vagrillo commited on
Commit
bf50961
·
verified ·
1 Parent(s): 1ed4d46

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +88 -49
app.py CHANGED
@@ -8,6 +8,7 @@ from itertools import cycle
8
  import os
9
  from datetime import datetime
10
  import gradio as gr
 
11
 
12
  # Load model and processor
13
  model_id = "fushh7/llmdet_swin_large_hf"
@@ -30,43 +31,32 @@ BOX_COLORS = [
30
  "orange", "chartreuse"
31
  ]
32
 
33
-
34
- def save_cropped_images(original_image, boxes, labels, scores, output_dir="static/output_crops"):
35
  """
36
- Salva ogni regione ritagliata definita dalle bounding box in file separati.
37
-
38
  :param original_image: Immagine PIL originale
39
  :param boxes: Lista di bounding box [x_min, y_min, x_max, y_max]
40
  :param labels: Lista di etichette per ogni box
41
  :param scores: Lista di punteggi di confidenza
42
- :param output_dir: Directory base dove salvare le immagini
43
- :return: Lista dei percorsi dei file salvati
44
  """
45
- # Crea una directory con timestamp per evitare sovrascritture
46
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
47
- output_path = os.path.join(output_dir, f"detections_{timestamp}")
48
- os.makedirs(output_path, exist_ok=True)
49
-
50
  saved_paths = []
51
-
52
  for i, (box, label, score) in enumerate(zip(boxes, labels, scores)):
53
- # Pulisci il label per usarlo nel nome del file
54
- clean_label = "".join(c if c.isalnum() else "_" for c in label)
55
-
 
56
  # Ritaglia la regione dall'immagine originale
57
  cropped_img = original_image.crop(box)
58
-
59
- # Crea il nome del file
60
- filename = f"crop_{i}_{clean_label}_{score:.2f}.jpg"
61
- filepath = os.path.join(output_path, filename)
62
-
63
  # Salva l'immagine ritagliata
64
  cropped_img.save(filepath)
65
  saved_paths.append(filepath)
66
-
67
  return saved_paths
68
 
69
-
70
  def draw_boxes(image, boxes, labels, scores, colors=BOX_COLORS, font_path="arial.ttf", font_size=16):
71
  """
72
  Draw bounding boxes and labels on a PIL Image.
@@ -146,9 +136,9 @@ def detect_and_draw(
146
  box_threshold: float = 0.14,
147
  text_threshold: float = 0.13,
148
  save_crops: bool = True
149
- ) -> Image.Image:
150
  """
151
- Detect objects described in `text_query`, draw boxes, return the image.
152
  Note: `text_query` must be lowercase and each concept ends with a dot
153
  (e.g. 'a cat. a remote control.')
154
  """
@@ -180,42 +170,91 @@ def detect_and_draw(
180
  labels = results.get("text_labels", results.get("labels", [])),
181
  scores = results["scores"]
182
  )
 
 
 
 
183
  if save_crops:
184
- saved_paths = save_cropped_images(
185
  img,
186
  boxes=results["boxes"].cpu().numpy(),
187
  labels=results.get("text_labels", results.get("labels", [])),
188
  scores=results["scores"]
189
  )
190
- print(f"Saved {len(saved_paths)} cropped images to: {os.path.dirname(saved_paths[0])}")
191
 
192
- return img_out
193
 
194
  # Create example list
195
  examples = [
196
  ["examples/stickers(1).jpg", "stickers. labels.", 0.24, 0.23],
197
- # ["examples/IMG_8920.jpeg", "bin. water bottle. hand. shoe.", 0.4, 0.3],
198
- # ["examples/IMG_9435.jpeg", "lettuce. orange slices (group). eggs (group). cheese (group). red cabbage. pear slices (group).", 0.4, 0.3],
199
  ]
200
 
 
 
 
 
 
 
 
 
201
  # Create Gradio demo
202
- app = gr.Interface(
203
- fn = detect_and_draw,
204
- inputs = [
205
- gr.Image(type="pil", label="Image"),
206
- gr.Textbox(value="stickers. labels. postcards.",
207
- label="Text Query (lowercase, end each with '.', for example 'a bird. a tree.')"),
208
- gr.Slider(0.0, 1.0, 0.14, 0.05, label="Box Threshold"),
209
- gr.Slider(0.0, 1.0, 0.13, 0.05, label="Text Threshold")
210
- ],
211
- outputs = gr.Image(type="pil", label="Detections"),
212
- title = "Sticker Geo Tagger",
213
- description = f"""Upload an image containings stickers and adjust thresholds to see detections.
214
- <a href='/output_crops/' target='crops'>output_crops</a>
215
- """,
216
- examples = examples,
217
- cache_examples = True,
218
- )
219
-
220
- #app.launch(server_name="0.0.0.0", server_port=22590, root_path="/stikkiers2", share=False)
221
- app.launch(server_name="0.0.0.0", share=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  import os
9
  from datetime import datetime
10
  import gradio as gr
11
+ import tempfile
12
 
13
  # Load model and processor
14
  model_id = "fushh7/llmdet_swin_large_hf"
 
31
  "orange", "chartreuse"
32
  ]
33
 
34
+ def save_cropped_images(original_image, boxes, labels, scores):
 
35
  """
36
+ Salva ogni regione ritagliata definita dalle bounding box in file temporanei.
37
+
38
  :param original_image: Immagine PIL originale
39
  :param boxes: Lista di bounding box [x_min, y_min, x_max, y_max]
40
  :param labels: Lista di etichette per ogni box
41
  :param scores: Lista di punteggi di confidenza
42
+ :return: Lista dei percorsi dei file temporanei salvati
 
43
  """
 
 
 
 
 
44
  saved_paths = []
45
+
46
  for i, (box, label, score) in enumerate(zip(boxes, labels, scores)):
47
+ # Crea un file temporaneo
48
+ with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp_file:
49
+ filepath = tmp_file.name
50
+
51
  # Ritaglia la regione dall'immagine originale
52
  cropped_img = original_image.crop(box)
53
+
 
 
 
 
54
  # Salva l'immagine ritagliata
55
  cropped_img.save(filepath)
56
  saved_paths.append(filepath)
57
+
58
  return saved_paths
59
 
 
60
  def draw_boxes(image, boxes, labels, scores, colors=BOX_COLORS, font_path="arial.ttf", font_size=16):
61
  """
62
  Draw bounding boxes and labels on a PIL Image.
 
136
  box_threshold: float = 0.14,
137
  text_threshold: float = 0.13,
138
  save_crops: bool = True
139
+ ):
140
  """
141
+ Detect objects described in `text_query`, draw boxes, return the image and crops.
142
  Note: `text_query` must be lowercase and each concept ends with a dot
143
  (e.g. 'a cat. a remote control.')
144
  """
 
170
  labels = results.get("text_labels", results.get("labels", [])),
171
  scores = results["scores"]
172
  )
173
+
174
+ # Lista per i percorsi dei crop
175
+ crop_paths = []
176
+
177
  if save_crops:
178
+ crop_paths = save_cropped_images(
179
  img,
180
  boxes=results["boxes"].cpu().numpy(),
181
  labels=results.get("text_labels", results.get("labels", [])),
182
  scores=results["scores"]
183
  )
184
+ print(f"Generated {len(crop_paths)} cropped images")
185
 
186
+ return img_out, crop_paths
187
 
188
  # Create example list
189
  examples = [
190
  ["examples/stickers(1).jpg", "stickers. labels.", 0.24, 0.23],
 
 
191
  ]
192
 
193
+ # Funzione per pulire i file temporanei dopo l'uso
194
+ def cleanup_temp_files(crop_paths):
195
+ for path in crop_paths:
196
+ try:
197
+ os.unlink(path)
198
+ except:
199
+ pass
200
+
201
  # Create Gradio demo
202
+ with gr.Blocks() as demo:
203
+ gr.Markdown("# Sticker Geo Tagger")
204
+ gr.Markdown("Upload an image containing stickers and adjust thresholds to see detections.")
205
+
206
+ with gr.Row():
207
+ with gr.Column():
208
+ image_input = gr.Image(type="pil", label="Input Image")
209
+ text_query = gr.Textbox(
210
+ value="stickers. labels. postcards.",
211
+ label="Text Query (lowercase, end each with '.', for example 'a bird. a tree.')"
212
+ )
213
+ box_threshold = gr.Slider(0.0, 1.0, 0.14, step=0.05, label="Box Threshold")
214
+ text_threshold = gr.Slider(0.0, 1.0, 0.13, step=0.05, label="Text Threshold")
215
+ submit_btn = gr.Button("Detect")
216
+
217
+ with gr.Column():
218
+ image_output = gr.Image(type="pil", label="Detections")
219
+
220
+ # Galleria per i crop
221
+ gallery = gr.Gallery(
222
+ label="Detected Crops",
223
+ columns=[4],
224
+ rows=[2],
225
+ object_fit="contain",
226
+ height="auto"
227
+ )
228
+
229
+ # Esempi
230
+ gr.Examples(
231
+ examples=examples,
232
+ inputs=[image_input, text_query, box_threshold, text_threshold],
233
+ outputs=[image_output, gallery],
234
+ fn=detect_and_draw,
235
+ cache_examples=True
236
+ )
237
+
238
+ # Pulsante di submit
239
+ submit_btn.click(
240
+ fn=detect_and_draw,
241
+ inputs=[image_input, text_query, box_threshold, text_threshold],
242
+ outputs=[image_output, gallery]
243
+ )
244
+
245
+ # Pulisci i file temporanei quando viene caricato un nuovo esempio
246
+ demo.load(
247
+ fn=lambda: None,
248
+ inputs=None,
249
+ outputs=None,
250
+ _js="""
251
+ function() {
252
+ // Pulisci i file temporanei quando la pagina viene ricaricata
253
+ fetch('/cleanup_temp_files', {method: 'POST'});
254
+ return [];
255
+ }
256
+ """
257
+ )
258
+
259
+ if __name__ == "__main__":
260
+ demo.launch(server_name="0.0.0.0", share=False)