|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torchvision import models |
|
|
|
|
|
from torchvision import transforms |
|
|
from PIL import Image |
|
|
from fastapi import FastAPI, HTTPException |
|
|
from pydantic import BaseModel |
|
|
from typing import List |
|
|
import requests |
|
|
from io import BytesIO |
|
|
import gradio as gr |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Inference device; the model and its inputs are placed here.
DEVICE = torch.device("cpu")

# Classifier / preprocessing configuration.
NUM_CLASSES, IMG_SIZE = 3, 224  # output logits of the head; square resize side (px)
MODEL_PATH = "resnet50_best_9838.pth"  # fine-tuned weights, relative to the CWD
|
|
|
|
|
# Classifier output index -> label. The order must match the model's output
# units, because predictions are looked up here by their arg-max index.
class_mapping = dict(enumerate(("safe", "sexy", "violence")))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ResNet-50 backbone built without downloading pretrained weights; the
# fine-tuned weights are loaded from MODEL_PATH below.
model = models.resnet50(weights=None)

# Replace the stock classifier with a dropout + linear head over NUM_CLASSES.
# This must match the head the checkpoint was saved with, because
# load_state_dict is strict by default and rejects missing/unexpected keys.
in_features = model.fc.in_features
model.fc = nn.Sequential(
    nn.Dropout(p=0.5),
    nn.Linear(in_features, NUM_CLASSES),
)

# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot execute arbitrary code while loading
# (requires torch >= 1.13). map_location uses the same DEVICE the model is
# moved to, rather than a separate hard-coded "cpu".
state_dict = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=True)
model.load_state_dict(state_dict)

model.to(DEVICE)
model.eval()  # inference mode for Dropout (and BatchNorm in the backbone)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Per-channel normalization statistics (the standard ImageNet values used with
# torchvision's ResNet models). Presumably these match the training
# preprocessing — confirm against the training script.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Preprocessing pipeline. It resizes to IMG_SIZE x IMG_SIZE, which does not
# preserve aspect ratio, converts to a float tensor, and normalizes three
# channels. The input is therefore expected to be an RGB PIL image.
transform = transforms.Compose(
    [
        transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def predict_single(img: Image.Image):
    """Classify one PIL image.

    Returns:
        ``(pred, conf, probs)``: the arg-max class index (int), its softmax
        probability (float), and the full probability list in class-index
        order.
    """
    # The Normalize step has exactly three channel statistics, so grayscale,
    # palette or RGBA images would fail inside the transform; convert them.
    if img.mode != "RGB":
        img = img.convert("RGB")
    # Add a batch dimension and place the input on the same device as the model.
    x = transform(img).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        probs = torch.softmax(model(x), dim=1)[0]
    conf, pred = torch.max(probs, dim=0)
    return pred.item(), conf.item(), probs.tolist()
|
|
|
|
|
|
|
|
def predict_batch(images: List[Image.Image]):
    """Classify several PIL images with a single forward pass.

    Returns:
        Three parallel lists ``(preds, confs, probs)``: the arg-max class
        indices, their softmax probabilities, and each image's full
        probability list in class-index order. An empty ``images`` yields
        three empty lists instead of a ``torch.stack`` error.
    """
    if not images:
        return [], [], []
    # Normalize expects three channels; convert non-RGB inputs up front, then
    # place the batch on the same device as the model.
    xs = torch.stack(
        [transform(img if img.mode == "RGB" else img.convert("RGB")) for img in images]
    ).to(DEVICE)
    with torch.no_grad():
        probs = torch.softmax(model(xs), dim=1)
    confs, preds = torch.max(probs, dim=1)
    return preds.tolist(), confs.tolist(), probs.tolist()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# HTTP API application; the moderation endpoints are registered on it below.
app = FastAPI(title="SFW Image Moderation API")


# Body of POST /moderate: a single image, referenced by URL.
class ImageRequest(BaseModel):
    image_url: str  # URL the server downloads the image from


# Body of POST /moderate_batch: several images, referenced by URL.
class BatchImageRequest(BaseModel):
    image_urls: List[str]  # the endpoint rejects requests with more than 16 entries
|
|
|
|
|
|
|
|
# Moderation policy and download limits shared by the API endpoints and the UI.
_ALLOW_THRESHOLD = 0.6  # minimum confidence for a "safe" prediction to be allowed
_MAX_BATCH_SIZE = 16  # maximum number of URLs accepted by /moderate_batch
_FETCH_TIMEOUT_S = 5  # requests timeout in seconds (connect / per read)
_MAX_IMAGE_BYTES = 20 * 1024 * 1024  # abandon downloads larger than 20 MiB
_CHUNK_BYTES = 64 * 1024  # streaming read size


def _fetch_image(url: str) -> Image.Image:
    """Download ``url`` and decode the body as an RGB PIL image.

    The body is streamed and abandoned once it exceeds ``_MAX_IMAGE_BYTES``,
    so a huge or never-ending response cannot exhaust server memory.

    Raises:
        requests.RequestException: network failure or non-2xx HTTP status.
        ValueError: the body is larger than ``_MAX_IMAGE_BYTES``.
        Exception: whatever Pillow raises for undecodable data (callers catch
            broadly at the request boundary).
    """
    # SECURITY(review): ``url`` is client-supplied, so the server can be made to
    # request internal hosts (SSRF). Consider a host allow-list or rejecting
    # private/loopback addresses before fetching.
    buf = BytesIO()
    with requests.get(url, timeout=_FETCH_TIMEOUT_S, stream=True) as resp:
        resp.raise_for_status()
        for chunk in resp.iter_content(chunk_size=_CHUNK_BYTES):
            buf.write(chunk)
            if buf.tell() > _MAX_IMAGE_BYTES:
                raise ValueError("image exceeds the download size limit")
    buf.seek(0)
    with Image.open(buf) as img:
        # convert() returns a new, loaded image, so closing `img` afterwards is safe.
        return img.convert("RGB")


def _probability_map(probs) -> dict:
    """Map each class label to its probability, rounded to 4 decimals."""
    return {class_mapping[i]: round(probs[i], 4) for i in range(NUM_CLASSES)}


def _moderation_result(pred: int, conf: float, probs) -> dict:
    """Build the per-image JSON payload shared by /moderate and /moderate_batch."""
    label = class_mapping[pred]
    return {
        "allowed": label == "safe" and conf >= _ALLOW_THRESHOLD,
        "label": label,
        "confidence": round(conf, 4),
        "probabilities": _probability_map(probs),
    }


@app.post("/moderate")
def moderate(req: ImageRequest):
    """Classify the image at ``req.image_url``.

    Raises:
        HTTPException: 400 when the image cannot be downloaded or decoded.
    """
    try:
        img = _fetch_image(req.image_url)
    except Exception as exc:
        # Download and decode failures are reported as a client error; the
        # original exception is kept as the cause for debugging.
        raise HTTPException(status_code=400, detail="Invalid image URL") from exc
    return _moderation_result(*predict_single(img))


@app.post("/moderate_batch")
def moderate_batch(req: BatchImageRequest):
    """Classify up to ``_MAX_BATCH_SIZE`` images in one forward pass.

    URLs that cannot be downloaded or decoded are skipped. They are listed
    under ``"failed"`` so callers can tell which inputs produced no result.

    Raises:
        HTTPException: 400 for too many URLs, or when no URL yields an image.
    """
    if len(req.image_urls) > _MAX_BATCH_SIZE:
        raise HTTPException(400, f"Max {_MAX_BATCH_SIZE} images per request")

    images, urls, failed = [], [], []
    for url in req.image_urls:
        try:
            img = _fetch_image(url)
        except Exception:
            failed.append(url)  # best-effort: skip this URL, but report it
        else:
            images.append(img)
            urls.append(url)

    if not images:
        raise HTTPException(400, "No valid images")

    preds, confs, probs_list = predict_batch(images)
    results = [
        {"image_url": url, **_moderation_result(pred, conf, probs)}
        for url, pred, conf, probs in zip(urls, preds, confs, probs_list)
    ]
    return {"count": len(results), "results": results, "failed": failed}


def ui_predict(image_url):
    """Gradio callback: return ``(preview_image, prediction_dict)``.

    If the URL cannot be downloaded or decoded, returns
    ``(None, {"error": "Invalid image URL"})``. Errors raised by the model
    itself propagate to Gradio rather than being reported as a bad URL.
    """
    try:
        img = _fetch_image(image_url)
    except Exception:
        return None, {"error": "Invalid image URL"}

    pred, conf, probs = predict_single(img)
    result = {
        "label": class_mapping[pred],
        "confidence": round(conf, 4),
        # Flattened per-class probabilities, keyed by class label.
        **_probability_map(probs),
    }
    return img, result
|
|
|
|
|
# Gradio demo: paste an image URL, then see the image and the raw prediction.
_url_input = gr.Textbox(
    label="Image URL",
    placeholder="Paste image URL here...",
)
_ui_outputs = [
    gr.Image(label="Preview Image"),
    gr.JSON(label="Prediction"),
]

ui = gr.Interface(
    fn=ui_predict,
    inputs=_url_input,
    outputs=_ui_outputs,
    title="SFW Image Moderation Demo",
    description="Demo UI. Backend should call API endpoints directly.",
)

# Serve the demo at the site root of the FastAPI app. NOTE(review): the API
# routes were registered on `app` before this mount and presumably keep
# precedence over "/" — verify /moderate and /docs are still reachable.
app = gr.mount_gradio_app(app, ui, path="/")
|
|
|