vagrillo committed on
Commit
d15a538
·
1 Parent(s): ebb9b75

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +221 -0
app.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from PIL import Image, ImageDraw, ImageFont
3
+ from transformers import GroundingDinoProcessor
4
+ from modeling_grounding_dino import GroundingDinoForObjectDetection
5
+
6
+ from PIL import Image, ImageDraw, ImageFont
7
+ from itertools import cycle
8
+ import os
9
+ from datetime import datetime
10
+ import gradio as gr
11
+
12
# Load model and processor.
# "fushh7/llmdet_swin_large_hf" is an alternative checkpoint id; swap it in
# here to use it. Only one id is assigned so it is clear which one is loaded.
model_id = "fushh7/llmdet_swin_tiny_hf"
DEVICE = "cpu"

print(f"[INFO] Using device: {DEVICE}")
print(f"[INFO] Loading model from {model_id}...")

# The processor handles image/text preprocessing and output post-processing.
processor = GroundingDinoProcessor.from_pretrained(model_id)
model = GroundingDinoForObjectDetection.from_pretrained(model_id).to(DEVICE)
# Put the model in evaluation mode; it is only used for inference here.
model.eval()

print("[INFO] Model loaded successfully.")
25
+
26
# Colour names used for bounding boxes, in the order they are handed out.
# Add or remove names here to change the palette.
BOX_COLORS = (
    "deepskyblue red lime dodgerblue "
    "cyan magenta yellow "
    "orange chartreuse"
).split()
32
+
33
+
34
def save_cropped_images(original_image, boxes, labels, scores, output_dir="static/output_crops"):
    """
    Save every region described by a bounding box as a separate JPEG file.

    A fresh ``detections_<timestamp>`` directory is created under
    ``output_dir`` for each call. If a directory for the same second already
    exists, a numeric suffix (``_1``, ``_2``, ...) is appended so that an
    earlier call's crops are never overwritten.

    :param original_image: Original PIL image to crop from
    :param boxes: Iterable of bounding boxes [x_min, y_min, x_max, y_max]
    :param labels: Iterable of labels, one per box (converted with ``str``)
    :param scores: Iterable of confidence scores, one per box
    :param output_dir: Base directory under which the crops are stored
    :return: List of paths of the files that were written. Boxes that are
             empty after clamping to the image bounds are skipped.
    """
    # Create a timestamped directory; never reuse an existing one.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    base_path = os.path.join(output_dir, f"detections_{timestamp}")
    output_path = base_path
    attempt = 0
    while True:
        try:
            os.makedirs(output_path)
            break
        except FileExistsError:
            attempt += 1
            output_path = f"{base_path}_{attempt}"

    img_width, img_height = original_image.size
    saved_paths = []

    for i, (box, label, score) in enumerate(zip(boxes, labels, scores)):
        # Sanitise the label so it is safe to use inside a file name.
        clean_label = "".join(c if c.isalnum() else "_" for c in str(label))

        # Round to whole pixels and clamp to the image so the crop never
        # contains padding from outside the picture.
        x_min, y_min, x_max, y_max = (int(round(float(v))) for v in box)
        x_min, y_min = max(x_min, 0), max(y_min, 0)
        x_max, y_max = min(x_max, img_width), min(y_max, img_height)
        if x_max <= x_min or y_max <= y_min:
            # Degenerate box: nothing to save.
            continue

        cropped_img = original_image.crop((x_min, y_min, x_max, y_max))

        # JPEG cannot store alpha or palette images; convert those to RGB.
        if cropped_img.mode not in ("RGB", "L", "CMYK"):
            cropped_img = cropped_img.convert("RGB")

        # `i` is the detection index, so file names stay aligned with the
        # detection order even when some boxes are skipped.
        filename = f"crop_{i}_{clean_label}_{float(score):.2f}.jpg"
        filepath = os.path.join(output_path, filename)

        cropped_img.save(filepath)
        saved_paths.append(filepath)

    return saved_paths
68
+
69
+
70
def draw_boxes(image, boxes, labels, scores, colors=BOX_COLORS, font_path="arial.ttf", font_size=16):
    """
    Draw bounding boxes and labels on a PIL Image.

    The image is drawn on in place and also returned.

    :param image: PIL Image object
    :param boxes: Iterable of [x_min, y_min, x_max, y_max]
    :param labels: Iterable of label strings
    :param scores: Iterable of scalar confidences (0-1)
    :param colors: List/tuple of colour names or RGB tuples
    :param font_path: Path to a TTF font for labels
    :param font_size: Int size of font to use, default 16
    :return: PIL Image with drawn boxes
    """
    # Ensure we can iterate colours indefinitely
    colour_cycle = cycle(colors)
    draw = ImageDraw.Draw(image)

    # Pick a font (fallback to default if missing)
    try:
        font = ImageFont.truetype(font_path, size=font_size)
    except OSError:
        try:
            font = ImageFont.load_default(size=font_size)
        except TypeError:
            # Older Pillow releases do not accept a `size` argument.
            font = ImageFont.load_default()

    # Assign a consistent colour per label
    label_to_colour = {}

    for box, label, score in zip(boxes, labels, scores):
        # Only advance the palette for labels not seen before, so no colour
        # is skipped when a label repeats.
        if label not in label_to_colour:
            label_to_colour[label] = next(colour_cycle)
        colour = label_to_colour[label]

        x_min, y_min, x_max, y_max = map(int, box)

        # Draw rectangle
        draw.rectangle([x_min, y_min, x_max, y_max], outline=colour, width=2)

        # Compose text
        text = f"{label} ({float(score):.3f})"
        text_w, text_h = draw.textbbox((0, 0), text, font=font)[2:]

        # Put the label strip above the box; if that would fall outside the
        # top of the image, put it just inside the box instead.
        strip_top = y_min - text_h - 4
        if strip_top < 0:
            strip_top = y_min

        # Draw text background for legibility
        bg_coords = [x_min, strip_top,
                     x_min + text_w + 4, strip_top + text_h + 4]
        draw.rectangle(bg_coords, fill=colour)

        # Draw text
        draw.text((x_min + 2, strip_top + 2),
                  text, fill="black", font=font)

    return image
119
+
120
def resize_image_max_dimension(image, max_size=4096):
    """
    Resize an image so that the longest side is at most max_size pixels,
    while maintaining the aspect ratio.

    :param image: PIL Image object
    :param max_size: Maximum dimension in pixels (default: 4096)
    :return: The same PIL Image object if it already fits, otherwise a
             resized copy
    """
    width, height = image.size
    longest_side = max(width, height)

    # Nothing to do when the image already fits.
    if longest_side <= max_size:
        return image

    # Scale both sides by the same factor. Clamp to 1 pixel so a very
    # elongated image never ends up with a zero-sized side.
    ratio = max_size / longest_side
    new_width = max(1, int(width * ratio))
    new_height = max(1, int(height * ratio))

    # Resize the image using high-quality resampling
    return image.resize((new_width, new_height), Image.Resampling.LANCZOS)
142
+
143
def detect_and_draw(
    img: Image.Image,
    text_query: str,
    box_threshold: float = 0.14,
    text_threshold: float = 0.13,
    save_crops: bool = True
) -> Image.Image:
    """
    Detect objects described in `text_query`, draw boxes, return the image.

    The query is lower-cased, stripped, and given a trailing dot if it has
    none, because each concept is expected to end with a dot
    (e.g. 'a cat. a remote control.').

    :param img: Input PIL image
    :param text_query: Text describing the concepts to detect
    :param box_threshold: Passed to the processor's post-processing as `box_threshold`
    :param text_threshold: Passed to the processor's post-processing as `text_threshold`
    :param save_crops: When True, every detected region is also saved to disk
    :return: Copy of the (possibly downscaled) image with boxes drawn on it
    :raises gr.Error: If no image or an empty text query was supplied
    """
    if img is None:
        raise gr.Error("Please upload an image first.")

    # Normalise the query: lowercase, and make sure it ends with a dot.
    text_query = (text_query or "").lower().strip()
    if not text_query:
        raise gr.Error("Please enter a text query, for example 'a bird. a tree.'")
    if not text_query.endswith("."):
        text_query += "."

    # If the image size is too large, we make it smaller
    img = resize_image_max_dimension(img, max_size=4096)

    # Preprocess the image
    inputs = processor(images=img, text=text_query, return_tensors="pt").to(DEVICE)

    with torch.no_grad():
        outputs = model(**inputs)

    results = processor.post_process_grounded_object_detection(
        outputs,
        inputs.input_ids,
        box_threshold=box_threshold,
        text_threshold=text_threshold,
        # PIL reports (width, height); reversed here to (height, width).
        target_sizes=[img.size[::-1]]
    )[0]

    # Extract the detections once and share them between drawing and saving.
    boxes = results["boxes"].cpu().numpy()
    labels = results.get("text_labels", results.get("labels", []))
    scores = results["scores"]

    # Draw on a copy so the crops below come from the clean image.
    img_out = draw_boxes(img.copy(), boxes=boxes, labels=labels, scores=scores)

    if save_crops:
        saved_paths = save_cropped_images(img, boxes=boxes, labels=labels, scores=scores)
        if saved_paths:
            print(f"Saved {len(saved_paths)} cropped images to: {os.path.dirname(saved_paths[0])}")
        else:
            # No detections: there is no first path to report.
            print("No cropped images were saved.")

    return img_out
193
+
194
# Create example list: [image path, text query, box threshold, text threshold]
examples = [
    ["examples/stickers.jpg", "stickers. labels.", 0.24, 0.23],
    # ["examples/IMG_8920.jpeg", "bin. water bottle. hand. shoe.", 0.4, 0.3],
    # ["examples/IMG_9435.jpeg", "lettuce. orange slices (group). eggs (group). cheese (group). red cabbage. pear slices (group).", 0.4, 0.3],
]

# Create Gradio demo
app = gr.Interface(
    fn=detect_and_draw,
    inputs=[
        gr.Image(type="pil", label="Image"),
        # The default follows the format the label asks for (trailing dot).
        gr.Textbox(value="stickers.",
                   label="Text Query (lowercase, end each with '.', for example 'a bird. a tree.')"),
        # `value` and `step` are passed by keyword so the call does not depend
        # on the positional order of Slider's parameters. A step of 0.01 lets
        # the sliders represent the defaults (0.14 / 0.13) and the example
        # thresholds (0.24 / 0.23) exactly.
        gr.Slider(0.0, 1.0, value=0.14, step=0.01, label="Box Threshold"),
        gr.Slider(0.0, 1.0, value=0.13, step=0.01, label="Text Threshold"),
    ],
    outputs=gr.Image(type="pil", label="Detections"),
    title="Sticker Geo Tagger",
    description="""Upload an image containing stickers and adjust thresholds to see detections.
<a href='/output_crops/' target='crops'>output_crops</a>
""",
    examples=examples,
    cache_examples=True,
)

#app.launch(server_name="0.0.0.0", server_port=22590, root_path="/stikkiers2", share=False)
app.launch(server_name="0.0.0.0", share=False)