Spaces:
Sleeping
Sleeping
| import torch | |
| import numpy as np | |
| import cv2 | |
| import os | |
| import gradio as gr | |
| import logging | |
| from pathlib import Path | |
| from PIL import Image | |
| from torch.utils.data.dataloader import DataLoader | |
| from torch.utils.data import Dataset | |
| import detection | |
| from detection.faster_rcnn import FastRCNNPredictor | |
| import torchvision.transforms as transforms | |
# Logging: INFO-level records, each prefixed with a timestamp and level name.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Application settings shared by the model loader, the inference code and
# the Gradio UI.
CONFIG = dict(
    # Checkpoint passed to torch.load() by load_model().
    model_path=os.path.join('st', 'tv_frcnn_r50fpn_faster_rcnn_st.pth'),
    # Forwarded to detection.fasterrcnn_resnet50_fpn() as min_size / max_size.
    min_size=600,
    max_size=1000,
    # Forwarded to the model builder as box_score_thresh.
    score_threshold=0.7,
    # Forwarded to FastRCNNPredictor when the box head is replaced.
    num_classes=2,
    num_theta_bins=359,
    # First candidate for the example image offered in the UI.
    example_image="dataset/Q1/img/img106.jpg",
    # Use CUDA when it is available, otherwise fall back to the CPU.
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)
class SceneTextTestDataset(Dataset):
    """Dataset over a collection of already-loaded images.

    Every item is run through ``transforms.ToTensor()``. NumPy arrays are
    wrapped in a PIL image first; any other item is handed to the transform
    as it is.
    """

    def __init__(self, images):
        # ``images`` only needs to support len() and integer indexing.
        self.images = images
        self.transform = transforms.Compose([transforms.ToTensor()])

    def __len__(self):
        """Return the number of images held by the dataset."""
        return len(self.images)

    def __getitem__(self, index):
        """Return the transformed image stored at ``index``."""
        item = self.images[index]
        prepared = Image.fromarray(item) if isinstance(item, np.ndarray) else item
        return self.transform(prepared)
def load_model(model_path=None):
    """Build the Faster R-CNN detector and load its weights.

    Args:
        model_path: Path to the checkpoint file. When ``None`` the path from
            ``CONFIG["model_path"]`` is used.

    Returns:
        The model in eval mode on ``CONFIG["device"]``, or ``None`` when the
        checkpoint file does not exist or any step raises. Failures are
        logged, never propagated.
    """
    try:
        weights_file = CONFIG["model_path"] if model_path is None else model_path

        # Bail out early so a missing file gets its own log message.
        if not os.path.exists(weights_file):
            logger.error(f"Model file not found: {weights_file}")
            return None

        # NOTE(review): pretrained=True is requested although every weight is
        # replaced by load_state_dict() below. Confirm whether the project's
        # `detection` builder needs it before changing it.
        detector = detection.fasterrcnn_resnet50_fpn(
            pretrained=True,
            min_size=CONFIG["min_size"],
            max_size=CONFIG["max_size"],
            box_score_thresh=CONFIG["score_threshold"],
        )

        # Replace the stock box head with one sized for this task; the new
        # head takes the same input width as the one it replaces.
        head_in_features = detector.roi_heads.box_predictor.cls_score.in_features
        detector.roi_heads.box_predictor = FastRCNNPredictor(
            head_in_features,
            num_classes=CONFIG["num_classes"],
            num_theta_bins=CONFIG["num_theta_bins"],
        )

        # map_location lets a checkpoint saved on another device load here.
        checkpoint = torch.load(weights_file, map_location=CONFIG["device"])
        detector.load_state_dict(checkpoint)

        detector.eval()
        detector.to(CONFIG["device"])

        logger.info(f"Model loaded successfully from {weights_file}")
        return detector
    except Exception as err:
        logger.error(f"Error loading model: {str(err)}")
        return None
def prepare_input(input_img):
    """Turn an input image into a batched tensor for the model.

    Args:
        input_img: A NumPy array, or anything ``np.array()`` can convert
            (for example a PIL image). ``None`` is accepted and reported.

    Returns:
        ``(input_tensor, image_copy)``: a float tensor with a leading batch
        dimension on ``CONFIG["device"]``, and a copy of the image as a NumPy
        array without the channel swap. ``(None, None)`` is returned when no
        image was given or when any step raises.
    """
    try:
        if input_img is None:
            logger.warning("No input image provided")
            return None, None

        pixels = input_img if isinstance(input_img, np.ndarray) else np.array(input_img)

        # Three-channel images get their first and last channels swapped;
        # every other layout is passed through untouched.
        # NOTE(review): the Gradio input is declared with type="numpy", which
        # presumably delivers RGB already. If so this swap feeds the model
        # BGR data. Confirm the channel order the model was trained on.
        if len(pixels.shape) == 3 and pixels.shape[2] == 3:
            model_pixels = cv2.cvtColor(pixels, cv2.COLOR_BGR2RGB)
        else:
            model_pixels = pixels

        # Reuse the dataset's transform, then add the batch dimension.
        single_item = SceneTextTestDataset([model_pixels])[0]
        batch = single_item.unsqueeze(0).float().to(CONFIG["device"])

        return batch, pixels.copy()
    except Exception as err:
        logger.error(f"Error preparing input: {str(err)}")
        return None, None
def _box_contains(outer, inner, margin):
    """Return True when ``inner`` fits inside ``outer`` grown by ``margin``."""
    return (
        outer[0] - margin <= inner[0]
        and outer[1] - margin <= inner[1]
        and outer[2] + margin >= inner[2]
        and outer[3] + margin >= inner[3]
    )


def remove_inner_boxes(boxes, margin=2, return_indices=False):
    """Drop boxes that lie inside another box, allowing a small margin.

    Args:
        boxes: Tensor whose rows are ``(x1, y1, x2, y2)``.
        margin: Slack, in coordinate units, by which the enclosing box is
            grown on every side before the containment test.
        return_indices: When true, also return the row indices that were
            kept so per-box data (angles, scores) can be filtered the same
            way.

    Returns:
        The kept boxes, or ``(kept_boxes, kept_indices)`` when
        ``return_indices`` is true. Inputs with at most one box are returned
        unchanged. If every box turns out to be nested in another one, the
        input is returned unchanged as well.
    """
    count = len(boxes)
    if count <= 1:
        return (boxes, list(range(count))) if return_indices else boxes

    boxes_np = boxes.detach().cpu().numpy()
    keep_indices = []
    for i, candidate in enumerate(boxes_np):
        nested = False
        for j, other in enumerate(boxes_np):
            if i == j or not _box_contains(other, candidate, margin):
                continue
            # Two near-identical boxes contain each other. Without this
            # tie-break both would be discarded and the detection lost, so
            # the earlier row wins (presumably the higher-scoring one if the
            # detector sorts its output by score - confirm).
            if j > i and _box_contains(candidate, other, margin):
                continue
            nested = True
            break
        if not nested:
            keep_indices.append(i)

    if keep_indices:
        kept = boxes[keep_indices]
    else:
        # Nothing survived: fall back to the unfiltered input.
        kept = boxes
        keep_indices = list(range(count))
    return (kept, keep_indices) if return_indices else kept


def _draw_rotated_boxes(image, boxes, thetas, color):
    """Draw one rotated rectangle per box onto ``image`` in place.

    ``thetas[k]`` is the angle for ``boxes[k]``. It is converted from radians
    to degrees before being handed to OpenCV.
    """
    for box, theta in zip(boxes, thetas):
        x1, y1, x2, y2 = (int(v) for v in box.detach().cpu().tolist())

        # Pass a plain Python float so OpenCV does not have to parse a
        # NumPy scalar as the angle.
        angle_deg = float(theta.detach().cpu().item()) * 180 / np.pi

        center = ((x1 + x2) / 2, (y1 + y2) / 2)
        size = (x2 - x1, y2 - y1)
        corners = cv2.boxPoints((center, size, angle_deg)).astype(np.int32)
        cv2.drawContours(image, [corners], 0, color, 2)


def process_image(input_img, filter_overlaps=True, color=(0, 255, 0)):
    """Run detection on an image and draw the rotated boxes that were found.

    Args:
        input_img: Image in any form accepted by ``prepare_input``.
        filter_overlaps: When true, boxes nested inside another box are
            removed before drawing.
        color: Colour tuple passed to ``cv2.drawContours``.

    Returns:
        A copy of the input with the boxes drawn on it. ``None`` is returned
        when the input could not be prepared. The undecorated copy is
        returned when the model cannot be loaded or inference fails. The
        original ``input_img`` is returned on any other error.
    """
    try:
        input_tensor, original_img = prepare_input(input_img)
        if input_tensor is None or original_img is None:
            return None

        # The model is cached on the function object so it is loaded once;
        # a failed load is retried on the next call.
        if getattr(process_image, "model", None) is None:
            process_image.model = load_model()
        model = process_image.model
        if model is None:
            return original_img  # Return original if model failed to load

        try:
            with torch.no_grad():
                output = model(input_tensor)[0]

            boxes = output["boxes"]
            thetas = output["thetas"]

            if filter_overlaps:
                boxes, keep = remove_inner_boxes(boxes, return_indices=True)
                # Filter the angles with the same indices. Indexing the
                # unfiltered tensor by position in the filtered boxes would
                # pair a box with another detection's angle.
                keep_index = torch.as_tensor(
                    keep, dtype=torch.long, device=thetas.device
                )
                thetas = thetas[keep_index]

            # output["scores"] is also available should confidence labels be
            # wanted next to the boxes.
            _draw_rotated_boxes(original_img, boxes, thetas, color)
        except Exception as e:
            # Keep the traceback in the log; the image drawn so far is still
            # returned.
            logger.exception(f"Error during inference: {str(e)}")
        return original_img
    except Exception as e:
        logger.exception(f"Error in process_image: {str(e)}")
        return input_img
def create_gradio_app():
    """Build the Gradio Blocks UI for rotated text-box detection.

    The page has an input image with a detect button and a filter checkbox,
    an optional example image (the first candidate path that exists on
    disk), an output image, and usage notes.

    Returns:
        The assembled ``gr.Blocks`` application, not yet launched.
    """
    with gr.Blocks(title="Rotated Text Box Detection") as app:
        gr.Markdown("# Rotated Text Box Detection with Faster R-CNN")
        gr.Markdown("Upload an image to detect text boxes with rotated bounding boxes.")

        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Input Image", type="numpy")
                with gr.Row():
                    submit_btn = gr.Button("Detect Text Boxes", variant="primary")
                    filter_checkbox = gr.Checkbox(label="Filter Overlapping Boxes", value=False)

                example_paths = [
                    CONFIG["example_image"],
                    "dataset/Q1/img/img108.jpg",
                    "dataset/Q1/img/img110.jpg",
                ]
                # Offer the first candidate that is actually present.
                example_path = next(
                    (path for path in example_paths if os.path.exists(path)),
                    None,
                )
                if example_path:
                    logger.info(f"Using example image: {example_path}")
                    gr.Examples(
                        examples=[[example_path]],
                        inputs=input_image,
                        label="Example Image",
                    )
                else:
                    logger.warning("No example images found. Please upload your own.")

            with gr.Column():
                output_image = gr.Image(label="Detection Result")

        # The checkbox has to be listed as an input. It was left out before,
        # so toggling it did nothing and process_image always ran with its
        # default filter_overlaps=True. It maps onto the second positional
        # parameter of process_image.
        submit_btn.click(
            fn=process_image,
            inputs=[input_image, filter_checkbox],
            outputs=output_image,
        )

        gr.Markdown("## How to use")
        gr.Markdown("1. Upload an image using the input panel or click on the example image")
        gr.Markdown("2. Toggle 'Filter Overlapping Boxes' if you want to remove nested detections")
        gr.Markdown("3. Click 'Detect Text Boxes' to perform detection")
        gr.Markdown("4. View the results with rotated bounding boxes")
        gr.Markdown("## Tips")
        gr.Markdown("- For best results, use images with clear text and good contrast")
        gr.Markdown("- The model works best with high-resolution images")
        gr.Markdown("- If you get too many overlapping detections, enable the filtering option")
    return app
if __name__ == "__main__":
    # Report the runtime environment before the server starts.
    for startup_line in (
        f"Using device: {CONFIG['device']}",
        f"PyTorch version: {torch.__version__}",
        f"OpenCV version: {cv2.__version__}",
    ):
        logger.info(startup_line)

    # Offline check without the UI (kept disabled):
    # img = cv2.imread(CONFIG["example_image"])
    # output = process_image(img)
    # cv2.imwrite("output.jpg", output)

    # Build the UI and serve it on all interfaces, port 7860.
    launch_options = dict(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
        debug=True,
    )
    app = create_gradio_app()
    app.launch(**launch_options)