enpaiva committed on
Commit
584b383
·
verified ·
1 Parent(s): a4cb188

Update app.py

Files changed (1)
  1. app.py +145 -159
app.py CHANGED
@@ -111,11 +111,11 @@ def nms_custom(boxes, scores, iou_threshold=0.5):
     return torch.tensor(keep, dtype=torch.long)
 
 def load_model(model_name):
-    """Load the selected model."""
+    """Load the selected model automatically."""
     global current_model, current_processor, current_model_name
 
     if current_model_name == model_name:
-        return f"✅ Model {model_name} is already loaded!"
+        return current_model, current_processor
 
     try:
         model_info = MODELS[model_name]
@@ -133,11 +133,11 @@ def load_model(model_name):
         current_model = model
         current_model_name = model_name
 
-        return f"✅ Successfully loaded {model_name}!"
+        return model, processor
 
     except Exception as e:
         print(f"Error loading model: {e}")
-        return f"❌ Error loading {model_name}: {str(e)}"
+        return None, None
 
 def visualize_bbox(image_input, bboxes, classes, scores, id_to_names, alpha=0.3, show_labels=True):
     """Visualize bounding boxes with OpenCV."""
@@ -199,13 +199,15 @@ def visualize_bbox(image_input, bboxes, classes, scores, id_to_names, alpha=0.3, show_labels=True):
 
     return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
 
-def process_image(input_img, conf_threshold, iou_threshold, nms_method, alpha, show_labels):
+def process_image(input_img, model_name, conf_threshold, iou_threshold, nms_method, alpha, show_labels):
     """Process image with document layout detection."""
     if input_img is None:
         return None, "❌ Please upload an image first."
-
-    if current_model is None or current_processor is None:
-        return None, "❌ Please load a model first."
+
+    # Load model if needed
+    model, processor = load_model(model_name)
+    if model is None or processor is None:
+        return None, f"❌ Error loading model {model_name}."
 
     try:
         # Prepare image
@@ -216,14 +218,14 @@ def process_image(input_img, conf_threshold, iou_threshold, nms_method, alpha, show_labels):
         input_img = input_img.convert('RGB')
 
         # Process with model
-        inputs = current_processor(images=[input_img], return_tensors="pt")
+        inputs = processor(images=[input_img], return_tensors="pt")
         inputs = {k: v.to(device) for k, v in inputs.items()}
 
         with torch.no_grad():
-            outputs = current_model(**inputs)
+            outputs = model(**inputs)
 
         # Post-process results
-        results = current_processor.post_process_object_detection(
+        results = processor.post_process_object_detection(
             outputs,
             target_sizes=torch.tensor([input_img.size[::-1]]),
             threshold=conf_threshold,
@@ -256,7 +258,7 @@ def process_image(input_img, conf_threshold, iou_threshold, nms_method, alpha, show_labels):
         output = visualize_bbox(input_img, boxes, labels, scores, classes_map, alpha=alpha, show_labels=show_labels)
 
         labels_status = "with labels" if show_labels else "without labels"
-        info = f"✅ Found {len(boxes)} detections ({labels_status}) | NMS: {nms_method} | Threshold: {conf_threshold:.2f}"
+        info = f"✅ Found {len(boxes)} detections ({labels_status}) | Model: {model_name} | Confidence: {conf_threshold:.2f}"
 
         return output, info
 
@@ -267,58 +269,54 @@ def process_image(input_img, conf_threshold, iou_threshold, nms_method, alpha, show_labels):
             return np.array(input_img), error_msg
         return np.zeros((512, 512, 3), dtype=np.uint8), error_msg
 
-def reset_interface():
-    """Reset all interface components."""
-    return gr.update(value=None), gr.update(value=None), gr.update(value="")
-
 if __name__ == "__main__":
     print(f"🚀 Starting Document Layout Analysis App")
     print(f"📱 Device: {device}")
     print(f"🤖 Available models: {len(MODELS)}")
 
-    # Custom CSS for full-width layout
+    # Custom CSS for compact layout
     custom_css = """
     .gradio-container {
-        max-width: 100% !important;
+        max-width: 1400px !important;
+        margin: 0 auto !important;
         padding: 20px !important;
     }
 
-    .main-container {
-        width: 100% !important;
-        max-width: none !important;
-    }
-
-    .panel-left, .panel-right {
-        min-height: 600px;
-        padding: 20px;
+    .controls-container {
         background: #f8f9fa;
         border-radius: 12px;
-        border: 1px solid #e9ecef;
+        border: 1px solid #dee2e6;
+        padding: 20px;
+        margin-bottom: 20px;
     }
 
-    .control-section {
-        margin-bottom: 20px;
-        padding: 15px;
-        background: white;
-        border-radius: 8px;
+    .results-container {
+        background: #ffffff;
+        border-radius: 12px;
         border: 1px solid #dee2e6;
+        padding: 20px;
     }
 
-    .status-good { color: #28a745; font-weight: bold; }
-    .status-error { color: #dc3545; font-weight: bold; }
-    .status-info { color: #17a2b8; font-weight: bold; }
+    .section-divider {
+        border-top: 2px solid #e9ecef;
+        margin: 20px 0;
+        padding-top: 20px;
+    }
 
-    .toggle-labels {
+    .analyze-btn {
         background: linear-gradient(45deg, #667eea, #764ba2) !important;
         border: none !important;
         color: white !important;
         font-weight: bold !important;
+        font-size: 18px !important;
+        padding: 15px 30px !important;
+        border-radius: 10px !important;
     }
     """
 
     # Create Gradio interface
     with gr.Blocks(
-        title="📄 Document Layout Analysis - Full Width",
+        title="📄 Document Layout Analysis",
         theme=gr.themes.Soft(),
         css=custom_css
     ) as demo:
@@ -326,138 +324,126 @@ if __name__ == "__main__":
         # Header
         gr.HTML("""
         <div style='text-align: center; padding: 30px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; border-radius: 15px; margin-bottom: 30px;'>
-            <h1 style='margin: 0; font-size: 3em; text-shadow: 2px 2px 4px rgba(0,0,0,0.3);'>🔍 Document Layout Analysis</h1>
-            <p style='margin: 10px 0 0 0; font-size: 1.3em; opacity: 0.9;'>Advanced document structure detection with multiple AI models</p>
+            <h1 style='margin: 0; font-size: 2.5em; text-shadow: 2px 2px 4px rgba(0,0,0,0.3);'>🔍 Document Layout Analysis</h1>
+            <p style='margin: 10px 0 0 0; font-size: 1.2em; opacity: 0.9;'>Compact interface for advanced document structure detection</p>
         </div>
         """)
 
-        # Main content in two columns
-        with gr.Row():
-            # LEFT COLUMN - Controls and Input
-            with gr.Column(scale=1, elem_classes=["panel-left"]):
-
-                # Model Section
-                with gr.Group(elem_classes=["control-section"]):
-                    gr.HTML("<h3>🤖 Model Configuration</h3>")
-
-                    model_dropdown = gr.Dropdown(
-                        choices=list(MODELS.keys()),
-                        value="Egret XLarge",
-                        label="Select Model",
-                        info="Choose the AI model for document analysis",
-                        interactive=True
-                    )
-
-                    with gr.Row():
-                        load_btn = gr.Button("📥 Load Model", variant="primary", scale=1)
-                        clear_btn = gr.Button("🗑️ Clear All", variant="secondary", scale=1)
-
-                    model_status = gr.Textbox(
-                        label="Model Status",
-                        value="🔄 No model loaded. Please select and load a model.",
-                        interactive=False,
-                        lines=2
-                    )
-
-                # Image Upload Section
-                with gr.Group(elem_classes=["control-section"]):
-                    gr.HTML("<h3>📄 Image Input</h3>")
-
-                    input_img = gr.Image(
-                        label="Upload Document Image",
-                        type="pil",
-                        height=400,
-                        interactive=True
-                    )
-
-                    detect_btn = gr.Button("🔍 Analyze Document", variant="primary", size="lg")
+        # Controls Section
+        with gr.Group(elem_classes=["controls-container"]):
+            # 1. Image Upload (First)
+            gr.HTML("<h3 style='margin-top: 0;'>📄 Upload Document</h3>")
+            input_img = gr.Image(
+                label="Document Image",
+                type="pil",
+                height=300,
+                interactive=True
+            )
+
+            # Divider
+            gr.HTML("<div class='section-divider'></div>")
+
+            # 2. Model Selection (Second)
+            gr.HTML("<h3>🤖 Model Selection</h3>")
+            model_dropdown = gr.Dropdown(
+                choices=list(MODELS.keys()),
+                value="Egret XLarge",
+                label="AI Model",
+                info="Model will load automatically when analyzing",
+                interactive=True
+            )
+
+            # Divider
+            gr.HTML("<div class='section-divider'></div>")
+
+            # 3. Detection Parameters (Third)
+            gr.HTML("<h3>⚙️ Detection Settings</h3>")
+
+            with gr.Row():
+                conf_threshold = gr.Slider(
+                    minimum=0.0,
+                    maximum=1.0,
+                    value=0.6,
+                    step=0.05,
+                    label="Confidence Threshold",
+                    info="Minimum confidence for detections"
+                )
 
-            # Parameters Section
-            with gr.Group(elem_classes=["control-section"]):
-                gr.HTML("<h3>⚙️ Detection Parameters</h3>")
-
-                conf_threshold = gr.Slider(
-                    minimum=0.0,
-                    maximum=1.0,
-                    value=0.6,
-                    step=0.05,
-                    label="Confidence Threshold",
-                    info="Minimum confidence for detections"
-                )
-
-                iou_threshold = gr.Slider(
-                    minimum=0.0,
-                    maximum=1.0,
-                    value=0.5,
-                    step=0.05,
-                    label="NMS IoU Threshold",
-                    info="Non-maximum suppression threshold"
-                )
-
-                nms_method = gr.Radio(
-                    choices=["Custom IoMin", "Standard IoU"],
-                    value="Custom IoMin",
-                    label="NMS Algorithm",
-                    info="Choose suppression method"
-                )
-
-                alpha_slider = gr.Slider(
-                    minimum=0.0,
-                    maximum=1.0,
-                    value=0.3,
-                    step=0.1,
-                    label="Overlay Transparency",
-                    info="Transparency of detection overlays"
-                )
+                iou_threshold = gr.Slider(
+                    minimum=0.0,
+                    maximum=1.0,
+                    value=0.5,
+                    step=0.05,
+                    label="NMS IoU Threshold",
+                    info="Non-maximum suppression threshold"
+                )
 
-            # RIGHT COLUMN - Results and Output
-            with gr.Column(scale=1, elem_classes=["panel-right"]):
+            with gr.Row():
+                nms_method = gr.Radio(
+                    choices=["Custom IoMin", "Standard IoU"],
+                    value="Custom IoMin",
+                    label="NMS Algorithm",
+                    info="Choose suppression method"
+                )
 
-                # Results Section
-                with gr.Group(elem_classes=["control-section"]):
-                    gr.HTML("<h3>🎯 Detection Results</h3>")
-
-                    output_img = gr.Image(
-                        label="Analyzed Document",
-                        type="numpy",
-                        height=500,
-                        interactive=False
-                    )
-
-                    detection_info = gr.Textbox(
-                        label="Analysis Summary",
-                        value="",
-                        interactive=False,
-                        lines=3,
-                        placeholder="Detection results will appear here..."
-                    )
-
-                # Visualization Options Section
-                with gr.Group(elem_classes=["control-section"]):
-                    gr.HTML("<h3>🎨 Visualization Options</h3>")
-
-                    show_labels_checkbox = gr.Checkbox(
-                        value=True,
-                        label="Show Class Labels",
-                        info="Display class names and confidence scores on detections",
-                        interactive=True
-                    )
-
-        # Event Handlers
-        load_btn.click(
-            fn=load_model,
-            inputs=[model_dropdown],
-            outputs=[model_status]
-        )
+            alpha_slider = gr.Slider(
+                minimum=0.0,
+                maximum=1.0,
+                value=0.3,
+                step=0.1,
+                label="Overlay Transparency",
+                info="Transparency of detection overlays"
+            )
+
+            show_labels_checkbox = gr.Checkbox(
+                value=True,
+                label="Show Class Labels and Confidence Scores",
+                info="Display detection labels on the output image",
+                interactive=True
+            )
+
+            # Divider
+            gr.HTML("<div class='section-divider'></div>")
+
+            # 4. Analyze Button (Last)
+            detect_btn = gr.Button(
+                "🔍 Analyze Document",
+                variant="primary",
+                size="lg",
+                elem_classes=["analyze-btn"]
+            )
 
-        clear_btn.click(
-            fn=reset_interface,
-            outputs=[input_img, output_img, detection_info]
-        )
+        # Results Section
+        with gr.Group(elem_classes=["results-container"]):
+            gr.HTML("<h3 style='margin-top: 0;'>🎯 Analysis Results</h3>")
+
+            output_img = gr.Image(
+                label="Analyzed Document",
+                type="numpy",
+                height=600,
+                interactive=False
+            )
+
+            detection_info = gr.Textbox(
+                label="Detection Summary",
+                value="Ready for analysis. Upload an image and click 'Analyze Document'.",
+                interactive=False,
+                lines=2,
+                show_copy_button=True
+            )
 
+        # Event Handler
         detect_btn.click(
             fn=process_image,
-            inputs=[input_img, conf_threshold, iou_threshold, nms_method, alpha_slider, show_labels_checkbox],
+            inputs=[
+                input_img,
+                model_dropdown,
+                conf_threshold,
+                iou_threshold,
+                nms_method,
+                alpha_slider,
+                show_labels_checkbox
+            ],
            outputs=[output_img, detection_info]
         )
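The substance of this commit is the switch from a manual "Load Model" button to lazy loading inside `process_image`: `load_model` now returns a cached `(model, processor)` pair instead of a status string, and the dropdown's `model_name` is threaded through the click handler. Below is a minimal self-contained sketch of that pattern. The `transformers` Auto-class loaders and the `MODELS` registry contents are assumptions for illustration; only the function name, the globals, and the caching/return contract come from the diff.

```python
# Sketch of the lazy-loading pattern this commit introduces.
# ASSUMED for illustration: the Auto* loader classes and the MODELS
# registry contents; the diff shows only the caching/return contract.
import torch
from transformers import AutoImageProcessor, AutoModelForObjectDetection

device = "cuda" if torch.cuda.is_available() else "cpu"
MODELS = {"Egret XLarge": {"repo_id": "example/egret-xlarge"}}  # placeholder repo id

current_model = None
current_processor = None
current_model_name = None

def load_model(model_name):
    """Return (model, processor); reload only when the selection changes."""
    global current_model, current_processor, current_model_name
    if current_model_name == model_name:
        # Cache hit: the requested model is already resident, skip the reload.
        return current_model, current_processor
    try:
        repo_id = MODELS[model_name]["repo_id"]
        current_processor = AutoImageProcessor.from_pretrained(repo_id)
        current_model = AutoModelForObjectDetection.from_pretrained(repo_id).to(device)
        current_model_name = model_name
        return current_model, current_processor
    except Exception as e:
        # process_image() turns (None, None) into a user-facing error string.
        print(f"Error loading model: {e}")
        return None, None
```

With this contract, `process_image` can call `load_model(model_name)` on every click and only pays the load cost when the dropdown selection actually changes.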
 
449