"""FastAPI service that classifies an image, fetched by URL, with an HF image classifier."""

from io import BytesIO

import requests
import torch
import torchvision.transforms as T
from fastapi import FastAPI, HTTPException
from PIL import Image, UnidentifiedImageError
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForImageClassification  # noqa: F401

app = FastAPI()

# Load the model once at import time so every request reuses the same weights.
model_name = "Falconsai/nsfw_image_detection"
model = AutoModelForImageClassification.from_pretrained(model_name)
# Inference only: make sure dropout and similar layers are in eval mode.
model.eval()
# NOTE: no tokenizer is loaded. It was never used by this image-only pipeline,
# and an image-classification checkpoint presumably ships no tokenizer files,
# so AutoTokenizer.from_pretrained could fail and take the service down.

# Preprocessing: resize to 224x224, convert to a float tensor in [0, 1],
# then map to [-1, 1] via (x - 0.5) / 0.5.
# NOTE(review): assumed to match the checkpoint's preprocessor config; confirm.
transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize([0.5], [0.5]),
])

# Seconds to wait when connecting to / reading from the image host.
FETCH_TIMEOUT_S = 10
# Upper bound on downloaded image size, to avoid unbounded memory use.
MAX_IMAGE_BYTES = 10 * 1024 * 1024


class ImageInput(BaseModel):
    """Request body for /predict."""

    url: str  # location of the image to classify


@app.get("/")
def read_root():
    """Liveness probe."""
    return {"status": "running"}


def _fetch_image(url: str) -> Image.Image:
    """Download *url* and decode it as an RGB PIL image.

    Raises:
        HTTPException: 400 if the URL cannot be fetched, returns an error
            status, exceeds MAX_IMAGE_BYTES, or is not a decodable image.
    """
    # NOTE(review): the URL comes straight from the client, so this lets
    # callers make the server issue arbitrary HTTP requests (SSRF). Consider
    # an allow-list of hosts if the service is publicly reachable.
    try:
        with requests.get(url, timeout=FETCH_TIMEOUT_S, stream=True) as resp:
            resp.raise_for_status()
            buf = bytearray()
            for chunk in resp.iter_content(chunk_size=64 * 1024):
                buf.extend(chunk)
                if len(buf) > MAX_IMAGE_BYTES:
                    raise HTTPException(status_code=400, detail="image too large")
    except requests.RequestException as err:
        raise HTTPException(
            status_code=400, detail=f"could not fetch image: {err}"
        ) from err

    try:
        return Image.open(BytesIO(buf)).convert("RGB")
    except (UnidentifiedImageError, OSError, Image.DecompressionBombError) as err:
        # UnidentifiedImageError: payload is not an image; OSError: truncated
        # or corrupt data; DecompressionBombError: absurd pixel count.
        raise HTTPException(
            status_code=400, detail="URL did not return a valid image"
        ) from err


@app.post("/predict")
def predict(input: ImageInput):
    """Classify the image at ``input.url``.

    Returns:
        ``{"class": <predicted class index>, "label": <label name or None>}``.

    Raises:
        HTTPException: 400 if the image cannot be fetched or decoded.
    """
    img = _fetch_image(input.url)
    # Add a leading batch dimension of size 1.
    img_tensor = transform(img).unsqueeze(0)
    with torch.no_grad():
        logits = model(img_tensor).logits
    pred = int(torch.argmax(logits, dim=1).item())
    return {"class": pred, "label": model.config.id2label.get(pred)}