ikram02ii commited on
Commit
30ac73a
·
verified ·
1 Parent(s): b6aa1c5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +79 -35
app.py CHANGED
@@ -1,61 +1,105 @@
1
- from fastapi import FastAPI, UploadFile, File
 
2
  from fastapi.middleware.cors import CORSMiddleware
3
- from transformers import AutoImageProcessor, SiglipForImageClassification
4
  from PIL import Image
 
5
  import torch
 
6
 
7
- MODEL_ID = "prithivMLmods/Trash-Net"
8
-
9
- PROCESSOR = AutoImageProcessor.from_pretrained(MODEL_ID)
10
- MODEL = SiglipForImageClassification.from_pretrained(MODEL_ID)
11
- MODEL.eval()
12
-
13
- id2label = MODEL.config.id2label
14
- CLASS_NAMES = [id2label[i] for i in range(len(id2label))]
15
- RECYCLABLE = {"cardboard", "glass", "metal", "paper", "plastic"}
16
-
17
- CONFIDENCE_THRESHOLD = 0.70
18
- MARGIN_THRESHOLD = 0.15
19
 
20
  app = FastAPI()
 
 
21
  app.add_middleware(
22
  CORSMiddleware,
23
  allow_origins=["*"],
 
24
  allow_methods=["*"],
25
  allow_headers=["*"],
26
  )
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  @app.get("/health")
29
  def health():
30
- return {"ok": True, "model": MODEL_ID}
 
31
 
 
 
 
32
  @app.post("/predict")
33
- async def predict(file: UploadFile = File(...)):
34
- img = Image.open(file.file).convert("RGB")
35
- inputs = PROCESSOR(images=img, return_tensors="pt")
 
 
 
36
 
37
- with torch.no_grad():
38
- probs = torch.softmax(MODEL(**inputs).logits, dim=1)[0]
39
- top2 = torch.topk(probs, k=2)
 
 
 
 
 
40
 
41
- top1_idx = int(top2.indices[0].item())
42
- top2_idx = int(top2.indices[1].item())
43
- top1_prob = float(top2.values[0].item())
44
- top2_prob = float(top2.values[1].item())
45
- margin = top1_prob - top2_prob
46
 
47
- pred_class = CLASS_NAMES[top1_idx]
 
48
 
49
- if top1_prob < CONFIDENCE_THRESHOLD or margin < MARGIN_THRESHOLD:
50
  return {
51
- "class": "unknown",
52
- "recyclable": False,
53
- "confidence": round(top1_prob, 4),
54
- "message": "Not confident enough. Retake photo.",
55
  }
56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  return {
58
- "class": pred_class,
59
- "recyclable": pred_class in RECYCLABLE,
60
- "confidence": round(top1_prob, 4),
61
  }
 
1
+ import os
2
+ from fastapi import FastAPI, UploadFile, File, Form
3
  from fastapi.middleware.cors import CORSMiddleware
 
4
  from PIL import Image
5
+ import io
6
  import torch
7
+ from transformers import AutoImageProcessor, SiglipForImageClassification
8
 
9
+ # Roboflow client
10
+ from inference_sdk import InferenceHTTPClient
 
 
 
 
 
 
 
 
 
 
11
 
12
  app = FastAPI()
13
+
14
+ # Allow frontend (Firebase) to call this API
15
  app.add_middleware(
16
  CORSMiddleware,
17
  allow_origins=["*"],
18
+ allow_credentials=True,
19
  allow_methods=["*"],
20
  allow_headers=["*"],
21
  )
22
 
23
+ # ===============================
24
+ # 🔹 Everyday recyclables model (HF)
25
+ # ===============================
26
+ MODEL_NAME = "prithivMLmods/Trash-Net"
27
+
28
+ processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
29
+ model = SiglipForImageClassification.from_pretrained(MODEL_NAME)
30
+ model.eval()
31
+
32
+ LABELS = model.config.id2label
33
+
34
+
35
+ # ===============================
36
+ # 🔹 Roboflow E-waste client
37
+ # ===============================
38
+ RF_CLIENT = InferenceHTTPClient(
39
+ api_url="https://serverless.roboflow.com",
40
+ api_key=os.getenv("ROBOFLOW_API_KEY") # set in Space secrets
41
+ )
42
+
43
+ RF_MODEL_ID = os.getenv("ROBOFLOW_MODEL_ID", "e-waste-2ecoq/2")
44
+
45
+
46
+ # ===============================
47
+ # 🔹 Health check
48
+ # ===============================
49
  @app.get("/health")
50
  def health():
51
+ return {"status": "ok"}
52
+
53
 
54
+ # ===============================
55
+ # 🔹 Unified predict endpoint
56
+ # ===============================
57
  @app.post("/predict")
58
+ async def predict(
59
+ file: UploadFile = File(...),
60
+ category: str = Form("regular") # "regular" or "ewaste"
61
+ ):
62
+ image_bytes = await file.read()
63
+ image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
64
 
65
+ # ---------------------------
66
+ # 🔸 E-WASTE → Roboflow API
67
+ # ---------------------------
68
+ if category == "ewaste":
69
+ result = RF_CLIENT.infer(image, model_id=RF_MODEL_ID)
70
+
71
+ if not result["predictions"]:
72
+ return {"error": "No object detected"}
73
 
74
+ pred = result["predictions"][0]
 
 
 
 
75
 
76
+ label = pred["class"]
77
+ confidence = pred["confidence"]
78
 
 
79
  return {
80
+ "class": label.lower(),
81
+ "recyclable": True,
82
+ "confidence": confidence
 
83
  }
84
 
85
+ # ---------------------------
86
+ # 🔸 REGULAR → HF model
87
+ # ---------------------------
88
+ inputs = processor(images=image, return_tensors="pt")
89
+
90
+ with torch.no_grad():
91
+ outputs = model(**inputs)
92
+ probs = torch.softmax(outputs.logits, dim=1)
93
+
94
+ score, idx = torch.max(probs, dim=1)
95
+
96
+ label = LABELS[idx.item()].lower()
97
+ confidence = float(score.item())
98
+
99
+ recyclable_classes = ["plastic", "paper", "metal", "glass", "cardboard"]
100
+
101
  return {
102
+ "class": label,
103
+ "recyclable": label in recyclable_classes,
104
+ "confidence": confidence
105
  }