Update app.py
Browse files
app.py
CHANGED
|
@@ -1,9 +1,3 @@
|
|
| 1 |
-
# app.py
|
| 2 |
-
# Purpose: FastAPI API for RecycloMate
|
| 3 |
-
# - POST /predict with (file, category=regular|ewaste)
|
| 4 |
-
# - regular -> HuggingFace Transformers model (Trash-Net)
|
| 5 |
-
# - ewaste -> Roboflow serverless API (no inference-sdk, no cv2)
|
| 6 |
-
|
| 7 |
import io
|
| 8 |
import os
|
| 9 |
from typing import Dict, Any
|
|
@@ -22,13 +16,49 @@ HF_MODEL_ID = "prithivMLmods/Trash-Net"
|
|
| 22 |
|
| 23 |
ROBOFLOW_API_KEY = os.getenv("ROBOFLOW_API_KEY", "")
|
| 24 |
ROBOFLOW_MODEL_ID = os.getenv("ROBOFLOW_MODEL_ID", "e-waste-2ecoq/2")
|
| 25 |
-
ROBOFLOW_URL = f"https://detect.roboflow.com/{ROBOFLOW_MODEL_ID}?api_key={ROBOFLOW_API_KEY}"
|
| 26 |
|
| 27 |
CONFIDENCE_THRESHOLD = 0.70
|
| 28 |
MARGIN_THRESHOLD = 0.15
|
| 29 |
|
| 30 |
RECYCLABLE = {"cardboard", "glass", "metal", "paper", "plastic"}
|
| 31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
# ----------------------------
|
| 33 |
# Load HF model once
|
| 34 |
# ----------------------------
|
|
@@ -44,7 +74,6 @@ CLASS_NAMES = [id2label[i] for i in range(len(id2label))]
|
|
| 44 |
# ----------------------------
|
| 45 |
app = FastAPI()
|
| 46 |
|
| 47 |
-
# If you call this Space from Firebase Hosting, keep CORS open or restrict later
|
| 48 |
app.add_middleware(
|
| 49 |
CORSMiddleware,
|
| 50 |
allow_origins=["*"], # later you can put your firebase domain only
|
|
@@ -114,7 +143,6 @@ def roboflow_predict(image_bytes: bytes) -> Dict[str, Any]:
|
|
| 114 |
if not preds:
|
| 115 |
return {"class": "unknown", "recyclable": False, "confidence": 0.0}
|
| 116 |
|
| 117 |
-
# pick best prediction
|
| 118 |
best = max(preds, key=lambda p: p.get("confidence", 0) or 0)
|
| 119 |
raw_label = str(best.get("class", "unknown")).strip().lower()
|
| 120 |
conf = float(best.get("confidence", 0.0) or 0.0)
|
|
@@ -132,13 +160,11 @@ def roboflow_predict(image_bytes: bytes) -> Dict[str, Any]:
|
|
| 132 |
}
|
| 133 |
|
| 134 |
return {
|
| 135 |
-
"class": label,
|
| 136 |
"recyclable": False,
|
| 137 |
"confidence": round(conf, 4),
|
| 138 |
}
|
| 139 |
|
| 140 |
-
|
| 141 |
-
|
| 142 |
@app.post("/predict")
|
| 143 |
async def predict(
|
| 144 |
file: UploadFile = File(...),
|
|
@@ -151,7 +177,6 @@ async def predict(
|
|
| 151 |
if category == "ewaste":
|
| 152 |
return roboflow_predict(image_bytes)
|
| 153 |
|
| 154 |
-
# default: regular
|
| 155 |
return hf_predict(image)
|
| 156 |
|
| 157 |
except Exception as e:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import io
|
| 2 |
import os
|
| 3 |
from typing import Dict, Any
|
|
|
|
| 16 |
|
| 17 |
ROBOFLOW_API_KEY = os.getenv("ROBOFLOW_API_KEY", "")
|
| 18 |
ROBOFLOW_MODEL_ID = os.getenv("ROBOFLOW_MODEL_ID", "e-waste-2ecoq/2")
|
|
|
|
| 19 |
|
| 20 |
CONFIDENCE_THRESHOLD = 0.70
|
| 21 |
MARGIN_THRESHOLD = 0.15
|
| 22 |
|
| 23 |
RECYCLABLE = {"cardboard", "glass", "metal", "paper", "plastic"}
|
| 24 |
|
| 25 |
+
# ----------------------------
|
| 26 |
+
# ✅ E-waste label control (ADD THIS)
|
| 27 |
+
# ----------------------------
|
| 28 |
+
EWASTE_ALLOWED = {
|
| 29 |
+
"keyboards",
|
| 30 |
+
"mobile",
|
| 31 |
+
"mouses",
|
| 32 |
+
"tv",
|
| 33 |
+
"camera",
|
| 34 |
+
"laptop",
|
| 35 |
+
"microwave",
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
# Map common label variations -> your allowed labels
|
| 39 |
+
EWASTE_ALIAS = {
|
| 40 |
+
"keyboard": "keyboards",
|
| 41 |
+
"keyboards": "keyboards",
|
| 42 |
+
"mouse": "mouses",
|
| 43 |
+
"mice": "mouses",
|
| 44 |
+
"mouses": "mouses",
|
| 45 |
+
"phone": "mobile",
|
| 46 |
+
"phones": "mobile",
|
| 47 |
+
"cellphone": "mobile",
|
| 48 |
+
"cell phone": "mobile",
|
| 49 |
+
"smartphone": "mobile",
|
| 50 |
+
"mobile": "mobile",
|
| 51 |
+
"television": "tv",
|
| 52 |
+
"tvs": "tv",
|
| 53 |
+
"tv": "tv",
|
| 54 |
+
"camera": "camera",
|
| 55 |
+
"cameras": "camera",
|
| 56 |
+
"laptop": "laptop",
|
| 57 |
+
"laptops": "laptop",
|
| 58 |
+
"microwave": "microwave",
|
| 59 |
+
"microwaves": "microwave",
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
# ----------------------------
|
| 63 |
# Load HF model once
|
| 64 |
# ----------------------------
|
|
|
|
| 74 |
# ----------------------------
|
| 75 |
app = FastAPI()
|
| 76 |
|
|
|
|
| 77 |
app.add_middleware(
|
| 78 |
CORSMiddleware,
|
| 79 |
allow_origins=["*"], # later you can put your firebase domain only
|
|
|
|
| 143 |
if not preds:
|
| 144 |
return {"class": "unknown", "recyclable": False, "confidence": 0.0}
|
| 145 |
|
|
|
|
| 146 |
best = max(preds, key=lambda p: p.get("confidence", 0) or 0)
|
| 147 |
raw_label = str(best.get("class", "unknown")).strip().lower()
|
| 148 |
conf = float(best.get("confidence", 0.0) or 0.0)
|
|
|
|
| 160 |
}
|
| 161 |
|
| 162 |
return {
|
| 163 |
+
"class": label, # return the normalized label
|
| 164 |
"recyclable": False,
|
| 165 |
"confidence": round(conf, 4),
|
| 166 |
}
|
| 167 |
|
|
|
|
|
|
|
| 168 |
@app.post("/predict")
|
| 169 |
async def predict(
|
| 170 |
file: UploadFile = File(...),
|
|
|
|
| 177 |
if category == "ewaste":
|
| 178 |
return roboflow_predict(image_bytes)
|
| 179 |
|
|
|
|
| 180 |
return hf_predict(image)
|
| 181 |
|
| 182 |
except Exception as e:
|