"""Plant-disease chat API.

A FastAPI service exposing a single ``/chat`` endpoint backed by a Qwen3-VL
vision-language model. Text-only requests are augmented with passages
retrieved from a FAISS index; requests carrying an image skip retrieval.
"""

import os
import base64
import torch
import faiss
import json
from typing import Optional

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from contextlib import asynccontextmanager
from huggingface_hub import snapshot_download
from sentence_transformers import SentenceTransformer
from PIL import Image
from io import BytesIO
from transformers import (
    AutoProcessor,
    Qwen3VLForConditionalGeneration,
)

# ─────────────────────────────
# CONFIG
# ─────────────────────────────
MODEL_REPO = "Rady10/Plant-Disease-Qwen3VL-2B"
RAG_REPO = "Rady10/Agriculture-Rag-Data-Index"

os.environ["TOKENIZERS_PARALLELISM"] = "false"

# ─────────────────────────────
# GLOBALS
# ─────────────────────────────
# All of these are populated once by `lifespan` at application startup.
model = None
processor = None
faiss_index = None
rag_chunks = None
embedder = None


# ─────────────────────────────
# LIFESPAN
# ─────────────────────────────
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the vision model, the RAG index and the embedder before serving.

    Everything is stored in module-level globals so the request handlers can
    reach it. Nothing is released on shutdown.
    """
    global model, processor, faiss_index, rag_chunks, embedder

    print("Loading vision model...")
    processor = AutoProcessor.from_pretrained(MODEL_REPO, trust_remote_code=True)
    model = Qwen3VLForConditionalGeneration.from_pretrained(
        MODEL_REPO,
        torch_dtype=torch.float32,
        device_map="cpu",
        trust_remote_code=True,
    )
    model.eval()

    print("Loading RAG index...")
    rag_dir = snapshot_download(
        repo_id=RAG_REPO, repo_type="dataset", local_dir="./rag"
    )
    faiss_index = faiss.read_index(os.path.join(rag_dir, "agro.index"))
    with open(os.path.join(rag_dir, "chunks.json"), "r", encoding="utf-8") as f:
        rag_chunks = json.load(f)

    embedder = SentenceTransformer(
        "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
    )

    print("ALL LOADED ✔")
    yield


# ─────────────────────────────
# APP
# ─────────────────────────────
app = FastAPI(title="🌿 Plant Disease Chat API", lifespan=lifespan)


# ─────────────────────────────
# REQUEST MODEL
# ─────────────────────────────
class ChatRequest(BaseModel):
    """Body of a ``/chat`` request."""

    # Conversation so far; each item is expected to be a
    # {"role": ..., "content": ...} mapping.
    messages: list
    # base64 — if given, RAG is skipped automatically.
    # Optional so that an explicit JSON `null` is accepted as "no image".
    image: Optional[str] = None


# ─────────────────────────────
# HELPERS
# ─────────────────────────────
def decode_image(b64: str) -> Image.Image:
    """Decode a base64 string (optionally a ``data:`` URL) into an RGB image.

    Raises:
        ValueError: if the payload is not valid base64 or not a readable image.
    """
    payload = b64
    if b64.lstrip().startswith("data:"):
        # Strip a "data:image/png;base64," style prefix; it is never part of
        # the base64 payload itself.
        _, sep, rest = b64.partition(",")
        if sep:
            payload = rest
    try:
        return Image.open(BytesIO(base64.b64decode(payload))).convert("RGB")
    except (ValueError, OSError) as exc:
        # binascii.Error is a ValueError; PIL's unreadable-image error is an
        # OSError. Normalise both so the endpoint can answer with a 400.
        raise ValueError("image is not valid base64-encoded image data") from exc


def chunk_to_text(chunk) -> str:
    """Return the textual payload of a RAG chunk, whatever its stored shape."""
    if isinstance(chunk, str):
        return chunk
    if isinstance(chunk, dict):
        # Prefer a well-known text field; otherwise flatten all values.
        for key in ("text", "content", "passage", "chunk", "body"):
            if key in chunk and isinstance(chunk[key], str):
                return chunk[key]
        return " ".join(str(v) for v in chunk.values())
    return str(chunk)


def to_content_list(content) -> list:
    """content must always be a list of dicts for apply_chat_template"""
    if isinstance(content, str):
        return [{"type": "text", "text": content}]
    if isinstance(content, list):
        # Bare strings inside the list become text blocks; anything else is
        # passed through untouched.
        return [
            {"type": "text", "text": block} if isinstance(block, str) else block
            for block in content
        ]
    return [{"type": "text", "text": str(content)}]


def retrieve_rag_context(messages: list, k: int = 3) -> str:
    """Return up to *k* retrieved passages for the latest non-empty user text.

    The query is the first text block (or the plain string content) of the
    most recent user message that has any text. Returns ``""`` when the index
    is not loaded or no user text is found.
    """
    if not rag_chunks or faiss_index is None:
        return ""

    last_user_text = ""
    for m in reversed(messages):
        if not isinstance(m, dict) or m.get("role") != "user":
            continue
        content = m.get("content", "")
        if isinstance(content, list):
            for block in content:
                if isinstance(block, dict) and block.get("type") == "text":
                    text = block.get("text", "")
                    last_user_text = text if isinstance(text, str) else ""
                    break
        elif isinstance(content, str):
            last_user_text = content
        if last_user_text:
            break

    if not last_user_text.strip():
        return ""

    query_vec = embedder.encode([last_user_text])
    _, indices = faiss_index.search(query_vec, k=k)
    # FAISS pads the result with -1 when fewer than k neighbours exist; a
    # negative index must be dropped, otherwise it would silently select the
    # last chunk.
    chunks = [
        chunk_to_text(rag_chunks[int(i)])
        for i in indices[0]
        if 0 <= i < len(rag_chunks)
    ]
    return "\n\n".join(chunks)


def build_full_messages(
    messages: list, image: Optional[Image.Image], rag_context: str
) -> list:
    """Assemble the message list handed to ``apply_chat_template``.

    The system prompt (plus any retrieved context) is sent as an opening user
    turn followed by a canned assistant acknowledgement. The caller's messages
    follow, with *image* prepended to the most recent user message, if any.
    """
    system_parts = ["You are a plant disease expert assistant."]
    if rag_context:
        system_parts.append(
            "Use the following retrieved knowledge to inform your answer:\n\n"
            + rag_context
        )
    system_prompt = "\n\n".join(system_parts)

    # content MUST be list of dicts — never plain string
    full_messages = [
        {"role": "user", "content": [{"type": "text", "text": system_prompt}]},
        {
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": "Understood. I will use this knowledge to help you.",
                }
            ],
        },
    ]

    norm = [
        {"role": m["role"], "content": to_content_list(m.get("content", ""))}
        for m in messages
    ]

    if image is not None:
        # Attach the image to the latest user turn only.
        for msg in reversed(norm):
            if msg["role"] == "user":
                msg["content"] = [{"type": "image", "image": image}] + msg["content"]
                break

    full_messages.extend(norm)
    return full_messages


# ─────────────────────────────
# UNIFIED ENDPOINT
# ─────────────────────────────
@app.post("/chat")
def chat(req: ChatRequest):
    """Answer a chat request, using the image if present and RAG otherwise.

    Returns a dict with the generated ``response`` and two flags reporting
    whether retrieval and an image were used. Responds with HTTP 400 when the
    supplied image cannot be decoded.
    """
    try:
        image = decode_image(req.image) if req.image else None
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    # Explicit None check: do not rely on the truthiness of an image object.
    rag_context = "" if image is not None else retrieve_rag_context(req.messages)
    full_messages = build_full_messages(req.messages, image, rag_context)

    # apply_chat_template with tokenize=True returns a plain Tensor, not a dict;
    # return_dict=True yields {"input_ids": ..., "attention_mask": ...} so the
    # result can be splatted into generate().
    inputs = processor.apply_chat_template(
        full_messages,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
        return_dict=True,
    )
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=512,
            # temperature / top_p only apply when sampling is enabled.
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
        )

    # decode only the newly generated tokens (skip the input prompt)
    input_len = inputs["input_ids"].shape[1]
    new_tokens = output_ids[0][input_len:]
    response_text = processor.decode(new_tokens, skip_special_tokens=True)

    return {
        "response": response_text,
        "rag_used": bool(rag_context),
        "image_used": image is not None,
    }


# ─────────────────────────────
# HEALTH CHECK
# ─────────────────────────────
@app.get("/")
def root():
    """Liveness probe."""
    return {"status": "plant disease chat api running"}