| | import io |
| | import os |
| | from typing import Dict, Any |
| |
|
| | import requests |
| | import torch |
| | from PIL import Image |
| | from fastapi import FastAPI, File, UploadFile, Form |
| | from fastapi.middleware.cors import CORSMiddleware |
| | from transformers import AutoImageProcessor, SiglipForImageClassification |
| |
|
| | |
| | |
| | |
# Hugging Face checkpoint used for everyday-recyclables classification.
HF_MODEL_ID = "prithivMLmods/Trash-Net"

# Roboflow hosted-model settings, read from the environment.
ROBOFLOW_API_KEY = os.environ.get("ROBOFLOW_API_KEY", "")
ROBOFLOW_MODEL_ID = os.environ.get("ROBOFLOW_MODEL_ID", "e-waste-2ecoq/2")

# A prediction is accepted only when the top-1 probability reaches
# CONFIDENCE_THRESHOLD and it leads the runner-up by MARGIN_THRESHOLD.
CONFIDENCE_THRESHOLD = 0.70
MARGIN_THRESHOLD = 0.15

# Labels that are reported as recyclable.
RECYCLABLE = set(("cardboard", "glass", "metal", "paper", "plastic"))

# Canonical e-waste label -> every raw spelling that maps onto it.
_EWASTE_SYNONYMS = {
    "keyboards": ("keyboard", "keyboards"),
    "mouses": ("mouse", "mice", "mouses"),
    "mobile": (
        "phone",
        "phones",
        "cellphone",
        "cell phone",
        "smartphone",
        "mobile",
    ),
    "tv": ("television", "tvs", "tv"),
    "camera": ("camera", "cameras"),
    "laptop": ("laptop", "laptops"),
    "microwave": ("microwave", "microwaves"),
}

# Canonical e-waste labels the service will report.
EWASTE_ALLOWED = set(_EWASTE_SYNONYMS)

# Raw label -> canonical label lookup.
EWASTE_ALIAS = {
    raw_name: canonical
    for canonical, raw_names in _EWASTE_SYNONYMS.items()
    for raw_name in raw_names
}
| |
|
| | |
| | |
| | |
# Load the image processor and the classifier once, at import time, so that
# every request reuses the same objects.
PROCESSOR = AutoImageProcessor.from_pretrained(HF_MODEL_ID)
# eval() switches the module to evaluation mode and returns the module itself.
MODEL = SiglipForImageClassification.from_pretrained(HF_MODEL_ID).eval()

# Class names ordered by class index, so CLASS_NAMES[i] is the label for
# logit position i.
id2label = MODEL.config.id2label
CLASS_NAMES = [id2label[class_idx] for class_idx in range(len(id2label))]
| |
|
| | |
| | |
| | |
app = FastAPI()

# Allowed CORS origins: a comma-separated list from CORS_ALLOW_ORIGINS, or "*"
# (any origin) when the variable is unset or empty.
_cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    # A wildcard origin must not be combined with credentials (the CORS spec
    # and the FastAPI docs both require explicit origins for credentialed
    # requests), so credentials are enabled only for an explicit origin list.
    allow_credentials="*" not in _cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
| |
|
@app.get("/health")
def health():
    """Liveness endpoint: always answers with ``{"ok": True}``."""
    status = {"ok": True}
    return status
| |
|
def hf_predict(image: Image.Image) -> Dict[str, Any]:
    """Classify *image* with the module-level Hugging Face model.

    Returns a dict with ``class``, ``recyclable`` and ``confidence``. When the
    top-1 probability is below ``CONFIDENCE_THRESHOLD`` or its lead over the
    runner-up is below ``MARGIN_THRESHOLD``, the class is reported as
    ``"unknown"`` and the dict additionally carries a ``message`` and a
    ``debug`` section with the two best candidates.
    """
    model_inputs = PROCESSOR(images=image, return_tensors="pt")

    # Inference only: no gradients are needed.
    with torch.no_grad():
        logits = MODEL(**model_inputs).logits
        class_probs = torch.softmax(logits, dim=1)[0]

    # Two most likely classes, best first.
    top_values, top_indices = torch.topk(class_probs, k=2)
    best_prob, runner_up_prob = top_values.tolist()
    best_idx, runner_up_idx = top_indices.tolist()
    lead = best_prob - runner_up_prob

    best_label = CLASS_NAMES[best_idx]
    confidence = round(best_prob, 4)

    uncertain = best_prob < CONFIDENCE_THRESHOLD or lead < MARGIN_THRESHOLD
    if not uncertain:
        return {
            "class": best_label,
            "recyclable": best_label in RECYCLABLE,
            "confidence": confidence,
        }

    return {
        "class": "unknown",
        "recyclable": False,
        "confidence": confidence,
        "message": "Not confident enough. Retake photo (good lighting, centered, plain background).",
        "debug": {
            "top1": {"class": best_label, "prob": round(best_prob, 4)},
            "top2": {"class": CLASS_NAMES[runner_up_idx], "prob": round(runner_up_prob, 4)},
            "margin": round(lead, 4),
        },
    }
| |
|
def roboflow_predict(image_bytes: bytes) -> Dict[str, Any]:
    """Classify an e-waste photo with the hosted Roboflow model.

    Args:
        image_bytes: Raw bytes of the uploaded image; they are forwarded
            unchanged as a multipart ``file`` field.

    Returns:
        On success, a dict with ``class``, ``recyclable`` (always ``False``
        here) and ``confidence``. The class is ``"unknown"`` when Roboflow
        returns no usable prediction or the best label is not in
        ``EWASTE_ALLOWED``. On failure, a dict with ``error`` and, where
        available, ``detail``.
    """
    if not ROBOFLOW_API_KEY:
        return {"error": "ROBOFLOW_API_KEY is missing in Space secrets/env."}

    url = f"https://detect.roboflow.com/{ROBOFLOW_MODEL_ID}"

    try:
        resp = requests.post(
            url,
            params={"api_key": ROBOFLOW_API_KEY},
            files={"file": ("image.jpg", image_bytes, "image/jpeg")},
            timeout=60,
        )
    except requests.RequestException as exc:
        # The exception text can contain the request URL, which carries the
        # api_key query parameter. Report only the exception type so the key
        # is never echoed back to the API client.
        return {"error": "Roboflow request failed", "detail": type(exc).__name__}

    if resp.status_code != 200:
        return {"error": f"Roboflow error {resp.status_code}", "detail": resp.text}

    try:
        data = resp.json()
    except ValueError:
        # A 200 response whose body is not valid JSON.
        return {"error": "Roboflow returned a non-JSON response", "detail": resp.text}

    preds = data.get("predictions") if isinstance(data, dict) else None
    # Keep only well-formed prediction objects; anything else is ignored.
    candidates = [p for p in preds if isinstance(p, dict)] if isinstance(preds, list) else []
    if not candidates:
        return {"class": "unknown", "recyclable": False, "confidence": 0.0}

    # A missing or null confidence counts as 0.
    best = max(candidates, key=lambda p: p.get("confidence", 0) or 0)
    raw_label = str(best.get("class", "unknown")).strip().lower()
    conf = float(best.get("confidence", 0.0) or 0.0)

    # Map spelling variants onto the canonical label before the allow-list check.
    label = EWASTE_ALIAS.get(raw_label, raw_label)

    if label not in EWASTE_ALLOWED:
        return {
            "class": "unknown",
            "recyclable": False,
            "confidence": round(conf, 4),
            "message": "Unrecognized e-waste item. Please select Everyday Recyclables for plastic/metal/paper.",
        }

    return {
        "class": label,
        "recyclable": False,
        "confidence": round(conf, 4),
    }
| |
|
def _run_prediction(image_bytes: bytes, category: str) -> Dict[str, Any]:
    """Decode the upload and dispatch to the matching classifier (blocking)."""
    # Decoding happens for every category, so an invalid image is rejected
    # before any classifier is called. The source image is closed once the
    # RGB copy exists.
    with Image.open(io.BytesIO(image_bytes)) as source:
        image = source.convert("RGB")

    if category.strip().lower() == "ewaste":
        return roboflow_predict(image_bytes)

    return hf_predict(image)


@app.post("/predict")
async def predict(
    file: UploadFile = File(...),
    category: str = Form("regular"),
):
    """Classify an uploaded image.

    Args:
        file: The uploaded image.
        category: ``"ewaste"`` (case- and whitespace-insensitive) routes the
            image to ``roboflow_predict``; any other value uses ``hf_predict``.

    Returns:
        The classifier's result dict, or ``{"error": "Prediction failed",
        "detail": ...}`` if anything raises.
    """
    # Local import: fastapi is already a dependency of this module.
    from fastapi.concurrency import run_in_threadpool

    try:
        image_bytes = await file.read()
        # Image decoding, model inference and the Roboflow HTTP call are all
        # blocking, so they run in a worker thread instead of on the event
        # loop, which would otherwise stall every other request.
        return await run_in_threadpool(_run_prediction, image_bytes, category)
    except Exception as e:
        # Request boundary: failures are reported in the response body.
        return {"error": "Prediction failed", "detail": str(e)}
| |
|