froidhj commited on
Commit
c3a6888
·
verified ·
1 Parent(s): f448105

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -15
app.py CHANGED
@@ -1,40 +1,82 @@
 
1
  from fastapi import FastAPI, Request, Response
2
  from PIL import Image
3
- import io, torch
4
  from transformers import AutoImageProcessor, AutoModelForImageClassification
5
 
6
- MODEL_ID = "huggingface-projects/resnet-50"
7
- PT_MAP = {
8
- "plastic":"plastico", "paper":"papel", "glass":"vidro",
9
- "metal":"metal", "cardboard":"papel", "trash":"nao_identificado"
 
 
 
 
 
10
  }
 
 
 
 
 
 
11
 
 
12
  processor = AutoImageProcessor.from_pretrained(MODEL_ID)
13
  model = AutoModelForImageClassification.from_pretrained(MODEL_ID)
14
  model.eval()
15
 
16
  app = FastAPI()
17
 
18
- def predict_bytes(img_bytes: bytes) -> str:
 
19
  img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
 
 
20
  inputs = processor(images=img, return_tensors="pt")
21
- with torch.no_grad():
22
- logits = model(**inputs).logits
23
- idx = int(logits.softmax(-1).argmax(-1))
24
- label_en = model.config.id2label[idx].lower()
25
- return PT_MAP.get(label_en, "nao_identificado")
 
 
 
26
 
27
  @app.get("/health")
28
  def health():
29
- return {"ok": True}
30
 
31
  @app.post("/predict")
32
  async def predict(request: Request):
 
 
 
 
33
  try:
34
- img_bytes = await request.body()
 
 
 
 
 
 
 
 
 
 
 
35
  if not img_bytes:
36
  return Response("nao_identificado", media_type="text/plain")
37
- label = predict_bytes(img_bytes)
 
 
 
 
 
 
38
  return Response(label, media_type="text/plain")
39
- except Exception:
 
 
40
  return Response("nao_identificado", media_type="text/plain")
 
1
+ # app.py
2
  from fastapi import FastAPI, Request, Response
3
  from PIL import Image
4
+ import io, os, torch
5
  from transformers import AutoImageProcessor, AutoModelForImageClassification
6
 
7
+ # ========= CONFIG =========
8
+ MODEL_ID = "AmadFR/ecovision_mobilenetv3"
9
+
10
+ # Mapeamento para português (apenas as 4 classes desejadas)
11
+ MAP_PT = {
12
+ "glass": "vidro",
13
+ "metal": "metal",
14
+ "paper": "papel",
15
+ "plastic": "plastico"
16
  }
17
+ ALLOWED = set(MAP_PT.values())
18
+
19
+ # ========= OTIMIZAÇÕES =========
20
+ torch.set_grad_enabled(False)
21
+ torch.set_num_threads(1)
22
+ torch.set_num_interop_threads(1)
23
 
24
+ # ========= CARREGA MODELO =========
25
  processor = AutoImageProcessor.from_pretrained(MODEL_ID)
26
  model = AutoModelForImageClassification.from_pretrained(MODEL_ID)
27
  model.eval()
28
 
29
  app = FastAPI()
30
 
31
+ def predict_image_bytes(img_bytes: bytes) -> str:
32
+ """Recebe bytes JPEG e devolve um dos rótulos simplificados."""
33
  img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
34
+ img = img.resize((224, 224)) # acelera a inferência
35
+
36
  inputs = processor(images=img, return_tensors="pt")
37
+ outputs = model(**inputs)
38
+ logits = outputs.logits
39
+ predicted_class_idx = int(logits.argmax(-1))
40
+ label_en = model.config.id2label[predicted_class_idx].lower()
41
+
42
+ # Converte para português
43
+ label_pt = MAP_PT.get(label_en, "nao_identificado")
44
+ return label_pt
45
 
46
  @app.get("/health")
47
  def health():
48
+ return {"ok": True, "model": MODEL_ID}
49
 
50
  @app.post("/predict")
51
  async def predict(request: Request):
52
+ """
53
+ Espera: imagem JPEG (application/octet-stream)
54
+ Retorna: texto puro - 'vidro', 'papel', 'plastico', 'metal' ou 'nao_identificado'
55
+ """
56
  try:
57
+ content_type = (request.headers.get("content-type") or "").lower()
58
+
59
+ # ESP32 envia como application/octet-stream
60
+ if "application/octet-stream" in content_type or "image/jpeg" in content_type:
61
+ img_bytes = await request.body()
62
+ else:
63
+ # fallback para JSON base64 (para testes)
64
+ data = await request.json()
65
+ import base64
66
+ b64 = data.get("image_b64", "").split(",")[-1]
67
+ img_bytes = base64.b64decode(b64) if b64 else b""
68
+
69
  if not img_bytes:
70
  return Response("nao_identificado", media_type="text/plain")
71
+
72
+ label = predict_image_bytes(img_bytes)
73
+
74
+ # filtro final — só 4 materiais
75
+ if label not in ALLOWED:
76
+ label = "nao_identificado"
77
+
78
  return Response(label, media_type="text/plain")
79
+
80
+ except Exception as e:
81
+ print("Erro:", e)
82
  return Response("nao_identificado", media_type="text/plain")