# Hugging Face Space (status: Sleeping)
# File size: 7,589 Bytes
import base64
import json
import os
from contextlib import asynccontextmanager
from io import BytesIO
from typing import Optional

# torch stays ahead of faiss, as in the original, so the native-library load
# order is unchanged.
import torch
import faiss
from fastapi import FastAPI, HTTPException
from huggingface_hub import snapshot_download
from PIL import Image
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
from transformers import (
    AutoProcessor,
    Qwen3VLForConditionalGeneration,
)
# ─────────────────────────────
# CONFIG
# ─────────────────────────────
# Hub repository passed to AutoProcessor / Qwen3VLForConditionalGeneration at startup.
MODEL_REPO = "Rady10/Plant-Disease-Qwen3VL-2B"
# Hub *dataset* repository downloaded at startup; read for "agro.index" and "chunks.json".
RAG_REPO = "Rady10/Agriculture-Rag-Data-Index"
# Set before any tokenizer is created (it is assigned at import time, ahead of model loading).
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# ─────────────────────────────
# GLOBALS
# ─────────────────────────────
# Process-wide singletons. All start as None and are assigned by `lifespan`
# during application startup; request handlers only read them.
model = None        # Qwen3VLForConditionalGeneration instance
processor = None    # AutoProcessor paired with `model`
faiss_index = None  # index read from "agro.index"
rag_chunks = None   # content of "chunks.json" (indexed by FAISS result ids)
embedder = None     # SentenceTransformer used to embed the user query
# ─────────────────────────────
# LIFESPAN
# ─────────────────────────────
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load every model and the RAG index once, before the app serves requests.

    Populates the module-level ``model``, ``processor``, ``faiss_index``,
    ``rag_chunks`` and ``embedder`` globals, yields control to FastAPI for the
    lifetime of the application, and drops the references again on shutdown.

    Args:
        app: The FastAPI application this lifespan is attached to (unused).
    """
    global model, processor, faiss_index, rag_chunks, embedder

    print("Loading vision model...")
    # NOTE: trust_remote_code=True executes Python shipped inside MODEL_REPO;
    # only point MODEL_REPO at a repository you trust.
    processor = AutoProcessor.from_pretrained(MODEL_REPO, trust_remote_code=True)
    model = Qwen3VLForConditionalGeneration.from_pretrained(
        MODEL_REPO,
        torch_dtype=torch.float32,
        device_map="cpu",
        trust_remote_code=True,
    )
    model.eval()

    print("Loading RAG index...")
    rag_dir = snapshot_download(repo_id=RAG_REPO, repo_type="dataset", local_dir="./rag")
    faiss_index = faiss.read_index(os.path.join(rag_dir, "agro.index"))
    with open(os.path.join(rag_dir, "chunks.json"), "r", encoding="utf-8") as f:
        rag_chunks = json.load(f)
    embedder = SentenceTransformer(
        "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
    )
    print("ALL LOADED ✓")

    try:
        yield
    finally:
        # Drop the references on shutdown so the memory can be reclaimed when
        # the app is restarted inside the same process (reloads, test runs).
        model = processor = faiss_index = rag_chunks = embedder = None
# ─────────────────────────────
# APP
# ─────────────────────────────
# Application instance; `lifespan` runs the model/index loading at startup.
# The title's leading emoji was mis-encoded ("πΏ") and is restored here.
app = FastAPI(title="🌿 Plant Disease Chat API", lifespan=lifespan)
# ─────────────────────────────
# REQUEST MODEL
# ─────────────────────────────
class ChatRequest(BaseModel):
    """Request body of ``POST /chat``."""

    # Conversation so far; each entry is read as a dict with "role" and
    # "content" (a string or a list of content blocks).
    messages: list
    # Optional base64-encoded image — if given, RAG is skipped automatically.
    # Declared Optional so that an explicit JSON `null` validates as well as
    # an omitted field (a bare `str = None` rejects `null` under pydantic v2).
    image: Optional[str] = None
# ─────────────────────────────
# HELPERS
# ─────────────────────────────
def decode_image(b64: str) -> Image.Image:
    """Decode a base64-encoded picture into an RGB PIL image.

    Args:
        b64: Base64 payload, either bare or wrapped in a data URL such as
            ``data:image/png;base64,<payload>``.

    Returns:
        The decoded image converted to RGB mode.

    Raises:
        ValueError: If the payload is not valid base64 (``binascii.Error``).
        OSError: If the decoded bytes are not an image PIL can open.
    """
    # Clients frequently send the full data URL; keep only the payload after
    # the first comma so both forms are accepted.
    if b64.startswith("data:"):
        _, _, b64 = b64.partition(",")
    raw = base64.b64decode(b64)
    return Image.open(BytesIO(raw)).convert("RGB")
def chunk_to_text(chunk) -> str:
    """Return the textual payload of one RAG chunk.

    Strings are returned unchanged. For dicts, the first string value found
    under "text", "content", "passage", "chunk" or "body" (in that order) is
    returned; if none matches, all values are stringified and space-joined.
    Any other object is passed through ``str()``.
    """
    if isinstance(chunk, str):
        return chunk
    if not isinstance(chunk, dict):
        return str(chunk)

    preferred_fields = ("text", "content", "passage", "chunk", "body")
    for field in preferred_fields:
        candidate = chunk.get(field)
        if isinstance(candidate, str):
            return candidate

    # No recognised text field: fall back to every value, in dict order.
    return " ".join(map(str, chunk.values()))
def to_content_list(content) -> list:
    """Normalise message content into the list-of-blocks form.

    ``apply_chat_template`` is always given a list of dicts: a plain string
    becomes a single text block, string items inside a list are wrapped as
    text blocks (other items are kept as they are), and any other value is
    stringified into a single text block.
    """

    def as_text_block(text):
        return {"type": "text", "text": text}

    if isinstance(content, list):
        return [
            as_text_block(item) if isinstance(item, str) else item
            for item in content
        ]

    text = content if isinstance(content, str) else str(content)
    return [as_text_block(text)]
def retrieve_rag_context(messages: list, k: int = 3) -> str:
    """Fetch the RAG passages closest to the most recent user text.

    Args:
        messages: Chat history; each entry is read as a dict with "role" and
            "content" (a string, or a list of content blocks).
        k: Number of nearest neighbours requested from the FAISS index.

    Returns:
        The retrieved chunk texts joined by blank lines, or "" when the index
        is not loaded or no non-blank user text is found.
    """
    if not rag_chunks or faiss_index is None or embedder is None:
        return ""

    # Walk backwards to the newest user message that carries any text.
    last_user_text = ""
    for m in reversed(messages):
        if m.get("role") != "user":
            continue
        content = m.get("content", "")
        if isinstance(content, str):
            last_user_text = content
        elif isinstance(content, list):
            for block in content:
                if isinstance(block, dict) and block.get("type") == "text":
                    text = block.get("text")
                    # A text block without a string payload is skipped rather
                    # than raising KeyError / failing later on .strip().
                    if isinstance(text, str):
                        last_user_text = text
                        break
        if last_user_text:
            break

    if not last_user_text.strip():
        return ""

    query_vec = embedder.encode([last_user_text])
    _, indices = faiss_index.search(query_vec, k=k)

    # FAISS pads the result with -1 when fewer than k neighbours exist; a bare
    # upper-bound check would let -1 through and index the *last* chunk.
    total = len(rag_chunks)
    chunks = [chunk_to_text(rag_chunks[i]) for i in indices[0] if 0 <= i < total]
    return "\n\n".join(chunks)
def build_full_messages(messages: list, image: Image.Image, rag_context: str) -> list:
    """Assemble the conversation handed to the chat template.

    The result opens with a synthetic user/assistant exchange carrying the
    expert instructions (plus the retrieved knowledge when ``rag_context`` is
    non-empty), followed by the caller's messages with their content
    normalised to block lists. When ``image`` is given it is prepended to the
    content of the last user message; the input ``messages`` are not modified.
    """

    def text_content(text):
        # Content MUST be a list of dicts — never a plain string.
        return [{"type": "text", "text": text}]

    system_prompt = "You are a plant disease expert assistant."
    if rag_context:
        knowledge = "Use the following retrieved knowledge to inform your answer:\n\n" + rag_context
        system_prompt = "\n\n".join([system_prompt, knowledge])

    conversation = [
        {"role": "user", "content": text_content(system_prompt)},
        {
            "role": "assistant",
            "content": text_content("Understood. I will use this knowledge to help you."),
        },
    ]

    normalised = []
    for message in messages:
        normalised.append(
            {
                "role": message["role"],
                "content": to_content_list(message.get("content", "")),
            }
        )

    if image is not None:
        # Attach the picture to the most recent user turn only.
        for entry in reversed(normalised):
            if entry["role"] == "user":
                entry["content"] = [{"type": "image", "image": image}] + entry["content"]
                break

    conversation.extend(normalised)
    return conversation
# ─────────────────────────────
# UNIFIED ENDPOINT
# ─────────────────────────────
@app.post("/chat")
def chat(req: ChatRequest):
    """Answer a chat turn, optionally grounded on an image or on RAG context.

    When ``req.image`` is set, the image is attached to the last user message
    and retrieval is skipped; otherwise passages retrieved for the latest user
    text are added to the prompt.

    Args:
        req: Conversation history plus an optional base64-encoded image.

    Returns:
        A dict with the generated ``response`` text and the ``rag_used`` /
        ``image_used`` flags.

    Raises:
        HTTPException: 400 when ``req.image`` cannot be decoded into an image.
    """
    image = None
    if req.image:
        try:
            image = decode_image(req.image)
        except (ValueError, OSError) as exc:
            # Invalid base64 (binascii.Error is a ValueError) or bytes PIL
            # cannot open: a client error, not a server fault.
            raise HTTPException(
                status_code=400,
                detail="image must be a valid base64-encoded picture",
            ) from exc

    rag_context = "" if image is not None else retrieve_rag_context(req.messages)
    full_messages = build_full_messages(req.messages, image, rag_context)

    # apply_chat_template with tokenize=True returns a plain Tensor, not a dict;
    # return_dict=True yields {"input_ids": ..., "attention_mask": ...} so the
    # result can be unpacked with ** below.
    inputs = processor.apply_chat_template(
        full_messages,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
        return_dict=True,
    )
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=512,
            # temperature / top_p only take effect when sampling is enabled.
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
        )

    # Decode only the newly generated tokens (skip the input prompt).
    input_len = inputs["input_ids"].shape[1]
    new_tokens = output_ids[0][input_len:]
    response_text = processor.decode(new_tokens, skip_special_tokens=True)

    return {
        "response": response_text,
        "rag_used": bool(rag_context),
        "image_used": image is not None,
    }
# ─────────────────────────────
# HEALTH CHECK
# ─────────────────────────────
@app.get("/")
def root() -> dict:
    """Health check: report that the API process is up."""
    return {"status": "plant disease chat api running"}