Rady10 commited on
Commit
d3d53b0
Β·
verified Β·
1 Parent(s): 0a0b6f8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +112 -40
app.py CHANGED
@@ -1,78 +1,150 @@
 
1
  import base64
2
  import torch
 
 
 
3
 
4
  from fastapi import FastAPI
5
  from pydantic import BaseModel
 
 
 
6
  from PIL import Image
7
  from io import BytesIO
8
 
9
- from transformers import AutoImageProcessor, AutoModelForImageClassification
10
 
11
  # ─────────────────────────────
12
- # MODEL (REPLACE WITH YOUR CLASSIFIER)
13
  # ─────────────────────────────
14
- MODEL_REPO = "Rady10/Plant-Disease-Qwen3VL-2B" # or your plant model
 
15
 
16
- app = FastAPI(title="🌿 Plant Disease Classifier")
 
 
17
 
18
  # ─────────────────────────────
19
- # LOAD MODEL
20
  # ─────────────────────────────
21
- print("Loading classifier...")
22
-
23
- processor = AutoImageProcessor.from_pretrained(MODEL_REPO)
24
- model = AutoModelForImageClassification.from_pretrained(MODEL_REPO)
25
 
26
- model.eval()
27
-
28
- print("Model loaded")
29
 
30
  # ─────────────────────────────
31
- # REQUEST FORMAT
32
  # ─────────────────────────────
33
- class Request(BaseModel):
34
- image: str # base64 string
35
 
36
  # ─────────────────────────────
37
- # DECODE IMAGE
38
  # ─────────────────────────────
39
- def decode_image(b64):
40
- return Image.open(
41
- BytesIO(base64.b64decode(b64))
42
- ).convert("RGB")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
  # ─────────────────────────────
45
- # PREDICTION FUNCTION
46
  # ─────────────────────────────
47
- def predict(image):
 
 
48
 
49
- inputs = processor(images=image, return_tensors="pt")
 
 
 
 
 
50
 
51
- with torch.no_grad():
52
- outputs = model(**inputs)
 
 
53
 
54
- logits = outputs.logits
 
55
 
56
- pred_id = logits.argmax(-1).item()
 
 
 
 
57
 
58
- label = model.config.id2label[pred_id]
59
 
60
- return label
 
 
 
 
 
 
 
 
 
 
 
61
 
62
  # ─────────────────────────────
63
- # API ENDPOINT
64
  # ──────��──────────────────────
65
- @app.post("/predict")
66
- def predict_api(req: Request):
 
67
 
68
- image = decode_image(req.image)
 
69
 
70
- label = predict(image)
71
 
72
- return {
73
- "prediction": label
74
- }
75
 
76
- @app.get("/")
77
- def home():
78
- return {"status": "classifier running"}
 
1
+ import os
2
  import base64
3
  import torch
4
+ import numpy as np
5
+ import faiss
6
+ import json
7
 
8
  from fastapi import FastAPI
9
  from pydantic import BaseModel
10
+ from contextlib import asynccontextmanager
11
+ from huggingface_hub import snapshot_download
12
+ from sentence_transformers import SentenceTransformer
13
  from PIL import Image
14
  from io import BytesIO
15
 
16
+ from transformers import AutoProcessor, AutoModelForVision2Seq
17
 
18
  # ─────────────────────────────
19
+ # CONFIG
20
  # ─────────────────────────────
21
+ MODEL_REPO = "Rady10/Plant-Disease-Qwen3VL-2B"
22
+ RAG_REPO = "Rady10/Plant-Disease-AgroRAG-Index"
23
 
24
+ DEVICE = "cpu"
25
+
26
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
27
 
28
  # ─────────────────────────────
29
+ # GLOBALS
30
  # ─────────────────────────────
31
+ model = None
32
+ processor = None
 
 
33
 
34
+ faiss_index = None
35
+ rag_chunks = None
36
+ embedder = None
37
 
38
  # ─────────────────────────────
39
+ # FASTAPI APP
40
  # ─────────────────────────────
41
+ app = FastAPI(title="🌿 Plant Disease Vision API")
 
42
 
43
  # ─────────────────────────────
44
+ # LOAD MODELS ONCE
45
  # ─────────────────────────────
46
+ @asynccontextmanager
47
+ async def lifespan(app: FastAPI):
48
+
49
+ global model, processor, faiss_index, rag_chunks, embedder
50
+
51
+ print("Loading vision model...")
52
+
53
+ processor = AutoProcessor.from_pretrained(
54
+ MODEL_REPO,
55
+ trust_remote_code=True
56
+ )
57
+
58
+ model = AutoModelForVision2Seq.from_pretrained(
59
+ MODEL_REPO,
60
+ torch_dtype=torch.float32,
61
+ device_map="cpu",
62
+ trust_remote_code=True
63
+ )
64
+
65
+ model.eval()
66
+
67
+ # ───── RAG (optional but included) ─────
68
+ print("Loading RAG...")
69
+
70
+ rag_dir = snapshot_download(
71
+ repo_id=RAG_REPO,
72
+ repo_type="dataset",
73
+ local_dir="./rag"
74
+ )
75
+
76
+ faiss_index = faiss.read_index(
77
+ os.path.join(rag_dir, "agro.index")
78
+ )
79
+
80
+ with open(os.path.join(rag_dir, "chunks.json"), "r", encoding="utf-8") as f:
81
+ rag_chunks = json.load(f)
82
+
83
+ embedder = SentenceTransformer(
84
+ "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
85
+ )
86
+
87
+ print("ALL LOADED")
88
+
89
+ yield
90
+
91
+ app = FastAPI(lifespan=lifespan)
92
 
93
  # ─────────────────────────────
94
+ # REQUEST MODEL
95
  # ─────────────────────────────
96
+ class VisionRequest(BaseModel):
97
+ image: str # base64
98
+ text: str = ""
99
 
100
+ # ─────────────────────────────
101
+ # IMAGE DECODER
102
+ # ─────────────────────────────
103
+ def decode_image(base64_str):
104
+ img_data = base64.b64decode(base64_str)
105
+ return Image.open(BytesIO(img_data)).convert("RGB")
106
 
107
+ # ─────────────────────────────
108
+ # GENERATION
109
+ # ─────────────────────────────
110
+ def generate(image, text):
111
 
112
+ if text.strip() == "":
113
+ text = "What disease is shown in this plant image?"
114
 
115
+ inputs = processor(
116
+ text=text,
117
+ images=image,
118
+ return_tensors="pt"
119
+ )
120
 
121
+ inputs = {k: v.to(model.device) for k, v in inputs.items()}
122
 
123
+ with torch.no_grad():
124
+ output = model.generate(
125
+ **inputs,
126
+ max_new_tokens=256,
127
+ temperature=0.7,
128
+ top_p=0.9
129
+ )
130
+
131
+ return processor.batch_decode(
132
+ output,
133
+ skip_special_tokens=True
134
+ )[0]
135
 
136
  # ─────────────────────────────
137
+ # API ROUTES
138
  # ──────��──────────────────────
139
+ @app.get("/")
140
+ def root():
141
+ return {"status": "vision api running"}
142
 
143
+ @app.post("/analyze")
144
+ def analyze(req: VisionRequest):
145
 
146
+ image = decode_image(req.image)
147
 
148
+ result = generate(image, req.text)
 
 
149
 
150
+ return {"response": result}