froidhj commited on
Commit
1b4701d
·
verified ·
1 Parent(s): 0ca0c79

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -31
app.py CHANGED
@@ -1,8 +1,8 @@
1
  # app.py
2
- from fastapi import FastAPI, Request, Response, UploadFile, File
3
  from fastapi.middleware.cors import CORSMiddleware
4
  from PIL import Image, UnidentifiedImageError
5
- import io, os, torch, base64, traceback
6
  from transformers import AutoImageProcessor, AutoModelForImageClassification
7
 
8
  # ========= CONFIG =========
@@ -22,18 +22,19 @@ torch.set_grad_enabled(False)
22
  torch.set_num_threads(1)
23
  torch.set_num_interop_threads(1)
24
 
25
- # Evita ataques com imagens gigantes (e economiza RAM)
26
- Image.MAX_IMAGE_PIXELS = 25_000_000 # ~25MP de teto (bem alto p/ segurança)
27
 
28
  # ========= CARREGAMENTO =========
29
- # Removemos o aviso de processador "lento"
 
30
  processor = AutoImageProcessor.from_pretrained(MODEL_ID, use_fast=True)
31
  model = AutoModelForImageClassification.from_pretrained(MODEL_ID)
32
  model.eval()
33
 
34
  app = FastAPI()
35
 
36
- # CORS (opcional; útil para testes via browser/front-end)
37
  app.add_middleware(
38
  CORSMiddleware,
39
  allow_origins=["*"], allow_credentials=True,
@@ -41,27 +42,23 @@ app.add_middleware(
41
  )
42
 
43
  def _to_label_pt(label_en: str) -> str:
44
- # Normaliza e converte somente se estiver mapeado
45
- label_en = (label_en or "").strip().lower()
46
- return MAP_PT.get(label_en, "nao_identificado")
47
 
48
  def _predict_image_bytes(img_bytes: bytes) -> str:
49
- # Lê, converte e reduz um pouco p/ acelerar, mantendo boa acurácia
50
  with Image.open(io.BytesIO(img_bytes)) as img:
51
- img = img.convert("RGB") # garante 3 canais
52
- img = img.resize((256, 256)) # tradeoff bom p/ CPU do Space
53
 
54
  with torch.inference_mode():
55
  inputs = processor(images=img, return_tensors="pt")
56
  logits = model(**inputs).logits
57
- idx = int(logits.softmax(-1).argmax(-1))
58
  label_en = model.config.id2label[idx]
59
  return _to_label_pt(label_en)
60
 
61
  # ========= ROTAS =========
62
  @app.get("/")
63
  def root():
64
- # Retornar 200 aqui ajuda o “wake up” do Space pelo firmware
65
  return {"ok": True, "message": "TrashNet classifier up", "model": MODEL_ID}
66
 
67
  @app.get("/health")
@@ -69,34 +66,26 @@ def health():
69
  return {"ok": True, "model": MODEL_ID}
70
 
71
  @app.post("/predict")
72
- async def predict(request: Request, file: UploadFile | None = File(default=None)):
73
  """
74
  Aceita:
75
  - application/octet-stream (raw JPEG no corpo)
76
  - image/jpeg (raw JPEG no corpo)
77
- - multipart/form-data (campo: file)
78
  - application/json {"image_b64": "..."} (dataURL ou base64 puro)
79
 
80
  Retorna: texto puro — 'vidro' | 'papel' | 'plastico' | 'metal' | 'nao_identificado'
81
  """
82
  try:
83
- img_bytes: bytes = b""
84
  ctype = (request.headers.get("content-type") or "").lower()
 
85
 
86
- if file is not None:
87
- # multipart/form-data
88
- img_bytes = await file.read()
89
-
90
- elif "application/octet-stream" in ctype or "image/jpeg" in ctype:
91
- # raw bytes (ESP32 manda deste jeito)
92
  img_bytes = await request.body()
93
-
94
  else:
95
- # fallback para JSON base64 (útil em testes manuais via Postman)
96
  data = await request.json()
97
  b64 = (data.get("image_b64") or "")
98
- if "," in b64:
99
- # aceita dataURL: "data:image/jpeg;base64,...."
100
  b64 = b64.split(",", 1)[1]
101
  img_bytes = base64.b64decode(b64) if b64 else b""
102
 
@@ -110,17 +99,14 @@ async def predict(request: Request, file: UploadFile | None = File(default=None)
110
  return Response(label, media_type="text/plain")
111
 
112
  except UnidentifiedImageError:
113
- # bytes não eram uma imagem válida
114
  return Response("nao_identificado", media_type="text/plain")
115
  except Exception:
116
- # loga stack trace no console do Space para debug
117
  traceback.print_exc()
118
  return Response("nao_identificado", media_type="text/plain")
119
 
120
  # ========= WARM-UP =========
121
  @app.on_event("startup")
122
  def _warmup():
123
- # Faz uma inferência boba para “carregar” tudo e reduzir a latência da 1ª chamada real
124
  try:
125
  dummy = Image.new("RGB", (256, 256), (127, 127, 127))
126
  with torch.inference_mode():
@@ -129,4 +115,4 @@ def _warmup():
129
  print("[startup] warm-up ok")
130
  except Exception:
131
  traceback.print_exc()
132
- print("[startup] warm-up falhou (seguindo sem)")
 
1
  # app.py
2
+ from fastapi import FastAPI, Request, Response
3
  from fastapi.middleware.cors import CORSMiddleware
4
  from PIL import Image, UnidentifiedImageError
5
+ import io, torch, base64, traceback
6
  from transformers import AutoImageProcessor, AutoModelForImageClassification
7
 
8
  # ========= CONFIG =========
 
22
  torch.set_num_threads(1)
23
  torch.set_num_interop_threads(1)
24
 
25
+ # Evita imagens gigantes
26
+ Image.MAX_IMAGE_PIXELS = 25_000_000
27
 
28
  # ========= CARREGAMENTO =========
29
+ # Se aparecer aviso "use_fast=True mas torchvision não disponível",
30
+ # é só um warning; pode trocar para use_fast=False se quiser ocultar.
31
  processor = AutoImageProcessor.from_pretrained(MODEL_ID, use_fast=True)
32
  model = AutoModelForImageClassification.from_pretrained(MODEL_ID)
33
  model.eval()
34
 
35
  app = FastAPI()
36
 
37
+ # CORS (opcional)
38
  app.add_middleware(
39
  CORSMiddleware,
40
  allow_origins=["*"], allow_credentials=True,
 
42
  )
43
 
44
  def _to_label_pt(label_en: str) -> str:
45
+ return MAP_PT.get((label_en or "").strip().lower(), "nao_identificado")
 
 
46
 
47
  def _predict_image_bytes(img_bytes: bytes) -> str:
 
48
  with Image.open(io.BytesIO(img_bytes)) as img:
49
+ img = img.convert("RGB")
50
+ img = img.resize((256, 256)) # tradeoff bom para CPU do Space
51
 
52
  with torch.inference_mode():
53
  inputs = processor(images=img, return_tensors="pt")
54
  logits = model(**inputs).logits
55
+ idx = int(logits.argmax(-1)) # softmax não precisa para argmax
56
  label_en = model.config.id2label[idx]
57
  return _to_label_pt(label_en)
58
 
59
  # ========= ROTAS =========
60
  @app.get("/")
61
  def root():
 
62
  return {"ok": True, "message": "TrashNet classifier up", "model": MODEL_ID}
63
 
64
  @app.get("/health")
 
66
  return {"ok": True, "model": MODEL_ID}
67
 
68
  @app.post("/predict")
69
+ async def predict(request: Request):
70
  """
71
  Aceita:
72
  - application/octet-stream (raw JPEG no corpo)
73
  - image/jpeg (raw JPEG no corpo)
 
74
  - application/json {"image_b64": "..."} (dataURL ou base64 puro)
75
 
76
  Retorna: texto puro — 'vidro' | 'papel' | 'plastico' | 'metal' | 'nao_identificado'
77
  """
78
  try:
 
79
  ctype = (request.headers.get("content-type") or "").lower()
80
+ img_bytes: bytes = b""
81
 
82
+ if "application/octet-stream" in ctype or "image/jpeg" in ctype:
 
 
 
 
 
83
  img_bytes = await request.body()
 
84
  else:
85
+ # fallback: JSON base64
86
  data = await request.json()
87
  b64 = (data.get("image_b64") or "")
88
+ if "," in b64: # dataURL
 
89
  b64 = b64.split(",", 1)[1]
90
  img_bytes = base64.b64decode(b64) if b64 else b""
91
 
 
99
  return Response(label, media_type="text/plain")
100
 
101
  except UnidentifiedImageError:
 
102
  return Response("nao_identificado", media_type="text/plain")
103
  except Exception:
 
104
  traceback.print_exc()
105
  return Response("nao_identificado", media_type="text/plain")
106
 
107
  # ========= WARM-UP =========
108
  @app.on_event("startup")
109
  def _warmup():
 
110
  try:
111
  dummy = Image.new("RGB", (256, 256), (127, 127, 127))
112
  with torch.inference_mode():
 
115
  print("[startup] warm-up ok")
116
  except Exception:
117
  traceback.print_exc()
118
+ print("[startup] warm-up falhou (seguindo sem)")