Spaces:
Sleeping
Sleeping
# app.py
import asyncio
import base64
import io
import random
import traceback

import torch
from fastapi import FastAPI, Request, Response
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image, UnidentifiedImageError
from transformers import AutoImageProcessor, AutoModelForImageClassification
# ========= CONFIG =========
MODEL_ID = "prithivMLmods/Trash-Net"

# English model label -> Portuguese label returned by the API.
# Only these four classes are mapped; any other model label has no translation.
MAP_PT = dict(
    glass="vidro",
    metal="metal",
    paper="papel",
    plastic="plastico",
)

# The four allowed answers, in a fixed order (this list is the pool that the
# random fallbacks draw from).
ALLOWED = [
    "plastico",
    "papel",
    "vidro",
    "metal",
]
# ========= CPU TUNING (Space hardware) =========
# Inference-only service: disable gradient tracking up front.
torch.set_grad_enabled(False)
# Restrict PyTorch to one intra-op thread and one inter-op thread.
torch.set_num_threads(1)
torch.set_num_interop_threads(1)

# Refuse gigantic images: PIL pixel-count limit of 25 megapixels.
Image.MAX_IMAGE_PIXELS = 25 * 10**6
# ========= MODEL LOADING =========
# Image pre-processor (fast variant) and classification model for MODEL_ID.
processor = AutoImageProcessor.from_pretrained(MODEL_ID, use_fast=True)
# nn.Module.eval() returns the module itself, so loading and switching to
# evaluation mode can be chained into one expression.
model = AutoModelForImageClassification.from_pretrained(MODEL_ID).eval()
app = FastAPI()

# CORS (optional): accept any origin, any method and any header.
# NOTE(review): a wildcard origin combined with allow_credentials=True may make
# the middleware echo the caller's Origin on credentialed requests; no cookie or
# auth handling is visible in this service, so credentials may be unnecessary —
# confirm and consider allow_credentials=False.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
def _force_allowed(label_en: str | None) -> str:
    """Map an English model label to its Portuguese API label.

    The label is trimmed and lower-cased before the MAP_PT lookup. If the
    label is empty/None, or its translation is not in ALLOWED, a random
    member of ALLOWED is returned. The result is therefore always one of
    the four allowed labels.
    """
    translated = MAP_PT.get(label_en.strip().lower()) if label_en else None
    if translated in ALLOWED:
        return translated
    # Forced fallback: never answer with anything outside ALLOWED.
    return random.choice(ALLOWED)
def _predict_image_bytes(img_bytes: bytes) -> str:
    """Classify encoded image bytes and return one of the ALLOWED Portuguese labels.

    Steps:
      - Decode the image with PIL.
      - Convert it to RGB and resize it to 256x256.
      - Run it through ``processor`` and ``model``.

    The answer is the highest-scoring class among the model labels that
    MAP_PT translates to an ALLOWED label. When the overall top-1 class has
    no translation, the best translatable class is used instead of a random
    answer. If no model label is translatable at all, the plain top-1 label
    goes through ``_force_allowed`` as before.

    Args:
        img_bytes: encoded image file contents (any format PIL can open).

    Returns:
        One of the labels in ALLOWED.

    Raises:
        PIL.UnidentifiedImageError: if the bytes are not a decodable image.
        Other PIL / torch errors propagate to the caller.
    """
    with Image.open(io.BytesIO(img_bytes)) as src:
        img = src.convert("RGB")
    # Fixed square input: a good trade-off for the Space's CPU (aspect ratio
    # is not preserved).
    img = img.resize((256, 256))

    with torch.inference_mode():
        inputs = processor(images=img, return_tensors="pt")
        # One image per call -> take the first (only) row of the batch.
        scores = model(**inputs).logits[0]

    id2label = model.config.id2label
    # Indices of model classes whose label translates to an allowed answer.
    candidates = [
        i
        for i in range(scores.shape[-1])
        if MAP_PT.get(str(id2label.get(i, "")).strip().lower()) in ALLOWED
    ]
    if candidates:
        best = candidates[int(scores[candidates].argmax())]
    else:
        # Nothing is translatable (e.g. a different model): plain top-1.
        best = int(scores.argmax())
    return _force_allowed(id2label.get(best))
# ========= ROUTES =========
def root():
    """Report that the classifier service is up and which model it serves.

    NOTE(review): no route decorator is visible for this function; presumably
    it is registered on ``app`` (e.g. as GET "/") — confirm.
    """
    payload = {"ok": True, "message": "TrashNet classifier up"}
    payload["model"] = MODEL_ID
    return payload
def health():
    """Return a minimal status payload: ``ok`` flag plus the model id.

    NOTE(review): no route decorator is visible for this function; presumably
    it is registered on ``app`` as a health-check route — confirm.
    """
    return dict(ok=True, model=MODEL_ID)
async def predict(request: Request):
    """Classify an uploaded image and answer with a bare Portuguese label.

    Accepted request bodies:
      - ``application/octet-stream`` or any ``image/*`` content type
        (``image/jpeg``, ``image/png``, ...): the raw image bytes.
      - anything else is parsed as JSON ``{"image_b64": "..."}``. The value
        may be a data URL or plain base64.

    Always returns ``text/plain`` containing one of
    'vidro' | 'papel' | 'plastico' | 'metal' (never 'nao_identificado').
    An empty body, an undecodable image or any other failure yields a random
    allowed label instead of an error response.

    NOTE(review): no route decorator is visible for this function; presumably
    it is registered on ``app`` as a POST route — confirm.
    """
    try:
        ctype = (request.headers.get("content-type") or "").lower()
        img_bytes: bytes = b""

        # "image/" also covers the previously supported "image/jpeg".
        if "application/octet-stream" in ctype or "image/" in ctype:
            img_bytes = await request.body()
        else:
            # Fallback: JSON carrying base64 (optionally as a data URL).
            data = await request.json()
            b64 = data.get("image_b64") or ""
            if "," in b64:  # data URL: keep only the payload after the comma
                b64 = b64.split(",", 1)[1]
            img_bytes = base64.b64decode(b64) if b64 else b""

        # Empty payload: still answer with one of the four labels.
        if not img_bytes:
            return Response(random.choice(ALLOWED), media_type="text/plain")

        # Inference is synchronous and CPU-bound; run it in a worker thread so
        # it does not block the event loop (other requests, health checks).
        label = await asyncio.to_thread(_predict_image_bytes, img_bytes)

        # Belt and braces: force the answer into ALLOWED.
        if label not in ALLOWED:
            label = random.choice(ALLOWED)
        return Response(label, media_type="text/plain")
    except UnidentifiedImageError:
        return Response(random.choice(ALLOWED), media_type="text/plain")
    except Exception:
        # Deliberate best-effort boundary: log the failure, still answer.
        traceback.print_exc()
        return Response(random.choice(ALLOWED), media_type="text/plain")
# ========= WARM-UP =========
def _warmup():
    """Run one forward pass on a blank 256x256 grey image and report the outcome.

    Presumably meant to absorb one-time initialization cost before real
    requests arrive — confirm. No caller is visible for this function.
    Failures are logged and swallowed, so startup continues without warm-up.
    """
    try:
        blank = Image.new("RGB", (256, 256), (127,) * 3)
        with torch.inference_mode():
            batch = processor(images=blank, return_tensors="pt")
            _logits = model(**batch).logits  # discarded; only the pass matters
        print("[startup] warm-up ok")
    except Exception:
        traceback.print_exc()
        print("[startup] warm-up falhou (seguindo sem)")