# dgapiv2 / app.py
# Last change: commit 3f59d69 "Update app.py" (verified), by Hayloo9838.
from io import BytesIO

import requests
import torch
import torchvision.transforms as T
from fastapi import FastAPI, HTTPException
from PIL import Image
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForImageClassification
app = FastAPI()

# Load the model once at import time so every request reuses the same weights.
model_name = "Falconsai/nsfw_image_detection"
model = AutoModelForImageClassification.from_pretrained(model_name)
# Inference only. from_pretrained already returns the model in eval mode;
# calling eval() here just makes that intent explicit.
model.eval()

# NOTE: no tokenizer is loaded. The previous
# AutoTokenizer.from_pretrained(model_name) result was never used, and this
# image-classification checkpoint presumably ships no tokenizer files, so
# that call could only slow down (or break) startup.

# Preprocessing applied to every request image:
# resize to 224x224, convert to a float tensor in [0, 1], then map each RGB
# channel to [-1, 1] via (x - 0.5) / 0.5.
# NOTE(review): assumed to match the checkpoint's own preprocessor config;
# transformers' AutoImageProcessor could load those values directly - confirm.
transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])
class ImageInput(BaseModel):
    """JSON request body accepted by ``POST /predict``."""

    # Address of the image to classify; the predict endpoint downloads it
    # server-side.
    url: str
@app.get("/")
def read_root():
    """Simple status endpoint reporting that the service is running."""
    payload = {"status": "running"}
    return payload
@app.post("/predict")
def predict(input: ImageInput):
    """Classify the image found at ``input.url``.

    The image is downloaded server-side, converted to RGB, preprocessed with
    the module-level ``transform`` and passed through ``model``.

    Args:
        input: Request body carrying the image URL.

    Returns:
        ``{"class": <int>}``, the index of the highest-scoring logit.
        Presumably this indexes ``model.config.id2label``; confirm the label
        order for this checkpoint.

    Raises:
        HTTPException: 400 if the URL cannot be fetched (network error,
            timeout, non-2xx status) or does not contain a decodable image.
    """
    # Without a timeout, requests can wait forever on an unresponsive host
    # and pin a worker indefinitely.
    fetch_timeout_s = 10

    # NOTE(review): the URL is client-supplied and fetched by the server, so
    # this endpoint can be used to probe internal hosts (SSRF). Consider a
    # host allow-list if the service is publicly exposed.
    try:
        response = requests.get(input.url, timeout=fetch_timeout_s)
        # Reject 4xx/5xx responses instead of trying to decode an error page.
        response.raise_for_status()
    except requests.RequestException as exc:
        raise HTTPException(
            status_code=400, detail="Could not fetch image from URL"
        ) from exc

    try:
        # PIL raises UnidentifiedImageError (an OSError subclass) for data it
        # cannot identify, and OSError for truncated/corrupt files.
        with Image.open(BytesIO(response.content)) as raw_img:
            img = raw_img.convert("RGB")
    except OSError as exc:
        raise HTTPException(
            status_code=400, detail="URL did not return a valid image"
        ) from exc

    # Add the batch dimension the model expects.
    img_tensor = transform(img).unsqueeze(0)
    with torch.no_grad():
        logits = model(img_tensor).logits
    pred = torch.argmax(logits, dim=1).item()
    return {"class": int(pred)}