Rady10 commited on
Commit
7c14e50
Β·
verified Β·
1 Parent(s): f5eb5ef

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -112
app.py CHANGED
@@ -1,150 +1,78 @@
1
- import os
2
  import base64
3
  import torch
4
- import numpy as np
5
- import faiss
6
- import json
7
 
8
  from fastapi import FastAPI
9
  from pydantic import BaseModel
10
- from contextlib import asynccontextmanager
11
- from huggingface_hub import snapshot_download
12
- from sentence_transformers import SentenceTransformer
13
  from PIL import Image
14
  from io import BytesIO
15
 
16
- from transformers import AutoProcessor, AutoModelForCausalLM
17
 
18
  # ─────────────────────────────
19
- # CONFIG
20
  # ─────────────────────────────
21
- MODEL_REPO = "Rady10/Plant-Disease-Qwen3VL-2B"
22
- RAG_REPO = "Rady10/Agriculture-Rag-Data-Index"
23
 
24
- DEVICE = "cpu"
25
 
26
- os.environ["TOKENIZERS_PARALLELISM"] = "false"
27
-
28
- # ─────────────────────────────
29
- # GLOBALS
30
  # ─────────────────────────────
31
- model = None
32
- processor = None
33
-
34
- faiss_index = None
35
- rag_chunks = None
36
- embedder = None
37
-
38
- # ─────────────────────────────
39
- # FASTAPI APP
40
- # ─────────────────────────────
41
- app = FastAPI(title="🌿 Plant Disease Vision API")
42
-
43
  # ─────────────────────────────
44
- # LOAD MODELS ONCE
45
- # ─────────────────────────────
46
- @asynccontextmanager
47
- async def lifespan(app: FastAPI):
48
-
49
- global model, processor, faiss_index, rag_chunks, embedder
50
-
51
- print("Loading vision model...")
52
-
53
- processor = AutoProcessor.from_pretrained(
54
- MODEL_REPO,
55
- trust_remote_code=True
56
- )
57
-
58
- model = AutoModelForCausalLM.from_pretrained(
59
- MODEL_REPO,
60
- torch_dtype=torch.float32,
61
- device_map="cpu",
62
- trust_remote_code=True
63
- )
64
-
65
- model.eval()
66
 
67
- # ───── RAG (optional but included) ─────
68
- print("Loading RAG...")
69
 
70
- rag_dir = snapshot_download(
71
- repo_id=RAG_REPO,
72
- repo_type="dataset",
73
- local_dir="./rag"
74
- )
75
 
76
- faiss_index = faiss.read_index(
77
- os.path.join(rag_dir, "agro.index")
78
- )
79
-
80
- with open(os.path.join(rag_dir, "chunks.json"), "r", encoding="utf-8") as f:
81
- rag_chunks = json.load(f)
82
-
83
- embedder = SentenceTransformer(
84
- "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
85
- )
86
-
87
- print("ALL LOADED")
88
-
89
- yield
90
-
91
- app = FastAPI(lifespan=lifespan)
92
 
93
  # ─────────────────────────────
94
- # REQUEST MODEL
95
  # ─────────────────────────────
96
- class VisionRequest(BaseModel):
97
- image: str # base64
98
- text: str = ""
99
 
100
  # ─────────────────────────────
101
- # IMAGE DECODER
102
  # ─────────────────────────────
103
- def decode_image(base64_str):
104
- img_data = base64.b64decode(base64_str)
105
- return Image.open(BytesIO(img_data)).convert("RGB")
 
106
 
107
  # ─────────────────────────────
108
- # GENERATION
109
  # ─────────────────────────────
110
- def generate(image, text):
111
 
112
- if text.strip() == "":
113
- text = "What disease is shown in this plant image?"
114
 
115
- inputs = processor(
116
- text=text,
117
- images=image,
118
- return_tensors="pt"
119
- )
120
 
121
- inputs = {k: v.to(model.device) for k, v in inputs.items()}
122
 
123
- with torch.no_grad():
124
- output = model.generate(
125
- **inputs,
126
- max_new_tokens=256,
127
- temperature=0.7,
128
- top_p=0.9
129
- )
130
-
131
- return processor.batch_decode(
132
- output,
133
- skip_special_tokens=True
134
- )[0]
135
 
136
  # ─────────────────────────────
137
- # API ROUTES
138
  # ────────���────────────────────
139
- @app.get("/")
140
- def root():
141
- return {"status": "vision api running"}
142
-
143
- @app.post("/analyze")
144
- def analyze(req: VisionRequest):
145
 
146
  image = decode_image(req.image)
147
 
148
- result = generate(image, req.text)
 
 
 
 
149
 
150
- return {"response": result}
 
 
 
 
1
  import base64
2
  import torch
 
 
 
3
 
4
  from fastapi import FastAPI
5
  from pydantic import BaseModel
 
 
 
6
  from PIL import Image
7
  from io import BytesIO
8
 
9
+ from transformers import AutoImageProcessor, AutoModelForImageClassification
10
 
11
  # ─────────────────────────────
12
+ # MODEL (REPLACE WITH YOUR CLASSIFIER)
13
  # ─────────────────────────────
14
+ MODEL_REPO = "Rady10/Plant-Disease-Qwen3VL-2B" # or your plant model
 
15
 
16
+ app = FastAPI(title="🌿 Plant Disease Classifier")
17
 
 
 
 
 
18
  # ─────────────────────────────
19
+ # LOAD MODEL
 
 
 
 
 
 
 
 
 
 
 
20
  # ─────────────────────────────
21
+ print("Loading classifier...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
+ processor = AutoImageProcessor.from_pretrained(MODEL_REPO)
24
+ model = AutoModelForImageClassification.from_pretrained(MODEL_REPO)
25
 
26
+ model.eval()
 
 
 
 
27
 
28
+ print("Model loaded")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
  # ─────────────────────────────
31
+ # REQUEST FORMAT
32
  # ─────────────────────────────
33
+ class Request(BaseModel):
34
+ image: str # base64 string
 
35
 
36
  # ─────────────────────────────
37
+ # DECODE IMAGE
38
  # ─────────────────────────────
39
+ def decode_image(b64):
40
+ return Image.open(
41
+ BytesIO(base64.b64decode(b64))
42
+ ).convert("RGB")
43
 
44
  # ─────────────────────────────
45
+ # PREDICTION FUNCTION
46
  # ─────────────────────────────
47
+ def predict(image):
48
 
49
+ inputs = processor(images=image, return_tensors="pt")
 
50
 
51
+ with torch.no_grad():
52
+ outputs = model(**inputs)
 
 
 
53
 
54
+ logits = outputs.logits
55
 
56
+ pred_id = logits.argmax(-1).item()
57
+
58
+ label = model.config.id2label[pred_id]
59
+
60
+ return label
 
 
 
 
 
 
 
61
 
62
  # ─────────────────────────────
63
+ # API ENDPOINT
64
  # ────────���────────────────────
65
+ @app.post("/predict")
66
+ def predict_api(req: Request):
 
 
 
 
67
 
68
  image = decode_image(req.image)
69
 
70
+ label = predict(image)
71
+
72
+ return {
73
+ "prediction": label
74
+ }
75
 
76
+ @app.get("/")
77
+ def home():
78
+ return {"status": "classifier running"}