Spaces:
Sleeping
Sleeping
import base64
import json
import os
from contextlib import asynccontextmanager
from io import BytesIO
from typing import Any, Dict, List, Optional

# Third-party imports keep their original relative order so native libraries
# (torch, faiss) are loaded in the same sequence as before.
import torch
import faiss
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from huggingface_hub import snapshot_download
from sentence_transformers import SentenceTransformer
from PIL import Image
from transformers import (
    AutoProcessor,
    Qwen3VLForConditionalGeneration,
)
# ─────────────────────────────
# CONFIG
# ─────────────────────────────
# Switch off parallelism inside the Hugging Face tokenizers library.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Hub repositories: the vision-language model, and the dataset repo that
# holds the FAISS index plus its chunk texts.
MODEL_REPO = "Rady10/Plant-Disease-Qwen3VL-2B"
RAG_REPO = "Rady10/Agriculture-Rag-Data-Index"

# ─────────────────────────────
# GLOBALS
# ─────────────────────────────
# Process-wide singletons. They start out empty and are populated once by
# lifespan() at application startup.
model = None        # Qwen3VLForConditionalGeneration instance
processor = None    # AutoProcessor paired with the model
faiss_index = None  # index read from "agro.index"
rag_chunks = None   # parsed contents of "chunks.json"
embedder = None     # SentenceTransformer used to embed queries
# ─────────────────────────────
# LIFESPAN
# ─────────────────────────────
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load every heavyweight resource once, before the app serves requests.

    Populates the module-level ``processor``, ``model``, ``faiss_index``,
    ``rag_chunks`` and ``embedder`` globals, then yields control to the
    application. Nothing is torn down after the yield.

    The function is wrapped in ``asynccontextmanager`` because
    ``FastAPI(lifespan=...)`` expects an async context manager factory, not
    a bare async generator function.

    Args:
        app: The FastAPI application being started (unused here).
    """
    global model, processor, faiss_index, rag_chunks, embedder

    print("Loading vision model...")
    processor = AutoProcessor.from_pretrained(MODEL_REPO, trust_remote_code=True)
    model = Qwen3VLForConditionalGeneration.from_pretrained(
        MODEL_REPO,
        torch_dtype=torch.float32,
        device_map="cpu",
        trust_remote_code=True,
    )
    # Inference only: switch off training-mode layers such as dropout.
    model.eval()

    print("Loading RAG index...")
    # Download the dataset repo into ./rag, then read the index and its texts.
    rag_dir = snapshot_download(repo_id=RAG_REPO, repo_type="dataset", local_dir="./rag")
    faiss_index = faiss.read_index(os.path.join(rag_dir, "agro.index"))
    with open(os.path.join(rag_dir, "chunks.json"), "r", encoding="utf-8") as f:
        rag_chunks = json.load(f)

    # Sentence encoder used to embed user queries for the FAISS search.
    embedder = SentenceTransformer(
        "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
    )

    print("ALL LOADED \u2705")
    yield
# ─────────────────────────────
# APP
# ─────────────────────────────
# Application instance. `lifespan` loads the models and the RAG index at
# startup. The title's leading emoji (U+1F33F, herb) is written as an escape
# so it cannot be mangled by a wrong file encoding.
app = FastAPI(title="\U0001F33F Plant Disease Chat API", lifespan=lifespan)
# ─────────────────────────────
# REQUEST MODEL
# ─────────────────────────────
class ChatRequest(BaseModel):
    """Request body accepted by the chat endpoint."""

    # Conversation turns. Each turn is a JSON object; downstream code reads
    # its "role" and "content" keys. Typing the items as dicts makes a
    # malformed turn fail validation instead of crashing later on `.get`.
    messages: List[Dict[str, Any]]
    # Optional base64-encoded image. When given, RAG is skipped automatically.
    # Optional[...] lets clients send an explicit JSON null as well as
    # omitting the field.
    image: Optional[str] = None
# ─────────────────────────────
# HELPERS
# ─────────────────────────────
def decode_image(b64: str) -> Image.Image:
    """Decode a base64-encoded image string into an RGB PIL image.

    Accepts either raw base64 or a data URL such as
    ``data:image/png;base64,<payload>``. Without stripping, the
    ``data:...;base64,`` prefix would be decoded as if it were part of the
    payload and corrupt the image bytes.

    Args:
        b64: Base64 text, optionally prefixed with a ``data:`` URL header.

    Returns:
        The decoded image converted to RGB mode.

    Raises:
        ValueError: If the payload is not valid base64 (``binascii.Error``).
        OSError: If PIL cannot identify or read the decoded bytes as an image.
    """
    payload = b64.strip()
    if payload.startswith("data:"):
        # Everything after the first comma is the actual base64 payload.
        payload = payload.partition(",")[2]
    raw = base64.b64decode(payload)
    return Image.open(BytesIO(raw)).convert("RGB")
def chunk_to_text(chunk) -> str:
    """Return the textual form of one RAG chunk.

    Strings are returned unchanged. For a dict, the first of the keys
    "text", "content", "passage", "chunk", "body" that holds a string wins;
    if none does, all values are stringified and joined with single spaces.
    Any other object is passed through ``str()``.
    """
    if isinstance(chunk, str):
        return chunk
    if not isinstance(chunk, dict):
        return str(chunk)

    preferred_keys = ("text", "content", "passage", "chunk", "body")
    # dict.get yields None for a missing key, which is never a str, so this is
    # the same test as "key present and value is a string".
    match = next(
        (chunk[name] for name in preferred_keys if isinstance(chunk.get(name), str)),
        None,
    )
    if match is not None:
        return match
    return " ".join(map(str, chunk.values()))
def to_content_list(content) -> list:
    """Normalise message content into the list form apply_chat_template needs.

    A string becomes a single text block. In a list, every string item is
    wrapped in a text block and every other item is kept as-is. Anything
    else is stringified into a single text block.
    """

    def as_text_block(value: str) -> dict:
        return {"type": "text", "text": value}

    if isinstance(content, list):
        return [
            as_text_block(item) if isinstance(item, str) else item
            for item in content
        ]
    if isinstance(content, str):
        return [as_text_block(content)]
    return [as_text_block(str(content))]
def retrieve_rag_context(messages: list, k: int = 3) -> str:
    """Return the RAG passages most similar to the latest user text.

    Walks ``messages`` from newest to oldest and takes the first non-empty
    user text found: either a plain-string ``content`` or the first
    ``{"type": "text"}`` block of a list ``content``. That text is embedded
    with the global ``embedder`` and searched against the global
    ``faiss_index``.

    Args:
        messages: Chat turns as dicts with "role" and "content" keys.
        k: Maximum number of passages to retrieve.

    Returns:
        The retrieved passages joined by blank lines, or "" when the RAG
        resources are not loaded, ``k`` is not positive, or there is no
        usable user text.
    """
    if not rag_chunks or faiss_index is None or embedder is None or k <= 0:
        return ""

    query = ""
    for message in reversed(messages):
        if message.get("role") != "user":
            continue
        content = message.get("content", "")
        if isinstance(content, str):
            query = content
        elif isinstance(content, list):
            # First text block that actually carries a string. A block with a
            # missing or non-string "text" is skipped instead of raising.
            query = next(
                (
                    block["text"]
                    for block in content
                    if isinstance(block, dict)
                    and block.get("type") == "text"
                    and isinstance(block.get("text"), str)
                ),
                "",
            )
        if query:
            break

    if not query.strip():
        return ""

    query_vec = embedder.encode([query])
    _, indices = faiss_index.search(query_vec, k=k)

    # FAISS pads the result with -1 when it finds fewer than k neighbours.
    # Without the lower bound, -1 would index the *last* chunk.
    total = len(rag_chunks)
    passages = [
        chunk_to_text(rag_chunks[int(i)]) for i in indices[0] if 0 <= i < total
    ]
    return "\n\n".join(passages)
def build_full_messages(messages: list, image: Image.Image, rag_context: str) -> list:
    """Assemble the full conversation handed to ``apply_chat_template``.

    The result starts with a synthetic user/assistant exchange carrying the
    system prompt (plus ``rag_context`` when non-empty), followed by the
    caller's turns with their content normalised via ``to_content_list``.

    If ``image`` is given, it is prepended to the most recent user turn. If
    the caller supplied no user turn at all, a new user turn holding only
    the image is appended so the image is never silently dropped.

    Args:
        messages: Chat turns as dicts; each must have a "role" key.
        image: Decoded image to attach, or None.
        rag_context: Retrieved knowledge text, or "" for none.

    Returns:
        A list of ``{"role": ..., "content": [block, ...]}`` dicts.

    Raises:
        KeyError: If a turn in ``messages`` has no "role" key.
    """
    system_parts = ["You are a plant disease expert assistant."]
    if rag_context:
        system_parts.append(
            "Use the following retrieved knowledge to inform your answer:\n\n" + rag_context
        )
    system_prompt = "\n\n".join(system_parts)

    # Content MUST be a list of dicts, never a plain string.
    full_messages = [
        {"role": "user", "content": [{"type": "text", "text": system_prompt}]},
        {"role": "assistant", "content": [{"type": "text", "text": "Understood. I will use this knowledge to help you."}]},
    ]

    turns = [
        {"role": m["role"], "content": to_content_list(m.get("content", ""))}
        for m in messages
    ]

    if image is not None:
        image_block = {"type": "image", "image": image}
        last_user = next((t for t in reversed(turns) if t["role"] == "user"), None)
        if last_user is None:
            # No user turn to attach to: give the image a turn of its own.
            turns.append({"role": "user", "content": [image_block]})
        else:
            last_user["content"] = [image_block] + last_user["content"]

    full_messages.extend(turns)
    return full_messages
# ─────────────────────────────
# UNIFIED ENDPOINT
# ─────────────────────────────
# NOTE(review): the route path "/chat" is assumed; this file showed no route
# registration for this handler. Confirm against existing clients.
@app.post("/chat")
def chat(req: ChatRequest):
    """Answer a chat request, using either the attached image or RAG context.

    When the request carries an image, it is decoded and attached to the
    conversation and RAG retrieval is skipped. Otherwise passages are
    retrieved for the latest user text.

    Args:
        req: Validated request body (messages plus optional base64 image).

    Returns:
        A dict with "response" (generated text), "rag_used" (whether any
        retrieved context was injected) and "image_used".

    Raises:
        HTTPException: 400 if the image field cannot be decoded as an image.
    """
    try:
        image = decode_image(req.image) if req.image else None
    except (ValueError, OSError) as exc:
        # Invalid base64 raises binascii.Error (a ValueError); unreadable
        # image bytes raise an OSError subclass from PIL. Both are client
        # errors, not server faults.
        raise HTTPException(
            status_code=400,
            detail="image must be valid base64-encoded image data",
        ) from exc

    rag_context = retrieve_rag_context(req.messages) if image is None else ""
    full_messages = build_full_messages(req.messages, image, rag_context)

    # With tokenize=True alone, apply_chat_template returns a plain Tensor.
    # return_dict=True yields a mapping (input_ids, attention_mask, ...) that
    # can be expanded with ** into generate().
    inputs = processor.apply_chat_template(
        full_messages,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
        return_dict=True,
    )
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=512,
            # temperature/top_p only take effect when sampling is enabled.
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
        )

    # Decode only the newly generated tokens (skip the echoed input prompt).
    input_len = inputs["input_ids"].shape[1]
    new_tokens = output_ids[0][input_len:]
    response_text = processor.decode(new_tokens, skip_special_tokens=True)

    return {
        "response": response_text,
        "rag_used": bool(rag_context),
        "image_used": image is not None,
    }
# ─────────────────────────────
# HEALTH CHECK
# ─────────────────────────────
# NOTE(review): the route path "/" is assumed; this file showed no route
# registration for this handler.
@app.get("/")
def root():
    """Health check: return a static payload showing the API is up."""
    return {"status": "plant disease chat api running"}