# NOTE: file-viewer residue removed from the top of this file (it was not part
# of the program): "File size: 983 Bytes", commit hashes 0067c59 / 3f59d69,
# and a 1-39 line-number gutter.
from io import BytesIO

import requests
import torch
import torchvision.transforms as T
from fastapi import FastAPI, HTTPException
from PIL import Image
from pydantic import BaseModel
from transformers import AutoModelForImageClassification, AutoTokenizer

app = FastAPI()

# Load the classifier once at import time so every request reuses the same
# weights instead of reloading them per call.
model_name = "Falconsai/nsfw_image_detection"
model = AutoModelForImageClassification.from_pretrained(model_name)
# No tokenizer is loaded: this checkpoint is an image classifier and ships no
# tokenizer, so AutoTokenizer.from_pretrained(model_name) raised during import
# and the service never started. Nothing in this file used the tokenizer.

# Preprocessing applied to every incoming image before it reaches the model:
# resize to 224x224, convert to a float tensor in [0, 1], then shift/scale
# with mean 0.5 and std 0.5 (the single-element lists broadcast over channels).
# NOTE(review): presumably chosen to mirror the checkpoint's own preprocessor
# configuration - confirm against the model card.
transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize([0.5], [0.5]),
])

class ImageInput(BaseModel):
    """Request body accepted by the ``/predict`` endpoint."""

    # Location of the image to classify; the handler downloads it with
    # requests.get(), so it must be reachable from the server.
    url: str

@app.get("/")
def read_root():
    """Answer ``GET /`` with a fixed payload showing the service is up."""
    payload = {"status": "running"}
    return payload

@app.post("/predict")
def predict(input: ImageInput):
    """Classify the image located at ``input.url``.

    The image is downloaded, converted to RGB, run through the module-level
    ``transform`` and fed to ``model``; the index of the largest logit is
    returned.

    Args:
        input: Request body carrying the image URL. The name shadows the
            builtin but is kept so the handler's signature does not change.

    Returns:
        A dict of the form ``{"class": <int>}`` holding the predicted class
        index.

    Raises:
        HTTPException: 400 if the URL cannot be fetched (invalid URL,
            connection error, timeout, non-2xx status) or if the downloaded
            bytes cannot be decoded as an image.
    """
    # NOTE(security): the URL is caller-controlled, so this endpoint will
    # fetch any address the server can reach (SSRF) and buffers the whole
    # response in memory. Consider an allow-list and a size cap before
    # exposing it publicly.
    try:
        # Without a timeout a stalled host would block this worker forever.
        response = requests.get(input.url, timeout=10)
        # Reject 4xx/5xx here rather than handing an error page to PIL.
        response.raise_for_status()
    except requests.RequestException as exc:
        raise HTTPException(
            status_code=400,
            detail="Could not fetch an image from the given URL.",
        ) from exc

    try:
        with Image.open(BytesIO(response.content)) as raw:
            img = raw.convert("RGB")
    except (OSError, Image.DecompressionBombError) as exc:
        # UnidentifiedImageError is an OSError subclass, so this also covers
        # payloads that are not images at all.
        raise HTTPException(
            status_code=400,
            detail="The URL did not return a readable image.",
        ) from exc

    # unsqueeze(0) adds the leading batch dimension for a single image.
    img_tensor = transform(img).unsqueeze(0)

    with torch.no_grad():
        logits = model(img_tensor).logits
    pred = torch.argmax(logits, dim=1).item()

    return {"class": int(pred)}