# app.py — TrashTrack Turbo (ESP32-compatible, multipart/form-data upload).
"""Zero-shot waste classifier (plastic / paper / metal) served over FastAPI.

Images arrive as multipart/form-data parts named ``files[]``; each is cropped to
its non-white foreground, embedded with a SigLIP image tower, compared against
pre-computed prompt embeddings, and the per-image class probabilities are
soft-voted into a single label.
"""
import io
import os
import threading
import time

# Avoid libgomp errors and CPU thread oversubscription. These must be set
# before torch / tokenizers are imported, hence their position between imports.
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import numpy as np
import torch
from fastapi import FastAPI, File, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from PIL import Image, ImageOps
from transformers import AutoModel, AutoProcessor

# ==============================================
# ⚙️ Configuration
# ==============================================
MODEL_ID = "google/siglip-so400m-patch14-384"  # most accurate model tried
device = "cuda" if torch.cuda.is_available() else "cpu"

# The fast image processor needs torchvision. A broken torchvision install can
# raise more than ImportError on import, so any failure selects the slow path.
try:
    import torchvision  # noqa: F401
    USE_FAST = True
except Exception:
    USE_FAST = False

print(f"🚀 Carregando modelo {MODEL_ID} (use_fast={USE_FAST})...")
processor = AutoProcessor.from_pretrained(MODEL_ID, use_fast=USE_FAST)
model = AutoModel.from_pretrained(
    MODEL_ID, dtype=torch.float16 if device == "cuda" else None
).to(device).eval()
print("✅ Modelo carregado com sucesso.")

# ==============================================
# 📋 Classes (PT + EN prompts) — glass intentionally excluded
# ==============================================
labels = {
    "plastico": ["plástico", "garrafa PET", "tampinha plástica", "sacola plástica", "plastic bottle"],
    "papel": ["papel", "folha", "envelope de papel", "paper sheet", "paper wrapper"],
    "metal": ["lata", "alumínio", "tampinha metálica", "metal cap", "can"],
}


def _promptize(term: str) -> str:
    """Wrap a label term in the fixed prompt template used for every class."""
    return f"centered {term} on a white background; ignore the background; classify only the object"


# Prompts in label order: rows [0:5) belong to the first class, [5:10) to the
# second, and so on. _score_image relies on this ordering when slicing.
texts = [_promptize(t) for group in labels.values() for t in group]


def _encode_texts(prompts: list[str]) -> torch.Tensor:
    """Return L2-normalised text embeddings for *prompts*, one row per prompt."""
    # SigLIP was trained with max-length padding; padding=True (pad to longest)
    # yields different, degraded text embeddings.
    inputs = processor(text=prompts, return_tensors="pt", padding="max_length").to(device)
    with torch.inference_mode():
        emb = model.get_text_features(**inputs)
        return torch.nn.functional.normalize(emb, dim=-1)


# The prompts never change, so encode them once instead of once per image.
TEXT_EMB = _encode_texts(texts)

# Learned temperature of the contrastive head. Raw cosine similarities sit in
# [-1, 1]; a softmax over them without this scale is almost uniform.
_logit_scale = getattr(model, "logit_scale", None)
LOGIT_SCALE = float(_logit_scale.detach().exp()) if _logit_scale is not None else 1.0


# ==============================================
# 🔧 Util — foreground crop ignoring a white background
# ==============================================
def crop_foreground_ignore_white(
    pil: Image.Image,
    *,
    white_thresh: int = 230,
    min_fg_pixels: int = 500,
    fallback_frac: float = 0.8,
    pad_frac: float = 0.03,
) -> Image.Image:
    """Crop *pil* (converted to RGB) to the bounding box of its non-white pixels.

    A pixel is background when all of R, G and B exceed ``white_thresh``. If fewer
    than ``min_fg_pixels`` foreground pixels remain, a centred crop covering
    ``fallback_frac`` of each dimension is returned instead. Otherwise the box is
    padded by ``pad_frac`` of the image height/width, clamped to the image.
    """
    img = pil.convert("RGB")
    arr = np.asarray(img)
    fg = ~(arr > white_thresh).all(axis=-1)
    if fg.sum() < min_fg_pixels:
        # Too little foreground to trust the mask: fall back to a centre crop.
        w, h = img.size
        cw, ch = int(w * fallback_frac), int(h * fallback_frac)
        left, top = (w - cw) // 2, (h - ch) // 2
        return img.crop((left, top, left + cw, top + ch))
    ys, xs = np.nonzero(fg)
    py, px = int(pad_frac * img.height), int(pad_frac * img.width)
    y0 = max(0, ys.min() - py)
    y1 = min(img.height - 1, ys.max() + py)
    x0 = max(0, xs.min() - px)
    x1 = min(img.width - 1, xs.max() + px)
    # PIL crop boxes are right/bottom-exclusive, hence the +1.
    return img.crop((x0, y0, x1 + 1, y1 + 1))


# Serialises model access. The original code ran inference on the event loop,
# one request at a time; this keeps that behaviour without blocking the loop.
_infer_lock = threading.Lock()


def _load_image(data: bytes) -> Image.Image:
    """Decode uploaded bytes into an EXIF-upright RGB image (raises OSError if invalid)."""
    image = Image.open(io.BytesIO(data))
    return ImageOps.exif_transpose(image).convert("RGB")


def _score_image(image: Image.Image) -> dict[str, float]:
    """Return per-class probabilities for one RGB image (they sum to ~1 across classes)."""
    image = crop_foreground_ignore_white(image)
    pixel_inputs = processor(images=image, return_tensors="pt").to(device)
    with _infer_lock, torch.inference_mode():
        img_emb = model.get_image_features(**pixel_inputs)
        img_emb = torch.nn.functional.normalize(img_emb, dim=-1)
        logits = (img_emb @ TEXT_EMB.t()).squeeze(0).float() * LOGIT_SCALE
        probs = torch.softmax(logits.cpu(), dim=-1).tolist()
    scores, start = {}, 0
    for key, group in labels.items():
        # The softmax runs over all prompts, so a class's probability is the
        # sum over its prompts (a mean would cap it at 1/len(group)).
        scores[key] = float(sum(probs[start:start + len(group)]))
        start += len(group)
    return scores


# ==============================================
# 🌐 FastAPI app
# ==============================================
app = FastAPI(title="TrashTrack Turbo — ESP32 Compatible")


@app.get("/")
def root():
    """Health/info endpoint: model id, upload mode and class names."""
    return {"ok": True, "model": MODEL_ID, "mode": "multipart/files[]", "classes": list(labels.keys())}


@app.post("/predict")
async def predict(files_: list[UploadFile] = File(..., alias="files[]")):
    """Classify the uploaded images and return one soft-voted label.

    Each image's class probabilities are summed and the class with the largest
    total wins; ``conf`` is that class's mean probability across images.
    Responds with ``{"label", "conf", "latency_s"}``, with 400 for an empty
    upload or an undecodable image, and 500 for any other failure.
    """
    try:
        t0 = time.perf_counter()
        if not files_:
            return JSONResponse({"error": "no files"}, status_code=400)
        totals = dict.fromkeys(labels, 0.0)
        for f in files_:
            data = await f.read()
            try:
                image = _load_image(data)
            except OSError as e:  # includes PIL.UnidentifiedImageError / truncated files
                print(f"[ERRO] {f.filename}: {e}")
                return JSONResponse({"error": f"invalid image: {e}"}, status_code=400)
            # Model inference is CPU/GPU-bound; run it off the event loop.
            scores = await run_in_threadpool(_score_image, image)
            for key, s in scores.items():
                totals[key] += s
        final = max(totals, key=totals.get)
        conf = round(totals[final] / len(files_), 3)
        latency = round(time.perf_counter() - t0, 2)
        print(f"[OK] {final} ({conf}) em {latency}s")
        return JSONResponse({"label": final, "conf": conf, "latency_s": latency})
    except Exception as e:
        # Top-level boundary: report the failure to the client instead of crashing.
        print(f"[ERRO] {e}")
        return JSONResponse({"error": str(e)}, status_code=500)