froidhj commited on
Commit
ac41ebb
·
verified ·
1 Parent(s): c3a6888

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -26
app.py CHANGED
@@ -5,23 +5,23 @@ import io, os, torch
5
  from transformers import AutoImageProcessor, AutoModelForImageClassification
6
 
7
  # ========= CONFIG =========
8
- MODEL_ID = "AmadFR/ecovision_mobilenetv3"
9
 
10
- # Mapeamento para português (apenas as 4 classes desejadas)
11
  MAP_PT = {
12
  "glass": "vidro",
13
  "metal": "metal",
14
  "paper": "papel",
15
- "plastic": "plastico"
16
  }
17
  ALLOWED = set(MAP_PT.values())
18
 
19
- # ========= OTIMIZAÇÕES =========
20
  torch.set_grad_enabled(False)
21
  torch.set_num_threads(1)
22
  torch.set_num_interop_threads(1)
23
 
24
- # ========= CARREGA MODELO =========
25
  processor = AutoImageProcessor.from_pretrained(MODEL_ID)
26
  model = AutoModelForImageClassification.from_pretrained(MODEL_ID)
27
  model.eval()
@@ -29,19 +29,17 @@ model.eval()
29
  app = FastAPI()
30
 
31
  def predict_image_bytes(img_bytes: bytes) -> str:
32
- """Recebe bytes JPEG e devolve um dos rótulos simplificados."""
33
  img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
34
- img = img.resize((224, 224)) # acelera a inferência
 
35
 
36
  inputs = processor(images=img, return_tensors="pt")
37
- outputs = model(**inputs)
38
- logits = outputs.logits
39
- predicted_class_idx = int(logits.argmax(-1))
40
- label_en = model.config.id2label[predicted_class_idx].lower()
41
 
42
- # Converte para português
43
- label_pt = MAP_PT.get(label_en, "nao_identificado")
44
- return label_pt
45
 
46
  @app.get("/health")
47
  def health():
@@ -50,33 +48,28 @@ def health():
50
  @app.post("/predict")
51
  async def predict(request: Request):
52
  """
53
- Espera: imagem JPEG (application/octet-stream)
54
- Retorna: texto puro - 'vidro', 'papel', 'plastico', 'metal' ou 'nao_identificado'
55
  """
56
  try:
57
- content_type = (request.headers.get("content-type") or "").lower()
58
 
59
- # ESP32 envia como application/octet-stream
60
- if "application/octet-stream" in content_type or "image/jpeg" in content_type:
61
  img_bytes = await request.body()
62
  else:
63
- # fallback para JSON base64 (para testes)
64
  data = await request.json()
65
  import base64
66
- b64 = data.get("image_b64", "").split(",")[-1]
67
  img_bytes = base64.b64decode(b64) if b64 else b""
68
 
69
  if not img_bytes:
70
  return Response("nao_identificado", media_type="text/plain")
71
 
72
  label = predict_image_bytes(img_bytes)
73
-
74
- # filtro final — só 4 materiais
75
  if label not in ALLOWED:
76
  label = "nao_identificado"
77
 
78
  return Response(label, media_type="text/plain")
79
-
80
- except Exception as e:
81
- print("Erro:", e)
82
  return Response("nao_identificado", media_type="text/plain")
 
5
  from transformers import AutoImageProcessor, AutoModelForImageClassification
6
 
7
  # ========= CONFIG =========
8
+ MODEL_ID = "prithivMLmods/Trash-Net"
9
 
10
+ # Mantemos estas 4 classes em PT-BR; o resto vira "nao_identificado"
11
  MAP_PT = {
12
  "glass": "vidro",
13
  "metal": "metal",
14
  "paper": "papel",
15
+ "plastic": "plastico",
16
  }
17
  ALLOWED = set(MAP_PT.values())
18
 
19
+ # ========= OTIMIZAÇÕES (CPU do Space) =========
20
  torch.set_grad_enabled(False)
21
  torch.set_num_threads(1)
22
  torch.set_num_interop_threads(1)
23
 
24
+ # ========= CARREGAMENTO =========
25
  processor = AutoImageProcessor.from_pretrained(MODEL_ID)
26
  model = AutoModelForImageClassification.from_pretrained(MODEL_ID)
27
  model.eval()
 
29
  app = FastAPI()
30
 
31
  def predict_image_bytes(img_bytes: bytes) -> str:
 
32
  img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
33
+ # Reduz um pouco para acelerar sem perder muito
34
+ img = img.resize((256, 256))
35
 
36
  inputs = processor(images=img, return_tensors="pt")
37
+ logits = model(**inputs).logits
38
+ idx = int(logits.softmax(-1).argmax(-1))
39
+ label_en = model.config.id2label[idx].lower()
 
40
 
41
+ # Converte apenas se for uma das 4; senão marca como não identificado
42
+ return MAP_PT.get(label_en, "nao_identificado")
 
43
 
44
  @app.get("/health")
45
  def health():
 
48
  @app.post("/predict")
49
  async def predict(request: Request):
50
  """
51
+ Espera: bytes JPEG (application/octet-stream)
52
+ Retorna: texto puro 'vidro' | 'papel' | 'plastico' | 'metal' | 'nao_identificado'
53
  """
54
  try:
55
+ ctype = (request.headers.get("content-type") or "").lower()
56
 
57
+ if "application/octet-stream" in ctype or "image/jpeg" in ctype:
 
58
  img_bytes = await request.body()
59
  else:
60
+ # fallback opcional para JSON base64 (testes manuais)
61
  data = await request.json()
62
  import base64
63
+ b64 = (data.get("image_b64") or "").split(",")[-1]
64
  img_bytes = base64.b64decode(b64) if b64 else b""
65
 
66
  if not img_bytes:
67
  return Response("nao_identificado", media_type="text/plain")
68
 
69
  label = predict_image_bytes(img_bytes)
 
 
70
  if label not in ALLOWED:
71
  label = "nao_identificado"
72
 
73
  return Response(label, media_type="text/plain")
74
+ except Exception:
 
 
75
  return Response("nao_identificado", media_type="text/plain")