Spaces:
Runtime error
Runtime error
| from fastapi import FastAPI | |
| from pydantic import BaseModel | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForImageClassification | |
| from PIL import Image | |
| import torchvision.transforms as T | |
| import requests | |
| from io import BytesIO | |
app = FastAPI()

# Load the classifier once at import time so every request reuses it.
model_name = "Falconsai/nsfw_image_detection"
model = AutoModelForImageClassification.from_pretrained(model_name)
# from_pretrained already returns the model in eval mode; stated explicitly so
# dropout stays disabled for deterministic inference even if this code moves.
model.eval()
# NOTE: no tokenizer is loaded. This checkpoint is an image classifier with no
# tokenizer, so AutoTokenizer.from_pretrained(model_name) fails at startup
# (most likely the cause of the Space's "Runtime error"), and nothing in this
# app consumes text input anyway.

# Preprocessing: resize to 224x224, convert to a float tensor scaled to [0, 1],
# then normalise with mean=0.5 / std=0.5 so pixel values land in [-1, 1].
# NOTE(review): intended to mirror the checkpoint's own preprocessor config
# (ViT-style 224px input, mean/std 0.5) -- confirm against preprocessor_config.json.
transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize([0.5], [0.5]),  # single value broadcasts across all 3 RGB channels
])
class ImageInput(BaseModel):
    """JSON request body for the prediction endpoint: one image URL."""

    # Location of the image to classify; the handler downloads it server-side.
    url: str
@app.get("/")
def read_root():
    """Health check: report that the service is up.

    The original function had no route decorator, so FastAPI never exposed it.

    Returns:
        ``{"status": "running"}``.
    """
    return {"status": "running"}
@app.post("/predict")
def predict(input: ImageInput):
    """Download the image at ``input.url`` and classify it with ``model``.

    Args:
        input: request body carrying the image URL. The name shadows the
            builtin but is kept so existing keyword callers keep working.

    Returns:
        ``{"class": <argmax class index>, "label": <name from
        model.config.id2label, or the index as a string if unmapped>}``.

    Raises:
        HTTPException: 400 if the URL cannot be fetched (network error,
            timeout, non-2xx status) or the response is not a decodable image.
    """
    # Local import keeps the file-level import block untouched.
    from fastapi import HTTPException

    # NOTE(review): input.url is caller-controlled and fetched from the server,
    # which is an SSRF vector (internal hosts, cloud metadata endpoints). Consider
    # an allow-list and a response-size cap if this endpoint is public.
    try:
        # Without a timeout, a slow or unresponsive host blocks a worker forever.
        resp = requests.get(input.url, timeout=10)
        # Do not hand an HTML error page to PIL.
        resp.raise_for_status()
    except requests.RequestException as err:
        raise HTTPException(status_code=400, detail="Could not fetch image from URL") from err

    try:
        # Image.open is lazy; convert() forces decoding and normalises
        # RGBA, grayscale and palette images to 3-channel RGB.
        img = Image.open(BytesIO(resp.content)).convert("RGB")
    except OSError as err:  # PIL.UnidentifiedImageError subclasses OSError
        raise HTTPException(status_code=400, detail="URL did not return a decodable image") from err

    # Add a batch dimension: (3, 224, 224) -> (1, 3, 224, 224).
    img_tensor = transform(img).unsqueeze(0)
    with torch.no_grad():
        logits = model(img_tensor).logits
    pred = int(torch.argmax(logits, dim=1).item())
    return {"class": pred, "label": model.config.id2label.get(pred, str(pred))}