iammraat committed on
Commit
c70ab27
·
verified ·
1 Parent(s): e55fda2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +75 -45
app.py CHANGED
@@ -2,42 +2,57 @@ import gradio as gr
2
  import cv2
3
  import numpy as np
4
  import onnxruntime as ort
5
- from huggingface_hub import hf_hub_download
6
 
7
- # --- STEP 1: Download the ONNX Model ---
8
- print("Downloading ONNX model...")
9
- model_path = hf_hub_download(repo_id="alex-dinh/PP-DocLayoutV3-ONNX", filename="model.onnx")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  print(f"Model downloaded to: {model_path}")
11
 
12
- # --- STEP 2: Initialize ONNX Engine ---
13
- # This loads the AI "brain" directly without needing Paddle
14
  session = ort.InferenceSession(model_path)
15
  input_names = [i.name for i in session.get_inputs()]
16
  output_names = [o.name for o in session.get_outputs()]
17
 
18
- # Define labels map (Standard for PP-DocLayout)
19
- LABELS = {1: "Text", 2: "Title", 3: "List", 4: "Table", 5: "Figure"}
 
20
 
21
- def preprocess_image(image, target_size=(800, 800)):
22
  """
23
- Prepares the image exactly how the AI expects it (Resize -> Normalize).
24
  """
25
  h, w = image.shape[:2]
26
 
27
- # 1. Resize
28
- # We do NOT keep aspect ratio for the input blob, but we keep scales to fix boxes later
29
- img_resized = cv2.resize(image, target_size)
30
 
31
- # 2. Normalize (Standard ImageNet mean/std)
32
  img_data = img_resized.astype(np.float32) / 255.0
33
  mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
34
  std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
35
  img_data = (img_data - mean) / std
36
 
37
- # 3. Transpose to (Batch, Channel, Height, Width)
38
  img_data = img_data.transpose(2, 0, 1)[None, :, :, :]
39
 
40
- # Calculate scale factors to map detections back to original image
 
41
  scale_factor = np.array([target_size[0] / h, target_size[1] / w], dtype=np.float32).reshape(1, 2)
42
 
43
  return img_data, scale_factor
@@ -46,60 +61,75 @@ def analyze_layout(input_image):
46
  if input_image is None:
47
  return None, "No image uploaded"
48
 
49
- # Convert PIL to Numpy/OpenCV
50
  image_np = np.array(input_image)
51
- orig_h, orig_w = image_np.shape[:2]
52
 
53
  # --- INFERENCE ---
 
54
  input_blob, scale_factor = preprocess_image(image_np)
55
 
56
- # ONNX Runtime inputs
57
  inputs = {
58
- input_names[0]: input_blob, # The image data
59
- input_names[1]: scale_factor # The resize scale
60
  }
61
 
62
- # Run!
 
 
 
63
  outputs = session.run(output_names, inputs)
64
 
65
- # --- POST-PROCESSING ---
66
- # Output format is typically [Batch, N, 6] -> [Class, Score, X1, Y1, X2, Y2]
67
- detections = outputs[0]
 
68
 
69
  viz_image = image_np.copy()
70
  log = []
71
 
 
 
 
 
72
  for det in detections:
73
- class_id = int(det[0])
74
  score = det[1]
75
- bbox = det[2:]
76
 
77
- if score < 0.5: continue # Filter weak detections
 
78
 
79
  # Map labels
80
- label_name = LABELS.get(class_id, "Unknown")
81
 
82
- # Coordinates
83
- x1, y1, x2, y2 = map(int, bbox)
84
-
85
- # Color coding
86
- color = (0, 255, 0) # Green
87
- if label_name == "Title": color = (0, 0, 255)
88
- elif label_name == "Table": color = (255, 255, 0)
89
- elif label_name == "Figure": color = (255, 0, 0)
90
-
91
- # Draw
92
- cv2.rectangle(viz_image, (x1, y1), (x2, y2), color, 3)
93
- cv2.putText(viz_image, f"{label_name} {score:.2f}", (x1, y1-10),
94
- cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)
95
-
96
- log.append(f"Found {label_name} at [{x1}, {y1}, {x2}, {y2}] (Conf: {score:.2f})")
 
 
 
 
 
 
 
97
 
98
  return viz_image, "\n".join(log)
99
 
100
  with gr.Blocks(title="ONNX Layout Analysis") as demo:
101
  gr.Markdown("## ⚡ Fast V3 Layout Analysis (ONNX)")
102
- gr.Markdown("Uses **PP-DocLayoutV3** via ONNX Runtime. No Paddle dependencies.")
103
 
104
  with gr.Row():
105
  with gr.Column():
 
2
  import cv2
3
  import numpy as np
4
  import onnxruntime as ort
5
+ from huggingface_hub import hf_hub_download, list_repo_files
6
 
7
+ # --- STEP 1: Find the correct ONNX filename ---
8
+ REPO_ID = "alex-dinh/PP-DocLayoutV3-ONNX"
9
+ print(f"Searching for ONNX model in {REPO_ID}...")
10
+
11
+ # Get list of all files in the repo
12
+ all_files = list_repo_files(repo_id=REPO_ID)
13
+
14
+ # Find the first file that ends with .onnx
15
+ onnx_filename = next((f for f in all_files if f.endswith('.onnx')), None)
16
+
17
+ if onnx_filename is None:
18
+ raise FileNotFoundError(f"No .onnx file found in {REPO_ID}. Repo contents: {all_files}")
19
+
20
+ print(f"Found model file: {onnx_filename}")
21
+
22
+ # --- STEP 2: Download the Model ---
23
+ model_path = hf_hub_download(repo_id=REPO_ID, filename=onnx_filename)
24
  print(f"Model downloaded to: {model_path}")
25
 
26
+ # --- STEP 3: Initialize ONNX Engine ---
 
27
  session = ort.InferenceSession(model_path)
28
  input_names = [i.name for i in session.get_inputs()]
29
  output_names = [o.name for o in session.get_outputs()]
30
 
31
+ # Standard PP-DocLayoutV3 labels
32
+ # Based on the paper/standard V3 config
33
+ LABELS = {0: "Text", 1: "Title", 2: "List", 3: "Table", 4: "Figure"}
34
 
35
+ def preprocess_image(image, target_size=(640, 640)):
36
  """
37
+ Standard RT-DETR / Paddle Preprocessing.
38
  """
39
  h, w = image.shape[:2]
40
 
41
+ # 1. Resize (Standard size for V3 is usually 640x640 or 800x800)
42
+ # We use linear interpolation
43
+ img_resized = cv2.resize(image, target_size, interpolation=cv2.INTER_LINEAR)
44
 
45
+ # 2. Normalize (ImageNet mean/std)
46
  img_data = img_resized.astype(np.float32) / 255.0
47
  mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
48
  std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
49
  img_data = (img_data - mean) / std
50
 
51
+ # 3. Transpose (HWC -> CHW)
52
  img_data = img_data.transpose(2, 0, 1)[None, :, :, :]
53
 
54
+ # Scale factor for post-processing
55
+ # shape: [batch, 2] -> [height_scale, width_scale]
56
  scale_factor = np.array([target_size[0] / h, target_size[1] / w], dtype=np.float32).reshape(1, 2)
57
 
58
  return img_data, scale_factor
 
61
  if input_image is None:
62
  return None, "No image uploaded"
63
 
 
64
  image_np = np.array(input_image)
 
65
 
66
  # --- INFERENCE ---
67
+ # Prepare input
68
  input_blob, scale_factor = preprocess_image(image_np)
69
 
70
+ # Run ONNX (No Paddle dependency!)
71
  inputs = {
72
+ input_names[0]: input_blob,
73
+ input_names[1]: scale_factor # Some exports require this, others don't.
74
  }
75
 
76
+ # Handle exports that don't use scale_factor input
77
+ if len(input_names) == 1:
78
+ del inputs[input_names[1]]
79
+
80
  outputs = session.run(output_names, inputs)
81
 
82
+ # --- PARSE RESULTS ---
83
+ # Output format varies by export, but usually it's [Batch, N, 6]
84
+ # [Class, Score, X1, Y1, X2, Y2]
85
+ detections = outputs[0]
86
 
87
  viz_image = image_np.copy()
88
  log = []
89
 
90
+ # If the output is wrapped in a batch dimension, unwrap it
91
+ if len(detections.shape) == 3:
92
+ detections = detections[0]
93
+
94
  for det in detections:
95
+ # Check confidence (usually index 1)
96
  score = det[1]
97
+ if score < 0.4: continue
98
 
99
+ class_id = int(det[0])
100
+ bbox = det[2:]
101
 
102
  # Map labels
103
+ label_name = LABELS.get(class_id, f"Class {class_id}")
104
 
105
+ # Draw Box
106
+ try:
107
+ x1, y1, x2, y2 = map(int, bbox)
108
+
109
+ # Color coding
110
+ color = (0, 255, 0) # Green
111
+ if "Title" in label_name: color = (0, 0, 255)
112
+ elif "Table" in label_name: color = (255, 255, 0)
113
+ elif "Figure" in label_name: color = (255, 0, 0)
114
+
115
+ cv2.rectangle(viz_image, (x1, y1), (x2, y2), color, 3)
116
+
117
+ # Label
118
+ label_text = f"{label_name} {score:.2f}"
119
+ (w, h), _ = cv2.getTextSize(label_text, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)
120
+ cv2.rectangle(viz_image, (x1, y1 - 20), (x1 + w, y1), color, -1)
121
+ cv2.putText(viz_image, label_text, (x1, y1 - 5),
122
+ cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
123
+
124
+ log.append(f"Found {label_name} at [{x1}, {y1}, {x2}, {y2}]")
125
+ except Exception:
126
+ pass
127
 
128
  return viz_image, "\n".join(log)
129
 
130
  with gr.Blocks(title="ONNX Layout Analysis") as demo:
131
  gr.Markdown("## ⚡ Fast V3 Layout Analysis (ONNX)")
132
+ gr.Markdown(f"Auto-loaded model from `{REPO_ID}`")
133
 
134
  with gr.Row():
135
  with gr.Column():