Mjolnir65 committed on
Commit
8168abc
·
verified ·
1 Parent(s): 76b1517

create app.py

Browse files
Files changed (1) hide show
  1. app.py +286 -0
app.py ADDED
@@ -0,0 +1,286 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import cv2
4
+ import os
5
+ import gradio as gr
6
+ import logging
7
+ from pathlib import Path
8
+ from PIL import Image
9
+ from torch.utils.data.dataloader import DataLoader
10
+ from torch.utils.data import Dataset
11
+ import detection
12
+ from detection.faster_rcnn import FastRCNNPredictor
13
+
14
+ import torchvision.transforms as transforms
15
+
16
+ # Configure logging
17
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
18
+ logger = logging.getLogger(__name__)
19
+
20
+ # Configuration
21
+ CONFIG = {
22
+ "model_path": os.path.join('st', 'tv_frcnn_r50fpn_faster_rcnn_st.pth'),
23
+ "min_size": 600,
24
+ "max_size": 1000,
25
+ "score_threshold": 0.7,
26
+ "num_classes": 2,
27
+ "num_theta_bins": 359,
28
+ "example_image": "dataset/Q1/img/img106.jpg",
29
+ "device": torch.device("cuda" if torch.cuda.is_available() else "cpu")
30
+ }
31
+
32
class SceneTextTestDataset(Dataset):
    """Minimal inference-only dataset wrapping a list of in-memory images."""

    def __init__(self, images):
        # images: list of PIL images or numpy arrays (HWC, uint8)
        self.images = images
        # ToTensor converts HWC uint8 [0, 255] into CHW float [0, 1].
        self.transform = transforms.Compose([transforms.ToTensor()])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        sample = self.images[index]
        # torchvision transforms expect PIL input; convert arrays on the fly.
        if isinstance(sample, np.ndarray):
            sample = Image.fromarray(sample)
        return self.transform(sample)
45
+
46
def load_model(model_path=None):
    """Build the Faster R-CNN detector and load the fine-tuned weights.

    Returns the model in eval mode on CONFIG["device"], or None on any
    failure (missing weight file, architecture/state-dict mismatch, ...).
    """
    try:
        # Fall back to the configured checkpoint when none is given.
        model_path = model_path or CONFIG["model_path"]

        if not os.path.exists(model_path):
            logger.error(f"Model file not found: {model_path}")
            return None

        # Backbone + detection heads from the project's detection package.
        model = detection.fasterrcnn_resnet50_fpn(
            pretrained=True,
            min_size=CONFIG["min_size"],
            max_size=CONFIG["max_size"],
            box_score_thresh=CONFIG["score_threshold"],
        )

        # Swap in the project-specific predictor (adds a theta/rotation head).
        in_features = model.roi_heads.box_predictor.cls_score.in_features
        model.roi_heads.box_predictor = FastRCNNPredictor(
            in_features,
            num_classes=CONFIG["num_classes"],
            num_theta_bins=CONFIG["num_theta_bins"],
        )

        # Restore the fine-tuned weights, then prepare for inference.
        model.load_state_dict(torch.load(model_path, map_location=CONFIG["device"]))
        model.eval()
        model.to(CONFIG["device"])

        logger.info(f"Model loaded successfully from {model_path}")
        return model

    except Exception as e:
        logger.error(f"Error loading model: {str(e)}")
        return None
87
+
88
def prepare_input(input_img):
    """Turn a user-supplied image into a batched float tensor on the device.

    Returns (input_tensor, copy_of_original_image) on success, or
    (None, None) when the input is missing or conversion fails.
    """
    try:
        if input_img is None:
            logger.warning("No input image provided")
            return None, None

        # Normalize to a numpy array (e.g. a PIL image from Gradio).
        if not isinstance(input_img, np.ndarray):
            input_img = np.array(input_img)

        # 3-channel input is assumed BGR and swapped to RGB; anything else
        # (e.g. grayscale) is passed through untouched.
        has_three_channels = len(input_img.shape) == 3 and input_img.shape[2] == 3
        img_rgb = cv2.cvtColor(input_img, cv2.COLOR_BGR2RGB) if has_three_channels else input_img

        # Reuse the dataset's transform pipeline, then add the batch dim.
        tensor = SceneTextTestDataset([img_rgb])[0]
        batched = tensor.unsqueeze(0).float().to(CONFIG["device"])

        return batched, input_img.copy()

    except Exception as e:
        logger.error(f"Error preparing input: {str(e)}")
        return None, None
112
+
113
def remove_inner_boxes(boxes):
    """Drop boxes fully contained (within a 2px margin) inside another box.

    Quirk preserved from the original: if every box is nested in some other
    box (e.g. exact duplicates eliminate each other), the input tensor is
    returned unchanged rather than an empty tensor.
    """
    if len(boxes) <= 1:
        return boxes

    coords = boxes.detach().cpu().numpy()
    margin = 2
    survivors = []

    for i, (x1, y1, x2, y2) in enumerate(coords):
        nested = any(
            j != i
            and coords[j][0] - margin <= x1
            and coords[j][1] - margin <= y1
            and coords[j][2] + margin >= x2
            and coords[j][3] + margin >= y2
            for j in range(len(coords))
        )
        if not nested:
            survivors.append(i)

    return boxes[survivors] if survivors else boxes
145
+
146
def process_image(input_img, filter_overlaps=True, color=(0, 255, 0)):
    """Run text detection on an image and draw rotated boxes on a copy.

    Args:
        input_img: image as a numpy array (or convertible to one).
        filter_overlaps: when True, drop boxes fully nested inside another.
        color: color tuple used for the drawn contours.

    Returns:
        The annotated image, the unannotated original on inference failure
        or when the model cannot be loaded, or None when no usable input
        was provided.
    """
    try:
        input_tensor, original_img = prepare_input(input_img)
        if input_tensor is None or original_img is None:
            return None

        # Lazily load and cache the model on the function object so the
        # load cost is paid only once per process.
        if not hasattr(process_image, "model") or process_image.model is None:
            process_image.model = load_model()
            if process_image.model is None:
                return original_img  # return original if model failed to load

        with torch.no_grad():
            try:
                output = process_image.model(input_tensor)[0]
                boxes = output["boxes"]
                thetas = output["thetas"]
                scores = output["scores"]

                if filter_overlaps:
                    kept = remove_inner_boxes(boxes)
                    # BUGFIX: thetas/scores must be filtered with the same
                    # indices as boxes — previously only boxes were filtered
                    # and thetas[idx]/scores[idx] below paired each drawn box
                    # with the theta/score of a *different* detection.
                    # remove_inner_boxes preserves order, so a two-pointer
                    # scan recovers the kept indices.
                    if kept.shape[0] != boxes.shape[0]:
                        all_np = boxes.detach().cpu().numpy()
                        kept_np = kept.detach().cpu().numpy()
                        kept_idx = []
                        j = 0
                        for i in range(len(all_np)):
                            if j < len(kept_np) and np.array_equal(all_np[i], kept_np[j]):
                                kept_idx.append(i)
                                j += 1
                        boxes = kept
                        thetas = thetas[kept_idx]
                        scores = scores[kept_idx]

                # Draw rotated bounding boxes
                for idx, box in enumerate(boxes):
                    x1, y1, x2, y2 = box.detach().cpu().numpy()
                    x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)

                    # Model predicts theta in radians; OpenCV expects degrees.
                    theta = thetas[idx].detach().cpu().numpy() * 180 / np.pi
                    score = scores[idx].detach().cpu().item()  # kept for optional overlays

                    # Axis-aligned box -> (center, size, angle) rotated rect.
                    cx, cy = (x1 + x2) / 2, (y1 + y2) / 2
                    w, h = x2 - x1, y2 - y1

                    rect = ((cx, cy), (w, h), theta)
                    box_points = cv2.boxPoints(rect).astype(np.int32)

                    cv2.drawContours(original_img, [box_points], 0, color, 2)

                return original_img

            except Exception as e:
                logger.error(f"Error during inference: {str(e)}")
                return original_img

    except Exception as e:
        logger.error(f"Error in process_image: {str(e)}")
        return input_img if input_img is not None else None
210
+
211
def create_gradio_app():
    """Build the Gradio UI and wire it to process_image.

    Returns the gr.Blocks app; the caller is responsible for launching it.
    """
    with gr.Blocks(title="Rotated Text Box Detection") as app:
        gr.Markdown("# Rotated Text Box Detection with Faster R-CNN")
        gr.Markdown("Upload an image to detect text boxes with rotated bounding boxes.")

        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Input Image", type="numpy")

                with gr.Row():
                    submit_btn = gr.Button("Detect Text Boxes", variant="primary")
                    filter_checkbox = gr.Checkbox(label="Filter Overlapping Boxes", value=False)

                # Use the first example image that actually exists on disk.
                example_paths = [
                    CONFIG["example_image"],
                    "dataset/Q1/img/img108.jpg",
                    "dataset/Q1/img/img110.jpg"
                ]

                example_path = None
                for path in example_paths:
                    if os.path.exists(path):
                        example_path = path
                        logger.info(f"Using example image: {path}")
                        break

                if example_path:
                    gr.Examples(
                        examples=[[example_path]],
                        inputs=input_image,
                        label="Example Image"
                    )
                else:
                    logger.warning("No example images found. Please upload your own.")

            with gr.Column():
                output_image = gr.Image(label="Detection Result")

        # BUGFIX: filter_checkbox was created but never passed to the click
        # handler, so process_image always ran with its default
        # filter_overlaps=True and the UI toggle had no effect.
        submit_btn.click(
            fn=process_image,
            inputs=[input_image, filter_checkbox],
            outputs=output_image
        )

        gr.Markdown("## How to use")
        gr.Markdown("1. Upload an image using the input panel or click on the example image")
        gr.Markdown("2. Toggle 'Filter Overlapping Boxes' if you want to remove nested detections")
        gr.Markdown("3. Click 'Detect Text Boxes' to perform detection")
        gr.Markdown("4. View the results with rotated bounding boxes")

        gr.Markdown("## Tips")
        gr.Markdown("- For best results, use images with clear text and good contrast")
        gr.Markdown("- The model works best with high-resolution images")
        gr.Markdown("- If you get too many overlapping detections, enable the filtering option")

    return app
268
+
269
if __name__ == "__main__":
    # Log environment details up front to simplify debugging deployments.
    logger.info(f"Using device: {CONFIG['device']}")
    logger.info(f"PyTorch version: {torch.__version__}")
    logger.info(f"OpenCV version: {cv2.__version__}")

    # Build the UI and serve it on all interfaces.
    app = create_gradio_app()
    app.launch(server_name="0.0.0.0", server_port=7860, share=True, debug=True)