# HF / app.py
# Author: ikram02ii
# Commit: cf89185 "Update app.py" (verified)
import io
import os
from typing import Any, Dict

import requests
import torch
from fastapi import FastAPI, File, Form, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from transformers import AutoImageProcessor, SiglipForImageClassification
# ----------------------------
# Config
# ----------------------------
# Hugging Face image classifier used by hf_predict (the non-"ewaste" category).
HF_MODEL_ID: str = "prithivMLmods/Trash-Net"

# Roboflow hosted model used by roboflow_predict (the "ewaste" category). Both are
# read from the environment so the key never lives in source control.
ROBOFLOW_API_KEY: str = os.getenv("ROBOFLOW_API_KEY", "")
ROBOFLOW_MODEL_ID: str = os.getenv("ROBOFLOW_MODEL_ID", "e-waste-2ecoq/2")

# hf_predict answers "unknown" when the top-1 probability is below this value...
CONFIDENCE_THRESHOLD: float = 0.70
# ...or when top-1 beats top-2 by less than this margin.
MARGIN_THRESHOLD: float = 0.15

# Classifier labels that hf_predict reports as recyclable.
RECYCLABLE = {"cardboard", "glass", "metal", "paper", "plastic"}
# ----------------------------
# E-waste label control
# ----------------------------
# Canonical e-waste label -> every lower-cased raw label spelling that maps to it.
# EWASTE_ALLOWED and EWASTE_ALIAS are both derived from this one table, so an
# alias can never point at a label that is not allowed.
_EWASTE_VARIANTS = {
    "keyboards": ("keyboard", "keyboards"),
    "mouses": ("mouse", "mice", "mouses"),
    "mobile": ("phone", "phones", "cellphone", "cell phone", "smartphone", "mobile"),
    "tv": ("television", "tvs", "tv"),
    "camera": ("camera", "cameras"),
    "laptop": ("laptop", "laptops"),
    "microwave": ("microwave", "microwaves"),
}

# Labels roboflow_predict may return; any other (normalised) label becomes "unknown".
EWASTE_ALLOWED = set(_EWASTE_VARIANTS)

# Lower-cased raw label -> canonical label (a member of EWASTE_ALLOWED).
EWASTE_ALIAS = {
    variant: canonical
    for canonical, variants in _EWASTE_VARIANTS.items()
    for variant in variants
}
# ----------------------------
# Load HF model once
# ----------------------------
# Loaded at import time so every request reuses the same processor and weights.
PROCESSOR = AutoImageProcessor.from_pretrained(HF_MODEL_ID)
# eval() puts the module in inference mode (disables dropout etc.) and returns it.
MODEL = SiglipForImageClassification.from_pretrained(HF_MODEL_ID).eval()

# Class names ordered by output index: CLASS_NAMES[k] is the label of logit column k,
# which is how hf_predict turns topk indices back into names.
id2label = MODEL.config.id2label
CLASS_NAMES = [id2label[k] for k in range(len(id2label))]
# ----------------------------
# App
# ----------------------------
app = FastAPI()

# NOTE(review): wildcard origins together with allow_credentials=True may let any
# site make credentialed requests, depending on the installed Starlette version.
# Narrow allow_origins (e.g. to the Firebase frontend domain) before this service
# relies on cookies or auth headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
@app.get("/health")
def health():
    """Liveness endpoint: unconditionally responds with ``{"ok": True}``."""
    return dict(ok=True)
def hf_predict(image: Image.Image) -> Dict[str, Any]:
    """Classify an everyday-waste photo with the Hugging Face SigLIP model.

    Args:
        image: Decoded PIL image (the /predict route passes an RGB-converted image).

    Returns:
        For a confident prediction: ``{"class", "recyclable", "confidence"}``, where
        ``recyclable`` is whether the class is in ``RECYCLABLE``.

        When the top-1 probability is below ``CONFIDENCE_THRESHOLD``, or exceeds the
        runner-up by less than ``MARGIN_THRESHOLD``: ``class`` is ``"unknown"``,
        ``recyclable`` is False, and a retake ``message`` plus top-2 ``debug`` details
        are added.

        Probabilities are rounded to 4 decimal places.
    """
    model_inputs = PROCESSOR(images=image, return_tensors="pt")
    with torch.no_grad():
        raw_logits = MODEL(**model_inputs).logits

    # Class probabilities for the first image of the batch (one image is passed in).
    class_probs = torch.softmax(raw_logits, dim=1)[0]
    best_two = torch.topk(class_probs, k=2)
    first_prob, second_prob = best_two.values.tolist()
    first_idx, second_idx = best_two.indices.tolist()

    first_name = CLASS_NAMES[first_idx]
    gap = first_prob - second_prob
    rounded_conf = round(first_prob, 4)

    # Reject low-confidence or ambiguous (close runner-up) predictions.
    if first_prob < CONFIDENCE_THRESHOLD or gap < MARGIN_THRESHOLD:
        debug_info = {
            "top1": {"class": first_name, "prob": round(first_prob, 4)},
            "top2": {"class": CLASS_NAMES[second_idx], "prob": round(second_prob, 4)},
            "margin": round(gap, 4),
        }
        return {
            "class": "unknown",
            "recyclable": False,
            "confidence": rounded_conf,
            "message": "Not confident enough. Retake photo (good lighting, centered, plain background).",
            "debug": debug_info,
        }

    return {
        "class": first_name,
        "recyclable": first_name in RECYCLABLE,
        "confidence": rounded_conf,
    }
def roboflow_predict(image_bytes: bytes) -> Dict[str, Any]:
    """Classify an e-waste photo with the hosted Roboflow model.

    Args:
        image_bytes: Raw uploaded image bytes, sent to Roboflow as ``image.jpg``.

    Returns:
        ``{"class", "recyclable", "confidence"}``. ``class`` is the highest-confidence
        prediction's label, lower-cased and normalised through ``EWASTE_ALIAS``.
        ``recyclable`` is always False on this path. ``class`` is ``"unknown"`` when
        Roboflow returns no predictions, or when the normalised label is not in
        ``EWASTE_ALLOWED``; in the latter case a ``message`` is also added.

        Missing configuration, transport errors, non-200 responses and non-JSON
        bodies return an ``{"error": ...}`` dict instead. The API key is never
        included in any returned value.
    """
    if not ROBOFLOW_API_KEY:
        return {"error": "ROBOFLOW_API_KEY is missing in Space secrets/env."}

    url = f"https://detect.roboflow.com/{ROBOFLOW_MODEL_ID}"
    try:
        # params= URL-encodes the key instead of splicing it raw into the URL.
        resp = requests.post(
            url,
            params={"api_key": ROBOFLOW_API_KEY},
            files={"file": ("image.jpg", image_bytes, "image/jpeg")},
            timeout=60,
        )
    except requests.RequestException as exc:
        # str(exc) for connection errors can include the full request URL, which
        # carries the api_key query parameter. Report only the exception type so
        # the secret is never echoed back to the client.
        return {"error": "Roboflow request failed", "detail": type(exc).__name__}

    if resp.status_code != 200:
        return {"error": f"Roboflow error {resp.status_code}", "detail": resp.text}

    try:
        data = resp.json()
    except ValueError:  # requests' JSON decode errors subclass ValueError
        return {"error": "Roboflow returned a non-JSON response", "detail": resp.text}

    preds = data.get("predictions", [])
    if not preds:
        return {"class": "unknown", "recyclable": False, "confidence": 0.0}

    # "or 0" treats an explicit null confidence as 0.
    best = max(preds, key=lambda p: p.get("confidence", 0) or 0)
    raw_label = str(best.get("class", "unknown")).strip().lower()
    conf = float(best.get("confidence", 0.0) or 0.0)

    # Normalise label spelling variants to the canonical e-waste label.
    label = EWASTE_ALIAS.get(raw_label, raw_label)

    # Only the configured e-waste items are reported; everything else is "unknown".
    if label not in EWASTE_ALLOWED:
        return {
            "class": "unknown",
            "recyclable": False,
            "confidence": round(conf, 4),
            "message": "Unrecognized e-waste item. Please select Everyday Recyclables for plastic/metal/paper.",
        }

    return {
        "class": label,
        "recyclable": False,
        "confidence": round(conf, 4),
    }
def _classify_upload(image_bytes: bytes, category: str) -> Dict[str, Any]:
    """Decode *image_bytes* and send it to the classifier selected by *category*.

    Everything in here blocks (PIL decoding, torch inference, HTTP), so the async
    /predict route runs it in a worker thread.
    """
    # Decoding doubles as validation: non-image uploads raise here for both categories.
    image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    if category == "ewaste":
        return roboflow_predict(image_bytes)
    return hf_predict(image)


@app.post("/predict")
async def predict(
    file: UploadFile = File(...),
    category: str = Form("regular"),  # "regular" or "ewaste"
):
    """Classify an uploaded photo.

    Form fields:
        file: The image upload.
        category: ``"ewaste"`` routes to the Roboflow e-waste model. Any other value
            (default ``"regular"``) uses the Hugging Face recyclables model.

    Returns:
        The selected classifier's result dict, or
        ``{"error": "Prediction failed", "detail": str(exc)}`` if reading, decoding,
        or classification raises.
    """
    try:
        image_bytes = await file.read()
        # Inference and the Roboflow HTTP call (60 s timeout) are blocking. Running
        # them inline would stall the event loop and every other in-flight request,
        # so hand them to the thread pool instead.
        return await run_in_threadpool(_classify_upload, image_bytes, category)
    except Exception as e:
        return {"error": "Prediction failed", "detail": str(e)}