Spaces:
Sleeping
Sleeping
import math
import os
import sys
import time
from collections import defaultdict

import cv2
import numpy as np
import torch
import yaml
from huggingface_hub import hf_hub_download
from ultralytics import YOLO
class TrafficSignDetector:
    """Traffic-sign detector wrapping an Ultralytics YOLO model.

    The model is loaded either from a local checkpoint or from the
    HuggingFace Hub; inference runs over overlapping tiles (see ``detect``).
    """

    def __init__(self, config_path):
        """Initialise the detector from a YAML config file.

        :param config_path: path to a YAML file with ``model`` (path,
            confidence_threshold), ``inference`` (box_color, text_color,
            thickness) and ``classes`` sections.
        """
        with open(config_path, 'r') as f:
            config = yaml.safe_load(f)
        # Monkey patch torch.load to disable weights_only for ultralytics.
        # NOTE(security): weights_only=False executes arbitrary pickle code on
        # load — only use checkpoints from trusted sources.
        original_torch_load = torch.load
        def patched_torch_load(*args, **kwargs):
            kwargs['weights_only'] = False
            return original_torch_load(*args, **kwargs)
        torch.load = patched_torch_load
        try:
            model_path = config['model']['path']
            if os.path.exists(model_path):
                # Existing local checkpoint — load it directly.  (Checked
                # first so local "*.pt" paths are not misrouted to the Hub.)
                self.model = YOLO(model_path)
            elif model_path.endswith('.pt'):
                # HuggingFace path with filename (e.g. VietCat/GTSRB-Model/models/GTSRB.pt).
                # repo_id can only be namespace/repo_name (2 parts max);
                # everything after that is the file path inside the repo.
                parts = model_path.split('/')
                repo_id = '/'.join(parts[:2])   # e.g. VietCat/GTSRB-Model
                file_path = '/'.join(parts[2:])  # e.g. models/GTSRB.pt
                local_model_path = hf_hub_download(repo_id=repo_id, filename=file_path)
                self.model = YOLO(local_model_path)
            else:
                # Bare model identifier resolved by ultralytics itself.
                self.model = YOLO(model_path)
        finally:
            # Always restore original torch.load
            torch.load = original_torch_load
        self.conf_threshold = config['model']['confidence_threshold']
        # Colors may be given as "(B, G, R)" strings or as sequences.
        self.box_color = self._parse_color(config['inference']['box_color'])
        self.text_color = self._parse_color(config['inference']['text_color'])
        self.thickness = config['inference']['thickness']
        self.classes = config['classes']
        # Print model information
        self._print_model_info()

    @staticmethod
    def _parse_color(value):
        """Normalise a config colour: "(128, 0, 128)" string -> (128, 0, 128) tuple.

        Non-string values (already tuples/lists) are returned unchanged.
        """
        if isinstance(value, str):
            return tuple(map(int, value.strip('()').split(',')))
        return value
| def _print_model_info(self): | |
| """ | |
| Print detailed information about the loaded model. | |
| """ | |
| print("\n" + "="*80) | |
| print("MODEL INFORMATION") | |
| print("="*80) | |
| # Basic model info | |
| print(f"Model type: {type(self.model)}") | |
| print(f"Model device: {self.model.device}") | |
| print(f"Confidence threshold: {self.conf_threshold}") | |
| print(f"Number of classes: {len(self.classes)}") | |
| # Model architecture | |
| try: | |
| print(f"\nModel architecture:") | |
| print(f" - Task: {self.model.task if hasattr(self.model, 'task') else 'Unknown'}") | |
| print(f" - Model type: {self.model.model.__class__.__name__ if hasattr(self.model, 'model') else 'Unknown'}") | |
| # Model parameters | |
| if hasattr(self.model, 'model') and hasattr(self.model.model, 'parameters'): | |
| total_params = sum(p.numel() for p in self.model.model.parameters()) | |
| trainable_params = sum(p.numel() for p in self.model.model.parameters() if p.requires_grad) | |
| weights_sum = sum(p.sum().item() for p in self.model.model.parameters()) | |
| print(f" - Total parameters: {total_params:,}") | |
| print(f" - Trainable parameters: {trainable_params:,}") | |
| print(f" - Weights sum: {weights_sum:.6f}") | |
| except Exception as e: | |
| print(f" - Could not retrieve architecture details: {e}") | |
| # Class information | |
| print(f"\nClasses ({len(self.classes)} total):") | |
| for i, cls in enumerate(self.classes): | |
| print(f" {i}: {cls}") | |
| # Try to get model summary | |
| try: | |
| if hasattr(self.model, 'info'): | |
| print(f"\nModel summary:") | |
| self.model.info() | |
| except Exception as e: | |
| print(f"Could not get model summary: {e}") | |
| print("="*80 + "\n") | |
| def _calculate_tiles_count(self, length, tile_size, min_overlap=0.2): | |
| """ | |
| Tính số tiles tối thiểu cần thiết cho 1 chiều. | |
| Đảm bảo overlap >= min_overlap. | |
| :param length: chiều dài của ảnh (width hoặc height) | |
| :param tile_size: kích thước tile | |
| :param min_overlap: overlap tối thiểu (0.2 = 20%) | |
| :return: (num_tiles, stride) | |
| """ | |
| if length <= tile_size: | |
| return 1, 0 | |
| # Cần ít nhất 2 tiles | |
| num_tiles = 2 | |
| max_iterations = 100 | |
| for _ in range(max_iterations): | |
| # stride = (length - tile_size) / (num_tiles - 1) | |
| stride = (length - tile_size) / (num_tiles - 1) | |
| overlap = (tile_size - stride) / tile_size | |
| if overlap >= min_overlap: | |
| return num_tiles, int(stride) | |
| num_tiles += 1 | |
| return num_tiles, int((length - tile_size) / (num_tiles - 1)) | |
| def _create_tiles(self, image, overlap_ratio=0.2): | |
| """ | |
| Cắt ảnh thành các tiles vuông với overlap tối thiểu. | |
| Tính số tiles cần thiết để cover hết ảnh với overlap >= overlap_ratio. | |
| :param image: input image (numpy array) | |
| :param overlap_ratio: tỉ lệ overlap tối thiểu (0.2 = 20%) | |
| :return: list of tile dicts | |
| """ | |
| height, width = image.shape[:2] | |
| tile_size = min(height, width) | |
| print(f"\n[TILING] Image: {width}x{height}, Min dimension (tile_size): {tile_size}") | |
| # Tính số tiles và stride cho mỗi chiều | |
| num_tiles_h, stride_h = self._calculate_tiles_count(height, tile_size, min_overlap=overlap_ratio) | |
| num_tiles_w, stride_w = self._calculate_tiles_count(width, tile_size, min_overlap=overlap_ratio) | |
| # Tính overlap thực tế | |
| overlap_h = (tile_size - stride_h) / tile_size if stride_h > 0 else 0 | |
| overlap_w = (tile_size - stride_w) / tile_size if stride_w > 0 else 0 | |
| print(f" - Tile size: {tile_size}x{tile_size}") | |
| print(f" - Height: {height} → {num_tiles_h} tiles, stride={stride_h}, overlap={overlap_h*100:.0f}%") | |
| print(f" - Width: {width} → {num_tiles_w} tiles, stride={stride_w}, overlap={overlap_w*100:.0f}%") | |
| tiles = [] | |
| # Tạo grid tiles | |
| for i in range(num_tiles_h): | |
| for j in range(num_tiles_w): | |
| # Tính vị trí | |
| y = int(i * stride_h) | |
| x = int(j * stride_w) | |
| # Đảm bảo không vượt quá bounds | |
| y = min(y, height - tile_size) | |
| x = min(x, width - tile_size) | |
| y_end = y + tile_size | |
| x_end = x + tile_size | |
| # Extract tile | |
| tile = image[y:y_end, x:x_end] | |
| tiles.append({ | |
| 'image': tile, | |
| 'y_min': y, | |
| 'x_min': x, | |
| 'y_max': y_end, | |
| 'x_max': x_end | |
| }) | |
| print(f" - Total tiles: {len(tiles)} ({num_tiles_h}x{num_tiles_w})") | |
| return tiles | |
| def _select_standard_size(self, tile_size): | |
| """ | |
| Chọn kích thước chuẩn gần nhất cho tile. | |
| :param tile_size: kích thước hiện tại | |
| :return: kích thước chuẩn (640, 960, hoặc 1024) | |
| """ | |
| standard_sizes = [640, 960, 1024] | |
| # Chọn size nhỏ nhất mà >= tile_size | |
| for size in standard_sizes: | |
| if size >= tile_size: | |
| return size | |
| return 1024 # Fallback to largest | |
| def _resize_to_standard(self, tile, target_size=640): | |
| """ | |
| Resize tile về size chuẩn với letterbox padding. | |
| :param tile: tile image | |
| :param target_size: target size (640, 960, hoặc 1024) | |
| :return: (resized_image, scale, pad_x, pad_y) | |
| """ | |
| height, width = tile.shape[:2] | |
| max_dim = max(width, height) | |
| # Scale to fit target while maintaining aspect ratio | |
| scale = target_size / max_dim | |
| # Calculate new dimensions | |
| new_width = int(width * scale) | |
| new_height = int(height * scale) | |
| # Resize image | |
| resized = cv2.resize(tile, (new_width, new_height), interpolation=cv2.INTER_LINEAR) | |
| # Create canvas and place resized image (letterbox) | |
| canvas = np.full((target_size, target_size, 3), (114, 114, 114), dtype=np.uint8) | |
| pad_x = (target_size - new_width) // 2 | |
| pad_y = (target_size - new_height) // 2 | |
| canvas[pad_y:pad_y + new_height, pad_x:pad_x + new_width] = resized | |
| return canvas, scale, pad_x, pad_y | |
| def _ensure_square(self, image, target_size=640): | |
| """ | |
| Adjust image to square while maintaining aspect ratio. | |
| Deprecated: use _resize_to_standard instead. | |
| """ | |
| return self._resize_to_standard(image, target_size) | |
| def _preprocess(self, image): | |
| """ | |
| Preprocess image: keep uint8 format as YOLO expects. | |
| :param image: input image (numpy array, uint8) | |
| :return: image in uint8 format | |
| """ | |
| # YOLO handles normalization internally, keep uint8 format | |
| print(f"Image format: {image.dtype}, Min: {image.min()}, Max: {image.max()}, Mean: {image.mean():.1f}") | |
| return image | |
| def _merge_detections(self, all_detections, overlap_threshold=0.5): | |
| """ | |
| Merge detections từ nhiều tiles, loại bỏ duplicates. | |
| Sử dụng NMS để gộp detections từ overlapping regions. | |
| :param all_detections: list of { | |
| 'x1': int, 'y1': int, 'x2': int, 'y2': int, | |
| 'conf': float, 'cls': int | |
| } | |
| :param overlap_threshold: IOU threshold cho NMS | |
| :return: merged_detections | |
| """ | |
| if not all_detections: | |
| return [] | |
| # Sort by confidence (descending) | |
| all_detections = sorted(all_detections, key=lambda x: x['conf'], reverse=True) | |
| merged = [] | |
| used = [False] * len(all_detections) | |
| for i, det in enumerate(all_detections): | |
| if used[i]: | |
| continue | |
| # Add this detection | |
| merged.append(det) | |
| used[i] = True | |
| # Mark overlapping detections as used | |
| for j in range(i + 1, len(all_detections)): | |
| if used[j]: | |
| continue | |
| # Calculate IOU | |
| x1_inter = max(det['x1'], all_detections[j]['x1']) | |
| y1_inter = max(det['y1'], all_detections[j]['y1']) | |
| x2_inter = min(det['x2'], all_detections[j]['x2']) | |
| y2_inter = min(det['y2'], all_detections[j]['y2']) | |
| if x2_inter < x1_inter or y2_inter < y1_inter: | |
| continue # No intersection | |
| inter_area = (x2_inter - x1_inter) * (y2_inter - y1_inter) | |
| det_area = (det['x2'] - det['x1']) * (det['y2'] - det['y1']) | |
| other_area = (all_detections[j]['x2'] - all_detections[j]['x1']) * (all_detections[j]['y2'] - all_detections[j]['y1']) | |
| union_area = det_area + other_area - inter_area | |
| iou = inter_area / union_area if union_area > 0 else 0 | |
| # Mark as duplicate if IOU > threshold | |
| if iou > overlap_threshold: | |
| used[j] = True | |
| return merged | |
    def detect(self, image, confidence_threshold=None):
        """
        Perform inference on the image using tiling strategy.

        The image is split into overlapping square tiles, each tile is
        letterboxed to a standard size and run through the model, and the
        per-tile boxes are mapped back to original-image coordinates,
        deduplicated across tiles with NMS, filtered by confidence and drawn.

        :param image: numpy array of the image (H x W x C, uint8 — presumably
            BGR as produced by OpenCV; confirm against callers)
        :param confidence_threshold: optional override for confidence threshold
        :return: tuple of (image with drawn bounding boxes, preprocessed image for visualization)
        """
        # Start timing
        start_time = time.time()
        start_time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(start_time))
        # Use provided threshold or fall back to config value
        if confidence_threshold is None:
            confidence_threshold = self.conf_threshold
        else:
            confidence_threshold = float(confidence_threshold)
        print(f"\n{'='*80}")
        print(f"DETECTION PIPELINE START (TILING STRATEGY)")
        print(f"{'='*80}")
        print(f"[START TIME] {start_time_str}")
        print(f"[STEP 1] INPUT IMAGE")
        print(f" - Shape: {image.shape}")
        print(f" - dtype: {image.dtype}")
        print(f" - Range: [{image.min()}, {image.max()}]")
        # Store original image for drawing: boxes are drawn on this copy so
        # the caller's array is left untouched.
        original_image = image.copy()
        orig_h, orig_w = original_image.shape[:2]
        # STEP 2: split into overlapping square tiles (min 20% overlap)
        print(f"\n[STEP 2] TILING")
        tiles = self._create_tiles(original_image, overlap_ratio=0.2)
        # STEP 3: run the model on each tile independently
        print(f"\n[STEP 3] PROCESSING TILES")
        all_detections = []
        for tile_idx, tile_info in enumerate(tiles):
            print(f"\n [TILE {tile_idx + 1}/{len(tiles)}]")
            print(f" Position in original: ({tile_info['x_min']}, {tile_info['y_min']}) → ({tile_info['x_max']}, {tile_info['y_max']})")
            tile = tile_info['image']
            tile_h, tile_w = tile.shape[:2]
            # Pick the nearest standard inference size for this tile
            standard_size = self._select_standard_size(max(tile_w, tile_h))
            print(f" Tile size: {tile_w}x{tile_h} → Standard size: {standard_size}x{standard_size}")
            # Letterbox-resize; scale/pad are needed to map boxes back
            resized_tile, scale, pad_x, pad_y = self._resize_to_standard(tile, target_size=standard_size)
            # Inference with conf=0.0: keep every candidate here; the real
            # confidence threshold is applied only after cross-tile merging
            # (STEP 5).  iou=0.55 is the model's per-tile NMS threshold.
            results = self.model(resized_tile, conf=0.0, imgsz=standard_size, iou=0.55)
            # Process results
            for result in results:
                boxes = result.boxes
                print(f" Detections in this tile: {len(boxes)}")
                for box in boxes:
                    # Get coordinates in resized tile space
                    x1, y1, x2, y2 = box.xyxy[0].cpu().numpy().astype(int)
                    # Undo letterbox padding and scaling → original tile space
                    x1 = int((x1 - pad_x) / scale)
                    y1 = int((y1 - pad_y) / scale)
                    x2 = int((x2 - pad_x) / scale)
                    y2 = int((y2 - pad_y) / scale)
                    # Clamp to tile bounds
                    x1 = max(0, min(x1, tile_w))
                    y1 = max(0, min(y1, tile_h))
                    x2 = max(0, min(x2, tile_w))
                    y2 = max(0, min(y2, tile_h))
                    # Shift by the tile's offset → original image space
                    x1_orig = x1 + tile_info['x_min']
                    y1_orig = y1 + tile_info['y_min']
                    x2_orig = x2 + tile_info['x_min']
                    y2_orig = y2 + tile_info['y_min']
                    # Clamp to original image bounds
                    x1_orig = max(0, min(x1_orig, orig_w))
                    y1_orig = max(0, min(y1_orig, orig_h))
                    x2_orig = max(0, min(x2_orig, orig_w))
                    y2_orig = max(0, min(y2_orig, orig_h))
                    conf = float(box.conf[0].cpu().numpy())
                    cls = int(box.cls[0].cpu().numpy())
                    all_detections.append({
                        'x1': x1_orig,
                        'y1': y1_orig,
                        'x2': x2_orig,
                        'y2': y2_orig,
                        'conf': conf,
                        'cls': cls
                    })
        # STEP 4: cross-tile NMS — the same sign detected in two overlapping
        # tiles collapses to a single box
        print(f"\n[STEP 4] MERGING DETECTIONS")
        sys.stdout.flush()
        print(f" - Raw detections from all tiles: {len(all_detections)}")
        sys.stdout.flush()
        merged_detections = self._merge_detections(all_detections, overlap_threshold=0.5)
        print(f" - After deduplication: {len(merged_detections)}")
        sys.stdout.flush()
        # STEP 5: apply the confidence threshold and draw surviving boxes
        print(f"\n[STEP 5] FILTERING & DRAWING")
        sys.stdout.flush()
        print(f" - Confidence threshold: {confidence_threshold}")
        sys.stdout.flush()
        # Log the top-5 detections (regardless of threshold) for debugging
        top_5_dets = sorted(merged_detections, key=lambda x: x['conf'], reverse=True)[:5]
        print(f"\n[TOP 5 DETECTIONS]")
        sys.stdout.flush()
        if len(top_5_dets) > 0:
            for rank, det in enumerate(top_5_dets, 1):
                x1, y1, x2, y2 = det['x1'], det['y1'], det['x2'], det['y2']
                cls = det['cls']
                conf = det['conf']
                w = x2 - x1
                h = y2 - y1
                area = w * h
                print(f" {rank}. {self.classes[cls]:30s} | conf={conf:.4f} | size=({w}x{h}) | area={area:7d} | bbox=({x1},{y1})-({x2},{y2})")
                sys.stdout.flush()
        else:
            print(f" No detections found")
            sys.stdout.flush()
        drawn_count = 0
        for det in merged_detections:
            if det['conf'] >= confidence_threshold:
                x1, y1, x2, y2 = det['x1'], det['y1'], det['x2'], det['y2']
                cls = det['cls']
                conf = det['conf']
                # Draw bounding box
                cv2.rectangle(original_image, (x1, y1), (x2, y2), self.box_color, self.thickness)
                # Draw label just above the box's top-left corner
                label = f"{self.classes[cls]}: {conf:.2f}"
                cv2.putText(original_image, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, self.text_color, 2)
                drawn_count += 1
        print(f"\n[FILTERING RESULT]")
        sys.stdout.flush()
        print(f" - Total detections: {len(merged_detections)}")
        sys.stdout.flush()
        print(f" - Drawn (conf >= {confidence_threshold}): {drawn_count}")
        sys.stdout.flush()
        # End timing
        end_time = time.time()
        end_time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(end_time))
        elapsed = end_time - start_time
        print(f"\n{'='*80}")
        sys.stdout.flush()
        print(f"DETECTION PIPELINE COMPLETE")
        sys.stdout.flush()
        print(f"{'='*80}")
        sys.stdout.flush()
        print(f"[END TIME] {end_time_str}")
        sys.stdout.flush()
        print(f"[TOTAL TIME] {elapsed:.2f} seconds\n")
        sys.stdout.flush()
        # Create preprocessed visualization (first tile for reference)
        preprocessed_display = tiles[0]['image'].copy() if tiles else original_image.copy()
        return original_image, preprocessed_display