trashnet-server / app.py
froidhj's picture
Update app.py
c3a6888 verified
raw
history blame
2.6 kB
# app.py
from fastapi import FastAPI, Request, Response
from PIL import Image
import io, os, torch
from transformers import AutoImageProcessor, AutoModelForImageClassification
# ========= CONFIG =========
# Model identifier handed to the transformers `from_pretrained` loaders.
MODEL_ID = "AmadFR/ecovision_mobilenetv3"

# English model label -> Portuguese label (only the 4 desired classes).
MAP_PT = {
    "glass": "vidro",
    "metal": "metal",
    "paper": "papel",
    "plastic": "plastico",
}

# Portuguese labels that count as a recognised material.
ALLOWED = {*MAP_PT.values()}
# ========= OPTIMIZATIONS =========
# Inference only: switch autograd off for the code that follows.
# NOTE(review): PyTorch grad mode is thread-local, so this presumably only
# covers the thread that imports this module — confirm how the server runs it.
torch.set_grad_enabled(False)
# Restrict PyTorch to a single intra-op thread and a single inter-op thread.
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
# ========= LOAD MODEL =========
# Both objects are created once, at import time, and shared by every request.
processor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = AutoModelForImageClassification.from_pretrained(MODEL_ID)
model.eval()  # evaluation mode
# ASGI application object; the route decorators below register on it.
app = FastAPI()
def predict_image_bytes(img_bytes: bytes) -> str:
    """Classify an encoded image and return its simplified Portuguese label.

    Args:
        img_bytes: Encoded image data (JPEG is what callers send) that
            Pillow can open.

    Returns:
        The ``MAP_PT`` translation of the model's top-scoring class, or
        ``"nao_identificado"`` when that class has no entry in ``MAP_PT``.

    Raises:
        Nothing is caught here: decoding errors from Pillow and failures in
        the processor or model propagate to the caller.
    """
    # Open inside a context manager so the decoder is released as soon as the
    # independent RGB copy has been made.
    with Image.open(io.BytesIO(img_bytes)) as raw_img:
        img = raw_img.convert("RGB")
    img = img.resize((224, 224))  # speeds up inference

    inputs = processor(images=img, return_tensors="pt")
    # Grad mode is thread-local in PyTorch, so a module-level
    # set_grad_enabled(False) does not cover calls made from other threads.
    # Disabling it here keeps inference graph-free wherever this runs.
    with torch.no_grad():
        logits = model(**inputs).logits

    predicted_class_idx = int(logits.argmax(-1))
    label_en = model.config.id2label[predicted_class_idx].lower()
    # Translate to Portuguese; classes outside the mapping are "unidentified".
    return MAP_PT.get(label_en, "nao_identificado")
@app.get("/health")
def health():
    """Report that the service is up and which model identifier it uses."""
    return dict(ok=True, model=MODEL_ID)
@app.post("/predict")
async def predict(request: Request):
    """Classify the image carried by the request and answer in plain text.

    Expects: raw image bytes sent as ``application/octet-stream`` or with any
    ``image/*`` content type (the ESP32 sends JPEG). Every other content type
    is parsed as JSON holding a base64 string under ``image_b64`` (an optional
    data-URL prefix before the comma is dropped); this path is for testing.

    Returns: plain text - 'vidro', 'papel', 'plastico', 'metal' or
    'nao_identificado'. Any failure (empty body, bad JSON, undecodable image,
    model error) is printed and also answered with 'nao_identificado', so the
    client always receives one of these five words.
    """
    import base64

    try:
        content_type = (request.headers.get("content-type") or "").lower()
        # ESP32 sends application/octet-stream. Any image/* type is taken as a
        # raw body too; previously only image/jpeg was, so e.g. a PNG upload
        # fell into the JSON branch and could never be classified.
        if "application/octet-stream" in content_type or "image/" in content_type:
            img_bytes = await request.body()
        else:
            # Fallback for base64-in-JSON (for tests).
            data = await request.json()
            # A non-object JSON body or a null field simply means "no image".
            b64 = (data.get("image_b64") or "") if isinstance(data, dict) else ""
            b64 = b64.split(",")[-1]
            img_bytes = base64.b64decode(b64) if b64 else b""

        if not img_bytes:
            return Response("nao_identificado", media_type="text/plain")

        label = predict_image_bytes(img_bytes)
        # Final filter - only the 4 materials are ever reported.
        if label not in ALLOWED:
            label = "nao_identificado"
        return Response(label, media_type="text/plain")
    except Exception as e:
        # Deliberate catch-all at the request boundary: the device expects a
        # plain-text label on every call, never an HTTP error page.
        print("Erro:", e)
        return Response("nao_identificado", media_type="text/plain")