# app.py
"""FastAPI service that classifies trash photos with the TrashNet model.

Endpoints:
    GET  /        - liveness info
    GET  /health  - health check
    POST /predict - classify an image; responds with a plain-text PT-BR
                    label: 'vidro' | 'papel' | 'plastico' | 'metal',
                    or 'nao_identificado' when classification fails.
"""

import base64
import io
import traceback

import torch
from fastapi import FastAPI, Request, Response
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image, UnidentifiedImageError
from transformers import AutoImageProcessor, AutoModelForImageClassification

# ========= CONFIG =========
MODEL_ID = "prithivMLmods/Trash-Net"

# Only these 4 classes are kept (mapped to PT-BR); everything else becomes
# the unknown label.
MAP_PT = {
    "glass": "vidro",
    "metal": "metal",
    "paper": "papel",
    "plastic": "plastico",
}
ALLOWED = frozenset(MAP_PT.values())

# Runtime response string — must stay exactly as clients expect it.
UNKNOWN_LABEL = "nao_identificado"

# ========= OPTIMIZATIONS (Space CPU) =========
torch.set_grad_enabled(False)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)

# Guard against decompression-bomb images.
Image.MAX_IMAGE_PIXELS = 25_000_000

# ========= MODEL LOADING =========
# NOTE: the warning "use_fast=True but torchvision not available" is
# harmless; switch to use_fast=False if you want to silence it.
processor = AutoImageProcessor.from_pretrained(MODEL_ID, use_fast=True)
model = AutoModelForImageClassification.from_pretrained(MODEL_ID)
model.eval()

app = FastAPI()

# CORS (optional)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


def _to_label_pt(label_en: str) -> str:
    """Map an English model label to its PT-BR name, or the unknown label."""
    return MAP_PT.get((label_en or "").strip().lower(), UNKNOWN_LABEL)


def _predict_image_bytes(img_bytes: bytes) -> str:
    """Decode *img_bytes*, run the classifier, and return a PT-BR label.

    Raises:
        PIL.UnidentifiedImageError: if the bytes are not a decodable image.
    """
    with Image.open(io.BytesIO(img_bytes)) as img:
        img = img.convert("RGB")
        # 256x256 is a good accuracy/latency tradeoff for the Space CPU.
        img = img.resize((256, 256))
    with torch.inference_mode():
        inputs = processor(images=img, return_tensors="pt")
        logits = model(**inputs).logits
    idx = int(logits.argmax(-1))  # softmax is unnecessary for argmax
    label_en = model.config.id2label[idx]
    return _to_label_pt(label_en)


# ========= ROUTES =========
@app.get("/")
def root():
    """Liveness endpoint with basic service info."""
    return {"ok": True, "message": "TrashNet classifier up", "model": MODEL_ID}


@app.get("/health")
def health():
    """Health-check endpoint."""
    return {"ok": True, "model": MODEL_ID}


@app.post("/predict")
async def predict(request: Request):
    """Classify an uploaded image.

    Accepts:
        - application/octet-stream (raw image bytes in the body)
        - image/* (raw image bytes in the body)
        - application/json {"image_b64": "..."} (data URL or bare base64)

    Returns:
        Plain text — 'vidro' | 'papel' | 'plastico' | 'metal' |
        'nao_identificado'. Never raises to the client: any failure
        (undecodable image, bad base64, malformed JSON) yields the
        unknown label with HTTP 200.
    """
    try:
        ctype = (request.headers.get("content-type") or "").lower()
        img_bytes: bytes = b""

        # Fix: accept ANY image/* content type. The original only matched
        # "image/jpeg", so PNG/WebP uploads with a correct content type
        # fell through to the JSON branch and failed.
        if "application/octet-stream" in ctype or ctype.startswith("image/"):
            img_bytes = await request.body()
        else:
            # Fallback: JSON payload carrying base64 data.
            data = await request.json()
            # Fix: guard against non-dict JSON bodies (a bare list/str would
            # previously raise AttributeError on .get).
            b64 = (data.get("image_b64") or "") if isinstance(data, dict) else ""
            if "," in b64:  # data URL: strip the "data:...;base64," prefix
                b64 = b64.split(",", 1)[1]
            img_bytes = base64.b64decode(b64) if b64 else b""

        if not img_bytes:
            return Response(UNKNOWN_LABEL, media_type="text/plain")

        label = _predict_image_bytes(img_bytes)
        if label not in ALLOWED:
            label = UNKNOWN_LABEL
        return Response(label, media_type="text/plain")
    except UnidentifiedImageError:
        return Response(UNKNOWN_LABEL, media_type="text/plain")
    except Exception:
        # Best-effort API contract: log the error, always answer 200 with
        # the unknown label instead of a 500.
        traceback.print_exc()
        return Response(UNKNOWN_LABEL, media_type="text/plain")


# ========= WARM-UP =========
@app.on_event("startup")
def _warmup():
    """Run one dummy inference at startup so the first real request is fast."""
    try:
        dummy = Image.new("RGB", (256, 256), (127, 127, 127))
        with torch.inference_mode():
            inputs = processor(images=dummy, return_tensors="pt")
            _ = model(**inputs).logits
        print("[startup] warm-up ok")
    except Exception:
        # Warm-up is optional; serving continues without it.
        traceback.print_exc()
        print("[startup] warm-up falhou (seguindo sem)")