mohamedtsou commited on
Commit
227a247
·
verified ·
1 Parent(s): 3e40551

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -0
app.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, File, UploadFile
2
+ from transformers import AutoImageProcessor, AutoModelForImageClassification
3
+ from PIL import Image
4
+ import torch, io, os, uvicorn
5
+
6
+ app = FastAPI()
7
+
8
+ MODEL_NAME = "yangy50/garbage-classification"
9
+
10
+ processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
11
+ model = AutoModelForImageClassification.from_pretrained(MODEL_NAME)
12
+ model.eval()
13
+
14
+ @app.get("/")
15
+ def root():
16
+ return {"status": "ok"}
17
+
18
+ @app.post("/predict")
19
+ async def predict(file: UploadFile = File(...)):
20
+ image = Image.open(io.BytesIO(await file.read())).convert("RGB")
21
+ inputs = processor(images=image, return_tensors="pt")
22
+
23
+ with torch.no_grad():
24
+ outputs = model(**inputs)
25
+
26
+ probs = torch.softmax(outputs.logits, dim=1)[0]
27
+ return {
28
+ model.config.id2label[i]: float(probs[i])
29
+ for i in range(len(probs))
30
+ }
31
+
32
+ if __name__ == "__main__":
33
+ uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", 7860)))