ALSv commited on
Commit
82ea080
·
verified ·
1 Parent(s): 5cc8910

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +70 -55
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import gradio as gr
2
  import torch
3
  import torch.nn as nn
@@ -6,12 +7,13 @@ from torchvision.models import resnet18
6
  from PIL import Image
7
  import base64
8
  import io
 
9
 
10
  # ---------------- CONFIG ----------------
11
  labels = ["Drawings", "Hentai", "Neutral", "Porn", "Sexy"]
12
  theme_color = "#6C5B7B"
13
 
14
- # ---------------- MODEL ----------------
15
  class Classifier(nn.Module):
16
  def __init__(self):
17
  super().__init__()
@@ -30,93 +32,106 @@ class Classifier(nn.Module):
30
  return x
31
 
32
  preprocess = transforms.Compose([
33
- transforms.Resize((224,224)),
34
  transforms.ToTensor(),
35
  transforms.Normalize(mean=[0.485,0.456,0.406],
36
- std=[0.229,0.224,0.225])
37
  ])
38
 
 
39
  model = Classifier()
40
  model.load_state_dict(torch.load("classify_nsfw_v3.0.pth", map_location="cpu"))
41
  model.eval()
42
 
43
- # ---------------- FUNZIONE ----------------
44
- def predict(image_input):
45
  """
46
- Supporta:
47
- - PIL Image (UI web)
48
- - stringa base64 (API)
49
  """
50
  try:
51
- if isinstance(image_input, str):
52
- if image_input.startswith("data:image"):
53
- image_input = image_input.split(",",1)[1]
54
- img_bytes = base64.b64decode(image_input)
55
- img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
56
- else:
57
- img = image_input.convert("RGB")
58
 
59
- img_tensor = preprocess(img).unsqueeze(0)
 
 
 
 
 
 
 
 
 
 
60
 
 
 
61
  with torch.no_grad():
62
  logits = model(img_tensor)
63
  probs = torch.nn.functional.softmax(logits[0], dim=0)
64
 
65
  probs_dict = {labels[i]: float(probs[i]) for i in range(len(labels))}
66
  max_label = max(probs_dict, key=probs_dict.get)
67
-
68
  return max_label, probs_dict
69
 
70
- except Exception as e:
71
- return f"Error: {str(e)}", {}
72
-
73
- def clear_all():
74
- return "", ""
75
-
76
- # ---------------- INTERFACCIA ----------------
77
- with gr.Blocks(title="NSFW Image Classifier") as demo:
78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  gr.HTML(f"""
80
- <div style="padding:10px; background:linear-gradient(135deg,#f8f9fa 0%,#e9ecef 100%); border-radius:10px;">
81
- <h2 style="color:{theme_color};">🎨 NSFW Image Classifier</h2>
82
- <p>Carica un'immagine o incolla la stringa base64 per analizzarla.</p>
83
  </div>
84
  """)
85
-
86
  with gr.Row():
87
  with gr.Column(scale=2):
88
- # Input UI
89
- img_input = gr.Image(label="📷 Carica immagine", type="pil")
90
- base64_input = gr.Textbox(
91
- label="📤 Base64 dell'immagine (API)",
92
- lines=6,
93
- placeholder="Incolla qui la stringa base64..."
94
- )
95
  with gr.Row():
96
- submit_btn = gr.Button("✨ Analizza", variant="primary")
97
- clear_btn = gr.Button("🔄 Pulisci", variant="secondary")
98
-
99
  with gr.Column(scale=1):
100
  label_output = gr.Textbox(label="Classe predetta", interactive=False)
101
  result_display = gr.Label(label="Distribuzione probabilità", num_top_classes=len(labels))
102
 
103
- # ---------------- Eventi UI ----------------
104
- submit_btn.click(
105
- fn=predict,
106
- inputs=[img_input],
107
- outputs=[label_output, result_display]
108
- )
109
- clear_btn.click(fn=clear_all, inputs=None, outputs=[img_input, base64_input])
110
-
111
- # ---------------- Pulsante invisibile per API base64 ----------------
112
- api_button = gr.Button(visible=False)
113
- api_button.click(
114
- fn=predict,
115
- inputs=[base64_input],
116
- outputs=[label_output, result_display],
117
- api_name="predict" # espone /run/predict
118
- )
119
 
120
  # ---------------- LAUNCH ----------------
121
  if __name__ == "__main__":
 
122
  demo.launch(server_name="0.0.0.0", server_port=7860, show_api=True)
 
 
1
+ # nsfw_app.py
2
  import gradio as gr
3
  import torch
4
  import torch.nn as nn
 
7
  from PIL import Image
8
  import base64
9
  import io
10
+ import traceback
11
 
12
  # ---------------- CONFIG ----------------
13
  labels = ["Drawings", "Hentai", "Neutral", "Porn", "Sexy"]
14
  theme_color = "#6C5B7B"
15
 
16
+ # ---------------- MODEL (stesso file pesi) ----------------
17
  class Classifier(nn.Module):
18
  def __init__(self):
19
  super().__init__()
 
32
  return x
33
 
34
  preprocess = transforms.Compose([
35
+ transforms.Resize((224, 224)),
36
  transforms.ToTensor(),
37
  transforms.Normalize(mean=[0.485,0.456,0.406],
38
+ std =[0.229,0.224,0.225])
39
  ])
40
 
41
+ # Carica pesi (stesso file che usavi)
42
  model = Classifier()
43
  model.load_state_dict(torch.load("classify_nsfw_v3.0.pth", map_location="cpu"))
44
  model.eval()
45
 
46
+ # ---------------- FUNZIONE UNICA predict (accetta SOLO base64) ----------------
47
+ def predict(base64_input: str):
48
  """
49
+ Unico input dell'API: stringa base64 (es. "data:image/jpeg;base64,...")
50
+ Ritorna: (label_str, {label:prob})
 
51
  """
52
  try:
53
+ if not base64_input or not isinstance(base64_input, str):
54
+ return "Input base64 mancante o non valido", {}
55
+
56
+ # rimuovi eventuale prefisso data:image...
57
+ if base64_input.startswith("data:image"):
58
+ base64_input = base64_input.split(",", 1)[1]
 
59
 
60
+ # decodifica base64
61
+ try:
62
+ img_bytes = base64.b64decode(base64_input)
63
+ except Exception as e:
64
+ return f"Errore decodifica base64: {e}", {}
65
+
66
+ # apri immagine
67
+ try:
68
+ img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
69
+ except Exception as e:
70
+ return f"Errore apertura immagine: {e}", {}
71
 
72
+ # preprocess + inferenza
73
+ img_tensor = preprocess(img).unsqueeze(0) # 1x3x224x224
74
  with torch.no_grad():
75
  logits = model(img_tensor)
76
  probs = torch.nn.functional.softmax(logits[0], dim=0)
77
 
78
  probs_dict = {labels[i]: float(probs[i]) for i in range(len(labels))}
79
  max_label = max(probs_dict, key=probs_dict.get)
 
80
  return max_label, probs_dict
81
 
82
+ except Exception:
83
+ return f"Unhandled error:\n{traceback.format_exc()}", {}
 
 
 
 
 
 
84
 
85
+ # ---------------- Helper: convert image upload -> base64 ----------------
86
+ def image_to_base64(img: Image.Image):
87
+ """
88
+ Converte PIL image in data:image/jpeg;base64,...
89
+ (usato dall'UI: caricamento immagine -> si popola la textbox base64)
90
+ """
91
+ if img is None:
92
+ return ""
93
+ buffer = io.BytesIO()
94
+ img.save(buffer, format="JPEG", quality=90)
95
+ b64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
96
+ return "data:image/jpeg;base64," + b64
97
+
98
+ def clear_box():
99
+ return ""
100
+
101
+ # ---------------- UI (Blocks) ----------------
102
+ with gr.Blocks(title="NSFW Image Classifier (base64 single-input)"):
103
  gr.HTML(f"""
104
+ <div style="padding:12px; background:linear-gradient(135deg,#f8f9fa 0%,#e9ecef 100%); border-radius:8px;">
105
+ <h2 style="color:{theme_color}; margin:0;">🎨 NSFW Image Classifier</h2>
106
+ <p style="margin:6px 0 0 0;">Carica un'immagine oppure incolla la base64. L'API accetta solo base64.</p>
107
  </div>
108
  """)
 
109
  with gr.Row():
110
  with gr.Column(scale=2):
111
+ image_input = gr.Image(label="📷 Carica immagine (verrà convertita in base64)", type="pil")
112
+ base64_input = gr.Textbox(label="📤 Base64 (API) — unico input", lines=6,
113
+ placeholder="Incolla qui la stringa base64 (data:image/..;base64,...)")
 
 
 
 
114
  with gr.Row():
115
+ analyze_btn = gr.Button("✨ Analizza (usa la base64 sopra)")
116
+ clear_btn = gr.Button("🔄 Pulisci")
 
117
  with gr.Column(scale=1):
118
  label_output = gr.Textbox(label="Classe predetta", interactive=False)
119
  result_display = gr.Label(label="Distribuzione probabilità", num_top_classes=len(labels))
120
 
121
+ # quando carichi immagine -> converto e popolo la textbox con la base64
122
+ image_input.change(fn=image_to_base64, inputs=image_input, outputs=base64_input)
123
+
124
+ # quando la base64 cambia -> chiamo predict (questo espone automaticamente l'endpoint API
125
+ # che accetta solo la textbox base64; gradio mappa l'endpoint in /run/predict in locale)
126
+ base64_input.change(fn=predict, inputs=base64_input, outputs=[label_output, result_display], api_name="predict")
127
+
128
+ # pulsante per analizzare manualmente (usa la base64 contenuta nella textbox)
129
+ analyze_btn.click(fn=predict, inputs=base64_input, outputs=[label_output, result_display])
130
+
131
+ clear_btn.click(fn=clear_box, inputs=None, outputs=base64_input)
 
 
 
 
 
132
 
133
  # ---------------- LAUNCH ----------------
134
  if __name__ == "__main__":
135
+ # show_api=True per vedere il link "View API" nella UI (opzionale)
136
  demo.launch(server_name="0.0.0.0", server_port=7860, show_api=True)
137
+