Spaces:
Runtime error
Runtime error
File size: 983 Bytes
0067c59 3f59d69 0067c59 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 |
from io import BytesIO

import requests
import torch
import torchvision.transforms as T
from fastapi import FastAPI, HTTPException
from PIL import Image
from pydantic import BaseModel
from transformers import AutoModelForImageClassification, AutoTokenizer
app = FastAPI()

# Load the classifier once at import time so every request reuses the same
# weights instead of re-reading them per call.
model_name = "Falconsai/nsfw_image_detection"
model = AutoModelForImageClassification.from_pretrained(model_name)
# Inference only: make sure dropout-style layers are disabled.
model.eval()

# NOTE: the former `AutoTokenizer.from_pretrained(model_name)` call was
# removed. Image-classification checkpoints normally ship no text tokenizer,
# so that call raises while the module is being imported and the app never
# starts; the resulting object was also never used anywhere in this file.

# Preprocessing pipeline: resize to 224x224, convert to a float tensor, then
# normalize. The single-element mean/std are applied to every channel.
transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize([0.5], [0.5]),
])
class ImageInput(BaseModel):
    """Request body accepted by the ``/predict`` endpoint."""

    # Address of the image to classify; it is downloaded with `requests.get`.
    url: str
@app.get("/")
def read_root():
    """Liveness probe: return a fixed payload saying the service is up."""
    payload = {"status": "running"}
    return payload
# Upper bound, in seconds, on how long the image download may take. Without
# it a slow or unresponsive host would block the worker indefinitely.
_FETCH_TIMEOUT_SECONDS = 10


@app.post("/predict")
def predict(input: ImageInput):
    """Download the image at ``input.url`` and classify it.

    Args:
        input: Request body carrying the image URL. The parameter name
            shadows the ``input`` builtin but is kept so the endpoint's
            signature stays unchanged.

    Returns:
        A dict ``{"class": <int>}`` where the integer is the index of the
        highest logit produced by the model.

    Raises:
        HTTPException: With status 400 when the URL cannot be fetched
            (network error, timeout, non-2xx response) or when the
            downloaded bytes cannot be decoded as an image.
    """
    # NOTE(security): this fetches a caller-supplied URL from the server
    # side, which exposes internal hosts to SSRF. If the service is reachable
    # by untrusted clients, restrict the allowed hosts before deploying.
    try:
        response = requests.get(input.url, timeout=_FETCH_TIMEOUT_SECONDS)
        # Reject 4xx/5xx replies here instead of handing an error page to PIL.
        response.raise_for_status()
    except requests.RequestException as err:
        raise HTTPException(
            status_code=400,
            detail="Could not fetch an image from the given URL",
        ) from err

    try:
        # convert() forces a full decode, so corrupt data fails here too.
        img = Image.open(BytesIO(response.content)).convert("RGB")
    except (OSError, ValueError, Image.DecompressionBombError) as err:
        raise HTTPException(
            status_code=400,
            detail="The URL did not return a readable image",
        ) from err

    # Add a leading batch dimension: the model is called on a batch of one.
    img_tensor = transform(img).unsqueeze(0)
    with torch.no_grad():
        logits = model(img_tensor).logits
    pred = torch.argmax(logits, dim=1).item()
    return {"class": int(pred)}
|