AIProject2025 commited on
Commit
cfc3a56
·
verified ·
1 Parent(s): 27af485

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -0
app.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import timm
4
+ from PIL import Image
5
+
6
+ model = timm.create_model("hf_hub:Marqo/nsfw-image-detection-384", pretrained=True).eval()
7
+ data_config = timm.data.resolve_model_data_config(model)
8
+ transforms = timm.data.create_transform(**data_config, is_training=False)
9
+ class_names = model.pretrained_cfg["label_names"]
10
+
11
+ @torch.inference_mode()
12
+ def predict(image: Image.Image):
13
+ tensor = transforms(image).unsqueeze(0)
14
+ probs = model(tensor).softmax(dim=-1).cpu().flatten()
15
+ top_id = int(probs.argmax())
16
+ top_label = class_names[top_id]
17
+ probs_dict = {class_names[i]: float(p) for i, p in enumerate(probs)}
18
+ return top_label, probs_dict
19
+
20
+ demo = gr.Interface(
21
+ fn=predict,
22
+ inputs=gr.Image(type="pil"),
23
+ outputs=[
24
+ gr.Label(label="Top prediction"),
25
+ gr.Label(label="All probabilities", num_top_classes=len(class_names)),
26
+ ],
27
+ title="NSFW Image Detection",
28
+ description="Drag & drop an image to see the predicted class",
29
+ )
30
+
31
+ if __name__ == "__main__":
32
+ demo.launch(show_error=True)