froidhj commited on
Commit
587862b
·
verified ·
1 Parent(s): 4638890

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -0
app.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, Request, Response
2
+ from PIL import Image
3
+ import io, torch
4
+ from transformers import AutoImageProcessor, AutoModelForImageClassification
5
+
6
+ MODEL_ID = "prithivMLmods/Trash-Net"
7
+ PT_MAP = {
8
+ "plastic":"plastico", "paper":"papel", "glass":"vidro",
9
+ "metal":"metal", "cardboard":"papel", "trash":"nao_identificado"
10
+ }
11
+
12
+ processor = AutoImageProcessor.from_pretrained(MODEL_ID)
13
+ model = AutoModelForImageClassification.from_pretrained(MODEL_ID)
14
+ model.eval()
15
+
16
+ app = FastAPI()
17
+
18
+ def predict_bytes(img_bytes: bytes) -> str:
19
+ img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
20
+ inputs = processor(images=img, return_tensors="pt")
21
+ with torch.no_grad():
22
+ logits = model(**inputs).logits
23
+ idx = int(logits.softmax(-1).argmax(-1))
24
+ label_en = model.config.id2label[idx].lower()
25
+ return PT_MAP.get(label_en, "nao_identificado")
26
+
27
+ @app.get("/health")
28
+ def health():
29
+ return {"ok": True}
30
+
31
+ @app.post("/predict")
32
+ async def predict(request: Request):
33
+ try:
34
+ img_bytes = await request.body()
35
+ if not img_bytes:
36
+ return Response("nao_identificado", media_type="text/plain")
37
+ label = predict_bytes(img_bytes)
38
+ return Response(label, media_type="text/plain")
39
+ except Exception:
40
+ return Response("nao_identificado", media_type="text/plain")