ALSv commited on
Commit
e34292b
·
verified ·
1 Parent(s): 1fad6d0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +95 -46
app.py CHANGED
@@ -1,57 +1,106 @@
1
  import gradio as gr
2
- from PIL import Image, ImageOps
 
 
 
 
3
  import base64
4
  import io
5
 
 
 
 
6
 
7
- # Funzione di predizione condivisa
8
- def predict(text, image_b64):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  """
10
- Elabora un testo e un'immagine codificata in base64.
11
- - Testo: invertito
12
- - Immagine: convertita in scala di grigi e ritrasformata in base64
13
  """
14
- # Elabora testo
15
- out_text = text[::-1] if text else ""
16
-
17
- # Elabora immagine base64
18
- out_b64 = None
19
- if image_b64:
20
- try:
21
- # Decodifica da base64 (gestisce anche stringhe tipo "data:image/png;base64,...")
22
- img_bytes = base64.b64decode(image_b64.split(",")[-1])
23
- img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
24
-
25
- # Conversione in bianco e nero
26
- img = ImageOps.grayscale(img)
27
-
28
- # Ricodifica in base64
29
- buffer = io.BytesIO()
30
- img.save(buffer, format="PNG")
31
- out_b64 = "data:image/png;base64," + base64.b64encode(buffer.getvalue()).decode("utf-8")
32
- except Exception as e:
33
- out_text += f"\n[Errore immagine: {e}]"
34
-
35
- return out_text, out_b64
36
-
37
-
38
- # Interfaccia Gradio
39
- demo = gr.Interface(
40
- fn=predict,
41
- inputs=[
42
- gr.Textbox(label="Testo"),
43
- gr.Textbox(label="Immagine in base64")
44
- ],
45
- outputs=[
46
- gr.Textbox(label="Testo elaborato"),
47
- gr.Textbox(label="Immagine in base64")
48
- ],
49
- title="Demo Base64 API Gradio 3.5",
50
- description="Input: testo + immagine base64 → Output: testo invertito + immagine base64 in scala di grigi"
51
- )
52
 
 
53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  if __name__ == "__main__":
55
- # In Gradio 3.5 le API sono esposte automaticamente su /api/predict
56
- demo.launch(server_name="0.0.0.0", server_port=7860)
57
 
 
1
  import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ from torchvision import transforms
5
+ from torchvision.models import resnet18
6
+ from PIL import Image
7
  import base64
8
  import io
9
 
10
+ # ---------------- CONFIG ----------------
11
+ labels = ["Drawings", "Hentai", "Neutral", "Porn", "Sexy"]
12
+ theme_color = "#6C5B7B"
13
 
14
+ # ---------------- MODEL ----------------
15
+ class Classifier(nn.Module):
16
+ def __init__(self):
17
+ super().__init__()
18
+ self.cnn_layers = resnet18(weights=None)
19
+ self.fc_layers = nn.Sequential(
20
+ nn.Linear(1000, 512),
21
+ nn.Dropout(0.3),
22
+ nn.Linear(512, 128),
23
+ nn.ReLU(),
24
+ nn.Linear(128, 5)
25
+ )
26
+
27
+ def forward(self, x):
28
+ x = self.cnn_layers(x)
29
+ x = self.fc_layers(x)
30
+ return x
31
+
32
+ preprocess = transforms.Compose([
33
+ transforms.Resize((224,224)),
34
+ transforms.ToTensor(),
35
+ transforms.Normalize(mean=[0.485,0.456,0.406],
36
+ std=[0.229,0.224,0.225])
37
+ ])
38
+
39
+ model = Classifier()
40
+ model.load_state_dict(torch.load("classify_nsfw_v3.0.pth", map_location="cpu"))
41
+ model.eval()
42
+
43
+ # ---------------- FUNZIONE ----------------
44
+ def predict(base64_image):
45
  """
46
+ Input: immagine base64
47
+ Output: label più probabile + probabilità per tutte le classi
 
48
  """
49
+ try:
50
+ if base64_image.startswith("data:image"):
51
+ base64_image = base64_image.split(",",1)[1]
52
+ img_bytes = base64.b64decode(base64_image)
53
+ img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
54
+ img_tensor = preprocess(img).unsqueeze(0)
55
+
56
+ with torch.no_grad():
57
+ pred = torch.nn.functional.softmax(model(img_tensor)[0], dim=0)
58
+
59
+ max_label = labels[torch.argmax(pred).item()]
60
+ probs = {labels[i]: float(pred[i]) for i in range(len(labels))}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
+ return max_label, probs
63
 
64
+ except Exception as e:
65
+ return f"Error: {str(e)}", {}
66
+
67
+ def clear_all():
68
+ return ""
69
+
70
+ # ---------------- INTERFACCIA ----------------
71
+ with gr.Blocks(title="NSFW Classifier") as demo:
72
+ gr.HTML(f"""
73
+ <div style="padding:10px; background:linear-gradient(135deg,#f8f9fa 0%,#e9ecef 100%); border-radius:10px;">
74
+ <h2 style="color:{theme_color};">🎨 NSFW Image Classifier</h2>
75
+ <p>Incolla qui l'immagine in base64 per analizzarla.</p>
76
+ </div>
77
+ """)
78
+
79
+ with gr.Row():
80
+ with gr.Column(scale=2):
81
+ base64_input = gr.Textbox(
82
+ label="📤 Base64 dell'immagine",
83
+ lines=6,
84
+ placeholder="Incolla qui la stringa base64 dell'immagine..."
85
+ )
86
+ with gr.Row():
87
+ submit_btn = gr.Button("✨ Analizza", variant="primary")
88
+ clear_btn = gr.Button("🔄 Pulisci", variant="secondary")
89
+ with gr.Column(scale=1):
90
+ label_output = gr.Textbox(label="Classe predetta", interactive=False)
91
+ result_display = gr.Label(label="Distribuzione probabilità", num_top_classes=5)
92
+
93
+ # Eventi
94
+ submit_btn.click(
95
+ fn=predict,
96
+ inputs=base64_input,
97
+ outputs=[label_output, result_display],
98
+ api_name="predict" # <- espone /run/predict
99
+ )
100
+ clear_btn.click(fn=clear_all, inputs=None, outputs=base64_input)
101
+
102
+ # ---------------- LAUNCH ----------------
103
  if __name__ == "__main__":
104
+ demo.launch(server_name="0.0.0.0", server_port=7860, show_api=True)
105
+
106