Spaces:
Sleeping
Sleeping
| import torch | |
| from PIL import Image, ImageDraw, ImageFont | |
| from transformers import GroundingDinoProcessor | |
| from modeling_grounding_dino import GroundingDinoForObjectDetection | |
| from PIL import Image, ImageDraw, ImageFont | |
| from itertools import cycle | |
| import os | |
| from datetime import datetime | |
| import gradio as gr | |
# Load model and processor.
# Larger checkpoint, kept for reference (swap it in to use it):
#   model_id = "fushh7/llmdet_swin_large_hf"
model_id = "fushh7/llmdet_swin_tiny_hf"
# Inference is pinned to the CPU.
DEVICE = "cpu"

print(f"[INFO] Using device: {DEVICE}")
print(f"[INFO] Loading model from {model_id}...")
processor = GroundingDinoProcessor.from_pretrained(model_id)
model = GroundingDinoForObjectDetection.from_pretrained(model_id).to(DEVICE)
# Put the model in evaluation (inference) mode
model.eval()
print("[INFO] Model loaded successfully.")
# Default colour palette for bounding boxes. draw_boxes() walks through it
# with itertools.cycle, so it wraps around once every entry has been used.
BOX_COLORS = [
    "deepskyblue",
    "red",
    "lime",
    "dodgerblue",
    "cyan",
    "magenta",
    "yellow",
    "orange",
    "chartreuse",
]
def save_cropped_images(original_image, boxes, labels, scores, output_dir="static/output_crops"):
    """
    Save every region delimited by a bounding box as a separate JPEG file.

    Crops are written to a fresh sub-directory of ``output_dir`` named
    ``detections_<YYYYmmdd_HHMMSS_ffffff>``; each file is named
    ``crop_<index>_<label>_<score>.jpg``, where ``<index>`` is the box's
    position in ``boxes``.

    :param original_image: Original PIL image the boxes refer to
    :param boxes: Iterable of bounding boxes [x_min, y_min, x_max, y_max] in pixels
    :param labels: Iterable of labels, one per box (converted with ``str``)
    :param scores: Iterable of confidence scores, one per box
    :param output_dir: Base directory in which the timestamped directory is created
    :return: List of the paths of the files actually written. Boxes lying
             entirely outside the image are skipped, so the list may be
             shorter than ``boxes``.
    """
    # Microsecond resolution: two requests within the same second would
    # otherwise share a directory and overwrite each other's crops.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    output_path = os.path.join(output_dir, f"detections_{timestamp}")
    os.makedirs(output_path, exist_ok=True)

    width, height = original_image.size
    saved_paths = []
    for i, (box, label, score) in enumerate(zip(boxes, labels, scores)):
        # Make the label safe for use inside a file name
        clean_label = "".join(c if c.isalnum() else "_" for c in str(label))

        # Clamp to the image bounds: crop() fills areas outside the image
        # with zeros (black) instead of cutting them off.
        left, top, right, bottom = (int(round(float(v))) for v in box)
        left, top = max(0, left), max(0, top)
        right, bottom = min(width, right), min(height, bottom)
        if right <= left or bottom <= top:
            continue  # no part of this box lies inside the image
        cropped_img = original_image.crop((left, top, right, bottom))

        # JPEG cannot store alpha or palette modes (e.g. RGBA, LA, P)
        if cropped_img.mode not in ("RGB", "L"):
            cropped_img = cropped_img.convert("RGB")

        filename = f"crop_{i}_{clean_label}_{float(score):.2f}.jpg"
        filepath = os.path.join(output_path, filename)
        cropped_img.save(filepath)
        saved_paths.append(filepath)
    return saved_paths
def draw_boxes(image, boxes, labels, scores, colors=BOX_COLORS, font_path="arial.ttf", font_size=16):
    """
    Draw bounding boxes and labels on a PIL Image.

    The image is modified in place and also returned.

    :param image: PIL Image object (drawn on directly; pass a copy to keep the original)
    :param boxes: Iterable of [x_min, y_min, x_max, y_max]
    :param labels: Iterable of labels (anything usable in an f-string and as a dict key)
    :param scores: Iterable of scalar confidences (0-1)
    :param colors: Non-empty list/tuple of colour names or RGB tuples; each distinct
                   label takes the next colour, wrapping around when exhausted
    :param font_path: Path to a TTF font for labels
    :param font_size: Int size of font to use, default 16
    :return: The same PIL Image, with boxes drawn
    """
    colour_cycle = cycle(colors)
    draw = ImageDraw.Draw(image)

    # Pick a font, falling back to Pillow's built-in one if the TTF is missing
    try:
        font = ImageFont.truetype(font_path, size=font_size)
    except OSError:
        try:
            font = ImageFont.load_default(size=font_size)
        except TypeError:
            # Pillow < 10.1 does not accept a size for the default font
            font = ImageFont.load_default()

    # One consistent colour per distinct label
    label_to_colour = {}
    for box, label, score in zip(boxes, labels, scores):
        # Advance the palette only for unseen labels. The previous
        # setdefault(label, next(cycle)) advanced it on every box, which could
        # hand two distinct labels the same colour.
        if label not in label_to_colour:
            label_to_colour[label] = next(colour_cycle)
        colour = label_to_colour[label]
        x_min, y_min, x_max, y_max = map(int, box)

        draw.rectangle([x_min, y_min, x_max, y_max], outline=colour, width=2)

        text = f"{label} ({score:.3f})"
        text_w, text_h = draw.textbbox((0, 0), text, font=font)[2:]

        # The label normally sits just above the box. A box touching the top
        # edge leaves no room there, so place the label just inside the box
        # instead of drawing it off-image.
        label_top = y_min - text_h - 4
        if label_top < 0:
            label_top = y_min

        # Background rectangle for legibility, then the text itself
        draw.rectangle([x_min, label_top, x_min + text_w + 4, label_top + text_h + 4], fill=colour)
        draw.text((x_min + 2, label_top + 2), text, fill="black", font=font)
    return image
def resize_image_max_dimension(image, max_size=4096):
    """
    Resize an image so that its longest side is at most ``max_size`` pixels,
    preserving the aspect ratio.

    Images already within the limit are returned unchanged (same object).

    :param image: PIL Image object
    :param max_size: Maximum dimension in pixels (default: 4096)
    :return: PIL Image object (a resized copy, or the original if small enough)
    """
    width, height = image.size
    longest = max(width, height)
    if longest <= max_size:
        return image

    ratio = max_size / longest
    # Keep every side at least 1 px: a very elongated image would otherwise
    # truncate its short side to 0, which is not a valid target size.
    new_width = max(1, int(width * ratio))
    new_height = max(1, int(height * ratio))
    # High-quality resampling for downscaling
    return image.resize((new_width, new_height), Image.Resampling.LANCZOS)
def detect_and_draw(
    img: Image.Image,
    text_query: str,
    box_threshold: float = 0.14,
    text_threshold: float = 0.13,
    save_crops: bool = True
) -> Image.Image:
    """
    Detect objects described in `text_query`, draw boxes, return the image.

    The query must be lowercase, with each concept ending in a dot
    (e.g. 'a cat. a remote control.'). Here it is stripped and lowercased,
    and a trailing dot is appended when missing.

    :param img: Input PIL image (converted to RGB if it is in another mode)
    :param text_query: Concepts to detect
    :param box_threshold: Detection threshold passed to post-processing
    :param text_threshold: Text threshold passed to post-processing
    :param save_crops: If True, also save each detected region with save_cropped_images
    :return: A copy of the (possibly downsized) image with boxes drawn
    :raises gr.Error: If no image or an empty query was provided
    """
    if img is None:
        # Gradio passes None when the form is submitted without an image
        raise gr.Error("Please upload an image.")

    text_query = (text_query or "").strip().lower()
    if not text_query:
        raise gr.Error("Please enter a text query.")
    if not text_query.endswith("."):
        text_query += "."

    # Make sure the processor receives a 3-channel image (uploads may be RGBA, P, L, ...)
    if img.mode != "RGB":
        img = img.convert("RGB")
    # If the image is too large, make it smaller
    img = resize_image_max_dimension(img, max_size=4096)

    inputs = processor(images=img, text=text_query, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        outputs = model(**inputs)

    results = processor.post_process_grounded_object_detection(
        outputs,
        inputs.input_ids,
        box_threshold=box_threshold,
        text_threshold=text_threshold,
        target_sizes=[img.size[::-1]]  # PIL size is (w, h); reversed gives (h, w)
    )[0]

    # Extract once; the same values feed both drawing and cropping
    boxes = results["boxes"].cpu().numpy()
    # Prefer "text_labels", fall back to "labels" (the key presumably differs
    # across transformers versions)
    labels = results.get("text_labels", results.get("labels", []))
    scores = results["scores"]

    img_out = draw_boxes(img.copy(), boxes=boxes, labels=labels, scores=scores)

    if save_crops:
        saved_paths = save_cropped_images(img, boxes=boxes, labels=labels, scores=scores)
        # Guard the [0] lookup: with no detections nothing is saved
        if saved_paths:
            print(f"Saved {len(saved_paths)} cropped images to: {os.path.dirname(saved_paths[0])}")
        else:
            print("No detections; no cropped images saved.")
    return img_out
# Example rows for the Gradio demo, one value per input component in order:
# [image path, text query, box threshold, text threshold]
examples = [
    [
        "examples/stickers.jpg",
        "stickers. labels.",
        0.24,
        0.23,
    ],
    # ["examples/IMG_8920.jpeg", "bin. water bottle. hand. shoe.", 0.4, 0.3],
    # ["examples/IMG_9435.jpeg", "lettuce. orange slices (group). eggs (group). cheese (group). red cabbage. pear slices (group).", 0.4, 0.3],
]
# Create the Gradio demo. The inputs map positionally onto detect_and_draw's
# (img, text_query, box_threshold, text_threshold) parameters.
app = gr.Interface(
    fn=detect_and_draw,
    inputs=[
        gr.Image(type="pil", label="Image"),
        gr.Textbox(
            value="stickers.",
            label="Text Query (lowercase, end each with '.', for example 'a bird. a tree.')",
        ),
        # step=0.01 so that the defaults (0.14 / 0.13) and the example values
        # (0.24 / 0.23) lie on the slider grid; a 0.05 step cannot represent them.
        gr.Slider(minimum=0.0, maximum=1.0, value=0.14, step=0.01, label="Box Threshold"),
        gr.Slider(minimum=0.0, maximum=1.0, value=0.13, step=0.01, label="Text Threshold"),
    ],
    outputs=gr.Image(type="pil", label="Detections"),
    title="Sticker Geo Tagger",
    description="""Upload an image containing stickers and adjust thresholds to see detections.
<a href='/output_crops/' target='crops'>output_crops</a>
""",
    examples=examples,
    # Pre-compute the example outputs with detect_and_draw
    cache_examples=True,
)

if __name__ == "__main__":
    # Alternative launch behind a reverse proxy:
    # app.launch(server_name="0.0.0.0", server_port=22590, root_path="/stikkiers2", share=False)
    app.launch(server_name="0.0.0.0", share=False)