Spaces:
Runtime error
"""Gradio demo serving the Marqo NSFW image-detection classifier via timm."""
import gradio as gr
import torch
import timm
from PIL import Image

# Pull the pretrained checkpoint from the Hugging Face Hub and switch it to
# evaluation mode (eval() toggles layers such as dropout/batch-norm).
model = timm.create_model(
    "hf_hub:Marqo/nsfw-image-detection-384",
    pretrained=True,
)
model.eval()

# Build the inference-time preprocessing pipeline from the model's own
# data config, so resizing/normalisation match what the checkpoint expects.
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(is_training=False, **data_config)

# Human-readable class names, indexed like the model's output logits.
class_names = model.pretrained_cfg["label_names"]
def predict(image: Image.Image):
    """Classify *image* with the NSFW detector.

    Args:
        image: PIL image supplied by the Gradio input, or ``None`` when the
            user submits without uploading anything.

    Returns:
        A ``(top_label, probs_dict)`` pair: the most likely class name, and a
        mapping from every class name to its softmax probability. The order
        matches the two ``gr.Label`` outputs of the interface.

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        # Without this guard the transform pipeline fails with an opaque
        # error instead of telling the user what went wrong.
        raise gr.Error("Please upload an image first.")
    # The normalisation step in `transforms` presumably uses 3-channel
    # mean/std (timm default); RGBA, greyscale or palette images would not
    # broadcast against it, so force RGB defensively.
    tensor = transforms(image.convert("RGB")).unsqueeze(0)
    # Pure inference: don't build an autograd graph for every request.
    with torch.no_grad():
        logits = model(tensor)
    probs = logits.softmax(dim=-1).cpu().flatten()
    top_id = int(probs.argmax())
    probs_dict = {name: float(p) for name, p in zip(class_names, probs)}
    return class_names[top_id], probs_dict
# Web UI: one image in, two label panels out.
demo = gr.Interface(
    predict,
    # type="pil" makes Gradio hand predict() a PIL.Image.
    gr.Image(type="pil"),
    # One component per element of predict()'s (top_label, probs_dict) result.
    [
        gr.Label(label="Top prediction"),
        gr.Label(
            label="All probabilities",
            # List every class instead of a truncated top-k.
            num_top_classes=len(class_names),
        ),
    ],
    title="NSFW Image Detection",
    description="Drag & drop an image to see the predicted class",
)
if __name__ == "__main__":
    # The original `demo.launch(error=true)` crashed at startup: `true` is not
    # a Python name (NameError), and launch() has no `error` parameter anyway.
    # show_error=True is the intended option: it surfaces errors raised in the
    # interface to the user in the browser.
    demo.launch(show_error=True)