# Rady10's picture
# Update app.py
# 5b9d376 verified
import base64
import json
import os
from contextlib import asynccontextmanager
from io import BytesIO
from typing import Optional

import torch
import faiss
from fastapi import FastAPI
from pydantic import BaseModel
from huggingface_hub import snapshot_download
from sentence_transformers import SentenceTransformer
from PIL import Image
from transformers import (
    AutoProcessor,
    Qwen3VLForConditionalGeneration,
)
# ─────────────────────────────
# CONFIG
# ─────────────────────────────
# Hub repo id passed to from_pretrained() for both the processor and the vision-language model.
MODEL_REPO = "Rady10/Plant-Disease-Qwen3VL-2B"
# Hub dataset repo downloaded at startup; agro.index and chunks.json are read from it.
RAG_REPO = "Rady10/Agriculture-Rag-Data-Index"
# Set at import time, before any model is loaded.
# NOTE(review): presumably silences the tokenizers fork/parallelism warning — confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# ─────────────────────────────
# GLOBALS
# ─────────────────────────────
# Module-level singletons: all start as None and are assigned by lifespan() at startup.
model = None  # Qwen3VLForConditionalGeneration loaded from MODEL_REPO
processor = None  # AutoProcessor loaded from MODEL_REPO
faiss_index = None  # index read from agro.index in the RAG snapshot
rag_chunks = None  # parsed contents of chunks.json in the RAG snapshot
embedder = None  # SentenceTransformer used to embed the retrieval query
# ─────────────────────────────
# LIFESPAN
# ─────────────────────────────
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load every heavyweight resource once, before the app starts serving.

    Assigns the module-level ``model``, ``processor``, ``faiss_index``,
    ``rag_chunks`` and ``embedder`` globals and then yields control to
    FastAPI. Nothing is torn down after the yield.

    Args:
        app: The FastAPI application this lifespan is attached to (unused).
    """
    global model, processor, faiss_index, rag_chunks, embedder

    print("Loading vision model...")
    processor = AutoProcessor.from_pretrained(MODEL_REPO, trust_remote_code=True)
    model = Qwen3VLForConditionalGeneration.from_pretrained(
        MODEL_REPO,
        torch_dtype=torch.float32,
        device_map="cpu",
        trust_remote_code=True,
    )
    model.eval()  # inference only

    print("Loading RAG index...")
    # The dataset snapshot is materialised under ./rag; both files are read from there.
    rag_dir = snapshot_download(repo_id=RAG_REPO, repo_type="dataset", local_dir="./rag")
    faiss_index = faiss.read_index(os.path.join(rag_dir, "agro.index"))
    with open(os.path.join(rag_dir, "chunks.json"), "r", encoding="utf-8") as f:
        rag_chunks = json.load(f)

    embedder = SentenceTransformer(
        "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
    )

    # The check mark was mis-encoded ("βœ”") in the original log message.
    print("ALL LOADED ✔")
    yield
# ─────────────────────────────
# APP
# ─────────────────────────────
# The lifespan hook loads the models and the RAG index before requests are served.
# The title's leading emoji was mis-encoded ("🌿") and is restored here.
app = FastAPI(title="🌿 Plant Disease Chat API", lifespan=lifespan)
# ─────────────────────────────
# REQUEST MODEL
# ─────────────────────────────
class ChatRequest(BaseModel):
    """Request body accepted by ``POST /chat``."""

    # Conversation so far; the helpers read "role" and "content" from each entry.
    messages: list
    # base64 — if given, RAG is skipped automatically.
    # Optional so that an explicit JSON null validates the same way as
    # omitting the field (a bare `str` annotation rejects null in Pydantic v2).
    image: Optional[str] = None
# ─────────────────────────────
# HELPERS
# ─────────────────────────────
def decode_image(b64: str) -> Image.Image:
    """Decode a base64-encoded image payload into an RGB PIL image."""
    raw_bytes = base64.b64decode(b64)
    picture = Image.open(BytesIO(raw_bytes))
    return picture.convert("RGB")
def chunk_to_text(chunk) -> str:
    """Extract a plain-text representation from one RAG chunk.

    Dicts yield the first string found under a known text key, falling back
    to all values joined by spaces; strings pass through unchanged; anything
    else is converted with ``str()``.
    """
    if isinstance(chunk, dict):
        preferred_keys = ("text", "content", "passage", "chunk", "body")
        for name in preferred_keys:
            candidate = chunk.get(name)
            if isinstance(candidate, str):
                return candidate
        # No usable text field: flatten every value into one string.
        return " ".join(map(str, chunk.values()))
    return chunk if isinstance(chunk, str) else str(chunk)
def to_content_list(content) -> list:
    """Normalise message content into the list-of-dicts form apply_chat_template needs.

    Strings (top-level or inside a list) become ``{"type": "text", ...}``
    blocks; non-string list items are kept as they are; any other value is
    stringified into a single text block.
    """

    def as_text_block(value):
        return {"type": "text", "text": value}

    if isinstance(content, list):
        return [
            as_text_block(item) if isinstance(item, str) else item
            for item in content
        ]
    if isinstance(content, str):
        return [as_text_block(content)]
    return [as_text_block(str(content))]
def retrieve_rag_context(messages: list, k: int = 3) -> str:
    """Return retrieved knowledge for the most recent user message.

    Walks ``messages`` from newest to oldest, takes the text of the latest
    user turn that has any, embeds it with the module-level ``embedder`` and
    looks up the ``k`` nearest chunks in ``faiss_index``.

    Args:
        messages: Chat messages; dict entries are read via "role" and
            "content". Entries that are not dicts are ignored.
        k: Number of neighbours to request from the index.

    Returns:
        The matching chunks joined by blank lines, or "" when the index is
        not loaded, ``k`` is not positive, or there is no user text to query.
    """
    if not rag_chunks or faiss_index is None or k <= 0:
        return ""

    last_user_text = ""
    for m in reversed(messages):
        if not isinstance(m, dict) or m.get("role") != "user":
            continue
        content = m.get("content", "")
        if isinstance(content, str):
            last_user_text = content
        elif isinstance(content, list):
            # Only the first text block of a multi-part message is used.
            for block in content:
                if isinstance(block, dict) and block.get("type") == "text":
                    text = block.get("text")
                    # A text block without a usable "text" value counts as empty.
                    last_user_text = text if isinstance(text, str) else ""
                    break
        if last_user_text:
            break

    if not last_user_text.strip():
        return ""

    query_vec = embedder.encode([last_user_text])
    _, indices = faiss_index.search(query_vec, k=k)

    # FAISS pads the result with -1 when it finds fewer than k neighbours.
    # Without the lower bound, -1 would index the *last* chunk and be
    # returned as if it were a real hit.
    total = len(rag_chunks)
    hits = [chunk_to_text(rag_chunks[i]) for i in indices[0] if 0 <= i < total]
    return "\n\n".join(hits)
def build_full_messages(messages: list, image: Image.Image, rag_context: str) -> list:
    """Assemble the full conversation handed to the chat template.

    The result starts with a user turn carrying the system prompt (plus the
    retrieved knowledge, when any) and a fixed assistant acknowledgement,
    followed by the caller's messages with their content normalised. When an
    image is supplied it is prepended to the content of the last user turn.
    """
    system_prompt = "You are a plant disease expert assistant."
    if rag_context:
        knowledge = (
            "Use the following retrieved knowledge to inform your answer:\n\n" + rag_context
        )
        system_prompt = "\n\n".join([system_prompt, knowledge])

    # content MUST be a list of dicts — never a plain string
    conversation = [
        {"role": "user", "content": [{"type": "text", "text": system_prompt}]},
        {"role": "assistant", "content": [{"type": "text", "text": "Understood. I will use this knowledge to help you."}]},
    ]

    normalized = []
    for entry in messages:
        normalized.append(
            {"role": entry["role"], "content": to_content_list(entry.get("content", ""))}
        )

    if image is not None:
        # Attach the image to the most recent user turn only.
        for turn in reversed(normalized):
            if turn["role"] == "user":
                turn["content"] = [{"type": "image", "image": image}] + turn["content"]
                break

    conversation.extend(normalized)
    return conversation
# ─────────────────────────────
# UNIFIED ENDPOINT
# ─────────────────────────────
@app.post("/chat")
def chat(req: ChatRequest):
    """Answer a chat request, optionally grounded on an image or on RAG context.

    When the request carries an image, the image is attached to the last user
    turn and retrieval is skipped; otherwise retrieved knowledge for the last
    user message is added to the prompt.

    Args:
        req: Parsed request body (messages plus optional base64 image).

    Returns:
        A dict with the generated ``response`` text and the ``rag_used`` /
        ``image_used`` flags.

    Raises:
        HTTPException: 400 when ``req.image`` cannot be decoded as an image.
    """
    # Function-scope import keeps this block independent of the file's import list.
    from fastapi import HTTPException

    image = None
    if req.image:
        try:
            image = decode_image(req.image)
        except (ValueError, OSError) as exc:
            # binascii.Error is a ValueError; PIL's UnidentifiedImageError is an
            # OSError. Either means the client sent a bad payload, not a server fault.
            raise HTTPException(
                status_code=400,
                detail="image is not a valid base64-encoded image",
            ) from exc

    rag_context = "" if image is not None else retrieve_rag_context(req.messages)
    full_messages = build_full_messages(req.messages, image, rag_context)

    # apply_chat_template with tokenize=True returns a plain Tensor, not a dict;
    # use return_dict=True to get {"input_ids": ..., "attention_mask": ...}
    inputs = processor.apply_chat_template(
        full_messages,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
        return_dict=True,  # ← fixes: argument after ** must be a mapping, not Tensor
    )
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=512,
            # temperature/top_p only take effect when sampling is enabled.
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
        )

    # decode only the newly generated tokens (skip the input prompt)
    input_len = inputs["input_ids"].shape[1]
    new_tokens = output_ids[0][input_len:]
    response_text = processor.decode(new_tokens, skip_special_tokens=True)

    return {
        "response": response_text,
        "rag_used": bool(rag_context),
        "image_used": image is not None,
    }
# ─────────────────────────────
# HEALTH CHECK
# ─────────────────────────────
@app.get("/")
def root():
    """Health check: report that the API process is up."""
    payload = {"status": "plant disease chat api running"}
    return payload