froidhj commited on
Commit
929f4df
·
verified ·
1 Parent(s): 1b4701d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -31
app.py CHANGED
@@ -2,20 +2,20 @@
2
  from fastapi import FastAPI, Request, Response
3
  from fastapi.middleware.cors import CORSMiddleware
4
  from PIL import Image, UnidentifiedImageError
5
- import io, torch, base64, traceback
6
  from transformers import AutoImageProcessor, AutoModelForImageClassification
7
 
8
  # ========= CONFIG =========
9
  MODEL_ID = "prithivMLmods/Trash-Net"
10
 
11
- # Mantemos estas 4 classes em PT-BR; o resto vira "nao_identificado"
12
  MAP_PT = {
13
  "glass": "vidro",
14
  "metal": "metal",
15
  "paper": "papel",
16
  "plastic": "plastico",
17
  }
18
- ALLOWED = set(MAP_PT.values())
19
 
20
  # ========= OTIMIZAÇÕES (CPU do Space) =========
21
  torch.set_grad_enabled(False)
@@ -26,8 +26,6 @@ torch.set_num_interop_threads(1)
26
  Image.MAX_IMAGE_PIXELS = 25_000_000
27
 
28
  # ========= CARREGAMENTO =========
29
- # Se aparecer aviso "use_fast=True mas torchvision não disponível",
30
- # é só um warning; pode trocar para use_fast=False se quiser ocultar.
31
  processor = AutoImageProcessor.from_pretrained(MODEL_ID, use_fast=True)
32
  model = AutoModelForImageClassification.from_pretrained(MODEL_ID)
33
  model.eval()
@@ -41,8 +39,14 @@ app.add_middleware(
41
  allow_methods=["*"], allow_headers=["*"],
42
  )
43
 
44
- def _to_label_pt(label_en: str) -> str:
45
- return MAP_PT.get((label_en or "").strip().lower(), "nao_identificado")
 
 
 
 
 
 
46
 
47
  def _predict_image_bytes(img_bytes: bytes) -> str:
48
  with Image.open(io.BytesIO(img_bytes)) as img:
@@ -52,9 +56,9 @@ def _predict_image_bytes(img_bytes: bytes) -> str:
52
  with torch.inference_mode():
53
  inputs = processor(images=img, return_tensors="pt")
54
  logits = model(**inputs).logits
55
- idx = int(logits.argmax(-1)) # softmax não precisa para argmax
56
  label_en = model.config.id2label[idx]
57
- return _to_label_pt(label_en)
58
 
59
  # ========= ROTAS =========
60
  @app.get("/")
@@ -73,36 +77,40 @@ async def predict(request: Request):
73
  - image/jpeg (raw JPEG no corpo)
74
  - application/json {"image_b64": "..."} (dataURL ou base64 puro)
75
 
76
- Retorna: texto puro — 'vidro' | 'papel' | 'plastico' | 'metal' | 'nao_identificado'
 
77
  """
78
  try:
79
- ctype = (request.headers.get("content-type") or "").lower()
80
- img_bytes: bytes = b""
81
 
82
- if "application/octet-stream" in ctype or "image/jpeg" in ctype:
83
- img_bytes = await request.body()
84
- else:
85
- # fallback: JSON base64
86
- data = await request.json()
87
- b64 = (data.get("image_b64") or "")
88
- if "," in b64: # dataURL
89
- b64 = b64.split(",", 1)[1]
90
- img_bytes = base64.b64decode(b64) if b64 else b""
91
 
92
- if not img_bytes:
93
- return Response("nao_identificado", media_type="text/plain")
 
94
 
95
- label = _predict_image_bytes(img_bytes)
96
- if label not in ALLOWED:
97
- label = "nao_identificado"
98
 
99
- return Response(label, media_type="text/plain")
 
 
 
 
100
 
101
  except UnidentifiedImageError:
102
- return Response("nao_identificado", media_type="text/plain")
103
  except Exception:
104
- traceback.print_exc()
105
- return Response("nao_identificado", media_type="text/plain")
106
 
107
  # ========= WARM-UP =========
108
  @app.on_event("startup")
@@ -115,4 +123,4 @@ def _warmup():
115
  print("[startup] warm-up ok")
116
  except Exception:
117
  traceback.print_exc()
118
- print("[startup] warm-up falhou (seguindo sem)")
 
2
  from fastapi import FastAPI, Request, Response
3
  from fastapi.middleware.cors import CORSMiddleware
4
  from PIL import Image, UnidentifiedImageError
5
+ import io, torch, base64, traceback, random
6
  from transformers import AutoImageProcessor, AutoModelForImageClassification
7
 
8
  # ========= CONFIG =========
9
  MODEL_ID = "prithivMLmods/Trash-Net"
10
 
11
+ # Mapa EN -> PT (apenas 4 classes desejadas)
12
  MAP_PT = {
13
  "glass": "vidro",
14
  "metal": "metal",
15
  "paper": "papel",
16
  "plastic": "plastico",
17
  }
18
+ ALLOWED = ["plastico", "papel", "vidro", "metal"] # ordem fixa p/ random
19
 
20
  # ========= OTIMIZAÇÕES (CPU do Space) =========
21
  torch.set_grad_enabled(False)
 
26
  Image.MAX_IMAGE_PIXELS = 25_000_000
27
 
28
  # ========= CARREGAMENTO =========
 
 
29
  processor = AutoImageProcessor.from_pretrained(MODEL_ID, use_fast=True)
30
  model = AutoModelForImageClassification.from_pretrained(MODEL_ID)
31
  model.eval()
 
39
  allow_methods=["*"], allow_headers=["*"],
40
  )
41
 
42
+ def _force_allowed(label_en: str | None) -> str:
43
+ """Converte label EN para PT se mapeado; caso contrário, escolhe aleatoriamente uma das 4."""
44
+ if label_en:
45
+ pt = MAP_PT.get(label_en.strip().lower())
46
+ if pt in ALLOWED:
47
+ return pt
48
+ # fallback forçado
49
+ return random.choice(ALLOWED)
50
 
51
  def _predict_image_bytes(img_bytes: bytes) -> str:
52
  with Image.open(io.BytesIO(img_bytes)) as img:
 
56
  with torch.inference_mode():
57
  inputs = processor(images=img, return_tensors="pt")
58
  logits = model(**inputs).logits
59
+ idx = int(logits.argmax(-1))
60
  label_en = model.config.id2label[idx]
61
+ return _force_allowed(label_en)
62
 
63
  # ========= ROTAS =========
64
  @app.get("/")
 
77
  - image/jpeg (raw JPEG no corpo)
78
  - application/json {"image_b64": "..."} (dataURL ou base64 puro)
79
 
80
+ Retorna SEMPRE: texto puro — 'vidro' | 'papel' | 'plastico' | 'metal'
81
+ (nunca 'nao_identificado')
82
  """
83
  try:
84
+ ctype = (request.headers.get("content-type") or "").lower()
85
+ img_bytes: bytes = b""
86
 
87
+ if "application/octet-stream" in ctype or "image/jpeg" in ctype:
88
+ img_bytes = await request.body()
89
+ else:
90
+ # fallback: JSON base64
91
+ data = await request.json()
92
+ b64 = (data.get("image_b64") or "")
93
+ if "," in b64: # dataURL
94
+ b64 = b64.split(",", 1)[1]
95
+ img_bytes = base64.b64decode(b64) if b64 else b""
96
 
97
+ # Se veio vazio, ainda assim devolve um dos 4
98
+ if not img_bytes:
99
+ return Response(random.choice(ALLOWED), media_type="text/plain")
100
 
101
+ label = _predict_image_bytes(img_bytes)
 
 
102
 
103
+ # Por garantia, força para uma das 4
104
+ if label not in ALLOWED:
105
+ label = random.choice(ALLOWED)
106
+
107
+ return Response(label, media_type="text/plain")
108
 
109
  except UnidentifiedImageError:
110
+ return Response(random.choice(ALLOWED), media_type="text/plain")
111
  except Exception:
112
+ traceback.print_exc()
113
+ return Response(random.choice(ALLOWED), media_type="text/plain")
114
 
115
  # ========= WARM-UP =========
116
  @app.on_event("startup")
 
123
  print("[startup] warm-up ok")
124
  except Exception:
125
  traceback.print_exc()
126
+ print("[startup] warm-up falhou (seguindo sem)")