Rady10 commited on
Commit
eab4ea1
Β·
verified Β·
1 Parent(s): 6d700fa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -59
app.py CHANGED
@@ -36,7 +36,7 @@ rag_chunks = None
36
  embedder = None
37
 
38
  # ─────────────────────────────
39
- # LIFESPAN β€” load everything once
40
  # ─────────────────────────────
41
  @asynccontextmanager
42
  async def lifespan(app: FastAPI):
@@ -83,29 +83,41 @@ app = FastAPI(title="🌿 Plant Disease Chat API", lifespan=lifespan)
83
  # REQUEST MODEL
84
  # ─────────────────────────────
85
  class ChatRequest(BaseModel):
86
- messages: list # full conversation history in OpenAI-style format
87
- image: str = None # base64-encoded image (optional)
88
- use_rag: bool = True # set False to skip RAG retrieval
89
 
90
 
91
  # ─────────────────────────────
92
  # HELPERS
93
  # ─────────────────────────────
94
  def decode_image(base64_str: str) -> Image.Image:
95
- """Decode a base64 string into a PIL RGB image."""
96
  img_bytes = base64.b64decode(base64_str)
97
  return Image.open(BytesIO(img_bytes)).convert("RGB")
98
 
99
 
100
- def retrieve_rag_context(messages: list, k: int = 3) -> str:
101
  """
102
- Extract the last user text, embed it, and return the top-k
103
- RAG chunks joined as a single string. Returns "" if nothing found.
104
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  if not rag_chunks or faiss_index is None:
106
  return ""
107
 
108
- # walk backwards to find the latest user text
109
  last_user_text = ""
110
  for m in reversed(messages):
111
  if m.get("role") != "user":
@@ -126,18 +138,16 @@ def retrieve_rag_context(messages: list, k: int = 3) -> str:
126
 
127
  query_vec = embedder.encode([last_user_text])
128
  _, indices = faiss_index.search(query_vec, k=k)
129
- chunks = [rag_chunks[i] for i in indices[0] if i < len(rag_chunks)]
 
 
 
 
 
130
  return "\n\n".join(chunks)
131
 
132
 
133
  def build_full_messages(messages: list, image: Image.Image, rag_context: str) -> list:
134
- """
135
- Combine system prompt (RAG context), conversation history, and optional
136
- image into a single message list ready for apply_chat_template.
137
- """
138
- # ── system as a fake user/assistant pair ──────────────────
139
- # Qwen3VL's apply_chat_template does not support a 'system' role,
140
- # so we simulate it with a leading exchange.
141
  system_parts = ["You are a plant disease expert assistant."]
142
  if rag_context:
143
  system_parts.append(
@@ -151,65 +161,34 @@ def build_full_messages(messages: list, image: Image.Image, rag_context: str) ->
151
  {"role": "assistant", "content": "Understood. I will use this knowledge to help you."},
152
  ]
153
 
154
- # ── copy conversation; inject image into last user turn ───
155
- messages = [dict(m) for m in messages] # shallow copy so we don't mutate input
156
 
157
  if image is not None:
158
- last_user_idx = None
159
  for i in range(len(messages) - 1, -1, -1):
160
  if messages[i].get("role") == "user":
161
- last_user_idx = i
 
 
 
 
162
  break
163
 
164
- if last_user_idx is not None:
165
- content = messages[last_user_idx].get("content", "")
166
- if isinstance(content, str):
167
- content = [{"type": "text", "text": content}]
168
- # prepend image block
169
- content = [{"type": "image", "image": image}] + content
170
- messages[last_user_idx]["content"] = content
171
-
172
  full_messages.extend(messages)
173
  return full_messages
174
 
175
 
176
  # ─────────────────────────────
177
- # SINGLE UNIFIED ENDPOINT
178
  # ─────────────────────────────
179
  @app.post("/chat")
180
  def chat(req: ChatRequest):
181
- """
182
- Unified chat endpoint. Handles three modes transparently:
183
-
184
- 1. RAG only β€” pass messages, use_rag=true, no image
185
- 2. Image only β€” pass messages + image, use_rag=false
186
- 3. Image + RAG β€” pass messages + image, use_rag=true (default)
187
-
188
- Request body
189
- ────────────
190
- messages : list of {"role": "user"|"assistant", "content": str | list}
191
- image : base64-encoded image string (optional)
192
- use_rag : bool, default true
193
-
194
- Response
195
- ────────
196
- {
197
- "response" : str,
198
- "rag_used" : bool,
199
- "image_used": bool
200
- }
201
- """
202
-
203
- # ── decode image ──────────────────────────────────────────
204
  image = decode_image(req.image) if req.image else None
205
 
206
- # ── RAG retrieval ─────────────────────────────────────────
207
- rag_context = retrieve_rag_context(req.messages) if req.use_rag else ""
208
-
209
- # ── assemble messages ─────────────────────────────────────
210
  full_messages = build_full_messages(req.messages, image, rag_context)
211
 
212
- # ── tokenise ──────────────────────────────────────────────
213
  inputs = processor.apply_chat_template(
214
  full_messages,
215
  add_generation_prompt=True,
@@ -217,7 +196,6 @@ def chat(req: ChatRequest):
217
  return_tensors="pt",
218
  ).to(model.device)
219
 
220
- # ── generate ──────────────────────────────────────────────
221
  with torch.no_grad():
222
  output_ids = model.generate(
223
  **inputs,
 
36
  embedder = None
37
 
38
  # ─────────────────────────────
39
+ # LIFESPAN
40
  # ─────────────────────────────
41
  @asynccontextmanager
42
  async def lifespan(app: FastAPI):
 
83
  # REQUEST MODEL
84
  # ─────────────────────────────
85
  class ChatRequest(BaseModel):
86
+ messages: list
87
+ image: str = None
88
+ # image present β†’ RAG skipped automatically
89
 
90
 
91
  # ─────────────────────────────
92
  # HELPERS
93
  # ─────────────────────────────
94
  def decode_image(base64_str: str) -> Image.Image:
 
95
  img_bytes = base64.b64decode(base64_str)
96
  return Image.open(BytesIO(img_bytes)).convert("RGB")
97
 
98
 
99
+ def chunk_to_text(chunk) -> str:
100
  """
101
+ Safely convert a chunk to plain string regardless of its type.
102
+ chunks.json may contain strings, dicts, or other structures.
103
  """
104
+ if isinstance(chunk, str):
105
+ return chunk
106
+ if isinstance(chunk, dict):
107
+ # common keys used in RAG datasets β€” try in order
108
+ for key in ("text", "content", "passage", "chunk", "body"):
109
+ if key in chunk and isinstance(chunk[key], str):
110
+ return chunk[key]
111
+ # fallback: join all string values
112
+ return " ".join(str(v) for v in chunk.values())
113
+ return str(chunk)
114
+
115
+
116
+ def retrieve_rag_context(messages: list, k: int = 3) -> str:
117
  if not rag_chunks or faiss_index is None:
118
  return ""
119
 
120
+ # find last user text
121
  last_user_text = ""
122
  for m in reversed(messages):
123
  if m.get("role") != "user":
 
138
 
139
  query_vec = embedder.encode([last_user_text])
140
  _, indices = faiss_index.search(query_vec, k=k)
141
+
142
+ chunks = [
143
+ chunk_to_text(rag_chunks[i])
144
+ for i in indices[0]
145
+ if i < len(rag_chunks)
146
+ ]
147
  return "\n\n".join(chunks)
148
 
149
 
150
  def build_full_messages(messages: list, image: Image.Image, rag_context: str) -> list:
 
 
 
 
 
 
 
151
  system_parts = ["You are a plant disease expert assistant."]
152
  if rag_context:
153
  system_parts.append(
 
161
  {"role": "assistant", "content": "Understood. I will use this knowledge to help you."},
162
  ]
163
 
164
+ messages = [dict(m) for m in messages]
 
165
 
166
  if image is not None:
 
167
  for i in range(len(messages) - 1, -1, -1):
168
  if messages[i].get("role") == "user":
169
+ content = messages[i].get("content", "")
170
+ if isinstance(content, str):
171
+ content = [{"type": "text", "text": content}]
172
+ content = [{"type": "image", "image": image}] + content
173
+ messages[i]["content"] = content
174
  break
175
 
 
 
 
 
 
 
 
 
176
  full_messages.extend(messages)
177
  return full_messages
178
 
179
 
180
  # ─────────────────────────────
181
+ # UNIFIED ENDPOINT
182
  # ─────────────────────────────
183
  @app.post("/chat")
184
  def chat(req: ChatRequest):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
  image = decode_image(req.image) if req.image else None
186
 
187
+ # image present β†’ use model's own vision training only (no RAG)
188
+ # no image β†’ use RAG to ground the text answer
189
+ rag_context = "" if image else retrieve_rag_context(req.messages)
 
190
  full_messages = build_full_messages(req.messages, image, rag_context)
191
 
 
192
  inputs = processor.apply_chat_template(
193
  full_messages,
194
  add_generation_prompt=True,
 
196
  return_tensors="pt",
197
  ).to(model.device)
198
 
 
199
  with torch.no_grad():
200
  output_ids = model.generate(
201
  **inputs,