trashnet-server / app.py
froidhj's picture
Update app.py
929f4df verified
raw
history blame
4 kB
# app.py
from fastapi import FastAPI, Request, Response
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image, UnidentifiedImageError
import io, torch, base64, traceback, random
from transformers import AutoImageProcessor, AutoModelForImageClassification
# ========= CONFIG =========
# Hugging Face Hub model id of the trash/recyclables image classifier.
MODEL_ID = "prithivMLmods/Trash-Net"
# EN -> PT label map (only the 4 desired classes are kept).
MAP_PT = {
"glass": "vidro",
"metal": "metal",
"paper": "papel",
"plastic": "plastico",
}
ALLOWED = ["plastico", "papel", "vidro", "metal"]  # fixed order for the random fallback
# ========= OPTIMIZATIONS (Space CPU) =========
# Inference-only service: disable autograd and pin torch to one thread
# (extra threads only add contention on the small shared-CPU Space).
torch.set_grad_enabled(False)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
# Guard against huge images (PIL's decompression-bomb limit, in pixels).
Image.MAX_IMAGE_PIXELS = 25_000_000
# ========= LOADING =========
# Download (or use the local cache of) the image processor and model weights.
processor = AutoImageProcessor.from_pretrained(MODEL_ID, use_fast=True)
model = AutoModelForImageClassification.from_pretrained(MODEL_ID)
model.eval()  # inference mode: disables dropout / batch-norm updates
app = FastAPI()
# CORS (optional): wide open so any web client can call the API.
# NOTE(review): browsers reject credentialed requests against a wildcard
# origin — allow_credentials=True with allow_origins=["*"] may not behave
# as intended; confirm whether credentials are actually needed.
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], allow_credentials=True,
allow_methods=["*"], allow_headers=["*"],
)
def _force_allowed(label_en: str | None) -> str:
    """Translate an English model label to its PT class name.

    Returns one of ALLOWED ('plastico' | 'papel' | 'vidro' | 'metal').
    Unknown or missing labels fall back to a random allowed class, so the
    caller is guaranteed to always receive a valid answer.
    """
    if not label_en:
        return random.choice(ALLOWED)
    translated = MAP_PT.get(label_en.strip().lower())
    if translated in ALLOWED:
        return translated
    # Forced fallback: never answer with an unmapped class.
    return random.choice(ALLOWED)
def _predict_image_bytes(img_bytes: bytes) -> str:
    """Decode raw image bytes, run the classifier, and return a PT label.

    Raises PIL.UnidentifiedImageError when the bytes are not a decodable
    image; the /predict route handles that case.
    """
    with Image.open(io.BytesIO(img_bytes)) as decoded:
        # 256x256 is a good speed/quality tradeoff on the Space CPU.
        rgb = decoded.convert("RGB").resize((256, 256))
        with torch.inference_mode():
            batch = processor(images=rgb, return_tensors="pt")
            scores = model(**batch).logits
            best_idx = int(scores.argmax(-1))
        english_label = model.config.id2label[best_idx]
    return _force_allowed(english_label)
# ========= ROUTES =========
@app.get("/")
def root():
    """Root endpoint: confirms the service is up and reports the model id."""
    payload = {"ok": True, "message": "TrashNet classifier up", "model": MODEL_ID}
    return payload
@app.get("/health")
def health():
    """Simple health-check endpoint."""
    return dict(ok=True, model=MODEL_ID)
@app.post("/predict")
async def predict(request: Request):
    """
    Classify an uploaded image into one of the 4 recyclable-material classes.

    Accepts:
      - application/octet-stream (raw image bytes in the body)
      - image/* (raw image bytes in the body — jpeg, png, webp, ...)
      - application/json {"image_b64": "..."} (data URL or plain base64)

    ALWAYS returns plain text: 'vidro' | 'papel' | 'plastico' | 'metal'
    (never 'nao_identificado'). Empty or undecodable input yields a random
    allowed class — a deliberate best-effort contract for the client.
    """
    try:
        ctype = (request.headers.get("content-type") or "").lower()
        img_bytes: bytes = b""
        # Raw-body uploads: octet-stream or any image/* subtype. Generalized
        # from jpeg-only so PNG/WebP clients are not forced through base64.
        if "application/octet-stream" in ctype or "image/" in ctype:
            img_bytes = await request.body()
        else:
            # Fallback: JSON body carrying base64 (optionally a data URL).
            data = await request.json()
            b64 = (data.get("image_b64") or "")
            if "," in b64:  # data URL: strip the "data:...;base64," prefix
                b64 = b64.split(",", 1)[1]
            img_bytes = base64.b64decode(b64) if b64 else b""
        # Empty payload still answers with one of the 4 classes (contract).
        if not img_bytes:
            return Response(random.choice(ALLOWED), media_type="text/plain")
        label = _predict_image_bytes(img_bytes)
        # Belt-and-braces: force the answer into the allowed set.
        if label not in ALLOWED:
            label = random.choice(ALLOWED)
        return Response(label, media_type="text/plain")
    except UnidentifiedImageError:
        # Bytes were not a decodable image — still honor the contract.
        return Response(random.choice(ALLOWED), media_type="text/plain")
    except Exception:
        # Deliberate catch-all: log the traceback but never surface a 500.
        traceback.print_exc()
        return Response(random.choice(ALLOWED), media_type="text/plain")
# ========= WARM-UP =========
@app.on_event("startup")
def _warmup():
    """Run one dummy inference at startup so the first real request is fast."""
    try:
        gray = Image.new("RGB", (256, 256), (127, 127, 127))
        with torch.inference_mode():
            tensors = processor(images=gray, return_tensors="pt")
            _ = model(**tensors).logits
        print("[startup] warm-up ok")
    except Exception:
        # Best effort: a failed warm-up must never prevent serving.
        traceback.print_exc()
        print("[startup] warm-up falhou (seguindo sem)")