Rady10 commited on
Commit
56d265c
Β·
verified Β·
1 Parent(s): eab4ea1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -44
app.py CHANGED
@@ -23,7 +23,6 @@ from transformers import (
23
  MODEL_REPO = "Rady10/Plant-Disease-Qwen3VL-2B"
24
  RAG_REPO = "Rady10/Agriculture-Rag-Data-Index"
25
 
26
- DEVICE = "cpu"
27
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
28
 
29
  # ─────────────────────────────
@@ -43,10 +42,7 @@ async def lifespan(app: FastAPI):
43
  global model, processor, faiss_index, rag_chunks, embedder
44
 
45
  print("Loading vision model...")
46
- processor = AutoProcessor.from_pretrained(
47
- MODEL_REPO,
48
- trust_remote_code=True,
49
- )
50
  model = Qwen3VLForConditionalGeneration.from_pretrained(
51
  MODEL_REPO,
52
  torch_dtype=torch.float32,
@@ -56,11 +52,7 @@ async def lifespan(app: FastAPI):
56
  model.eval()
57
 
58
  print("Loading RAG index...")
59
- rag_dir = snapshot_download(
60
- repo_id=RAG_REPO,
61
- repo_type="dataset",
62
- local_dir="./rag",
63
- )
64
  faiss_index = faiss.read_index(os.path.join(rag_dir, "agro.index"))
65
  with open(os.path.join(rag_dir, "chunks.json"), "r", encoding="utf-8") as f:
66
  rag_chunks = json.load(f)
@@ -84,40 +76,49 @@ app = FastAPI(title="🌿 Plant Disease Chat API", lifespan=lifespan)
84
  # ─────────────────────────────
85
  class ChatRequest(BaseModel):
86
  messages: list
87
- image: str = None
88
- # image present β†’ RAG skipped automatically
89
 
90
 
91
  # ─────────────────────────────
92
  # HELPERS
93
  # ─────────────────────────────
94
- def decode_image(base64_str: str) -> Image.Image:
95
- img_bytes = base64.b64decode(base64_str)
96
- return Image.open(BytesIO(img_bytes)).convert("RGB")
97
 
98
 
99
  def chunk_to_text(chunk) -> str:
100
- """
101
- Safely convert a chunk to plain string regardless of its type.
102
- chunks.json may contain strings, dicts, or other structures.
103
- """
104
  if isinstance(chunk, str):
105
  return chunk
106
  if isinstance(chunk, dict):
107
- # common keys used in RAG datasets β€” try in order
108
  for key in ("text", "content", "passage", "chunk", "body"):
109
  if key in chunk and isinstance(chunk[key], str):
110
  return chunk[key]
111
- # fallback: join all string values
112
  return " ".join(str(v) for v in chunk.values())
113
  return str(chunk)
114
 
115
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
  def retrieve_rag_context(messages: list, k: int = 3) -> str:
117
  if not rag_chunks or faiss_index is None:
118
  return ""
119
 
120
- # find last user text
121
  last_user_text = ""
122
  for m in reversed(messages):
123
  if m.get("role") != "user":
@@ -138,12 +139,7 @@ def retrieve_rag_context(messages: list, k: int = 3) -> str:
138
 
139
  query_vec = embedder.encode([last_user_text])
140
  _, indices = faiss_index.search(query_vec, k=k)
141
-
142
- chunks = [
143
- chunk_to_text(rag_chunks[i])
144
- for i in indices[0]
145
- if i < len(rag_chunks)
146
- ]
147
  return "\n\n".join(chunks)
148
 
149
 
@@ -151,29 +147,30 @@ def build_full_messages(messages: list, image: Image.Image, rag_context: str) ->
151
  system_parts = ["You are a plant disease expert assistant."]
152
  if rag_context:
153
  system_parts.append(
154
- "Use the following retrieved knowledge to inform your answer:\n\n"
155
- + rag_context
156
  )
157
  system_prompt = "\n\n".join(system_parts)
158
 
 
159
  full_messages = [
160
- {"role": "user", "content": system_prompt},
161
- {"role": "assistant", "content": "Understood. I will use this knowledge to help you."},
162
  ]
163
 
164
- messages = [dict(m) for m in messages]
 
 
 
 
165
 
 
166
  if image is not None:
167
- for i in range(len(messages) - 1, -1, -1):
168
- if messages[i].get("role") == "user":
169
- content = messages[i].get("content", "")
170
- if isinstance(content, str):
171
- content = [{"type": "text", "text": content}]
172
- content = [{"type": "image", "image": image}] + content
173
- messages[i]["content"] = content
174
  break
175
 
176
- full_messages.extend(messages)
177
  return full_messages
178
 
179
 
@@ -183,9 +180,6 @@ def build_full_messages(messages: list, image: Image.Image, rag_context: str) ->
183
  @app.post("/chat")
184
  def chat(req: ChatRequest):
185
  image = decode_image(req.image) if req.image else None
186
-
187
- # image present β†’ use model's own vision training only (no RAG)
188
- # no image β†’ use RAG to ground the text answer
189
  rag_context = "" if image else retrieve_rag_context(req.messages)
190
  full_messages = build_full_messages(req.messages, image, rag_context)
191
 
 
23
  MODEL_REPO = "Rady10/Plant-Disease-Qwen3VL-2B"
24
  RAG_REPO = "Rady10/Agriculture-Rag-Data-Index"
25
 
 
26
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
27
 
28
  # ─────────────────────────────
 
42
  global model, processor, faiss_index, rag_chunks, embedder
43
 
44
  print("Loading vision model...")
45
+ processor = AutoProcessor.from_pretrained(MODEL_REPO, trust_remote_code=True)
 
 
 
46
  model = Qwen3VLForConditionalGeneration.from_pretrained(
47
  MODEL_REPO,
48
  torch_dtype=torch.float32,
 
52
  model.eval()
53
 
54
  print("Loading RAG index...")
55
+ rag_dir = snapshot_download(repo_id=RAG_REPO, repo_type="dataset", local_dir="./rag")
 
 
 
 
56
  faiss_index = faiss.read_index(os.path.join(rag_dir, "agro.index"))
57
  with open(os.path.join(rag_dir, "chunks.json"), "r", encoding="utf-8") as f:
58
  rag_chunks = json.load(f)
 
76
  # ─────────────────────────────
77
  class ChatRequest(BaseModel):
78
  messages: list
79
+ image: str = None # base64 β€” if given, RAG is skipped automatically
 
80
 
81
 
82
  # ─────────────────────────────
83
  # HELPERS
84
  # ─────────────────────────────
85
+ def decode_image(b64: str) -> Image.Image:
86
+ return Image.open(BytesIO(base64.b64decode(b64))).convert("RGB")
 
87
 
88
 
89
  def chunk_to_text(chunk) -> str:
 
 
 
 
90
  if isinstance(chunk, str):
91
  return chunk
92
  if isinstance(chunk, dict):
 
93
  for key in ("text", "content", "passage", "chunk", "body"):
94
  if key in chunk and isinstance(chunk[key], str):
95
  return chunk[key]
 
96
  return " ".join(str(v) for v in chunk.values())
97
  return str(chunk)
98
 
99
 
100
+ def to_content_list(content) -> list:
101
+ """
102
+ apply_chat_template requires content to ALWAYS be a list of dicts.
103
+ Never a plain string β€” that causes: TypeError: string indices must be integers
104
+ """
105
+ if isinstance(content, str):
106
+ return [{"type": "text", "text": content}]
107
+ if isinstance(content, list):
108
+ result = []
109
+ for block in content:
110
+ if isinstance(block, str):
111
+ result.append({"type": "text", "text": block})
112
+ else:
113
+ result.append(block)
114
+ return result
115
+ return [{"type": "text", "text": str(content)}]
116
+
117
+
118
  def retrieve_rag_context(messages: list, k: int = 3) -> str:
119
  if not rag_chunks or faiss_index is None:
120
  return ""
121
 
 
122
  last_user_text = ""
123
  for m in reversed(messages):
124
  if m.get("role") != "user":
 
139
 
140
  query_vec = embedder.encode([last_user_text])
141
  _, indices = faiss_index.search(query_vec, k=k)
142
+ chunks = [chunk_to_text(rag_chunks[i]) for i in indices[0] if i < len(rag_chunks)]
 
 
 
 
 
143
  return "\n\n".join(chunks)
144
 
145
 
 
147
  system_parts = ["You are a plant disease expert assistant."]
148
  if rag_context:
149
  system_parts.append(
150
+ "Use the following retrieved knowledge to inform your answer:\n\n" + rag_context
 
151
  )
152
  system_prompt = "\n\n".join(system_parts)
153
 
154
+ # ⚠️ content MUST be list of dicts β€” never a plain string
155
  full_messages = [
156
+ {"role": "user", "content": [{"type": "text", "text": system_prompt}]},
157
+ {"role": "assistant", "content": [{"type": "text", "text": "Understood. I will use this knowledge to help you."}]},
158
  ]
159
 
160
+ # normalize every incoming message too
161
+ norm = [
162
+ {"role": m["role"], "content": to_content_list(m.get("content", ""))}
163
+ for m in messages
164
+ ]
165
 
166
+ # inject image into last user turn
167
  if image is not None:
168
+ for i in range(len(norm) - 1, -1, -1):
169
+ if norm[i]["role"] == "user":
170
+ norm[i]["content"] = [{"type": "image", "image": image}] + norm[i]["content"]
 
 
 
 
171
  break
172
 
173
+ full_messages.extend(norm)
174
  return full_messages
175
 
176
 
 
180
  @app.post("/chat")
181
  def chat(req: ChatRequest):
182
  image = decode_image(req.image) if req.image else None
 
 
 
183
  rag_context = "" if image else retrieve_rag_context(req.messages)
184
  full_messages = build_full_messages(req.messages, image, rag_context)
185