"""FastAPI service that classifies waste photos.

Two back ends are exposed through a single ``/predict`` endpoint:

* ``regular`` -- a Hugging Face SigLIP image classifier loaded in-process.
* ``ewaste``  -- a hosted Roboflow detection model called over HTTP.
"""

import io
import logging
import os
from typing import Any, Dict

import requests
import torch
from PIL import Image
from fastapi import FastAPI, File, Form, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from transformers import AutoImageProcessor, SiglipForImageClassification

logger = logging.getLogger(__name__)

# ----------------------------
# Config
# ----------------------------
HF_MODEL_ID = "prithivMLmods/Trash-Net"

ROBOFLOW_API_KEY = os.getenv("ROBOFLOW_API_KEY", "")
ROBOFLOW_MODEL_ID = os.getenv("ROBOFLOW_MODEL_ID", "e-waste-2ecoq/2")

# A prediction is accepted only if the top-1 probability reaches
# CONFIDENCE_THRESHOLD *and* beats the runner-up by at least MARGIN_THRESHOLD.
CONFIDENCE_THRESHOLD = 0.70
MARGIN_THRESHOLD = 0.15

RECYCLABLE = {"cardboard", "glass", "metal", "paper", "plastic"}

# ----------------------------
# E-waste label control
# ----------------------------
# Only these (normalized) labels are reported for the "ewaste" category.
EWASTE_ALLOWED = {
    "keyboards",
    "mobile",
    "mouses",
    "tv",
    "camera",
    "laptop",
    "microwave",
}

# Map common label variations -> the allowed labels above.
EWASTE_ALIAS = {
    "keyboard": "keyboards",
    "keyboards": "keyboards",
    "mouse": "mouses",
    "mice": "mouses",
    "mouses": "mouses",
    "phone": "mobile",
    "phones": "mobile",
    "cellphone": "mobile",
    "cell phone": "mobile",
    "smartphone": "mobile",
    "mobile": "mobile",
    "television": "tv",
    "tvs": "tv",
    "tv": "tv",
    "camera": "camera",
    "cameras": "camera",
    "laptop": "laptop",
    "laptops": "laptop",
    "microwave": "microwave",
    "microwaves": "microwave",
}

# ----------------------------
# Load HF model once
# ----------------------------
PROCESSOR = AutoImageProcessor.from_pretrained(HF_MODEL_ID)
MODEL = SiglipForImageClassification.from_pretrained(HF_MODEL_ID)
MODEL.eval()

id2label = MODEL.config.id2label
CLASS_NAMES = [id2label[i] for i in range(len(id2label))]

# ----------------------------
# App
# ----------------------------
app = FastAPI()

# NOTE(review): a wildcard origin combined with allow_credentials=True is
# very permissive; restrict allow_origins to the real front-end domain
# before relying on cookies or other credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # later you can put your firebase domain only
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/health")
def health():
    """Liveness probe; always returns ``{"ok": True}``."""
    return {"ok": True}


def hf_predict(image: Image.Image) -> Dict[str, Any]:
    """Classify *image* with the in-process Hugging Face model.

    Returns a dict with ``class``, ``recyclable`` and ``confidence``.
    When the top-1 probability is below ``CONFIDENCE_THRESHOLD`` or its
    lead over the runner-up is below ``MARGIN_THRESHOLD``, ``class`` is
    ``"unknown"`` and the dict additionally carries ``message`` and a
    ``debug`` section with the top-2 candidates.
    """
    inputs = PROCESSOR(images=image, return_tensors="pt")

    with torch.no_grad():
        outputs = MODEL(**inputs)

    logits = outputs.logits
    probs_tensor = torch.softmax(logits, dim=1)[0]

    top2 = torch.topk(probs_tensor, k=2)
    top1_idx = int(top2.indices[0].item())
    top2_idx = int(top2.indices[1].item())
    top1_prob = float(top2.values[0].item())
    top2_prob = float(top2.values[1].item())
    margin = top1_prob - top2_prob

    pred_class = CLASS_NAMES[top1_idx]

    if top1_prob < CONFIDENCE_THRESHOLD or margin < MARGIN_THRESHOLD:
        return {
            "class": "unknown",
            "recyclable": False,
            "confidence": round(top1_prob, 4),
            "message": "Not confident enough. Retake photo (good lighting, centered, plain background).",
            "debug": {
                "top1": {"class": pred_class, "prob": round(top1_prob, 4)},
                "top2": {"class": CLASS_NAMES[top2_idx], "prob": round(top2_prob, 4)},
                "margin": round(margin, 4),
            },
        }

    return {
        "class": pred_class,
        "recyclable": pred_class in RECYCLABLE,
        "confidence": round(top1_prob, 4),
    }


def roboflow_predict(image_bytes: bytes) -> Dict[str, Any]:
    """Detect an e-waste item in *image_bytes* via the hosted Roboflow model.

    Returns either a prediction dict (``class``, ``recyclable``,
    ``confidence`` and optionally ``message``) or an error dict with an
    ``error`` key. Network failures are reported as an error dict rather
    than raised, so the API key embedded in the request URL can never end
    up in an exception message returned to the client.
    """
    if not ROBOFLOW_API_KEY:
        return {"error": "ROBOFLOW_API_KEY is missing in Space secrets/env."}

    url = f"https://detect.roboflow.com/{ROBOFLOW_MODEL_ID}"

    try:
        resp = requests.post(
            url,
            params={"api_key": ROBOFLOW_API_KEY},
            files={"file": ("image.jpg", image_bytes, "image/jpeg")},
            timeout=60,
        )
    except requests.RequestException as exc:
        # str(exc) typically contains the full request URL, including the
        # api_key query parameter -- expose only the exception class name.
        logger.warning("Roboflow request failed: %s", type(exc).__name__)
        return {"error": "Roboflow request failed", "detail": type(exc).__name__}

    if resp.status_code != 200:
        return {"error": f"Roboflow error {resp.status_code}", "detail": resp.text}

    try:
        data = resp.json()
    except ValueError:
        return {"error": "Roboflow returned an invalid JSON response"}

    preds = data.get("predictions", [])

    if not preds:
        return {"class": "unknown", "recyclable": False, "confidence": 0.0}

    # "or 0" guards against an explicit null confidence in the payload.
    best = max(preds, key=lambda p: p.get("confidence", 0) or 0)
    raw_label = str(best.get("class", "unknown")).strip().lower()
    conf = float(best.get("confidence", 0.0) or 0.0)

    # Normalize label (aliases).
    label = EWASTE_ALIAS.get(raw_label, raw_label)

    # Filter: only report the supported e-waste items.
    if label not in EWASTE_ALLOWED:
        return {
            "class": "unknown",
            "recyclable": False,
            "confidence": round(conf, 4),
            "message": "Unrecognized e-waste item. Please select Everyday Recyclables for plastic/metal/paper.",
        }

    return {
        "class": label,  # return the normalized label
        "recyclable": False,
        "confidence": round(conf, 4),
    }


def _decode_image(image_bytes: bytes) -> Image.Image:
    """Decode raw upload bytes into an RGB PIL image."""
    return Image.open(io.BytesIO(image_bytes)).convert("RGB")


@app.post("/predict")
async def predict(
    file: UploadFile = File(...),
    category: str = Form("regular"),  # "regular" or "ewaste"
):
    """Classify an uploaded image.

    ``category == "ewaste"`` routes to the Roboflow model; any other value
    uses the local Hugging Face classifier. Failures are returned as a
    JSON body with ``error``/``detail`` keys rather than raised.
    """
    try:
        image_bytes = await file.read()

        # Image decoding, model inference and the outbound HTTP call are all
        # blocking; run them in the thread pool so the event loop (and e.g.
        # /health) stays responsive.
        image = await run_in_threadpool(_decode_image, image_bytes)

        if category == "ewaste":
            return await run_in_threadpool(roboflow_predict, image_bytes)

        return await run_in_threadpool(hf_predict, image)

    except Exception as e:
        # Top-level boundary: keep the response contract, but record the
        # traceback server-side instead of swallowing it silently.
        logger.exception("Prediction failed")
        return {"error": "Prediction failed", "detail": str(e)}