enpaiva committed on
Commit
8602b6d
·
verified ·
1 Parent(s): 4c69cf1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +240 -116
app.py CHANGED
@@ -3,6 +3,7 @@ os.environ["GRADIO_TEMP_DIR"] = "./tmp"
3
 
4
  import sys
5
  import torch
 
6
  import gradio as gr
7
  import numpy as np
8
  import cv2
@@ -13,10 +14,10 @@ from transformers import (
13
  RTDetrImageProcessor,
14
  )
15
 
16
- # == select device ==
17
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
18
 
19
- # Available models with their corresponding model classes
20
  MODELS = {
21
  "Egret XLarge": {
22
  "path": "ds4sd/docling-layout-egret-xlarge",
@@ -40,34 +41,22 @@ MODELS = {
40
  }
41
  }
42
 
43
- # Classes mapping for the docling model
44
  classes_map = {
45
- 0: "Caption",
46
- 1: "Footnote",
47
- 2: "Formula",
48
- 3: "List-item",
49
- 4: "Page-footer",
50
- 5: "Page-header",
51
- 6: "Picture",
52
- 7: "Section-header",
53
- 8: "Table",
54
- 9: "Text",
55
- 10: "Title",
56
- 11: "Document Index",
57
- 12: "Code",
58
- 13: "Checkbox-Selected",
59
- 14: "Checkbox-Unselected",
60
- 15: "Form",
61
- 16: "Key-Value Region",
62
  }
63
 
64
- # Global variables for model
65
  current_model = None
66
  current_processor = None
67
  current_model_name = None
68
 
69
  def colormap(N=256, normalized=False):
70
- """Generate the color map."""
71
  def bitget(byteval, idx):
72
  return ((byteval & (1 << idx)) != 0)
73
 
@@ -84,25 +73,24 @@ def colormap(N=256, normalized=False):
84
 
85
  if normalized:
86
  cmap = cmap.astype(np.float32) / 255.0
87
-
88
  return cmap
89
 
90
  def iomin(box1, box2):
91
- """Intersection over Minimum (IoMin)"""
92
  x1 = torch.max(box1[:, 0], box2[:, 0])
93
  y1 = torch.max(box1[:, 1], box2[:, 1])
94
  x2 = torch.min(box1[:, 2], box2[:, 2])
95
  y2 = torch.min(box1[:, 3], box2[:, 3])
96
  inter_area = torch.clamp(x2 - x1, min=0) * torch.clamp(y2 - y1, min=0)
97
-
98
  box1_area = (box1[:, 2] - box1[:, 0]) * (box1[:, 3] - box1[:, 1])
99
  box2_area = (box2[:, 2] - box2[:, 0]) * (box2[:, 3] - box2[:, 1])
100
  min_area = torch.min(box1_area, box2_area)
101
-
102
  return inter_area / min_area
103
 
104
- def nms(boxes, scores, iou_threshold=0.5):
105
- """Custom NMS implementation using IoMin"""
106
  keep = []
107
  _, order = scores.sort(descending=True)
108
 
@@ -123,18 +111,19 @@ def nms(boxes, scores, iou_threshold=0.5):
123
  return torch.tensor(keep, dtype=torch.long)
124
 
125
  def load_model(model_name):
126
- """Load the selected model"""
127
  global current_model, current_processor, current_model_name
128
 
129
  if current_model_name == model_name:
130
  return f"✅ Model {model_name} is already loaded!"
131
 
132
  try:
133
- print(f"Loading model: {model_name}")
134
  model_info = MODELS[model_name]
135
  model_path = model_info["path"]
136
  model_class = model_info["model_class"]
137
 
 
 
138
  processor = RTDetrImageProcessor.from_pretrained(model_path)
139
  model = model_class.from_pretrained(model_path)
140
  model = model.to(device)
@@ -147,10 +136,11 @@ def load_model(model_name):
147
  return f"✅ Successfully loaded {model_name}!"
148
 
149
  except Exception as e:
 
150
  return f"❌ Error loading {model_name}: {str(e)}"
151
 
152
  def visualize_bbox(image_input, bboxes, classes, scores, id_to_names, alpha=0.3):
153
- """Visualize bounding boxes with transparent overlays using OpenCV"""
154
  if isinstance(image_input, Image.Image):
155
  image = np.array(image_input)
156
  image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
@@ -162,12 +152,12 @@ def visualize_bbox(image_input, bboxes, classes, scores, id_to_names, alpha=0.3)
162
  else:
163
  raise ValueError("Input must be PIL Image or numpy array")
164
 
165
- overlay = image.copy()
166
- cmap = colormap(N=len(id_to_names), normalized=False)
167
-
168
  if len(bboxes) == 0:
169
  return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
170
 
 
 
 
171
  for i in range(len(bboxes)):
172
  try:
173
  bbox = bboxes[i]
@@ -186,43 +176,52 @@ def visualize_bbox(image_input, bboxes, classes, scores, id_to_names, alpha=0.3)
186
  class_id = int(class_id)
187
  class_name = id_to_names.get(class_id, f"unknown_{class_id}")
188
 
189
- text = f"{class_name}:{score:.3f}"
190
  color = tuple(int(c) for c in cmap[class_id % len(cmap)])
191
 
 
192
  cv2.rectangle(overlay, (x_min, y_min), (x_max, y_max), color, -1)
193
- cv2.rectangle(image, (x_min, y_min), (x_max, y_max), color, 2)
 
194
 
195
- (text_width, text_height), baseline = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.7, 2)
196
- cv2.rectangle(image, (x_min, y_min - text_height - baseline), (x_min + text_width, y_min), color, -1)
197
- cv2.putText(image, text, (x_min, y_min - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
 
 
198
 
199
  except Exception as e:
200
  print(f"Skipping box {i} due to error: {e}")
201
 
 
202
  cv2.addWeighted(overlay, alpha, image, 1 - alpha, 0, image)
 
203
  return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
204
 
205
- def recognize_image(input_img, conf_threshold, iou_threshold, nms_method, alpha):
206
- """Process image with docling layout model"""
207
  if input_img is None:
208
- return None, "Please upload an image first."
209
 
210
  if current_model is None or current_processor is None:
211
- return None, "Please load a model first."
212
 
213
  try:
 
214
  if isinstance(input_img, np.ndarray):
215
  input_img = Image.fromarray(input_img)
216
 
217
  if input_img.mode != 'RGB':
218
  input_img = input_img.convert('RGB')
219
 
 
220
  inputs = current_processor(images=[input_img], return_tensors="pt")
221
  inputs = {k: v.to(device) for k, v in inputs.items()}
222
 
223
  with torch.no_grad():
224
  outputs = current_model(**inputs)
225
 
 
226
  results = current_processor.post_process_object_detection(
227
  outputs,
228
  target_sizes=torch.tensor([input_img.size[::-1]]),
@@ -230,7 +229,7 @@ def recognize_image(input_img, conf_threshold, iou_threshold, nms_method, alpha)
230
  )
231
 
232
  if not results or len(results) == 0:
233
- return np.array(input_img), "No detections found."
234
 
235
  result = results[0]
236
  boxes = result["boxes"]
@@ -238,116 +237,241 @@ def recognize_image(input_img, conf_threshold, iou_threshold, nms_method, alpha)
238
  labels = result["labels"]
239
 
240
  if len(boxes) == 0:
241
- return np.array(input_img), "No detections above confidence threshold."
242
 
 
243
  if iou_threshold < 1.0:
244
  if nms_method == "Custom IoMin":
245
- keep_indices = nms(boxes=boxes, scores=scores, iou_threshold=iou_threshold)
246
  else:
247
- keep_indices = torch.ops.torchvision.nms(boxes=boxes, scores=scores, iou_threshold=iou_threshold)
 
248
 
249
  boxes = boxes[keep_indices]
250
  scores = scores[keep_indices]
251
  labels = labels[keep_indices]
252
 
253
- if len(boxes.shape) == 1:
254
- boxes = boxes.unsqueeze(0)
255
- scores = scores.unsqueeze(0)
256
- labels = labels.unsqueeze(0)
257
-
258
  output = visualize_bbox(input_img, boxes, labels, scores, classes_map, alpha=alpha)
259
- detection_info = f"Found {len(boxes)} detections after NMS ({nms_method})"
260
- return output, detection_info
 
261
 
262
  except Exception as e:
263
- print(f"[ERROR] recognize_image failed: {e}")
264
- error_msg = f"Error during processing: {str(e)}"
265
  if input_img is not None:
266
  return np.array(input_img), error_msg
267
  return np.zeros((512, 512, 3), dtype=np.uint8), error_msg
268
 
269
- def gradio_reset():
 
270
  return gr.update(value=None), gr.update(value=None), gr.update(value="")
271
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
272
  if __name__ == "__main__":
273
- print(f"Using device: {device}")
 
 
274
 
275
- # Custom CSS for better scrolling and layout
276
  custom_css = """
277
  .gradio-container {
278
- max-width: 1200px !important;
279
- margin: auto !important;
280
  }
281
- .main-content {
282
- overflow-y: auto !important;
283
- max-height: 100vh !important;
 
284
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
285
  """
286
 
287
- with gr.Blocks(title="Document Layout Analysis", theme=gr.themes.Soft(), css=custom_css) as demo:
 
 
 
 
 
 
288
  # Header
289
  gr.HTML("""
290
- <div style="text-align: center; margin-bottom: 20px;">
291
- <h1>🔍 Document Layout Analysis</h1>
292
- <p>Using Docling Layout Models for document structure detection</p>
293
  </div>
294
  """)
295
 
 
296
  with gr.Row():
297
- # Left Column - Controls
298
- with gr.Column(scale=1):
299
- # Model selection
300
- model_dropdown = gr.Dropdown(
301
- choices=list(MODELS.keys()),
302
- value="Egret XLarge",
303
- label="🤖 Select Model"
304
- )
305
-
306
- load_btn = gr.Button("📥 Load Model", variant="secondary", size="sm")
307
- model_status = gr.Textbox(label="Model Status", interactive=False, value="No model loaded", max_lines=2)
308
 
309
- input_img = gr.Image(label="📄 Upload Image", type="pil", height=300)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
310
 
311
- with gr.Row():
312
- clear = gr.Button("🗑️ Clear", size="sm")
313
- predict = gr.Button("🔍 Detect", variant="primary", size="sm")
 
 
 
 
 
 
 
 
 
314
 
315
- # Parameters
316
- conf_threshold = gr.Slider(0.0, 1.0, value=0.6, step=0.05, label="Confidence Threshold")
317
- iou_threshold = gr.Slider(0.0, 1.0, value=0.5, step=0.05, label="NMS IoU Threshold")
318
- nms_method = gr.Radio(["Custom IoMin", "Standard IoU"], value="Custom IoMin", label="NMS Method")
319
- alpha_slider = gr.Slider(0.0, 1.0, value=0.3, step=0.1, label="Overlay Transparency")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
320
 
321
- # Right Column - Results
322
- with gr.Column(scale=1):
323
- gr.HTML("<h3>🎯 Detection Results</h3>")
324
- output_img = gr.Image(label="Detected Layout", interactive=False, type="numpy", height=400)
325
- detection_info = gr.Textbox(label="Detection Info", interactive=False, max_lines=2)
 
 
 
 
 
 
 
 
 
 
 
 
 
326
 
327
- # Legend at the bottom
328
- with gr.Accordion("📋 Detected Classes", open=False):
329
- cmap = colormap(N=len(classes_map), normalized=False)
330
- legend_items = []
331
- for class_id, class_name in classes_map.items():
332
- color_rgb = cmap[class_id % len(cmap)]
333
- color_hex = f"#{color_rgb[0]:02x}{color_rgb[1]:02x}{color_rgb[2]:02x}"
334
- legend_items.append(f'<span style="display:inline-block;width:15px;height:15px;background-color:{color_hex};margin-right:5px;border:1px solid #ccc;"></span>{class_name}')
335
-
336
- legend_html = f"""
337
- <div style='display: grid; grid-template-columns: repeat(3, 1fr); gap: 10px; font-size: 14px;'>
338
- {''.join([f'<div>{item}</div>' for item in legend_items])}
339
- </div>
340
- """
341
- gr.HTML(legend_html)
 
342
 
343
- # Event handlers
344
- load_btn.click(load_model, inputs=[model_dropdown], outputs=[model_status])
345
- clear.click(gradio_reset, inputs=None, outputs=[input_img, output_img, detection_info])
346
- predict.click(
347
- recognize_image,
348
- inputs=[input_img, conf_threshold, iou_threshold, nms_method, alpha_slider],
349
  outputs=[output_img, detection_info]
350
  )
351
 
352
- # Launch
353
- demo.launch(server_name="0.0.0.0", server_port=7860, debug=True, share=False)
 
 
 
 
 
 
 
 
3
 
4
  import sys
5
  import torch
6
+ import torchvision
7
  import gradio as gr
8
  import numpy as np
9
  import cv2
 
14
  RTDetrImageProcessor,
15
  )
16
 
17
+ # == Device configuration ==
18
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
19
 
20
+ # == Model configurations ==
21
  MODELS = {
22
  "Egret XLarge": {
23
  "path": "ds4sd/docling-layout-egret-xlarge",
 
41
  }
42
  }
43
 
44
+ # == Class mappings ==
45
  classes_map = {
46
+ 0: "Caption", 1: "Footnote", 2: "Formula", 3: "List-item",
47
+ 4: "Page-footer", 5: "Page-header", 6: "Picture", 7: "Section-header",
48
+ 8: "Table", 9: "Text", 10: "Title", 11: "Document Index",
49
+ 12: "Code", 13: "Checkbox-Selected", 14: "Checkbox-Unselected",
50
+ 15: "Form", 16: "Key-Value Region",
 
 
 
 
 
 
 
 
 
 
 
 
51
  }
52
 
53
+ # == Global model variables ==
54
  current_model = None
55
  current_processor = None
56
  current_model_name = None
57
 
58
  def colormap(N=256, normalized=False):
59
+ """Generate dynamic colormap."""
60
  def bitget(byteval, idx):
61
  return ((byteval & (1 << idx)) != 0)
62
 
 
73
 
74
  if normalized:
75
  cmap = cmap.astype(np.float32) / 255.0
 
76
  return cmap
77
 
78
  def iomin(box1, box2):
79
+ """Intersection over Minimum (IoMin)."""
80
  x1 = torch.max(box1[:, 0], box2[:, 0])
81
  y1 = torch.max(box1[:, 1], box2[:, 1])
82
  x2 = torch.min(box1[:, 2], box2[:, 2])
83
  y2 = torch.min(box1[:, 3], box2[:, 3])
84
  inter_area = torch.clamp(x2 - x1, min=0) * torch.clamp(y2 - y1, min=0)
85
+
86
  box1_area = (box1[:, 2] - box1[:, 0]) * (box1[:, 3] - box1[:, 1])
87
  box2_area = (box2[:, 2] - box2[:, 0]) * (box2[:, 3] - box2[:, 1])
88
  min_area = torch.min(box1_area, box2_area)
89
+
90
  return inter_area / min_area
91
 
92
+ def nms_custom(boxes, scores, iou_threshold=0.5):
93
+ """Custom NMS implementation using IoMin."""
94
  keep = []
95
  _, order = scores.sort(descending=True)
96
 
 
111
  return torch.tensor(keep, dtype=torch.long)
112
 
113
  def load_model(model_name):
114
+ """Load the selected model."""
115
  global current_model, current_processor, current_model_name
116
 
117
  if current_model_name == model_name:
118
  return f"✅ Model {model_name} is already loaded!"
119
 
120
  try:
 
121
  model_info = MODELS[model_name]
122
  model_path = model_info["path"]
123
  model_class = model_info["model_class"]
124
 
125
+ print(f"Loading {model_name} from {model_path}")
126
+
127
  processor = RTDetrImageProcessor.from_pretrained(model_path)
128
  model = model_class.from_pretrained(model_path)
129
  model = model.to(device)
 
136
  return f"✅ Successfully loaded {model_name}!"
137
 
138
  except Exception as e:
139
+ print(f"Error loading model: {e}")
140
  return f"❌ Error loading {model_name}: {str(e)}"
141
 
142
  def visualize_bbox(image_input, bboxes, classes, scores, id_to_names, alpha=0.3):
143
+ """Visualize bounding boxes with OpenCV."""
144
  if isinstance(image_input, Image.Image):
145
  image = np.array(image_input)
146
  image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
 
152
  else:
153
  raise ValueError("Input must be PIL Image or numpy array")
154
 
 
 
 
155
  if len(bboxes) == 0:
156
  return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
157
 
158
+ overlay = image.copy()
159
+ cmap = colormap(N=len(id_to_names), normalized=False)
160
+
161
  for i in range(len(bboxes)):
162
  try:
163
  bbox = bboxes[i]
 
176
  class_id = int(class_id)
177
  class_name = id_to_names.get(class_id, f"unknown_{class_id}")
178
 
179
+ text = f"{class_name}: {score:.3f}"
180
  color = tuple(int(c) for c in cmap[class_id % len(cmap)])
181
 
182
+ # Draw filled rectangle on overlay
183
  cv2.rectangle(overlay, (x_min, y_min), (x_max, y_max), color, -1)
184
+ # Draw border on main image
185
+ cv2.rectangle(image, (x_min, y_min), (x_max, y_max), color, 3)
186
 
187
+ # Add text label
188
+ (text_width, text_height), baseline = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.8, 2)
189
+ cv2.rectangle(image, (x_min, y_min - text_height - baseline - 4),
190
+ (x_min + text_width + 8, y_min), color, -1)
191
+ cv2.putText(image, text, (x_min + 4, y_min - 6), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2)
192
 
193
  except Exception as e:
194
  print(f"Skipping box {i} due to error: {e}")
195
 
196
+ # Apply transparency
197
  cv2.addWeighted(overlay, alpha, image, 1 - alpha, 0, image)
198
+
199
  return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
200
 
201
+ def process_image(input_img, conf_threshold, iou_threshold, nms_method, alpha):
202
+ """Process image with document layout detection."""
203
  if input_img is None:
204
+ return None, "❌ Please upload an image first."
205
 
206
  if current_model is None or current_processor is None:
207
+ return None, "❌ Please load a model first."
208
 
209
  try:
210
+ # Prepare image
211
  if isinstance(input_img, np.ndarray):
212
  input_img = Image.fromarray(input_img)
213
 
214
  if input_img.mode != 'RGB':
215
  input_img = input_img.convert('RGB')
216
 
217
+ # Process with model
218
  inputs = current_processor(images=[input_img], return_tensors="pt")
219
  inputs = {k: v.to(device) for k, v in inputs.items()}
220
 
221
  with torch.no_grad():
222
  outputs = current_model(**inputs)
223
 
224
+ # Post-process results
225
  results = current_processor.post_process_object_detection(
226
  outputs,
227
  target_sizes=torch.tensor([input_img.size[::-1]]),
 
229
  )
230
 
231
  if not results or len(results) == 0:
232
+ return np.array(input_img), "ℹ️ No detections found."
233
 
234
  result = results[0]
235
  boxes = result["boxes"]
 
237
  labels = result["labels"]
238
 
239
  if len(boxes) == 0:
240
+ return np.array(input_img), f"ℹ️ No detections above threshold {conf_threshold:.2f}."
241
 
242
+ # Apply NMS
243
  if iou_threshold < 1.0:
244
  if nms_method == "Custom IoMin":
245
+ keep_indices = nms_custom(boxes=boxes, scores=scores, iou_threshold=iou_threshold)
246
  else:
247
+ # Use torchvision NMS with correct format
248
+ keep_indices = torchvision.ops.nms(boxes, scores, iou_threshold)
249
 
250
  boxes = boxes[keep_indices]
251
  scores = scores[keep_indices]
252
  labels = labels[keep_indices]
253
 
254
+ # Visualize results
 
 
 
 
255
  output = visualize_bbox(input_img, boxes, labels, scores, classes_map, alpha=alpha)
256
+ info = f"✅ Found {len(boxes)} detections | NMS: {nms_method} | Threshold: {conf_threshold:.2f}"
257
+
258
+ return output, info
259
 
260
  except Exception as e:
261
+ print(f"[ERROR] process_image failed: {e}")
262
+ error_msg = f"❌ Processing error: {str(e)}"
263
  if input_img is not None:
264
  return np.array(input_img), error_msg
265
  return np.zeros((512, 512, 3), dtype=np.uint8), error_msg
266
 
267
+ def reset_interface():
268
+ """Reset all interface components."""
269
  return gr.update(value=None), gr.update(value=None), gr.update(value="")
270
 
271
+ def create_legend_html():
272
+ """Create HTML for the class legend."""
273
+ cmap = colormap(N=len(classes_map), normalized=False)
274
+ legend_items = []
275
+
276
+ for class_id, class_name in classes_map.items():
277
+ color_rgb = cmap[class_id % len(cmap)]
278
+ color_hex = f"#{color_rgb[0]:02x}{color_rgb[1]:02x}{color_rgb[2]:02x}"
279
+ legend_items.append(f"""
280
+ <div style='display: flex; align-items: center; padding: 8px; margin: 4px; background-color: #f8f9fa; border-radius: 6px;'>
281
+ <div style='width: 24px; height: 24px; background-color: {color_hex}; margin-right: 12px; border: 2px solid #dee2e6; border-radius: 4px;'></div>
282
+ <span style='font-weight: 500; color: #495057;'>{class_name}</span>
283
+ </div>
284
+ """)
285
+
286
+ return f"""
287
+ <div style='display: grid; grid-template-columns: repeat(auto-fit, minmax(200px, 1fr)); gap: 8px; padding: 16px; background-color: #ffffff; border-radius: 8px; border: 1px solid #e9ecef;'>
288
+ {''.join(legend_items)}
289
+ </div>
290
+ """
291
+
292
  if __name__ == "__main__":
293
+ print(f"🚀 Starting Document Layout Analysis App")
294
+ print(f"📱 Device: {device}")
295
+ print(f"🤖 Available models: {len(MODELS)}")
296
 
297
+ # Custom CSS for full-width layout
298
  custom_css = """
299
  .gradio-container {
300
+ max-width: 100% !important;
301
+ padding: 20px !important;
302
  }
303
+
304
+ .main-container {
305
+ width: 100% !important;
306
+ max-width: none !important;
307
  }
308
+
309
+ .panel-left, .panel-right {
310
+ min-height: 600px;
311
+ padding: 20px;
312
+ background: #f8f9fa;
313
+ border-radius: 12px;
314
+ border: 1px solid #e9ecef;
315
+ }
316
+
317
+ .control-section {
318
+ margin-bottom: 20px;
319
+ padding: 15px;
320
+ background: white;
321
+ border-radius: 8px;
322
+ border: 1px solid #dee2e6;
323
+ }
324
+
325
+ .status-good { color: #28a745; font-weight: bold; }
326
+ .status-error { color: #dc3545; font-weight: bold; }
327
+ .status-info { color: #17a2b8; font-weight: bold; }
328
  """
329
 
330
+ # Create Gradio interface
331
+ with gr.Blocks(
332
+ title="📄 Document Layout Analysis - Full Width",
333
+ theme=gr.themes.Soft(),
334
+ css=custom_css
335
+ ) as demo:
336
+
337
  # Header
338
  gr.HTML("""
339
+ <div style='text-align: center; padding: 30px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; border-radius: 15px; margin-bottom: 30px;'>
340
+ <h1 style='margin: 0; font-size: 3em; text-shadow: 2px 2px 4px rgba(0,0,0,0.3);'>🔍 Document Layout Analysis</h1>
341
+ <p style='margin: 10px 0 0 0; font-size: 1.3em; opacity: 0.9;'>Advanced document structure detection with multiple AI models</p>
342
  </div>
343
  """)
344
 
345
+ # Main content in two columns
346
  with gr.Row():
347
+ # LEFT COLUMN - Controls and Input
348
+ with gr.Column(scale=1, elem_classes=["panel-left"]):
 
 
 
 
 
 
 
 
 
349
 
350
+ # Model Section
351
+ with gr.Group(elem_classes=["control-section"]):
352
+ gr.HTML("<h3>🤖 Model Configuration</h3>")
353
+
354
+ model_dropdown = gr.Dropdown(
355
+ choices=list(MODELS.keys()),
356
+ value="Egret XLarge",
357
+ label="Select Model",
358
+ info="Choose the AI model for document analysis",
359
+ interactive=True
360
+ )
361
+
362
+ with gr.Row():
363
+ load_btn = gr.Button("📥 Load Model", variant="primary", scale=1)
364
+ clear_btn = gr.Button("🗑️ Clear All", variant="secondary", scale=1)
365
+
366
+ model_status = gr.Textbox(
367
+ label="Model Status",
368
+ value="οΏ½οΏ½οΏ½ No model loaded. Please select and load a model.",
369
+ interactive=False,
370
+ lines=2
371
+ )
372
 
373
+ # Image Upload Section
374
+ with gr.Group(elem_classes=["control-section"]):
375
+ gr.HTML("<h3>📄 Image Input</h3>")
376
+
377
+ input_img = gr.Image(
378
+ label="Upload Document Image",
379
+ type="pil",
380
+ height=400,
381
+ interactive=True
382
+ )
383
+
384
+ detect_btn = gr.Button("🔍 Analyze Document", variant="primary", size="lg")
385
 
386
+ # Parameters Section
387
+ with gr.Group(elem_classes=["control-section"]):
388
+ gr.HTML("<h3>⚙️ Detection Parameters</h3>")
389
+
390
+ conf_threshold = gr.Slider(
391
+ minimum=0.0,
392
+ maximum=1.0,
393
+ value=0.6,
394
+ step=0.05,
395
+ label="Confidence Threshold",
396
+ info="Minimum confidence for detections"
397
+ )
398
+
399
+ iou_threshold = gr.Slider(
400
+ minimum=0.0,
401
+ maximum=1.0,
402
+ value=0.5,
403
+ step=0.05,
404
+ label="NMS IoU Threshold",
405
+ info="Non-maximum suppression threshold"
406
+ )
407
+
408
+ nms_method = gr.Radio(
409
+ choices=["Custom IoMin", "Standard IoU"],
410
+ value="Custom IoMin",
411
+ label="NMS Algorithm",
412
+ info="Choose suppression method"
413
+ )
414
+
415
+ alpha_slider = gr.Slider(
416
+ minimum=0.0,
417
+ maximum=1.0,
418
+ value=0.3,
419
+ step=0.1,
420
+ label="Overlay Transparency",
421
+ info="Transparency of detection overlays"
422
+ )
423
+
424
+ # RIGHT COLUMN - Results and Output
425
+ with gr.Column(scale=1, elem_classes=["panel-right"]):
426
 
427
+ # Results Section
428
+ with gr.Group(elem_classes=["control-section"]):
429
+ gr.HTML("<h3>🎯 Detection Results</h3>")
430
+
431
+ output_img = gr.Image(
432
+ label="Analyzed Document",
433
+ type="numpy",
434
+ height=500,
435
+ interactive=False
436
+ )
437
+
438
+ detection_info = gr.Textbox(
439
+ label="Analysis Summary",
440
+ value="",
441
+ interactive=False,
442
+ lines=3,
443
+ placeholder="Detection results will appear here..."
444
+ )
445
 
446
+ # Legend Section (Full Width)
447
+ with gr.Group():
448
+ with gr.Accordion("📋 Class Legend - All Detectable Elements", open=False):
449
+ gr.HTML(create_legend_html())
450
+
451
+ # Event Handlers
452
+ load_btn.click(
453
+ fn=load_model,
454
+ inputs=[model_dropdown],
455
+ outputs=[model_status]
456
+ )
457
+
458
+ clear_btn.click(
459
+ fn=reset_interface,
460
+ outputs=[input_img, output_img, detection_info]
461
+ )
462
 
463
+ detect_btn.click(
464
+ fn=process_image,
465
+ inputs=[input_img, conf_threshold, iou_threshold, nms_method, alpha_slider],
 
 
 
466
  outputs=[output_img, detection_info]
467
  )
468
 
469
+ # Launch application
470
+ demo.launch(
471
+ server_name="0.0.0.0",
472
+ server_port=7860,
473
+ debug=True,
474
+ share=False,
475
+ show_error=True,
476
+ inbrowser=True
477
+ )