Rady10 commited on
Commit
57bbdbe
Β·
verified Β·
1 Parent(s): 3e226a3

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +150 -0
app.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import base64
3
+ import torch
4
+ import numpy as np
5
+ import faiss
6
+ import json
7
+
8
+ from fastapi import FastAPI
9
+ from pydantic import BaseModel
10
+ from contextlib import asynccontextmanager
11
+ from huggingface_hub import snapshot_download
12
+ from sentence_transformers import SentenceTransformer
13
+ from PIL import Image
14
+ from io import BytesIO
15
+
16
+ from transformers import AutoProcessor, AutoModelForVision2Seq
17
+
18
+ # ─────────────────────────────
19
+ # CONFIG
20
+ # ─────────────────────────────
21
+ MODEL_REPO = "Rady10/Plant-Disease-Qwen3VL-2B"
22
+ RAG_REPO = "Rady10/Agriculture-Rag-Data-Index"
23
+
24
+ DEVICE = "cpu"
25
+
26
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
27
+
28
+ # ─────────────────────────────
29
+ # GLOBALS
30
+ # ─────────────────────────────
31
+ model = None
32
+ processor = None
33
+
34
+ faiss_index = None
35
+ rag_chunks = None
36
+ embedder = None
37
+
38
+ # ─────────────────────────────
39
+ # FASTAPI APP
40
+ # ─────────────────────────────
41
+ app = FastAPI(title="🌿 Plant Disease Vision API")
42
+
43
+ # ─────────────────────────────
44
+ # LOAD MODELS ONCE
45
+ # ─────────────────────────────
46
+ @asynccontextmanager
47
+ async def lifespan(app: FastAPI):
48
+
49
+ global model, processor, faiss_index, rag_chunks, embedder
50
+
51
+ print("Loading vision model...")
52
+
53
+ processor = AutoProcessor.from_pretrained(
54
+ MODEL_REPO,
55
+ trust_remote_code=True
56
+ )
57
+
58
+ model = AutoModelForVision2Seq.from_pretrained(
59
+ MODEL_REPO,
60
+ torch_dtype=torch.float32,
61
+ device_map="cpu",
62
+ trust_remote_code=True
63
+ )
64
+
65
+ model.eval()
66
+
67
+ # ───── RAG (optional but included) ─────
68
+ print("Loading RAG...")
69
+
70
+ rag_dir = snapshot_download(
71
+ repo_id=RAG_REPO,
72
+ repo_type="dataset",
73
+ local_dir="./rag"
74
+ )
75
+
76
+ faiss_index = faiss.read_index(
77
+ os.path.join(rag_dir, "agro.index")
78
+ )
79
+
80
+ with open(os.path.join(rag_dir, "chunks.json"), "r", encoding="utf-8") as f:
81
+ rag_chunks = json.load(f)
82
+
83
+ embedder = SentenceTransformer(
84
+ "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
85
+ )
86
+
87
+ print("ALL LOADED")
88
+
89
+ yield
90
+
91
+ app = FastAPI(lifespan=lifespan)
92
+
93
+ # ─────────────────────────────
94
+ # REQUEST MODEL
95
+ # ─────────────────────────────
96
+ class VisionRequest(BaseModel):
97
+ image: str # base64
98
+ text: str = ""
99
+
100
+ # ─────────────────────────────
101
+ # IMAGE DECODER
102
+ # ─────────────────────────────
103
+ def decode_image(base64_str):
104
+ img_data = base64.b64decode(base64_str)
105
+ return Image.open(BytesIO(img_data)).convert("RGB")
106
+
107
+ # ─────────────────────────────
108
+ # GENERATION
109
+ # ─────────────────────────────
110
+ def generate(image, text):
111
+
112
+ if text.strip() == "":
113
+ text = "What disease is shown in this plant image?"
114
+
115
+ inputs = processor(
116
+ text=text,
117
+ images=image,
118
+ return_tensors="pt"
119
+ )
120
+
121
+ inputs = {k: v.to(model.device) for k, v in inputs.items()}
122
+
123
+ with torch.no_grad():
124
+ output = model.generate(
125
+ **inputs,
126
+ max_new_tokens=256,
127
+ temperature=0.7,
128
+ top_p=0.9
129
+ )
130
+
131
+ return processor.batch_decode(
132
+ output,
133
+ skip_special_tokens=True
134
+ )[0]
135
+
136
+ # ─────────────────────────────
137
+ # API ROUTES
138
+ # ─────────────────────────────
139
+ @app.get("/")
140
+ def root():
141
+ return {"status": "vision api running"}
142
+
143
+ @app.post("/analyze")
144
+ def analyze(req: VisionRequest):
145
+
146
+ image = decode_image(req.image)
147
+
148
+ result = generate(image, req.text)
149
+
150
+ return {"response": result}