garbage-api / app.py
mohamedtsou's picture
Create app.py
227a247 verified
import io
import os

import torch
import uvicorn
from fastapi import FastAPI, File, HTTPException, UploadFile
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification
# Hugging Face model identifier shared by the image processor and classifier.
MODEL_NAME = "yangy50/garbage-classification"

# Both objects are loaded once at import time and reused by every request.
processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
# ``eval()`` returns the module itself, so the model is bound already in
# inference mode.
model = AutoModelForImageClassification.from_pretrained(MODEL_NAME).eval()

# ASGI application the route decorators below register against.
app = FastAPI()
@app.get("/")
def root():
    """Return a fixed payload indicating the service is up."""
    health = {"status": "ok"}
    return health
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image and return per-label probabilities.

    Args:
        file: The uploaded file; its bytes are decoded as an image.

    Returns:
        A dict mapping each label from ``model.config.id2label`` to its
        softmax probability as a Python float.

    Raises:
        HTTPException: With status 400 when the upload cannot be decoded
            as an image.
    """
    payload = await file.read()
    try:
        image = Image.open(io.BytesIO(payload)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as err:
        # PIL's UnidentifiedImageError subclasses OSError; truncated or
        # corrupt files also raise OSError when ``convert`` loads the pixels.
        # Report a client error instead of letting it surface as a 500.
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a valid image.",
        ) from err

    # NOTE(review): inference below runs synchronously inside an async
    # handler, so the event loop is busy for the duration of the forward
    # pass. Consider offloading to a worker thread if concurrency matters.
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    # Softmax over the class dimension; [0] selects the single image in the
    # batch, and one ``tolist()`` call converts every score to a float.
    scores = torch.softmax(outputs.logits, dim=1)[0].tolist()
    return {
        model.config.id2label[idx]: score
        for idx, score in enumerate(scores)
    }
if __name__ == "__main__":
    # The PORT environment variable overrides the default of 7860.
    port = int(os.environ.get("PORT", 7860))
    # Bind on all interfaces.
    uvicorn.run(app, host="0.0.0.0", port=port)