froidhj commited on
Commit
65f698e
·
verified ·
1 Parent(s): 9fd9d16

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -106
app.py CHANGED
@@ -1,89 +1,52 @@
1
  # app.py
2
  from fastapi import FastAPI, Request, Response
3
  from PIL import Image, ImageOps
4
- import io, os, torch, json
5
  import torch.nn.functional as F
6
  from transformers import AutoImageProcessor, AutoModelForImageClassification
7
 
8
  # ========= CONFIG =========
9
  MODEL_ID = "prithivMLmods/Trash-Net"
10
 
11
- # PT-BR map (somente 4 classes)
12
  MAP_PT = {
13
  "glass": "vidro",
14
  "metal": "metal",
15
  "paper": "papel",
16
  "plastic": "plastico",
17
  }
 
18
  TARGETS_EN = list(MAP_PT.keys()) # ["glass","metal","paper","plastic"]
19
 
20
- # ========= OTIMIZAÇÕES (CPU do Space) =========
21
  torch.set_grad_enabled(False)
22
  torch.set_num_threads(1)
23
  torch.set_num_interop_threads(1)
24
- DEVICE = "cpu"
25
 
26
- # ========= CARREGAMENTO =========
27
- # use_fast=True evita o aviso e tende a ser mais eficiente
28
- processor = AutoImageProcessor.from_pretrained(MODEL_ID, use_fast=True)
29
  model = AutoModelForImageClassification.from_pretrained(MODEL_ID)
30
  model.eval()
31
 
32
- # ========= MAPAS DE RÓTULO (robustos) =========
33
- cfg = model.config
34
-
35
- # Tenta id2label direto (id->str)
36
- _id2label = {}
37
- if getattr(cfg, "id2label", None):
38
- # pode vir com chaves str ou int; normalizamos:
39
- try:
40
- _id2label = {int(k): str(v) for k, v in cfg.id2label.items()}
41
- except Exception:
42
- # alguns modelos já trazem chaves int
43
- _id2label = {int(i): str(lbl) for i, lbl in cfg.id2label.items()}
44
-
45
- # Tenta label2id direto (str->id)
46
- _label2id = {}
47
- if getattr(cfg, "label2id", None):
48
- try:
49
- _label2id = {str(k).strip().lower(): int(v) for k, v in cfg.label2id.items()}
50
- except Exception:
51
- # fallback: se o modelo tiver salvo ao contrário (id->label), inverta
52
- _label2id = {}
53
-
54
- # Se label2id não veio, derive de id2label
55
- if not _label2id and _id2label:
56
- _label2id = {str(v).strip().lower(): int(k) for k, v in _id2label.items()}
57
 
58
- # Se id2label não veio, derive de label2id
59
- if not _id2label and _label2id:
60
- _id2label = {int(v): str(k) for k, v in _label2id.items()}
61
-
62
- # Logs de diagnóstico (aparecem no console do Space)
63
- print("config.id2label:", _id2label)
64
- print("config.label2id:", _label2id)
65
-
66
- # ========= DESCOBERTA DOS 4 ÍNDICES‐ALVO =========
67
  target_indices = []
68
- target_indices_en = [] # rótulo EN correspondente na mesma ordem
69
-
70
- # 1) tentamos correspondência exata (case-insensitive)
71
  for en in TARGETS_EN:
72
- key = en.lower()
73
- if key in _label2id:
74
- idx = _label2id[key]
75
- if idx not in target_indices:
76
- target_indices.append(idx)
77
- target_indices_en.append(en)
78
 
79
- # 2) se faltar algum, tentamos "contém" no id2label (ex.: "cardboard" ~ paper)
80
  if len(target_indices) < 4:
81
  for en in TARGETS_EN:
82
  if en in target_indices_en:
83
  continue
84
  found = None
85
  en_low = en.lower()
86
- for i, lab in _id2label.items():
87
  if en_low in lab.lower():
88
  found = i
89
  break
@@ -91,107 +54,89 @@ if len(target_indices) < 4:
91
  target_indices.append(found)
92
  target_indices_en.append(en)
93
 
94
- # (continua mesmo que haja <4; sempre escolheremos dentre os disponíveis)
95
-
96
  app = FastAPI()
97
 
98
- # ========= PRÉ‐PROCESS =========
99
  def _prepare_image(img_bytes: bytes) -> Image.Image:
 
100
  img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
101
- # Corrige rotação (EXIF)
102
  img = ImageOps.exif_transpose(img)
103
- # Center-crop quadrado para reduzir distorções periféricas
104
  w, h = img.size
105
  side = min(w, h)
106
  left = (w - side) // 2
107
  top = (h - side) // 2
108
  img = img.crop((left, top, left + side, top + side))
109
- # O processor cuida do resize/padding normalizado do modelo
110
  return img
111
 
112
- # ========= PREDICT =========
113
  def predict_image_bytes(img_bytes: bytes):
114
  """
115
  Retorna (label_pt, confidence_float_0_1)
116
- Sempre uma das 4 classes: vidro/papel/plastico/metal
117
  """
118
  img = _prepare_image(img_bytes)
119
  inputs = processor(images=img, return_tensors="pt")
120
  logits = model(**inputs).logits # [1, num_labels]
121
- probs = F.softmax(logits, dim=-1)[0] # [num_labels]
122
 
123
  if target_indices:
124
- subset = probs[target_indices] # [<=4]
125
- j = int(torch.argmax(subset).item()) # posição dentro do subset
126
- best_idx_global = target_indices[j] # índice real no espaço do modelo
127
- best_model_label = _id2label.get(best_idx_global, "").lower()
 
128
  conf = float(subset[j].item())
129
-
130
- # Mapeia para PT (prioriza rótulo exato; senão usa a intenção TARGETS_EN[j])
131
- if best_model_label in MAP_PT:
132
- label_pt = MAP_PT[best_model_label]
133
  else:
134
  label_pt = MAP_PT[target_indices_en[j]]
135
  return label_pt, conf
136
-
137
- # Fallback global (se não achamos índice nenhum)
138
- i = int(torch.argmax(probs).item())
139
- best_en = _id2label.get(i, "").lower()
140
- conf = float(probs[i].item())
141
- if "glass" in best_en:
142
- label_pt = "vidro"
143
- elif ("metal" in best_en) or ("steel" in best_en) or ("alum" in best_en):
144
- label_pt = "metal"
145
- elif ("paper" in best_en) or ("cardboard" in best_en):
146
- label_pt = "papel"
147
  else:
148
- label_pt = "plastico"
149
- return label_pt, conf
 
 
 
 
 
 
 
 
 
 
 
150
 
151
- # ========= ENDPOINTS =========
152
  @app.get("/health")
153
  def health():
154
- return {
155
- "ok": True,
156
- "model": MODEL_ID,
157
- "targets_en": TARGETS_EN,
158
- "targets_pt": list(MAP_PT.values()),
159
- "mapped_indices": target_indices,
160
- }
161
 
162
  @app.post("/predict")
163
  async def predict(request: Request):
164
  """
165
  Entrada:
166
- - bytes JPEG/PNG (Content-Type: application/octet-stream | image/jpeg | image/png)
167
- - ou JSON {"image_b64": "..."} (útil para teste manual)
 
168
  Saída:
169
  - texto puro: 'vidro' | 'papel' | 'plastico' | 'metal'
170
  - header X-Confidence com a confiança 0..1
171
  """
172
  try:
173
  ctype = (request.headers.get("content-type") or "").lower()
174
- img_bytes = b""
175
-
176
  if "application/octet-stream" in ctype or "image/jpeg" in ctype or "image/png" in ctype:
177
- img_bytes = await request.body()
178
  else:
179
  data = await request.json()
180
  import base64
181
  b64 = (data.get("image_b64") or "").split(",")[-1]
182
- if b64:
183
- img_bytes = base64.b64decode(b64)
184
 
185
  if not img_bytes:
186
- # Sem imagem: retorna uma classe válida com confiança 0
187
- return Response("plastico", media_type="text/plain",
188
- headers={"X-Confidence": "0.0000"})
189
 
190
  label_pt, conf = predict_image_bytes(img_bytes)
191
- return Response(label_pt, media_type="text/plain",
192
- headers={"X-Confidence": f"{conf:.4f}"})
193
  except Exception as e:
194
- # Loga erro e ainda assim devolve uma das 4 (mantém o pipeline vivo)
195
- print("predict error:", repr(e))
196
- return Response("plastico", media_type="text/plain",
197
- headers={"X-Confidence": "0.0000"})
 
1
  # app.py
2
  from fastapi import FastAPI, Request, Response
3
  from PIL import Image, ImageOps
4
+ import io, os, torch
5
  import torch.nn.functional as F
6
  from transformers import AutoImageProcessor, AutoModelForImageClassification
7
 
8
  # ========= CONFIG =========
9
  MODEL_ID = "prithivMLmods/Trash-Net"
10
 
11
+ # PT-BR map (somente 4 classes principais)
12
  MAP_PT = {
13
  "glass": "vidro",
14
  "metal": "metal",
15
  "paper": "papel",
16
  "plastic": "plastico",
17
  }
18
+
19
  TARGETS_EN = list(MAP_PT.keys()) # ["glass","metal","paper","plastic"]
20
 
21
+ # ========= OTIMIZAÇÕES (para CPU do Space) =========
22
  torch.set_grad_enabled(False)
23
  torch.set_num_threads(1)
24
  torch.set_num_interop_threads(1)
 
25
 
26
+ # ========= CARREGAMENTO DO MODELO =========
27
+ processor = AutoImageProcessor.from_pretrained(MODEL_ID)
 
28
  model = AutoModelForImageClassification.from_pretrained(MODEL_ID)
29
  model.eval()
30
 
31
+ # Cria dicionários auxiliares de mapeamento
32
+ id2label = {int(k): v for k, v in model.config.id2label.items()}
33
+ label2id = {v.lower(): int(k) for k, v in model.config.label2id.items()}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
+ # Descobre os índices das classes principais dentro do modelo
 
 
 
 
 
 
 
 
36
  target_indices = []
37
+ target_indices_en = []
 
 
38
  for en in TARGETS_EN:
39
+ if en in label2id:
40
+ target_indices.append(label2id[en])
41
+ target_indices_en.append(en)
 
 
 
42
 
 
43
  if len(target_indices) < 4:
44
  for en in TARGETS_EN:
45
  if en in target_indices_en:
46
  continue
47
  found = None
48
  en_low = en.lower()
49
+ for i, lab in id2label.items():
50
  if en_low in lab.lower():
51
  found = i
52
  break
 
54
  target_indices.append(found)
55
  target_indices_en.append(en)
56
 
57
+ # ========= FASTAPI APP =========
 
58
  app = FastAPI()
59
 
60
+ # ========= FUNÇÕES =========
61
  def _prepare_image(img_bytes: bytes) -> Image.Image:
62
+ """Prepara a imagem (corrige orientação, recorta e converte RGB)."""
63
  img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
 
64
  img = ImageOps.exif_transpose(img)
 
65
  w, h = img.size
66
  side = min(w, h)
67
  left = (w - side) // 2
68
  top = (h - side) // 2
69
  img = img.crop((left, top, left + side, top + side))
 
70
  return img
71
 
 
72
  def predict_image_bytes(img_bytes: bytes):
73
  """
74
  Retorna (label_pt, confidence_float_0_1)
 
75
  """
76
  img = _prepare_image(img_bytes)
77
  inputs = processor(images=img, return_tensors="pt")
78
  logits = model(**inputs).logits # [1, num_labels]
 
79
 
80
  if target_indices:
81
+ probs = F.softmax(logits, dim=-1)[0]
82
+ subset = probs[target_indices]
83
+ j = int(torch.argmax(subset).item())
84
+ best_idx_global = target_indices[j]
85
+ best_en = id2label[best_idx_global].lower()
86
  conf = float(subset[j].item())
87
+ if best_en in MAP_PT:
88
+ label_pt = MAP_PT[best_en]
 
 
89
  else:
90
  label_pt = MAP_PT[target_indices_en[j]]
91
  return label_pt, conf
 
 
 
 
 
 
 
 
 
 
 
92
  else:
93
+ probs = F.softmax(logits, dim=-1)[0]
94
+ i = int(torch.argmax(probs).item())
95
+ best_en = id2label[i].lower()
96
+ conf = float(probs[i].item())
97
+ if "glass" in best_en:
98
+ label_pt = "vidro"
99
+ elif "metal" in best_en or "steel" in best_en or "aluminum" in best_en:
100
+ label_pt = "metal"
101
+ elif "paper" in best_en or "cardboard" in best_en:
102
+ label_pt = "papel"
103
+ else:
104
+ label_pt = "plastico"
105
+ return label_pt, conf
106
 
107
+ # ========= ROTAS =========
108
  @app.get("/health")
109
  def health():
110
+ """Verifica se o servidor está ativo."""
111
+ return {"ok": True, "model": MODEL_ID, "targets": list(MAP_PT.values())}
 
 
 
 
 
112
 
113
  @app.post("/predict")
114
  async def predict(request: Request):
115
  """
116
  Entrada:
117
+ - bytes JPEG (Content-Type: application/octet-stream ou image/jpeg)
118
+ - ou JSON {"image_b64": "..."} (apenas para testes manuais)
119
+
120
  Saída:
121
  - texto puro: 'vidro' | 'papel' | 'plastico' | 'metal'
122
  - header X-Confidence com a confiança 0..1
123
  """
124
  try:
125
  ctype = (request.headers.get("content-type") or "").lower()
 
 
126
  if "application/octet-stream" in ctype or "image/jpeg" in ctype or "image/png" in ctype:
127
+ img_bytes = await request.body() # <-- aqui está o correto
128
  else:
129
  data = await request.json()
130
  import base64
131
  b64 = (data.get("image_b64") or "").split(",")[-1]
132
+ img_bytes = base64.b64decode(b64) if b64 else b""
 
133
 
134
  if not img_bytes:
135
+ return Response("plastico", media_type="text/plain", headers={"X-Confidence": "0.00"})
 
 
136
 
137
  label_pt, conf = predict_image_bytes(img_bytes)
138
+ return Response(label_pt, media_type="text/plain", headers={"X-Confidence": f"{conf:.4f}"})
139
+
140
  except Exception as e:
141
+ print("predict error:", e)
142
+ return Response("plastico", media_type="text/plain", headers={"X-Confidence": "0.00"})