# trashnet-server / app.py
# Source: Hugging Face Space upload by froidhj (commit 1b4701d, 3.86 kB).
# app.py
from fastapi import FastAPI, Request, Response
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image, UnidentifiedImageError
import io, torch, base64, traceback
from transformers import AutoImageProcessor, AutoModelForImageClassification
# ========= CONFIG =========
# Hugging Face model id of the TrashNet image classifier.
MODEL_ID = "prithivMLmods/Trash-Net"
# Keep only these 4 classes in PT-BR; everything else becomes "nao_identificado".
MAP_PT = {
    "glass": "vidro",
    "metal": "metal",
    "paper": "papel",
    "plastic": "plastico",
}
# Accepted PT-BR labels; used to gate the /predict response.
ALLOWED = set(MAP_PT.values())
# ========= OPTIMIZATIONS (Space CPU) =========
torch.set_grad_enabled(False)      # inference only: no autograd bookkeeping
torch.set_num_threads(1)           # single torch thread on the small shared CPU
torch.set_num_interop_threads(1)
# Avoid huge images (Pillow's decompression-bomb guard triggers above this pixel count).
Image.MAX_IMAGE_PIXELS = 25_000_000
# ========= LOADING =========
# If a warning appears saying "use_fast=True but torchvision not available",
# it is only a warning; switch to use_fast=False if you want to hide it.
processor = AutoImageProcessor.from_pretrained(MODEL_ID, use_fast=True)
model = AutoModelForImageClassification.from_pretrained(MODEL_ID)
model.eval()  # disable dropout/batch-norm training behavior
app = FastAPI()
# CORS (optional): wide open so any front-end origin can call the API.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# rejected by browsers for credentialed requests — confirm the intent.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"], allow_credentials=True,
    allow_methods=["*"], allow_headers=["*"],
)
def _to_label_pt(label_en: str) -> str:
    """Translate an English model label into its PT-BR class name.

    Empty, None, or unmapped labels yield "nao_identificado".
    """
    key = (label_en or "").strip().lower()
    return MAP_PT.get(key, "nao_identificado")
def _predict_image_bytes(img_bytes: bytes) -> str:
    """Decode raw image bytes, run the classifier, and return the PT-BR label.

    Raises PIL.UnidentifiedImageError when the bytes are not a decodable image.
    """
    stream = io.BytesIO(img_bytes)
    with Image.open(stream) as raw:
        # 256x256 RGB is a good speed/quality tradeoff for the Space CPU.
        rgb = raw.convert("RGB").resize((256, 256))
        with torch.inference_mode():
            batch = processor(images=rgb, return_tensors="pt")
            scores = model(**batch).logits
        best = int(scores.argmax(-1))  # argmax does not need a softmax first
        english = model.config.id2label[best]
    return _to_label_pt(english)
# ========= ROUTES =========
@app.get("/")
def root():
    """Root liveness check: reports service status and the loaded model id."""
    payload = {"ok": True, "message": "TrashNet classifier up", "model": MODEL_ID}
    return payload
@app.get("/health")
def health():
    """Health-check endpoint for uptime probes."""
    status = {"ok": True, "model": MODEL_ID}
    return status
@app.post("/predict")
async def predict(request: Request):
    """
    Classify an uploaded image.

    Accepts:
      - application/octet-stream (raw image bytes in the body)
      - image/* (raw image bytes in the body: JPEG, PNG, WebP, ...)
      - application/json {"image_b64": "..."} (data URL or plain base64)

    Returns plain text: 'vidro' | 'papel' | 'plastico' | 'metal' | 'nao_identificado'
    """
    try:
        ctype = (request.headers.get("content-type") or "").lower()
        img_bytes: bytes = b""
        # Raw binary body: accept ANY image/* subtype, not only image/jpeg
        # (previously raw PNG/WebP uploads fell through to the JSON branch and failed).
        if "application/octet-stream" in ctype or "image/" in ctype:
            img_bytes = await request.body()
        else:
            # Fallback: JSON payload carrying a base64-encoded image.
            data = await request.json()
            b64 = (data.get("image_b64") or "")
            if "," in b64:  # strip the data-URL prefix ("data:image/...;base64,")
                b64 = b64.split(",", 1)[1]
            img_bytes = base64.b64decode(b64) if b64 else b""
        if not img_bytes:
            return Response("nao_identificado", media_type="text/plain")
        label = _predict_image_bytes(img_bytes)
        if label not in ALLOWED:
            label = "nao_identificado"
        return Response(label, media_type="text/plain")
    except UnidentifiedImageError:
        # Bytes were not a decodable image.
        return Response("nao_identificado", media_type="text/plain")
    except Exception:
        # Top-level boundary: log the error and degrade to the sentinel label.
        traceback.print_exc()
        return Response("nao_identificado", media_type="text/plain")
# ========= WARM-UP =========
# NOTE(review): @app.on_event is deprecated in newer FastAPI versions in favor
# of lifespan handlers — confirm the installed version before migrating.
@app.on_event("startup")
def _warmup():
    # Run one dummy inference at startup so the first real request is fast.
    # Failures are logged and ignored: the server still starts without warm-up.
    try:
        # Flat gray 256x256 image, matching the size _predict_image_bytes uses.
        dummy = Image.new("RGB", (256, 256), (127, 127, 127))
        with torch.inference_mode():
            inputs = processor(images=dummy, return_tensors="pt")
            _ = model(**inputs).logits
        print("[startup] warm-up ok")
    except Exception:
        traceback.print_exc()
        print("[startup] warm-up falhou (seguindo sem)")