File size: 7,589 Bytes
d3d53b0
57bbdbe
 
d3d53b0
 
57bbdbe
 
 
d3d53b0
 
 
57bbdbe
 
 
879c56d
 
6d700fa
879c56d
57bbdbe
 
d3d53b0
57bbdbe
d3d53b0
6d700fa
57bbdbe
d3d53b0
57bbdbe
 
d3d53b0
57bbdbe
6d700fa
 
d3d53b0
6d700fa
 
57bbdbe
 
eab4ea1
57bbdbe
d3d53b0
 
 
 
 
56d265c
69eb2dc
d3d53b0
 
 
6d700fa
d3d53b0
 
 
6d700fa
56d265c
6d700fa
d3d53b0
 
 
 
 
 
 
879c56d
d3d53b0
 
879c56d
57bbdbe
6d700fa
57bbdbe
6d700fa
879c56d
57bbdbe
d3d53b0
6d700fa
d3d53b0
6d700fa
eab4ea1
56d265c
6d700fa
 
 
 
 
56d265c
 
6d700fa
 
eab4ea1
 
 
 
 
 
 
 
 
 
 
56d265c
5b9d376
56d265c
 
 
 
 
 
 
 
 
 
 
 
 
eab4ea1
6d700fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56d265c
6d700fa
 
 
 
 
 
 
56d265c
6d700fa
 
57bbdbe
5b9d376
6d700fa
56d265c
 
879c56d
 
56d265c
 
 
 
7c14e50
6d700fa
56d265c
 
 
6d700fa
7c14e50
56d265c
6d700fa
57bbdbe
879c56d
 
eab4ea1
879c56d
 
 
6d700fa
eab4ea1
6d700fa
 
5b9d376
 
879c56d
6d700fa
879c56d
6d700fa
 
5b9d376
 
 
879c56d
 
6d700fa
879c56d
6d700fa
 
 
879c56d
 
5b9d376
 
 
 
6d700fa
879c56d
6d700fa
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
import base64
import json
import os
from contextlib import asynccontextmanager
from io import BytesIO
from typing import Optional

import torch
import faiss
from fastapi import FastAPI
from pydantic import BaseModel
from huggingface_hub import snapshot_download
from sentence_transformers import SentenceTransformer
from PIL import Image

from transformers import (
    AutoProcessor,
    Qwen3VLForConditionalGeneration,
)

# ─────────────────────────────
# CONFIG
# ─────────────────────────────
MODEL_REPO = "Rady10/Plant-Disease-Qwen3VL-2B"
RAG_REPO   = "Rady10/Agriculture-Rag-Data-Index"

os.environ["TOKENIZERS_PARALLELISM"] = "false"

# ─────────────────────────────
# GLOBALS
# ─────────────────────────────
model       = None
processor   = None
faiss_index = None
rag_chunks  = None
embedder    = None

# ─────────────────────────────
# LIFESPAN
# ─────────────────────────────
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load every heavyweight resource once, then hand control to the app.

    Assigns the module-level ``model``, ``processor``, ``faiss_index``,
    ``rag_chunks`` and ``embedder`` globals before yielding. Nothing is
    released after the yield.
    """
    global model, processor, faiss_index, rag_chunks, embedder

    # Vision-language model and its processor come from the same repo.
    print("Loading vision model...")
    processor = AutoProcessor.from_pretrained(MODEL_REPO, trust_remote_code=True)
    vision_kwargs = {
        "torch_dtype": torch.float32,
        "device_map": "cpu",
        "trust_remote_code": True,
    }
    model = Qwen3VLForConditionalGeneration.from_pretrained(MODEL_REPO, **vision_kwargs)
    model.eval()

    # Retrieval assets: the FAISS index plus the chunk list it points into.
    print("Loading RAG index...")
    rag_dir = snapshot_download(repo_id=RAG_REPO, repo_type="dataset", local_dir="./rag")
    index_path = os.path.join(rag_dir, "agro.index")
    chunks_path = os.path.join(rag_dir, "chunks.json")
    faiss_index = faiss.read_index(index_path)
    with open(chunks_path, "r", encoding="utf-8") as fh:
        rag_chunks = json.load(fh)

    # Sentence encoder used to embed the user's query for the index search.
    embedder = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")

    print("ALL LOADED βœ”")
    yield


# ─────────────────────────────
# APP
# ─────────────────────────────
app = FastAPI(title="🌿 Plant Disease Chat API", lifespan=lifespan)


# ─────────────────────────────
# REQUEST MODEL
# ─────────────────────────────
class ChatRequest(BaseModel):
    """Request body accepted by ``POST /chat``."""

    # Conversation turns. Each entry is read as a mapping with "role" and
    # "content" keys; the last "user" entry is treated as the latest turn.
    messages: list
    # Base64-encoded image — if given, RAG is skipped automatically.
    # Declared Optional so that an explicit JSON ``null`` validates the same
    # way as omitting the field (a bare ``str`` annotation rejects ``null``).
    image: Optional[str] = None


# ─────────────────────────────
# HELPERS
# ─────────────────────────────
def decode_image(b64: str) -> Image.Image:
    """Decode a base64-encoded image into an RGB PIL image.

    Accepts either a bare base64 payload or a ``data:`` URL such as
    ``data:image/png;base64,<payload>``. Without stripping that header, its
    letters would be decoded as base64 too and prepend garbage bytes to the
    image data.

    Args:
        b64: Base64 text, optionally wrapped in a ``data:`` URL.

    Returns:
        The decoded image converted to RGB mode.

    Raises:
        binascii.Error: (a ``ValueError``) if the base64 padding is invalid.
        OSError: if the decoded bytes cannot be read as an image
            (includes ``PIL.UnidentifiedImageError``).
    """
    payload = b64.strip()
    if payload.startswith("data:"):
        # Keep only what follows the first comma, i.e. the actual payload.
        _, _, payload = payload.partition(",")
    raw = base64.b64decode(payload)
    return Image.open(BytesIO(raw)).convert("RGB")


def chunk_to_text(chunk) -> str:
    """Flatten one stored RAG chunk into plain text.

    Strings pass through unchanged. For a dict, the first string value found
    under "text", "content", "passage", "chunk" or "body" (checked in that
    order) is returned; if none is a string, all values are stringified and
    joined with single spaces. Anything else is converted with ``str()``.
    """
    if isinstance(chunk, dict):
        candidates = (
            chunk.get(name)
            for name in ("text", "content", "passage", "chunk", "body")
        )
        hit = next((value for value in candidates if isinstance(value, str)), None)
        if hit is not None:
            return hit
        # No recognised text field: fall back to every value, space-separated.
        return " ".join(map(str, chunk.values()))
    return chunk if isinstance(chunk, str) else str(chunk)


def to_content_list(content) -> list:
    """Normalise message content into the list-of-blocks form.

    ``apply_chat_template`` is fed a list of dicts, so a plain string becomes
    a single text block, a list has its string items wrapped as text blocks
    (other items are passed through untouched), and any other value is
    stringified into one text block.
    """
    def as_block(item):
        # Only bare strings need wrapping; existing blocks are kept as-is.
        return {"type": "text", "text": item} if isinstance(item, str) else item

    if isinstance(content, list):
        return [as_block(item) for item in content]

    text = content if isinstance(content, str) else str(content)
    return [{"type": "text", "text": text}]


def retrieve_rag_context(messages: list, k: int = 3) -> str:
    """Return the knowledge chunks nearest to the latest user text.

    Walks ``messages`` from the end, takes the text of the most recent
    "user" turn that has any (a plain string, or the first "text" block of a
    content list), embeds it and searches the FAISS index.

    Args:
        messages: Chat turns; each is read as a mapping with "role" and
            "content" keys.
        k: Number of neighbours to request from the index.

    Returns:
        The matching chunks joined by blank lines, or "" when the RAG
        resources are not loaded, ``k`` is not positive, or no non-blank
        user text is found.
    """
    if not rag_chunks or faiss_index is None or embedder is None or k <= 0:
        return ""

    last_user_text = ""
    for m in reversed(messages):
        if m.get("role") != "user":
            continue
        content = m.get("content", "")
        if isinstance(content, str):
            last_user_text = content
        elif isinstance(content, list):
            for block in content:
                if isinstance(block, dict) and block.get("type") == "text":
                    # Tolerate a text block with a missing or non-string
                    # "text" entry instead of raising.
                    text = block.get("text")
                    last_user_text = text if isinstance(text, str) else ""
                    break
        if last_user_text:
            break

    if not last_user_text.strip():
        return ""

    query_vec = embedder.encode([last_user_text])
    _, indices = faiss_index.search(query_vec, k=k)

    # FAISS pads the result with -1 when it finds fewer than k neighbours.
    # A bare upper-bound check would let -1 through and silently pick the
    # last chunk, so the lower bound is checked as well.
    total = len(rag_chunks)
    hits = [chunk_to_text(rag_chunks[int(i)]) for i in indices[0] if 0 <= i < total]
    return "\n\n".join(hits)


def build_full_messages(messages: list, image: "Image.Image | None", rag_context: str) -> list:
    """Assemble the conversation handed to ``apply_chat_template``.

    The result starts with a priming exchange (the instructions, plus the
    retrieved knowledge when ``rag_context`` is non-empty, sent as a user
    turn followed by a fixed assistant acknowledgement), then the caller's
    messages with their content normalised to block lists.

    Args:
        messages: Caller-supplied turns; each must have a "role" key and may
            have a "content" key. The input list and its dicts are not
            modified.
        image: Image to attach, or None. It is prepended to the content of
            the most recent user turn; if there is no user turn, a new user
            turn holding only the image is appended so the image is never
            silently dropped.
        rag_context: Retrieved knowledge text, or "" for none.

    Returns:
        A new list of ``{"role": ..., "content": [blocks]}`` dicts.
    """
    system_parts = ["You are a plant disease expert assistant."]
    if rag_context:
        system_parts.append(
            "Use the following retrieved knowledge to inform your answer:\n\n" + rag_context
        )
    system_prompt = "\n\n".join(system_parts)

    # Content MUST be a list of dicts — never a plain string.
    full_messages = [
        {"role": "user",      "content": [{"type": "text", "text": system_prompt}]},
        {"role": "assistant", "content": [{"type": "text", "text": "Understood. I will use this knowledge to help you."}]},
    ]

    norm = [
        {"role": m["role"], "content": to_content_list(m.get("content", ""))}
        for m in messages
    ]

    if image is not None:
        image_block = {"type": "image", "image": image}
        latest_user = next((m for m in reversed(norm) if m["role"] == "user"), None)
        if latest_user is not None:
            # Build a new list so the caller's content list is left untouched.
            latest_user["content"] = [image_block] + latest_user["content"]
        else:
            norm.append({"role": "user", "content": [image_block]})

    full_messages.extend(norm)
    return full_messages


# ─────────────────────────────
# UNIFIED ENDPOINT
# ─────────────────────────────
@app.post("/chat")
def chat(req: ChatRequest):
    """Answer a chat request, optionally grounded in an image or RAG context.

    When ``req.image`` is supplied it is decoded and attached to the prompt,
    and retrieval is skipped; otherwise knowledge chunks are retrieved for
    the latest user text. The prompt is rendered with the processor's chat
    template and the model generates up to 512 new tokens.

    Returns:
        A dict with "response" (generated text), "rag_used" (whether any
        retrieved context was included) and "image_used".

    Raises:
        HTTPException: 400 if ``req.image`` cannot be decoded as an image.
    """
    # Imported here so this handler does not depend on the module's import
    # list beyond what it already provides.
    from fastapi import HTTPException

    image = None
    if req.image:
        try:
            image = decode_image(req.image)
        except (ValueError, OSError) as err:
            # Bad base64 (ValueError) or unreadable image bytes (OSError) is
            # a client error, not a server fault.
            raise HTTPException(
                status_code=400,
                detail="image is not a valid base64-encoded image",
            ) from err

    rag_context = "" if image is not None else retrieve_rag_context(req.messages)
    full_messages = build_full_messages(req.messages, image, rag_context)

    # apply_chat_template with tokenize=True returns a plain Tensor, not a dict;
    # return_dict=True yields {"input_ids": ..., "attention_mask": ...} so the
    # result can be unpacked into generate() below.
    inputs = processor.apply_chat_template(
        full_messages,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
        return_dict=True,
    )
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=512,
            # temperature/top_p only take effect when sampling is enabled.
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
        )

    # Decode only the newly generated tokens (skip the input prompt).
    input_len = inputs["input_ids"].shape[1]
    new_tokens = output_ids[0][input_len:]
    response_text = processor.decode(new_tokens, skip_special_tokens=True)

    return {
        "response":   response_text,
        "rag_used":   bool(rag_context),
        "image_used": image is not None,
    }


# ─────────────────────────────
# HEALTH CHECK
# ─────────────────────────────
@app.get("/")
def root():
    """Health check: return a fixed status payload."""
    payload = {"status": "plant disease chat api running"}
    return payload