# nsfw-dect / app.py
# atharv-16's picture
# Update app.py
# 362c834 verified
import gradio as gr
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
# Load the trained model checkpoint (a pickled state_dict).
model_path = "nsfw_classifier.pkl"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Define the model architecture: ResNet-18 backbone with a 2-class head
# (index 0 = SFW, index 1 = NSFW — presumably matches training label order).
model = models.resnet18(weights=None)
model.fc = nn.Linear(model.fc.in_features, 2)
# `weights_only=True` restricts unpickling to tensors/primitive containers,
# preventing arbitrary-code execution from a malicious checkpoint. (The
# original comment claimed this but the flag was never actually passed.)
model.load_state_dict(
    torch.load(model_path, map_location=device, weights_only=True),
    strict=True,
)
# Move the weights onto the selected device; inference sends inputs to
# `device`, so without this the app crashes on CUDA machines.
model.to(device)
model.eval()  # inference mode: freezes dropout / batch-norm statistics
# Preprocessing pipeline: resize to the ResNet-18 input size, convert to a
# tensor, and normalize with the standard ImageNet channel statistics.
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
])
# Function to predict NSFW status
def classify_image(image_path):
    """Classify an uploaded image as SFW or NSFW.

    Args:
        image_path: Filesystem path to the uploaded image. Gradio passes
            None when no image is provided / the input is cleared.

    Returns:
        A human-readable class label, or an error message string on failure.
    """
    # Gradio sends None for an empty input; fail gracefully instead of
    # letting Image.open(None) raise a confusing exception.
    if image_path is None:
        return "Error processing image: no image provided"
    try:
        # Open and normalize to 3-channel RGB (handles grayscale/RGBA uploads).
        image = Image.open(image_path).convert("RGB")
        batch = transform(image).unsqueeze(0)  # add batch dim -> (1, 3, 224, 224)
        # Inference without gradient tracking.
        with torch.no_grad():
            outputs = model(batch.to(device))
            _, predicted = outputs.max(1)
        # Index order assumed to match training labels — TODO confirm.
        classes = ["Safe for Work", "Not Safe for Work (NSFW)"]
        return classes[predicted.item()]
    except Exception as e:
        # Broad catch is deliberate: surface the error text in the UI
        # rather than crashing the Gradio worker.
        return f"Error processing image: {e}"
# Gradio UI: single image input -> text label output.
interface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="filepath"),  # pass a file path so classify_image can open it
    outputs="text",
    title="NSFW Image Classifier",
    description="Upload an image to classify whether it is Safe for Work (SFW) or Not Safe for Work (NSFW).",
    examples=[
        ["example1.jpg"],
        ["example2.jpg"],
    ],
)

# Launch only when executed as a script. HF Spaces runs app.py as __main__,
# so behavior there is unchanged; this just prevents an accidental server
# launch if the module is ever imported.
if __name__ == "__main__":
    interface.launch()