froidhj commited on
Commit
7b7ccff
·
verified ·
1 Parent(s): 3d60c5e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -91
app.py CHANGED
@@ -1,23 +1,22 @@
1
  # app.py
2
  from fastapi import FastAPI, Request, Response
3
- from PIL import Image, ImageOps
4
  import io, os, torch
5
- import torch.nn.functional as F
6
  from transformers import AutoImageProcessor, AutoModelForImageClassification
7
 
8
  # ========= CONFIG =========
9
  MODEL_ID = "prithivMLmods/Trash-Net"
10
 
11
- # PT-BR map (somente 4 classes principais)
12
  MAP_PT = {
13
  "glass": "vidro",
14
  "metal": "metal",
15
  "paper": "papel",
16
  "plastic": "plastico",
17
  }
18
- TARGETS_EN = list(MAP_PT.keys())
19
 
20
- # ========= OTIMIZAÇÕES =========
21
  torch.set_grad_enabled(False)
22
  torch.set_num_threads(1)
23
  torch.set_num_interop_threads(1)
@@ -27,113 +26,50 @@ processor = AutoImageProcessor.from_pretrained(MODEL_ID)
27
  model = AutoModelForImageClassification.from_pretrained(MODEL_ID)
28
  model.eval()
29
 
30
- # --- Trata o caso label2id invertido ---
31
- id2label_raw = model.config.id2label
32
- label2id_raw = model.config.label2id
33
-
34
- id2label = {}
35
- label2id = {}
36
-
37
- for k, v in id2label_raw.items():
38
- # Normaliza chaves e valores para int→str
39
- try:
40
- id2label[int(k)] = str(v)
41
- except Exception:
42
- id2label[int(v)] = str(k)
43
-
44
- for k, v in label2id_raw.items():
45
- # Normaliza para str→int
46
- try:
47
- label2id[str(k).lower()] = int(v)
48
- except Exception:
49
- label2id[str(v).lower()] = int(k)
50
-
51
- # Descobre índices das 4 classes-alvo
52
- target_indices = []
53
- target_indices_en = []
54
- for en in TARGETS_EN:
55
- if en in label2id:
56
- target_indices.append(label2id[en])
57
- target_indices_en.append(en)
58
-
59
- if len(target_indices) < 4:
60
- for en in TARGETS_EN:
61
- if en in target_indices_en:
62
- continue
63
- found = None
64
- en_low = en.lower()
65
- for i, lab in id2label.items():
66
- if en_low in lab.lower():
67
- found = i
68
- break
69
- if found is not None and found not in target_indices:
70
- target_indices.append(found)
71
- target_indices_en.append(en)
72
-
73
- # ========= APP =========
74
  app = FastAPI()
75
 
76
- # ========= FUNÇÕES =========
77
- def _prepare_image(img_bytes: bytes) -> Image.Image:
78
  img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
79
- img = ImageOps.exif_transpose(img)
80
- w, h = img.size
81
- side = min(w, h)
82
- left = (w - side) // 2
83
- top = (h - side) // 2
84
- img = img.crop((left, top, left + side, top + side))
85
- return img
86
 
87
- def predict_image_bytes(img_bytes: bytes):
88
- img = _prepare_image(img_bytes)
89
  inputs = processor(images=img, return_tensors="pt")
90
  logits = model(**inputs).logits
91
- probs = F.softmax(logits, dim=-1)[0]
 
92
 
93
- if target_indices:
94
- subset = probs[target_indices]
95
- j = int(torch.argmax(subset).item())
96
- best_idx_global = target_indices[j]
97
- best_en = id2label[best_idx_global].lower()
98
- conf = float(subset[j].item())
99
- label_pt = MAP_PT.get(best_en, MAP_PT[target_indices_en[j]])
100
- return label_pt, conf
101
- else:
102
- i = int(torch.argmax(probs).item())
103
- best_en = id2label[i].lower()
104
- conf = float(probs[i].item())
105
- if "glass" in best_en:
106
- label_pt = "vidro"
107
- elif "metal" in best_en or "steel" in best_en or "aluminum" in best_en:
108
- label_pt = "metal"
109
- elif "paper" in best_en or "cardboard" in best_en:
110
- label_pt = "papel"
111
- else:
112
- label_pt = "plastico"
113
- return label_pt, conf
114
 
115
- # ========= ROTAS =========
116
  @app.get("/health")
117
  def health():
118
- return {"ok": True, "model": MODEL_ID, "targets": list(MAP_PT.values())}
119
 
120
  @app.post("/predict")
121
  async def predict(request: Request):
 
 
 
 
122
  try:
123
  ctype = (request.headers.get("content-type") or "").lower()
124
- if "application/octet-stream" in ctype or "image/jpeg" in ctype or "image/png" in ctype:
 
125
  img_bytes = await request.body()
126
  else:
 
127
  data = await request.json()
128
  import base64
129
  b64 = (data.get("image_b64") or "").split(",")[-1]
130
  img_bytes = base64.b64decode(b64) if b64 else b""
131
 
132
  if not img_bytes:
133
- return Response("plastico", media_type="text/plain", headers={"X-Confidence": "0.00"})
 
 
 
 
134
 
135
- label_pt, conf = predict_image_bytes(img_bytes)
136
- return Response(label_pt, media_type="text/plain", headers={"X-Confidence": f"{conf:.4f}"})
137
- except Exception as e:
138
- print("predict error:", e)
139
- return Response("plastico", media_type="text/plain", headers={"X-Confidence": "0.00"})
 
1
  # app.py
2
  from fastapi import FastAPI, Request, Response
3
+ from PIL import Image
4
  import io, os, torch
 
5
  from transformers import AutoImageProcessor, AutoModelForImageClassification
6
 
7
  # ========= CONFIG =========
8
  MODEL_ID = "prithivMLmods/Trash-Net"
9
 
10
+ # Mantemos só estas 4 classes em PT-BR; o resto vira "nao_identificado"
11
  MAP_PT = {
12
  "glass": "vidro",
13
  "metal": "metal",
14
  "paper": "papel",
15
  "plastic": "plastico",
16
  }
17
+ ALLOWED = set(MAP_PT.values())
18
 
19
+ # ========= OTIMIZAÇÕES (CPU do Space) =========
20
  torch.set_grad_enabled(False)
21
  torch.set_num_threads(1)
22
  torch.set_num_interop_threads(1)
 
26
  model = AutoModelForImageClassification.from_pretrained(MODEL_ID)
27
  model.eval()
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  app = FastAPI()
30
 
31
+ def predict_image_bytes(img_bytes: bytes) -> str:
 
32
  img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
33
+ # Reduz um pouco para acelerar sem perder muito
34
+ img = img.resize((256, 256))
 
 
 
 
 
35
 
 
 
36
  inputs = processor(images=img, return_tensors="pt")
37
  logits = model(**inputs).logits
38
+ idx = int(logits.softmax(-1).argmax(-1))
39
+ label_en = model.config.id2label[idx].lower()
40
 
41
+ # Converte apenas se for uma das 4; senão marca como não identificado
42
+ return MAP_PT.get(label_en, "nao_identificado")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
 
44
  @app.get("/health")
45
  def health():
46
+ return {"ok": True, "model": MODEL_ID}
47
 
48
  @app.post("/predict")
49
  async def predict(request: Request):
50
+ """
51
+ Espera: bytes JPEG (application/octet-stream)
52
+ Retorna: texto puro — 'vidro' | 'papel' | 'plastico' | 'metal' | 'nao_identificado'
53
+ """
54
  try:
55
  ctype = (request.headers.get("content-type") or "").lower()
56
+
57
+ if "application/octet-stream" in ctype or "image/jpeg" in ctype:
58
  img_bytes = await request.body()
59
  else:
60
+ # fallback opcional para JSON base64 (testes manuais)
61
  data = await request.json()
62
  import base64
63
  b64 = (data.get("image_b64") or "").split(",")[-1]
64
  img_bytes = base64.b64decode(b64) if b64 else b""
65
 
66
  if not img_bytes:
67
+ return Response("nao_identificado", media_type="text/plain")
68
+
69
+ label = predict_image_bytes(img_bytes)
70
+ if label not in ALLOWED:
71
+ label = "nao_identificado"
72
 
73
+ return Response(label, media_type="text/plain")
74
+ except Exception:
75
+ return Response("nao_identificado", media_type="text/plain")