Rady10 commited on
Commit
6d700fa
Β·
verified Β·
1 Parent(s): 058a9d2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +154 -112
app.py CHANGED
@@ -1,4 +1,3 @@
1
-
2
  import os
3
  import base64
4
  import torch
@@ -15,15 +14,14 @@ from io import BytesIO
15
 
16
  from transformers import (
17
  AutoProcessor,
18
- AutoConfig,
19
- Qwen3VLForConditionalGeneration
20
  )
21
 
22
  # ─────────────────────────────
23
  # CONFIG
24
  # ─────────────────────────────
25
  MODEL_REPO = "Rady10/Plant-Disease-Qwen3VL-2B"
26
- RAG_REPO = "Rady10/Agriculture-Rag-Data-Index"
27
 
28
  DEVICE = "cpu"
29
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -31,48 +29,39 @@ os.environ["TOKENIZERS_PARALLELISM"] = "false"
31
  # ─────────────────────────────
32
  # GLOBALS
33
  # ─────────────────────────────
34
- model = None
35
- processor = None
36
  faiss_index = None
37
- rag_chunks = None
38
- embedder = None
39
 
40
  # ─────────────────────────────
41
- # FASTAPI APP
42
  # ─────────────────────────────
43
  @asynccontextmanager
44
  async def lifespan(app: FastAPI):
45
  global model, processor, faiss_index, rag_chunks, embedder
46
 
47
  print("Loading vision model...")
48
-
49
  processor = AutoProcessor.from_pretrained(
50
  MODEL_REPO,
51
- trust_remote_code=True
52
  )
53
-
54
  model = Qwen3VLForConditionalGeneration.from_pretrained(
55
  MODEL_REPO,
56
  torch_dtype=torch.float32,
57
  device_map="cpu",
58
- trust_remote_code=True
59
  )
60
-
61
  model.eval()
62
 
63
- # ───── LOAD RAG ─────
64
- print("Loading RAG...")
65
-
66
  rag_dir = snapshot_download(
67
  repo_id=RAG_REPO,
68
  repo_type="dataset",
69
- local_dir="./rag"
70
- )
71
-
72
- faiss_index = faiss.read_index(
73
- os.path.join(rag_dir, "agro.index")
74
  )
75
-
76
  with open(os.path.join(rag_dir, "chunks.json"), "r", encoding="utf-8") as f:
77
  rag_chunks = json.load(f)
78
 
@@ -81,121 +70,174 @@ async def lifespan(app: FastAPI):
81
  )
82
 
83
  print("ALL LOADED βœ”")
84
-
85
  yield
86
 
87
 
88
- app = FastAPI(
89
- title="🌿 Plant Disease Vision API",
90
- lifespan=lifespan
91
- )
92
-
93
  # ─────────────────────────────
94
- # REQUEST MODELS
95
  # ─────────────────────────────
96
- class VisionRequest(BaseModel):
97
- image: str
98
- text: str = ""
99
-
100
- class ChatRequest(BaseModel):
101
- messages: list
102
- image: str = None
103
 
104
- # ─────────────────────────────
105
- # IMAGE DECODER
106
- # ─────────────────────────────
107
- def decode_image(base64_str):
108
- img_data = base64.b64decode(base64_str)
109
- return Image.open(BytesIO(img_data)).convert("RGB")
110
 
111
  # ─────────────────────────────
112
- # VISION GENERATION (CHAT FORMAT)
113
  # ─────────────────────────────
114
- def generate(image, text):
115
-
116
- if not text.strip():
117
- text = "What disease is shown in this plant image?"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
 
119
- messages = [
120
- {
121
- "role": "user",
122
- "content": [
123
- {"type": "image", "image": image},
124
- {"type": "text", "text": text}
125
- ]
126
- }
127
  ]
128
 
129
- inputs = processor.apply_chat_template(
130
- messages,
131
- add_generation_prompt=True,
132
- return_tensors="pt"
133
- )
134
 
135
- inputs = inputs.to(model.device)
 
 
 
 
 
136
 
137
- with torch.no_grad():
138
- output = model.generate(
139
- **inputs,
140
- max_new_tokens=256,
141
- temperature=0.7,
142
- top_p=0.9
143
- )
144
-
145
- return processor.decode(output[0], skip_special_tokens=True)
146
-
147
- # ─────────────────────────────
148
- # ROUTES
149
- # ─────────────────────────────
150
- @app.get("/")
151
- def root():
152
- return {"status": "vision api running"}
153
-
154
- @app.post("/analyze")
155
- def analyze(req: VisionRequest):
156
 
157
- image = decode_image(req.image)
158
- result = generate(image, req.text)
159
 
160
- return {"response": result}
161
 
162
  # ─────────────────────────────
163
- # CHAT ENDPOINT (IMAGE + TEXT)
164
  # ─────────────────────────────
165
  @app.post("/chat")
166
  def chat(req: ChatRequest):
167
-
168
- messages = req.messages
169
-
170
- image = None
171
- if req.image:
172
- image = decode_image(req.image)
173
-
174
- # ───── inject image safely ─────
175
- if image:
176
- messages[-1]["content"].insert(0, {
177
- "type": "image",
178
- "image": image
179
- })
180
-
181
- # ───── IMPORTANT FIX HERE ─────
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
  inputs = processor.apply_chat_template(
183
- messages,
184
  add_generation_prompt=True,
185
- tokenize=True, # πŸ”΄ THIS FIXES IT
186
- return_tensors="pt"
187
- )
188
-
189
- # now inputs is a tensor dict (NOT string anymore)
190
-
191
- inputs = {k: v.to(model.device) for k, v in inputs.items()}
192
 
 
193
  with torch.no_grad():
194
- output = model.generate(
195
  **inputs,
196
- max_new_tokens=256
 
 
197
  )
198
 
 
 
199
  return {
200
- "response": processor.decode(output[0], skip_special_tokens=True)
201
- }
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import base64
3
  import torch
 
14
 
15
  from transformers import (
16
  AutoProcessor,
17
+ Qwen3VLForConditionalGeneration,
 
18
  )
19
 
20
  # ─────────────────────────────
21
  # CONFIG
22
  # ─────────────────────────────
23
  MODEL_REPO = "Rady10/Plant-Disease-Qwen3VL-2B"
24
+ RAG_REPO = "Rady10/Agriculture-Rag-Data-Index"
25
 
26
  DEVICE = "cpu"
27
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
29
  # ─────────────────────────────
30
  # GLOBALS
31
  # ─────────────────────────────
32
+ model = None
33
+ processor = None
34
  faiss_index = None
35
+ rag_chunks = None
36
+ embedder = None
37
 
38
  # ─────────────────────────────
39
+ # LIFESPAN β€” load everything once
40
  # ─────────────────────────────
41
  @asynccontextmanager
42
  async def lifespan(app: FastAPI):
43
  global model, processor, faiss_index, rag_chunks, embedder
44
 
45
  print("Loading vision model...")
 
46
  processor = AutoProcessor.from_pretrained(
47
  MODEL_REPO,
48
+ trust_remote_code=True,
49
  )
 
50
  model = Qwen3VLForConditionalGeneration.from_pretrained(
51
  MODEL_REPO,
52
  torch_dtype=torch.float32,
53
  device_map="cpu",
54
+ trust_remote_code=True,
55
  )
 
56
  model.eval()
57
 
58
+ print("Loading RAG index...")
 
 
59
  rag_dir = snapshot_download(
60
  repo_id=RAG_REPO,
61
  repo_type="dataset",
62
+ local_dir="./rag",
 
 
 
 
63
  )
64
+ faiss_index = faiss.read_index(os.path.join(rag_dir, "agro.index"))
65
  with open(os.path.join(rag_dir, "chunks.json"), "r", encoding="utf-8") as f:
66
  rag_chunks = json.load(f)
67
 
 
70
  )
71
 
72
  print("ALL LOADED βœ”")
 
73
  yield
74
 
75
 
 
 
 
 
 
76
  # ─────────────────────────────
77
+ # APP
78
  # ─────────────────────────────
79
+ app = FastAPI(title="🌿 Plant Disease Chat API", lifespan=lifespan)
 
 
 
 
 
 
80
 
 
 
 
 
 
 
81
 
82
  # ─────────────────────────────
83
+ # REQUEST MODEL
84
  # ─────────────────────────────
85
+ class ChatRequest(BaseModel):
86
+ messages: list # full conversation history in OpenAI-style format
87
+ image: str = None # base64-encoded image (optional)
88
+ use_rag: bool = True # set False to skip RAG retrieval
89
+
90
+
91
+ # ─────────────────────────────
92
+ # HELPERS
93
+ # ─────────────────────────────
94
+ def decode_image(base64_str: str) -> Image.Image:
95
+ """Decode a base64 string into a PIL RGB image."""
96
+ img_bytes = base64.b64decode(base64_str)
97
+ return Image.open(BytesIO(img_bytes)).convert("RGB")
98
+
99
+
100
+ def retrieve_rag_context(messages: list, k: int = 3) -> str:
101
+ """
102
+ Extract the last user text, embed it, and return the top-k
103
+ RAG chunks joined as a single string. Returns "" if nothing found.
104
+ """
105
+ if not rag_chunks or faiss_index is None:
106
+ return ""
107
+
108
+ # walk backwards to find the latest user text
109
+ last_user_text = ""
110
+ for m in reversed(messages):
111
+ if m.get("role") != "user":
112
+ continue
113
+ content = m.get("content", "")
114
+ if isinstance(content, list):
115
+ for block in content:
116
+ if isinstance(block, dict) and block.get("type") == "text":
117
+ last_user_text = block["text"]
118
+ break
119
+ elif isinstance(content, str):
120
+ last_user_text = content
121
+ if last_user_text:
122
+ break
123
+
124
+ if not last_user_text.strip():
125
+ return ""
126
+
127
+ query_vec = embedder.encode([last_user_text])
128
+ _, indices = faiss_index.search(query_vec, k=k)
129
+ chunks = [rag_chunks[i] for i in indices[0] if i < len(rag_chunks)]
130
+ return "\n\n".join(chunks)
131
+
132
+
133
+ def build_full_messages(messages: list, image: Image.Image, rag_context: str) -> list:
134
+ """
135
+ Combine system prompt (RAG context), conversation history, and optional
136
+ image into a single message list ready for apply_chat_template.
137
+ """
138
+ # ── system as a fake user/assistant pair ──────────────────
139
+ # Qwen3VL's apply_chat_template does not support a 'system' role,
140
+ # so we simulate it with a leading exchange.
141
+ system_parts = ["You are a plant disease expert assistant."]
142
+ if rag_context:
143
+ system_parts.append(
144
+ "Use the following retrieved knowledge to inform your answer:\n\n"
145
+ + rag_context
146
+ )
147
+ system_prompt = "\n\n".join(system_parts)
148
 
149
+ full_messages = [
150
+ {"role": "user", "content": system_prompt},
151
+ {"role": "assistant", "content": "Understood. I will use this knowledge to help you."},
 
 
 
 
 
152
  ]
153
 
154
+ # ── copy conversation; inject image into last user turn ───
155
+ messages = [dict(m) for m in messages] # shallow copy so we don't mutate input
 
 
 
156
 
157
+ if image is not None:
158
+ last_user_idx = None
159
+ for i in range(len(messages) - 1, -1, -1):
160
+ if messages[i].get("role") == "user":
161
+ last_user_idx = i
162
+ break
163
 
164
+ if last_user_idx is not None:
165
+ content = messages[last_user_idx].get("content", "")
166
+ if isinstance(content, str):
167
+ content = [{"type": "text", "text": content}]
168
+ # prepend image block
169
+ content = [{"type": "image", "image": image}] + content
170
+ messages[last_user_idx]["content"] = content
 
 
 
 
 
 
 
 
 
 
 
 
171
 
172
+ full_messages.extend(messages)
173
+ return full_messages
174
 
 
175
 
176
  # ─────────────────────────────
177
+ # SINGLE UNIFIED ENDPOINT
178
  # ─────────────────────────────
179
  @app.post("/chat")
180
  def chat(req: ChatRequest):
181
+ """
182
+ Unified chat endpoint. Handles three modes transparently:
183
+
184
+ 1. RAG only β€” pass messages, use_rag=true, no image
185
+ 2. Image only β€” pass messages + image, use_rag=false
186
+ 3. Image + RAG β€” pass messages + image, use_rag=true (default)
187
+
188
+ Request body
189
+ ────────────
190
+ messages : list of {"role": "user"|"assistant", "content": str | list}
191
+ image : base64-encoded image string (optional)
192
+ use_rag : bool, default true
193
+
194
+ Response
195
+ ────────
196
+ {
197
+ "response" : str,
198
+ "rag_used" : bool,
199
+ "image_used": bool
200
+ }
201
+ """
202
+
203
+ # ── decode image ──────────────────────────────────────────
204
+ image = decode_image(req.image) if req.image else None
205
+
206
+ # ── RAG retrieval ─────────────────────────────────────────
207
+ rag_context = retrieve_rag_context(req.messages) if req.use_rag else ""
208
+
209
+ # ── assemble messages ─────────────────────────────────────
210
+ full_messages = build_full_messages(req.messages, image, rag_context)
211
+
212
+ # ── tokenise ──────────────────────────────────────────────
213
  inputs = processor.apply_chat_template(
214
+ full_messages,
215
  add_generation_prompt=True,
216
+ tokenize=True,
217
+ return_tensors="pt",
218
+ ).to(model.device)
 
 
 
 
219
 
220
+ # ── generate ──────────────────────────────────────────────
221
  with torch.no_grad():
222
+ output_ids = model.generate(
223
  **inputs,
224
+ max_new_tokens=512,
225
+ temperature=0.7,
226
+ top_p=0.9,
227
  )
228
 
229
+ response_text = processor.decode(output_ids[0], skip_special_tokens=True)
230
+
231
  return {
232
+ "response": response_text,
233
+ "rag_used": bool(rag_context),
234
+ "image_used": image is not None,
235
+ }
236
+
237
+
238
+ # ─────────────────────────────
239
+ # HEALTH CHECK
240
+ # ─────────────────────────────
241
+ @app.get("/")
242
+ def root():
243
+ return {"status": "plant disease chat api running"}