froidhj commited on
Commit
0ca0c79
·
verified ·
1 Parent(s): 7b7ccff

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +78 -21
app.py CHANGED
@@ -1,7 +1,8 @@
1
  # app.py
2
- from fastapi import FastAPI, Request, Response
3
- from PIL import Image
4
- import io, os, torch
 
5
  from transformers import AutoImageProcessor, AutoModelForImageClassification
6
 
7
  # ========= CONFIG =========
@@ -21,55 +22,111 @@ torch.set_grad_enabled(False)
21
  torch.set_num_threads(1)
22
  torch.set_num_interop_threads(1)
23
 
 
 
 
24
  # ========= CARREGAMENTO =========
25
- processor = AutoImageProcessor.from_pretrained(MODEL_ID)
 
26
  model = AutoModelForImageClassification.from_pretrained(MODEL_ID)
27
  model.eval()
28
 
29
  app = FastAPI()
30
 
31
- def predict_image_bytes(img_bytes: bytes) -> str:
32
- img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
33
- # Reduz um pouco para acelerar sem perder muito
34
- img = img.resize((256, 256))
35
-
36
- inputs = processor(images=img, return_tensors="pt")
37
- logits = model(**inputs).logits
38
- idx = int(logits.softmax(-1).argmax(-1))
39
- label_en = model.config.id2label[idx].lower()
40
 
41
- # Converte apenas se for uma das 4; senão marca como não identificado
 
 
42
  return MAP_PT.get(label_en, "nao_identificado")
43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  @app.get("/health")
45
  def health():
46
  return {"ok": True, "model": MODEL_ID}
47
 
48
  @app.post("/predict")
49
- async def predict(request: Request):
50
  """
51
- Espera: bytes JPEG (application/octet-stream)
 
 
 
 
 
52
  Retorna: texto puro — 'vidro' | 'papel' | 'plastico' | 'metal' | 'nao_identificado'
53
  """
54
  try:
 
55
  ctype = (request.headers.get("content-type") or "").lower()
56
 
57
- if "application/octet-stream" in ctype or "image/jpeg" in ctype:
 
 
 
 
 
58
  img_bytes = await request.body()
 
59
  else:
60
- # fallback opcional para JSON base64 (testes manuais)
61
  data = await request.json()
62
- import base64
63
- b64 = (data.get("image_b64") or "").split(",")[-1]
 
 
64
  img_bytes = base64.b64decode(b64) if b64 else b""
65
 
66
  if not img_bytes:
67
  return Response("nao_identificado", media_type="text/plain")
68
 
69
- label = predict_image_bytes(img_bytes)
70
  if label not in ALLOWED:
71
  label = "nao_identificado"
72
 
73
  return Response(label, media_type="text/plain")
 
 
 
 
74
  except Exception:
 
 
75
  return Response("nao_identificado", media_type="text/plain")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # app.py
2
+ from fastapi import FastAPI, Request, Response, UploadFile, File
3
+ from fastapi.middleware.cors import CORSMiddleware
4
+ from PIL import Image, UnidentifiedImageError
5
+ import io, os, torch, base64, traceback
6
  from transformers import AutoImageProcessor, AutoModelForImageClassification
7
 
8
  # ========= CONFIG =========
 
22
  torch.set_num_threads(1)
23
  torch.set_num_interop_threads(1)
24
 
25
+ # Evita ataques com imagens gigantes (e economiza RAM)
26
+ Image.MAX_IMAGE_PIXELS = 25_000_000 # ~25MP de teto (bem alto p/ segurança)
27
+
28
  # ========= CARREGAMENTO =========
29
+ # Removemos o aviso de processador "lento"
30
+ processor = AutoImageProcessor.from_pretrained(MODEL_ID, use_fast=True)
31
  model = AutoModelForImageClassification.from_pretrained(MODEL_ID)
32
  model.eval()
33
 
34
  app = FastAPI()
35
 
36
+ # CORS (opcional; útil para testes via browser/front-end)
37
+ app.add_middleware(
38
+ CORSMiddleware,
39
+ allow_origins=["*"], allow_credentials=True,
40
+ allow_methods=["*"], allow_headers=["*"],
41
+ )
 
 
 
42
 
43
+ def _to_label_pt(label_en: str) -> str:
44
+ # Normaliza e converte somente se estiver mapeado
45
+ label_en = (label_en or "").strip().lower()
46
  return MAP_PT.get(label_en, "nao_identificado")
47
 
48
+ def _predict_image_bytes(img_bytes: bytes) -> str:
49
+ # Lê, converte e reduz um pouco p/ acelerar, mantendo boa acurácia
50
+ with Image.open(io.BytesIO(img_bytes)) as img:
51
+ img = img.convert("RGB") # garante 3 canais
52
+ img = img.resize((256, 256)) # tradeoff bom p/ CPU do Space
53
+
54
+ with torch.inference_mode():
55
+ inputs = processor(images=img, return_tensors="pt")
56
+ logits = model(**inputs).logits
57
+ idx = int(logits.softmax(-1).argmax(-1))
58
+ label_en = model.config.id2label[idx]
59
+ return _to_label_pt(label_en)
60
+
61
+ # ========= ROTAS =========
62
+ @app.get("/")
63
+ def root():
64
+ # Retornar 200 aqui ajuda o “wake up” do Space pelo firmware
65
+ return {"ok": True, "message": "TrashNet classifier up", "model": MODEL_ID}
66
+
67
  @app.get("/health")
68
  def health():
69
  return {"ok": True, "model": MODEL_ID}
70
 
71
  @app.post("/predict")
72
+ async def predict(request: Request, file: UploadFile | None = File(default=None)):
73
  """
74
+ Aceita:
75
+ - application/octet-stream (raw JPEG no corpo)
76
+ - image/jpeg (raw JPEG no corpo)
77
+ - multipart/form-data (campo: file)
78
+ - application/json {"image_b64": "..."} (dataURL ou base64 puro)
79
+
80
  Retorna: texto puro — 'vidro' | 'papel' | 'plastico' | 'metal' | 'nao_identificado'
81
  """
82
  try:
83
+ img_bytes: bytes = b""
84
  ctype = (request.headers.get("content-type") or "").lower()
85
 
86
+ if file is not None:
87
+ # multipart/form-data
88
+ img_bytes = await file.read()
89
+
90
+ elif "application/octet-stream" in ctype or "image/jpeg" in ctype:
91
+ # raw bytes (ESP32 manda deste jeito)
92
  img_bytes = await request.body()
93
+
94
  else:
95
+ # fallback para JSON base64 (útil em testes manuais via Postman)
96
  data = await request.json()
97
+ b64 = (data.get("image_b64") or "")
98
+ if "," in b64:
99
+ # aceita dataURL: "data:image/jpeg;base64,...."
100
+ b64 = b64.split(",", 1)[1]
101
  img_bytes = base64.b64decode(b64) if b64 else b""
102
 
103
  if not img_bytes:
104
  return Response("nao_identificado", media_type="text/plain")
105
 
106
+ label = _predict_image_bytes(img_bytes)
107
  if label not in ALLOWED:
108
  label = "nao_identificado"
109
 
110
  return Response(label, media_type="text/plain")
111
+
112
+ except UnidentifiedImageError:
113
+ # bytes não eram uma imagem válida
114
+ return Response("nao_identificado", media_type="text/plain")
115
  except Exception:
116
+ # loga stack trace no console do Space para debug
117
+ traceback.print_exc()
118
  return Response("nao_identificado", media_type="text/plain")
119
+
120
+ # ========= WARM-UP =========
121
+ @app.on_event("startup")
122
+ def _warmup():
123
+ # Faz uma inferência boba para “carregar” tudo e reduzir a latência da 1ª chamada real
124
+ try:
125
+ dummy = Image.new("RGB", (256, 256), (127, 127, 127))
126
+ with torch.inference_mode():
127
+ inputs = processor(images=dummy, return_tensors="pt")
128
+ _ = model(**inputs).logits
129
+ print("[startup] warm-up ok")
130
+ except Exception:
131
+ traceback.print_exc()
132
+ print("[startup] warm-up falhou (seguindo sem)")