vagrillo commited on
Commit
c69b41a
·
verified ·
1 Parent(s): 694eeaf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +214 -310
app.py CHANGED
@@ -1,35 +1,19 @@
1
- from flask import Flask, session, request, redirect, url_for, render_template_string, send_file
2
- import datetime
3
- import os
4
- import secrets
5
  import torch
6
- from PIL import Image, ImageDraw
7
  from transformers import GroundingDinoProcessor
8
  from modeling_grounding_dino import GroundingDinoForObjectDetection
 
 
9
  from itertools import cycle
 
 
 
10
  import tempfile
11
- import io
12
-
13
- app = Flask(__name__)
14
- app.secret_key = os.environ.get('SECRET_KEY', secrets.token_hex(16))
15
- SECRET_PASSWORD = "VeronaTrento25!"
16
- app.permanent_session_lifetime = datetime.timedelta(hours=24)
17
-
18
- # ===== AUTHENTICATION FUNCTIONS =====
19
- def is_authenticated():
20
- return session.get('authenticated', False)
21
-
22
- def require_auth(f):
23
- def decorated_function(*args, **kwargs):
24
- if not is_authenticated():
25
- return redirect(url_for('login'))
26
- return f(*args, **kwargs)
27
- decorated_function.__name__ = f.__name__
28
- return decorated_function
29
-
30
- # ===== ML MODEL SETUP =====
31
- DEVICE = "cpu"
32
  model_id = "fushh7/llmdet_swin_tiny_hf"
 
33
 
34
  print(f"[INFO] Using device: {DEVICE}")
35
  print(f"[INFO] Loading model from {model_id}...")
@@ -40,346 +24,266 @@ model.eval()
40
 
41
  print("[INFO] Model loaded successfully.")
42
 
43
- # Pre-defined palette
44
  BOX_COLORS = [
45
  "deepskyblue", "red", "lime", "dodgerblue",
46
- "cyan", "magenta", "yellow", "orange", "chartreuse"
 
47
  ]
48
 
49
- # ===== ML FUNCTIONS =====
50
  def save_cropped_images(original_image, boxes, labels, scores):
 
 
 
 
 
 
 
 
 
51
  saved_paths = []
 
52
  for i, (box, label, score) in enumerate(zip(boxes, labels, scores)):
 
53
  with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp_file:
54
  filepath = tmp_file.name
 
 
55
  cropped_img = original_image.crop(box)
 
 
56
  cropped_img.save(filepath)
57
  saved_paths.append(filepath)
 
58
  return saved_paths
59
 
60
- def draw_boxes(image, boxes, labels, scores, colors=BOX_COLORS, font_size=16):
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  colour_cycle = cycle(colors)
62
  draw = ImageDraw.Draw(image)
63
-
 
64
  try:
65
- font = ImageFont.truetype("arial.ttf", size=font_size)
66
- except:
67
- font = ImageFont.load_default()
68
-
 
69
  label_to_colour = {}
70
-
71
  for box, label, score in zip(boxes, labels, scores):
 
72
  colour = label_to_colour.setdefault(label, next(colour_cycle))
 
73
  x_min, y_min, x_max, y_max = map(int, box)
74
-
 
75
  draw.rectangle([x_min, y_min, x_max, y_max], outline=colour, width=2)
 
 
76
  text = f"{label} ({score:.3f})"
77
- text_bbox = draw.textbbox((0, 0), text, font=font)
78
- text_width = text_bbox[2] - text_bbox[0]
79
- text_height = text_bbox[3] - text_bbox[1]
80
-
81
- bg_coords = [x_min, y_min - text_height - 4, x_min + text_width + 4, y_min]
82
  draw.rectangle(bg_coords, fill=colour)
83
- draw.text((x_min + 2, y_min - text_height - 2), text, fill="black", font=font)
84
-
 
 
 
85
  return image
86
 
87
- def resize_image_max_dimension(image, max_size=1024):
 
 
 
 
 
 
 
 
88
  width, height = image.size
 
 
89
  if max(width, height) <= max_size:
90
  return image
 
 
91
  ratio = max_size / max(width, height)
92
  new_width = int(width * ratio)
93
  new_height = int(height * ratio)
 
 
94
  return image.resize((new_width, new_height), Image.Resampling.LANCZOS)
95
 
96
- def detect_and_draw(img, text_query, box_threshold=0.14, text_threshold=0.13):
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  text_query = text_query.lower()
98
- img = resize_image_max_dimension(img, max_size=1024)
99
-
 
 
 
100
  inputs = processor(images=img, text=text_query, return_tensors="pt").to(DEVICE)
101
-
102
  with torch.no_grad():
103
  outputs = model(**inputs)
104
-
105
  results = processor.post_process_grounded_object_detection(
106
  outputs,
107
  inputs.input_ids,
 
108
  text_threshold=text_threshold,
109
  target_sizes=[img.size[::-1]]
110
  )[0]
111
-
112
  img_out = img.copy()
113
  img_out = draw_boxes(
114
  img_out,
115
- boxes=results["boxes"].cpu().numpy(),
116
- labels=results.get("text_labels", results.get("labels", [])),
117
- scores=results["scores"]
118
  )
119
 
120
- crop_paths = save_cropped_images(
121
- img,
122
- boxes=results["boxes"].cpu().numpy(),
123
- labels=results.get("text_labels", results.get("labels", [])),
124
- scores=results["scores"]
125
- )
126
 
 
 
 
 
 
 
 
 
 
127
  return img_out, crop_paths
128
 
129
- # ===== FLASK ROUTES =====
130
- @app.route('/')
131
- #@require_auth
132
- def index():
133
- return render_template_string('''
134
- <!DOCTYPE html>
135
- <html>
136
- <head>
137
- <title>Student Finder - Protetto</title>
138
- <style>
139
- body { font-family: Arial, sans-serif; max-width: 1200px; margin: 0 auto; padding: 20px; }
140
- .header { background: #e8f5e8; padding: 20px; border-radius: 10px; margin-bottom: 20px; }
141
- .content { background: #f5f5f5; padding: 30px; border-radius: 10px; }
142
- .form-group { margin-bottom: 15px; }
143
- label { display: block; margin-bottom: 5px; font-weight: bold; }
144
- input, textarea, select { width: 100%; padding: 8px; border: 1px solid #ddd; border-radius: 4px; }
145
- button { background: #007bff; color: white; padding: 10px 20px; border: none; border-radius: 4px; cursor: pointer; }
146
- button:hover { background: #0056b3; }
147
- .logout { float: right; }
148
- .results { margin-top: 20px; }
149
- .gallery { display: grid; grid-template-columns: repeat(auto-fit, minmax(200px, 1fr)); gap: 10px; margin-top: 20px; }
150
- .gallery img { max-width: 100%; height: auto; border: 1px solid #ddd; border-radius: 4px; }
151
- </style>
152
- </head>
153
- <body>
154
- <div class="header">
155
- <h1>🎓 Student Finder</h1>
156
- <p>Carica una foto di classe e trova gli studenti</p>
157
- <a href="/logout" class="logout">🔓 Logout</a>
158
- <div style="clear: both;"></div>
159
- </div>
160
-
161
- <div class="content">
162
- <form method="post" enctype="multipart/form-data" action="/detect">
163
- <div class="form-group">
164
- <label for="image">Immagine:</label>
165
- <input type="file" id="image" name="image" accept="image/*" required>
166
- </div>
167
-
168
- <div class="form-group">
169
- <label for="text_query">Text Query:</label>
170
- <textarea id="text_query" name="text_query" rows="2" required>heads.</textarea>
171
- <small>Testo in lowercase, ogni concetto termina con '.' (es. 'heads. faces.')</small>
172
- </div>
173
-
174
- <div class="form-group">
175
- <label for="box_threshold">Box Threshold ({{ box_threshold }}):</label>
176
- <input type="range" id="box_threshold" name="box_threshold" min="0" max="1" step="0.05" value="0.14">
177
- </div>
178
-
179
- <div class="form-group">
180
- <label for="text_threshold">Text Threshold ({{ text_threshold }}):</label>
181
- <input type="range" id="text_threshold" name="text_threshold" min="0" max="1" step="0.05" value="0.13">
182
- </div>
183
-
184
- <button type="submit">🔍 Rileva Studenti</button>
185
- </form>
186
-
187
- {% if result_image %}
188
- <div class="results">
189
- <h3>Risultati:</h3>
190
- <img src="data:image/jpeg;base64,{{ result_image }}" alt="Risultato" style="max-width: 100%;">
191
-
192
- {% if crops %}
193
- <h4>Ritagli individuati ({{ crops|length }}):</h4>
194
- <div class="gallery">
195
- {% for crop in crops %}
196
- <img src="data:image/jpeg;base64,{{ crop }}" alt="Ritaglio {{ loop.index }}">
197
- {% endfor %}
198
- </div>
199
- {% endif %}
200
- </div>
201
- {% endif %}
202
- </div>
203
- </body>
204
- </html>
205
- ''', box_threshold=0.14, text_threshold=0.13)
206
-
207
- @app.route('/detect', methods=['POST'])
208
- @require_auth
209
- def detect():
210
- if 'image' not in request.files:
211
- return redirect(url_for('index'))
212
 
213
- image_file = request.files['image']
214
- if image_file.filename == '':
215
- return redirect(url_for('index'))
 
216
 
217
- try:
218
- # Process image
219
- image = Image.open(image_file.stream).convert('RGB')
220
- text_query = request.form.get('text_query', 'heads.')
221
- box_threshold = float(request.form.get('box_threshold', 0.14))
222
- text_threshold = float(request.form.get('text_threshold', 0.13))
223
-
224
- # Run detection
225
- result_image, crop_paths = detect_and_draw(image, text_query, box_threshold, text_threshold)
226
-
227
- # Convert images to base64 for display
228
- import base64
229
-
230
- # Convert result image to base64
231
- img_buffer = io.BytesIO()
232
- result_image.save(img_buffer, format='JPEG')
233
- result_b64 = base64.b64encode(img_buffer.getvalue()).decode()
234
-
235
- # Convert crops to base64
236
- crops_b64 = []
237
- for crop_path in crop_paths:
238
- with open(crop_path, 'rb') as f:
239
- crop_b64 = base64.b64encode(f.read()).decode()
240
- crops_b64.append(crop_b64)
241
- # Cleanup temp file
242
- os.unlink(crop_path)
243
-
244
- return render_template_string('''
245
- <!DOCTYPE html>
246
- <html>
247
- <head>
248
- <title>Risultati - Student Finder</title>
249
- <style>
250
- body { font-family: Arial, sans-serif; max-width: 1200px; margin: 0 auto; padding: 20px; }
251
- .header { background: #e8f5e8; padding: 20px; border-radius: 10px; margin-bottom: 20px; }
252
- .content { background: #f5f5f5; padding: 30px; border-radius: 10px; }
253
- .logout { float: right; }
254
- .gallery { display: grid; grid-template-columns: repeat(auto-fit, minmax(200px, 1fr)); gap: 10px; margin-top: 20px; }
255
- .gallery img { max-width: 100%; height: auto; border: 1px solid #ddd; border-radius: 4px; }
256
- .back-btn { background: #6c757d; color: white; padding: 10px 20px; border: none; border-radius: 4px; cursor: pointer; text-decoration: none; display: inline-block; margin-bottom: 20px; }
257
- .back-btn:hover { background: #545b62; }
258
- </style>
259
- </head>
260
- <body>
261
- <div class="header">
262
- <h1>🎓 Risultati Student Finder</h1>
263
- <a href="/logout" class="logout">🔓 Logout</a>
264
- <div style="clear: both;"></div>
265
- </div>
266
-
267
- <a href="/" class="back-btn">← Nuova Analisi</a>
268
-
269
- <div class="content">
270
- <h3>Immagine con bounding box:</h3>
271
- <img src="data:image/jpeg;base64,{{ result_image }}" alt="Risultato" style="max-width: 100%; border: 1px solid #ddd; border-radius: 4px;">
272
-
273
- {% if crops %}
274
- <h3>Ritagli individuati ({{ crops|length }}):</h3>
275
- <div class="gallery">
276
- {% for crop in crops %}
277
- <img src="data:image/jpeg;base64,{{ crop }}" alt="Ritaglio {{ loop.index }}">
278
- {% endfor %}
279
- </div>
280
- {% else %}
281
- <p>Nessun ritaglio individuato.</p>
282
- {% endif %}
283
- </div>
284
- </body>
285
- </html>
286
- ''', result_image=result_b64, crops=crops_b64)
287
-
288
- except Exception as e:
289
- return f"Errore durante l'elaborazione: {str(e)}", 500
290
 
291
- @app.route('/login', methods=['GET', 'POST'])
292
- def login():
293
- if is_authenticated():
294
- return redirect(url_for('index'))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
295
 
296
- error = None
297
- if request.method == 'POST':
298
- if request.form.get('password') == SECRET_PASSWORD:
299
- session.permanent = True
300
- session['authenticated'] = True
301
- return redirect(url_for('index'))
302
- else:
303
- error = "❌ Password errata. Riprova."
304
 
305
- return render_template_string('''
306
- <!DOCTYPE html>
307
- <html>
308
- <head>
309
- <title>Login - Student Finder</title>
310
- <style>
311
- body {
312
- font-family: Arial, sans-serif;
313
- max-width: 400px;
314
- margin: 100px auto;
315
- padding: 20px;
316
- background: #f5f5f5;
317
- }
318
- .login-form {
319
- background: white;
320
- padding: 30px;
321
- border-radius: 10px;
322
- box-shadow: 0 2px 10px rgba(0,0,0,0.1);
323
- }
324
- h2 {
325
- color: #333;
326
- text-align: center;
327
- margin-bottom: 20px;
328
- }
329
- input[type="password"] {
330
- width: 100%;
331
- padding: 12px;
332
- margin: 15px 0;
333
- border: 1px solid #ddd;
334
- border-radius: 5px;
335
- box-sizing: border-box;
336
- font-size: 16px;
337
- }
338
- button {
339
- background: #007bff;
340
- color: white;
341
- padding: 12px 20px;
342
- border: none;
343
- border-radius: 5px;
344
- cursor: pointer;
345
- width: 100%;
346
- font-size: 16px;
347
- }
348
- button:hover {
349
- background: #0056b3;
350
- }
351
- .error {
352
- color: red;
353
- margin-bottom: 15px;
354
- text-align: center;
355
- padding: 10px;
356
- background: #ffe6e6;
357
- border-radius: 5px;
358
- }
359
- </style>
360
- </head>
361
- <body>
362
- <div class="login-form">
363
- <h2>🔒 Student Finder - Accesso Protetto</h2>
364
- <p style="text-align: center; color: #666;">Inserisci la password per accedere</p>
365
- {% if error %}
366
- <div class="error">{{ error }}</div>
367
- {% endif %}
368
- <form method="POST">
369
- <input type="password" name="password" placeholder="Password" required>
370
- <button type="submit">🔑 Accedi</button>
371
- </form>
372
- </div>
373
- </body>
374
- </html>
375
- ''', error=error)
376
-
377
- @app.route('/logout')
378
- def logout():
379
- session.clear()
380
- return redirect(url_for('login'))
381
 
 
 
 
 
 
 
382
 
383
- if __name__ == '__main__':
384
- port = int(os.environ.get('PORT', 7860))
385
- app.run(host='0.0.0.0', port=port, debug=False)
 
 
 
 
 
1
  import torch
2
+ from PIL import Image, ImageDraw, ImageFont
3
  from transformers import GroundingDinoProcessor
4
  from modeling_grounding_dino import GroundingDinoForObjectDetection
5
+
6
+ from PIL import Image, ImageDraw, ImageFont
7
  from itertools import cycle
8
+ import os
9
+ from datetime import datetime
10
+ import gradio as gr
11
  import tempfile
12
+
13
+ # Load model and processor
14
+ model_id = "fushh7/llmdet_swin_large_hf"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  model_id = "fushh7/llmdet_swin_tiny_hf"
16
+ DEVICE = "cpu"
17
 
18
  print(f"[INFO] Using device: {DEVICE}")
19
  print(f"[INFO] Loading model from {model_id}...")
 
24
 
25
  print("[INFO] Model loaded successfully.")
26
 
27
+ # Pre-defined palette (extend or tweak as you like)
28
  BOX_COLORS = [
29
  "deepskyblue", "red", "lime", "dodgerblue",
30
+ "cyan", "magenta", "yellow",
31
+ "orange", "chartreuse"
32
  ]
33
 
 
34
  def save_cropped_images(original_image, boxes, labels, scores):
35
+ """
36
+ Salva ogni regione ritagliata definita dalle bounding box in file temporanei.
37
+
38
+ :param original_image: Immagine PIL originale
39
+ :param boxes: Lista di bounding box [x_min, y_min, x_max, y_max]
40
+ :param labels: Lista di etichette per ogni box
41
+ :param scores: Lista di punteggi di confidenza
42
+ :return: Lista dei percorsi dei file temporanei salvati
43
+ """
44
  saved_paths = []
45
+
46
  for i, (box, label, score) in enumerate(zip(boxes, labels, scores)):
47
+ # Crea un file temporaneo
48
  with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp_file:
49
  filepath = tmp_file.name
50
+
51
+ # Ritaglia la regione dall'immagine originale
52
  cropped_img = original_image.crop(box)
53
+
54
+ # Salva l'immagine ritagliata
55
  cropped_img.save(filepath)
56
  saved_paths.append(filepath)
57
+
58
  return saved_paths
59
 
60
+ def draw_boxes(image, boxes, labels, scores, colors=BOX_COLORS, font_path="arial.ttf", font_size=16):
61
+ """
62
+ Draw bounding boxes and labels on a PIL Image.
63
+
64
+ :param image: PIL Image object
65
+ :param boxes: Iterable of [x_min, y_min, x_max, y_max]
66
+ :param labels: Iterable of label strings
67
+ :param scores: Iterable of scalar confidences (0-1)
68
+ :param colors: List/tuple of colour names or RGB tuples
69
+ :param font_path: Path to a TTF font for labels
70
+ :param font_size: Int size of font to use, default 16
71
+ :return: PIL Image with drawn boxes
72
+ """
73
+ # Ensure we can iterate colours indefinitely
74
  colour_cycle = cycle(colors)
75
  draw = ImageDraw.Draw(image)
76
+
77
+ # Pick a font (fallback to default if missing)
78
  try:
79
+ font = ImageFont.truetype(font_path, size=font_size)
80
+ except IOError:
81
+ font = ImageFont.load_default(size=font_size)
82
+
83
+ # Assign a consistent colour per label (optional)
84
  label_to_colour = {}
85
+
86
  for box, label, score in zip(boxes, labels, scores):
87
+ # Reuse colour if label seen before, else take next from cycle
88
  colour = label_to_colour.setdefault(label, next(colour_cycle))
89
+
90
  x_min, y_min, x_max, y_max = map(int, box)
91
+
92
+ # Draw rectangle
93
  draw.rectangle([x_min, y_min, x_max, y_max], outline=colour, width=2)
94
+
95
+ # Compose text
96
  text = f"{label} ({score:.3f})"
97
+ text_size = draw.textbbox((0, 0), text, font=font)[2:]
98
+
99
+ # Draw text background for legibility
100
+ bg_coords = [x_min, y_min - text_size[1] - 4,
101
+ x_min + text_size[0] + 4, y_min]
102
  draw.rectangle(bg_coords, fill=colour)
103
+
104
+ # Draw text
105
+ draw.text((x_min + 2, y_min - text_size[1] - 2),
106
+ text, fill="black", font=font)
107
+
108
  return image
109
 
110
+ def resize_image_max_dimension(image, max_size=4096):
111
+ """
112
+ Resize an image so that the longest side is at most max_size pixels,
113
+ while maintaining the aspect ratio.
114
+
115
+ :param image: PIL Image object
116
+ :param max_size: Maximum dimension in pixels (default: 1024)
117
+ :return: PIL Image object (resized)
118
+ """
119
  width, height = image.size
120
+
121
+ # Check if resizing is needed
122
  if max(width, height) <= max_size:
123
  return image
124
+
125
+ # Calculate new dimensions maintaining aspect ratio
126
  ratio = max_size / max(width, height)
127
  new_width = int(width * ratio)
128
  new_height = int(height * ratio)
129
+
130
+ # Resize the image using high-quality resampling
131
  return image.resize((new_width, new_height), Image.Resampling.LANCZOS)
132
 
133
+ def detect_and_draw(
134
+ img: Image.Image,
135
+ text_query: str,
136
+ box_threshold: float = 0.14,
137
+ text_threshold: float = 0.13,
138
+ save_crops: bool = True
139
+ ):
140
+ """
141
+ Detect objects described in `text_query`, draw boxes, return the image and crops.
142
+ Note: `text_query` must be lowercase and each concept ends with a dot
143
+ (e.g. 'a cat. a remote control.')
144
+ """
145
+
146
+ # Make sure text is lowered
147
  text_query = text_query.lower()
148
+
149
+ # If the image size is too large, we make it smaller
150
+ img = resize_image_max_dimension(img, max_size=4096)
151
+
152
+ # Preprocess the image
153
  inputs = processor(images=img, text=text_query, return_tensors="pt").to(DEVICE)
154
+
155
  with torch.no_grad():
156
  outputs = model(**inputs)
157
+
158
  results = processor.post_process_grounded_object_detection(
159
  outputs,
160
  inputs.input_ids,
161
+ box_threshold=box_threshold,
162
  text_threshold=text_threshold,
163
  target_sizes=[img.size[::-1]]
164
  )[0]
165
+
166
  img_out = img.copy()
167
  img_out = draw_boxes(
168
  img_out,
169
+ boxes = results["boxes"].cpu().numpy(),
170
+ labels = results.get("text_labels", results.get("labels", [])),
171
+ scores = results["scores"]
172
  )
173
 
174
+ # Lista per i percorsi dei crop
175
+ crop_paths = []
 
 
 
 
176
 
177
+ if save_crops:
178
+ crop_paths = save_cropped_images(
179
+ img,
180
+ boxes=results["boxes"].cpu().numpy(),
181
+ labels=results.get("text_labels", results.get("labels", [])),
182
+ scores=results["scores"]
183
+ )
184
+ print(f"Generated {len(crop_paths)} cropped images")
185
+
186
  return img_out, crop_paths
187
 
188
+ # Create example list dynamically from examples directory
189
+ def load_examples_from_directory(directory="examples"):
190
+ """
191
+ Carica automaticamente tutti i file JPG dalla directory degli esempi.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
192
 
193
+ :param directory: Percorso della directory contenente gli esempi
194
+ :return: Lista di esempi nel formato [filepath, text_query, box_threshold, text_threshold]
195
+ """
196
+ examples = []
197
 
198
+ # Verifica se la directory esiste
199
+ if not os.path.exists(directory):
200
+ print(f"[WARNING] Directory '{directory}' non trovata. Creala e aggiungi file JPG.")
201
+ return examples
202
+
203
+ # Cerca tutti i file JPG nella directory
204
+ #jpg_files = [f for f in os.listdir(directory) if f.lower().endswith('.jpg')]
205
+ jpg_files = [f for f in os.listdir(directory) if f.lower().endswith(('.jpg', '.png'))]
206
+ if not jpg_files:
207
+ print(f"[WARNING] Nessun file JPG trovato nella directory '{directory}'")
208
+ return examples
209
+
210
+ print(f"[INFO] Trovati {len(jpg_files)} file JPG nella directory examples/")
211
+
212
+ # Crea gli esempi per ogni file JPG
213
+ for jpg_file in jpg_files:
214
+ filepath = os.path.join(directory, jpg_file)
215
+ examples.append([filepath, "heads.", 0.24, 0.23])
216
+
217
+ return examples
218
+
219
+ # Popola automaticamente la lista degli esempi
220
+ examples = load_examples_from_directory()
221
+
222
+ # Se non sono stati trovati esempi, usa un esempio di fallback
223
+ if not examples:
224
+ print("[INFO] Usando esempio di fallback")
225
+ examples = [
226
+ ["examples/stickers(1).jpg", "heads.", 0.24, 0.23],
227
+ ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
228
 
229
+ # Funzione per pulire i file temporanei dopo l'uso
230
+ def cleanup_temp_files(crop_paths):
231
+ for path in crop_paths:
232
+ try:
233
+ os.unlink(path)
234
+ except:
235
+ pass
236
+
237
+ # Create Gradio demo
238
+ with gr.Blocks(title="ClasmateFaceFinder", css=".gradio-container {max-width: 100% !important}") as demo:
239
+ gr.Markdown("# Classmate Finder")
240
+ gr.Markdown("Upload an image and adjust thresholds to see detections.")
241
+
242
+ with gr.Row():
243
+ with gr.Column():
244
+ image_input = gr.Image(type="pil", label="Input Image")
245
+ text_query = gr.Textbox(
246
+ value="head.",
247
+ label="Text Query (lowercase, end each with '.', for example 'a bird. a tree.')"
248
+ )
249
+ box_threshold = gr.Slider(0.0, 1.0, 0.14, step=0.05, label="Box Threshold")
250
+ text_threshold = gr.Slider(0.0, 1.0, 0.13, step=0.05, label="Text Threshold")
251
+ submit_btn = gr.Button("Detect")
252
+
253
+ with gr.Column():
254
+ image_output = gr.Image(type="pil", label="Detections")
255
+
256
+ # Galleria per i crop
257
+ gallery = gr.Gallery(
258
+ label="Detected Crops",
259
+ columns=[4],
260
+ rows=[2],
261
+ object_fit="contain",
262
+ height="auto"
263
+ )
264
 
265
+ # Esempi
266
+ gr.Examples(
267
+ examples=examples,
268
+ inputs=[image_input, text_query, box_threshold, text_threshold],
269
+ outputs=[image_output, gallery],
270
+ fn=detect_and_draw,
271
+ cache_examples=True
272
+ )
273
 
274
+ # Pulsante di submit
275
+ submit_btn.click(
276
+ fn=detect_and_draw,
277
+ inputs=[image_input, text_query, box_threshold, text_threshold],
278
+ outputs=[image_output, gallery]
279
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
280
 
281
+ # Pulisci i file temporanei quando viene caricato un nuovo esempio
282
+ demo.load(
283
+ fn=lambda: None,
284
+ inputs=None,
285
+ outputs=None,
286
+ )
287
 
288
+ if __name__ == "__main__":
289
+ demo.launch(server_name="0.0.0.0", share=False)