| | import gradio as gr |
| | import torch |
| | import torch.nn as nn |
| | from torchvision import models, transforms |
| | from PIL import Image |
| |
|
| | |
# Path to the trained classifier weights (a pickled state_dict).
model_path = "nsfw_classifier.pkl"
# Prefer the GPU when one is available; fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ResNet-18 backbone (no pretrained weights) with a fresh 2-class head
# (index 0 = SFW, index 1 = NSFW — see `classes` in classify_image).
model = models.resnet18(weights=None)
model.fc = nn.Linear(model.fc.in_features, 2)

# weights_only=True restricts unpickling to tensors/containers, guarding
# against arbitrary-code-execution payloads hidden in the checkpoint file.
# The checkpoint is a plain state_dict, so this restriction is safe.
model.load_state_dict(
    torch.load(model_path, map_location=device, weights_only=True),
    strict=True,
)
# BUG FIX: move the model itself onto `device`. load_state_dict only copies
# weights into the existing (CPU-resident) parameters, so without this,
# inference on CUDA inputs fails with a device-mismatch error.
model.to(device)
model.eval()

# Standard ImageNet preprocessing: 224x224 input, ImageNet mean/std.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])
| |
|
| | |
def classify_image(image_path):
    """Classify the image at *image_path* as SFW or NSFW.

    Returns a human-readable label string on success, or an error
    message string if the image cannot be opened or processed (Gradio
    displays either one in the text output).
    """
    labels = ["Safe for Work", "Not Safe for Work (NSFW)"]
    try:
        # Force RGB so grayscale/RGBA uploads match the 3-channel model input.
        pil_img = Image.open(image_path).convert("RGB")
        batch = transform(pil_img).unsqueeze(0).to(device)

        # Inference only — no autograd bookkeeping needed.
        with torch.no_grad():
            logits = model(batch)
        idx = int(logits.argmax(dim=1).item())

        return labels[idx]
    except Exception as exc:
        # Surface the failure in the UI rather than crashing the app.
        return f"Error processing image: {exc}"
| |
|
| | |
# Gradio front end: takes an uploaded image (as a file path, matching the
# classify_image signature) and shows the predicted label as plain text.
interface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="filepath"),
    outputs="text",
    title="NSFW Image Classifier",
    description=(
        "Upload an image to classify whether it is Safe for Work (SFW) "
        "or Not Safe for Work (NSFW)."
    ),
    # Sample images shown beneath the upload widget (must exist on disk).
    examples=[["example1.jpg"], ["example2.jpg"]],
)

# Start the local web server (blocks until the app is stopped).
interface.launch()
| |
|