Theo Viel committed on
Commit
ab35335
·
1 Parent(s): a38bc63
model.py ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import torch
4
+ import importlib
5
+ import numpy as np
6
+ import numpy.typing as npt
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from typing import Dict, List, Tuple, Union
10
+ from yolox.boxes import postprocess
11
+
12
+
13
def define_model(config_name: str = "page_element_v3", verbose: bool = True) -> nn.Module:
    """
    Defines and initializes the model based on the configuration.

    The configuration is a Python module exposing an ``Exp`` class. The directory
    part of ``config_name`` is added to ``sys.path`` (only once) so that the module
    can be imported by its base name.

    Args:
        config_name (str): Configuration name, or path to the config file (anything
            after the first "." of the base name is ignored).
            Defaults to "page_element_v3".
        verbose (bool): Whether to print verbose output. Defaults to True.

    Returns:
        torch.nn.Module: The YOLOX model wrapped in a YoloXWrapper, in eval mode
            and moved to ``config.device``.
    """
    # Load model from exp_file
    config_dir = os.path.dirname(config_name)
    if config_dir not in sys.path:  # do not grow sys.path on every call
        sys.path.append(config_dir)
    module_name = os.path.basename(config_name).split(".")[0]
    exp_module = importlib.import_module(module_name)

    config = exp_module.Exp()
    model = config.get_model()

    # Load weights
    if verbose:
        print(" -> Loading weights from", config.ckpt)

    # NOTE(security): weights_only=False unpickles arbitrary Python objects.
    # Only load checkpoints that come from a trusted source.
    ckpt = torch.load(config.ckpt, map_location="cpu", weights_only=False)
    model.load_state_dict(ckpt["model"], strict=True)

    model = YoloXWrapper(model, config)
    return model.eval().to(config.device)
40
+
41
+
42
def resize_pad(img: torch.Tensor, size: tuple) -> torch.Tensor:
    """
    Resizes and pads an image to a given size.
    The goal is to preserve the aspect ratio of the image.

    The image is scaled so that it fits inside `size`, then padded on the bottom
    and on the right with the value 114.

    Args:
        img (torch.Tensor[C x H x W]): The image to resize and pad.
        size (tuple[2]): The (height, width) to resize and pad the image to.

    Returns:
        torch.Tensor: The resized and padded float image of shape [C x size[0] x size[1]].
    """
    img = img.float()
    _, h, w = img.shape
    scale = min(size[0] / h, size[1] / w)
    # Truncation can give 0 for very elongated images, which interpolate rejects:
    # always keep at least one pixel per dimension.
    nh = max(1, int(h * scale))
    nw = max(1, int(w * scale))
    img = F.interpolate(
        img.unsqueeze(0), size=(nh, nw), mode="bilinear", align_corners=False
    ).squeeze(0)
    img = torch.clamp(img, 0, 255)
    pad_b = size[0] - nh
    pad_r = size[1] - nw
    # F.pad order for the last two dims is (left, right, top, bottom)
    img = F.pad(img, (0, pad_r, 0, pad_b), value=114.0)
    return img
67
+
68
+
69
class YoloXWrapper(nn.Module):
    """
    Wrapper for YoloX models.

    Holds the network together with the inference parameters copied from its
    config, and exposes preprocessing plus a forward pass that applies NMS and
    maps the boxes back to the original image coordinates.
    """
    def __init__(self, model: nn.Module, config) -> None:
        """
        Constructor

        Args:
            model (torch model): Yolo model.
            config (Config): Config object containing model parameters.
        """
        super().__init__()
        self.model = model
        self.config = config

        # Copy config parameters
        self.device = config.device
        self.img_size = config.size
        self.min_bbox_size = config.min_bbox_size
        self.normalize_boxes = config.normalize_boxes
        self.conf_thresh = config.conf_thresh
        self.iou_thresh = config.iou_thresh
        self.class_agnostic = config.class_agnostic
        self.threshold = config.threshold
        self.labels = config.labels
        self.num_classes = config.num_classes

    def reformat_input(
        self,
        x: torch.Tensor,
        orig_sizes: Union[torch.Tensor, List, Tuple, npt.NDArray]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Reformats the input data and original sizes to the correct format.

        Args:
            x (torch.Tensor[BS x C x H x W]): Input image batch.
            orig_sizes (torch.Tensor or list or np.ndarray): Original image sizes.
        Returns:
            torch tensor [BS x C x H x W]: Input image batch.
            torch tensor [BS x 2]: Original image sizes (before resizing and padding).
        """
        # Convert image size to tensor
        if isinstance(orig_sizes, (list, tuple)):
            orig_sizes = np.array(orig_sizes)
        if orig_sizes.shape[-1] == 3:  # remove channel
            orig_sizes = orig_sizes[..., :2]
        if isinstance(orig_sizes, np.ndarray):
            orig_sizes = torch.from_numpy(orig_sizes).to(self.device)

        # Add batch dimension if not present
        if len(x.size()) == 3:
            x = x.unsqueeze(0)
        if len(orig_sizes.size()) == 1:
            orig_sizes = orig_sizes.unsqueeze(0)

        return x, orig_sizes

    def preprocess(self, image: Union[torch.Tensor, npt.NDArray]) -> torch.Tensor:
        """
        YoloX preprocessing function:
        - Resizes to the longest edge to img_size while preserving the aspect ratio
        - Pads the shortest edge to img_size

        Args:
            image (torch tensor or np array [H x W x 3]): Input images in uint8 format.

        Returns:
            torch tensor [3 x H x W]: Processed image.
        """
        if not isinstance(image, torch.Tensor):
            image = torch.from_numpy(image)
        image = image.permute(2, 0, 1)  # [H, W, 3] -> [3, H, W]
        image = resize_pad(image, self.img_size)
        return image.float()

    @staticmethod
    def _to_pixel_range(x: torch.Tensor) -> torch.Tensor:
        """
        Scales a batch to 0-255 when its maximum is <= 1.
        Returns a new tensor in that case: the input is never modified in place.
        """
        if x.max() <= 1:
            return x * 255
        return x

    def forward(
        self,
        x: torch.Tensor,
        orig_sizes: Union[torch.Tensor, List, Tuple, npt.NDArray]
    ) -> List[Dict[str, torch.Tensor]]:
        """
        Forward pass of the model.
        Applies NMS and reformats the predictions.

        Args:
            x (torch.Tensor[BS x C x H x W]): Input image batch.
            orig_sizes (torch.Tensor or list or np.ndarray): Original image sizes.

        Returns:
            list[dict]: List of prediction dictionaries. Each dictionary contains:
                - labels (torch.Tensor[N]): Class labels
                - boxes (torch.Tensor[N x 4]): Bounding boxes
                - scores (torch.Tensor[N]): Confidence scores.
        """
        x, orig_sizes = self.reformat_input(x, orig_sizes)

        # Scale to 0-255 if in range 0-1 (out of place, the caller keeps its tensor)
        x = self._to_pixel_range(x).to(self.device)

        pred_boxes = self.model(x)

        # NMS
        pred_boxes = postprocess(
            pred_boxes,
            self.config.num_classes,
            self.conf_thresh,
            self.iou_thresh,
            class_agnostic=self.class_agnostic,
        )

        # Reformat output
        preds = []
        for p, size in zip(pred_boxes, orig_sizes):
            if p is None:  # No detections
                # Created on the input device so every entry of the batch lives
                # on the same device as the non-empty predictions.
                preds.append({
                    "labels": torch.empty(0, device=x.device),
                    "boxes": torch.empty((0, 4), device=x.device),
                    "scores": torch.empty(0, device=x.device),
                })
                continue

            p = p.view(-1, p.size(-1))
            # Undo the resize: same scale rule as the one used for preprocessing
            ratio = min(self.img_size[0] / size[0], self.img_size[1] / size[1])
            boxes = p[:, :4] / ratio

            # Clip (size is indexed as (height, width))
            boxes[:, [0, 2]] = torch.clamp(boxes[:, [0, 2]], 0, size[1])
            boxes[:, [1, 3]] = torch.clamp(boxes[:, [1, 3]], 0, size[0])

            # Remove too small
            kept = (
                (boxes[:, 2] - boxes[:, 0] > self.min_bbox_size) &
                (boxes[:, 3] - boxes[:, 1] > self.min_bbox_size)
            )
            boxes = boxes[kept]
            p = p[kept]

            # Normalize to 0-1
            if self.normalize_boxes:
                boxes[:, [0, 2]] /= size[1]
                boxes[:, [1, 3]] /= size[0]

            # Columns 4 and 5 are multiplied into the score, column 6 is the label
            scores = p[:, 4] * p[:, 5]
            labels = p[:, 6]

            preds.append({"labels": labels, "boxes": boxes, "scores": scores})

        return preds
post_processing/table_struct_pp.py ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import numpy.typing as npt
3
+ from typing import List, Tuple, Optional
4
+
5
+
6
def expand_boxes(
    boxes: npt.NDArray[np.float64],
    r_x: Tuple[float, float] = (1, 1),
    r_y: Tuple[float, float] = (1, 1),
    size_agnostic: bool = True,
) -> npt.NDArray[np.float64]:
    """
    Expands bounding boxes by a specified ratio.
    Expected box format is normalized [x_min, y_min, x_max, y_max].

    The input array is left untouched: the expansion is done on a copy.

    Args:
        boxes (numpy.ndarray): Array of bounding boxes with shape (N, 4).
        r_x (tuple, optional): Left, right expansion ratios. Defaults to (1, 1) (no expansion).
        r_y (tuple, optional): Up, down expansion ratios. Defaults to (1, 1) (no expansion).
        size_agnostic (bool, optional): Expand independently of the box shape. Defaults to True.

    Returns:
        numpy.ndarray: Adjusted bounding boxes clipped to the [0, 1] range.
    """
    old_boxes = boxes.copy()
    boxes = boxes.copy()  # work on a copy so the caller's array is not modified

    if not size_agnostic:
        # Expansion proportional to each box's own height / width
        h = boxes[:, 3] - boxes[:, 1]
        w = boxes[:, 2] - boxes[:, 0]
    else:
        # Expansion is a fixed fraction of the (normalized) image
        h, w = 1, 1

    boxes[:, 0] -= w * (r_x[0] - 1)  # left
    boxes[:, 2] += w * (r_x[1] - 1)  # right
    boxes[:, 1] -= h * (r_y[0] - 1)  # up
    boxes[:, 3] += h * (r_y[1] - 1)  # down

    boxes = np.clip(boxes, 0, 1)

    # Enforce non-overlapping boxes:
    # pairs that (almost) did not overlap before the expansion but do now get
    # their facing vertical edges pulled back.
    for i in range(len(boxes)):
        for j in range(i + 1, len(boxes)):
            iou = bb_iou_array(boxes[i][None], boxes[j])[0]
            old_iou = bb_iou_array(old_boxes[i][None], old_boxes[j])[0]
            if iou > 0.05 and old_iou < 0.1:
                if boxes[i, 1] < boxes[j, 1]:  # i above j
                    boxes[j, 1] = min(old_boxes[j, 1], boxes[i, 3])
                    if old_iou > 0:
                        boxes[i, 3] = max(old_boxes[i, 3], boxes[j, 1])
                else:
                    boxes[i, 1] = min(old_boxes[i, 1], boxes[j, 3])
                    if old_iou > 0:
                        boxes[j, 3] = max(old_boxes[j, 3], boxes[i, 1])

    return boxes
57
+
58
+
59
def merge_boxes(
    b1: npt.NDArray[np.float64], b2: npt.NDArray[np.float64]
) -> npt.NDArray[np.float64]:
    """
    Computes the smallest box containing two bounding boxes.

    Args:
        b1 (numpy.ndarray): First bounding box [x_min, y_min, x_max, y_max].
        b2 (numpy.ndarray): Second bounding box [x_min, y_min, x_max, y_max].

    Returns:
        numpy.ndarray: A copy of b1 whose first four entries cover both input boxes.
    """
    merged = b1.copy()
    # Min corner takes the smallest coordinates, max corner the largest ones
    for coord, pick in ((0, min), (1, min), (2, max), (3, max)):
        merged[coord] = pick(b1[coord], b2[coord])
    return merged
78
+
79
+
80
def bb_iou_array(
    boxes: npt.NDArray[np.float64], new_box: npt.NDArray[np.float64]
) -> npt.NDArray[np.float64]:
    """
    Calculates the Intersection over Union (IoU) between a box and an array of boxes.

    Args:
        boxes (numpy.ndarray): Array of bounding boxes with shape (N, 4).
        new_box (numpy.ndarray): A single bounding box [x_min, y_min, x_max, y_max].

    Returns:
        numpy.ndarray: Array of IoU values between the new_box and each box in the array.
            The IoU is 0 when the union area is not positive (degenerate boxes).
    """
    # bb intersection over union
    xA = np.maximum(boxes[:, 0], new_box[0])
    yA = np.maximum(boxes[:, 1], new_box[1])
    xB = np.minimum(boxes[:, 2], new_box[2])
    yB = np.minimum(boxes[:, 3], new_box[3])

    interArea = np.maximum(xB - xA, 0) * np.maximum(yB - yA, 0)

    # compute the area of both the prediction and ground-truth rectangles
    boxAArea = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    boxBArea = (new_box[2] - new_box[0]) * (new_box[3] - new_box[1])

    union = boxAArea + boxBArea - interArea

    # Zero-area boxes would give 0 / 0 = NaN (and a RuntimeWarning):
    # divide by a dummy value there and report an IoU of 0 instead.
    valid = union > 0
    iou = np.where(valid, interArea / np.where(valid, union, 1), 0.0)

    return iou
108
+
109
+
110
def match_with_title(
    box: npt.NDArray[np.float64],
    title_boxes: npt.NDArray[np.float64],
    match_dist: float = 0.1,
    delta: float = 1.,
    already_matched: Optional[List[int]] = None,
) -> Tuple[Optional[npt.NDArray[np.float64]], Optional[List[int]]]:
    """
    Matches a bounding box with a title bounding box based on IoU or proximity.

    Args:
        box (numpy.ndarray): Bounding box to match with title [x_min, y_min, x_max, y_max].
        title_boxes (numpy.ndarray): Array of title bounding boxes with shape (N, 4).
        match_dist (float, optional): Maximum distance for matching. Defaults to 0.1.
        delta (float, optional): Multiplier for matching several titles. Defaults to 1..
        already_matched (list, optional): List of already matched title indices.
            Defaults to None (no title matched yet).

    Returns:
        tuple: If matched, returns a tuple of (merged_bbox, matched_title_indices),
            where the indices refer to rows of title_boxes.
            If no match is found, returns None, None.
    """
    if not len(title_boxes):
        return None, None

    # Vertical gap between the box and a title located above / below it
    dist_above = np.abs(title_boxes[:, 3] - box[1])
    dist_below = np.abs(box[3] - title_boxes[:, 1])

    # Horizontal misalignment: left edges, or centers
    dist_left = np.abs(title_boxes[:, 0] - box[0])
    dist_center = np.abs(title_boxes[:, 0] + title_boxes[:, 2] - box[0] - box[2]) / 2

    dists = np.min([dist_above, dist_below], 0)
    dists += np.min([dist_left, dist_center], 0) / 2

    # Titles overlapping the box get a distance that is always within match_dist
    ious = bb_iou_array(title_boxes, box)
    dists = np.where(ious > 0, min(match_dist - 0.01, np.min(dists)) / delta, dists)

    if already_matched is not None and len(already_matched):
        dists[already_matched] = match_dist * 10  # Remove already matched titles

    if np.min(dists) > match_dist:
        return None, None

    matches = np.where(
        dists <= min(match_dist, np.min(dists) * delta)
    )[0]

    new_bbox = box
    for match in matches:
        new_bbox = merge_boxes(new_bbox, title_boxes[match])
    return new_bbox, list(matches)
162
+
163
+
164
def match_boxes_with_title(
    boxes: npt.NDArray[np.float64],
    confs: npt.NDArray[np.float64],
    labels: npt.NDArray[np.int_],
    classes: List[str],
    to_match_labels: Optional[List[str]] = None,
    remove_matched_titles: bool = False,
    match_dist: float = 0.1,
) -> Tuple[
    npt.NDArray[np.float64],
    npt.NDArray[np.float64],
    npt.NDArray[np.int_],
    List[int],
]:
    """
    Matches boxes of the requested classes (charts by default) with titles.

    The detections are first reordered so that titles come last. A matched box is
    replaced by the box covering both itself and its title(s).

    Args:
        boxes (numpy.ndarray): Array of bounding boxes with shape (N, 4).
        confs (numpy.ndarray): Array of confidence scores with shape (N,).
        labels (numpy.ndarray): Array of labels with shape (N,).
        classes (list): List of class names. Must contain "title".
        to_match_labels (list, optional): List of class names to match with titles.
            Defaults to None, which means ["chart"].
        remove_matched_titles (bool): Whether to remove matched titles from the boxes.
        match_dist (float, optional): Maximum distance for matching. Defaults to 0.1.

    Returns:
        boxes (numpy.ndarray): Array of bounding boxes with shape (M, 4).
        confs (numpy.ndarray): Array of confidence scores with shape (M,).
        labels (numpy.ndarray): Array of labels with shape (M,).
        found_title (list): Indices (in the reordered arrays, before any title
            removal) of the boxes that were merged with a title.
    """
    if to_match_labels is None:
        to_match_labels = ["chart"]

    title_label = classes.index("title")

    # Put titles at the end
    title_ids = np.where(labels == title_label)[0]
    order = np.concatenate([np.delete(np.arange(len(boxes)), title_ids), title_ids])
    boxes = boxes[order]
    confs = confs[order]
    labels = labels[order]

    # Ids
    title_ids = np.where(labels == title_label)[0]
    to_match = np.where(np.isin(labels, [classes.index(c) for c in to_match_labels]))[0]

    # Matching (np.where returns the candidate indices in increasing order)
    found_title, already_matched = [], []
    for i in to_match:
        i = int(i)
        merged_box, matched_title_ids = match_with_title(
            boxes[i],
            boxes[title_ids],
            already_matched=already_matched,
            match_dist=match_dist,
        )
        if matched_title_ids is not None:
            boxes[i] = merged_box
            already_matched += matched_title_ids
            found_title.append(i)

    if remove_matched_titles and len(already_matched):
        # already_matched indexes into the title subset: map back to row indices
        matched_rows = title_ids[already_matched]
        boxes = np.delete(boxes, matched_rows, axis=0)
        confs = np.delete(confs, matched_rows, axis=0)
        labels = np.delete(labels, matched_rows, axis=0)

    return boxes, confs, labels, found_title
post_processing/wbf.py ADDED
@@ -0,0 +1,321 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from:
2
+ # https://github.com/ZFTurbo/Weighted-Boxes-Fusion/blob/master/ensemble_boxes/ensemble_boxes_wbf.py
3
+
4
+ import warnings
5
+ from typing import Dict, List, Tuple, Union, Literal
6
+ import numpy as np
7
+ import numpy.typing as npt
8
+
9
+
10
def prefilter_boxes(
    boxes: List[npt.NDArray[np.float64]],
    scores: List[npt.NDArray[np.float64]],
    labels: List[npt.NDArray[np.int_]],
    weights: List[float],
    thr: float,
    class_agnostic: bool = False,
) -> Dict[Union[str, int], npt.NDArray[np.float64]]:
    """
    Validates, filters and groups the boxes of all models.
    Each group of the returned dict is merged separately.

    Args:
        boxes (list[np array[n x 4]]): List of boxes. One list per model.
        scores (list[np array[n]]): List of confidences.
        labels (list[np array[n]]): List of labels.
        weights (list): Model weights.
        thr (float): Confidence threshold
        class_agnostic (bool, optional): Merge boxes from different classes. Defaults to False.

    Returns:
        dict[np array [? x 8]]: Filtered boxes, keyed by label (or "*" when
            class_agnostic). Rows are [label, score, weight, model index, x1, y1, x2, y2],
            sorted by decreasing score.
    """
    grouped = dict()

    for model_idx, model_boxes in enumerate(boxes):
        model_scores = scores[model_idx]
        model_labels = labels[model_idx]
        assert len(model_boxes) == len(model_scores), "len(boxes) != len(scores)"
        assert len(model_boxes) == len(model_labels), "len(boxes) != len(labels)"

        model_weight = weights[model_idx]

        for box_part, score, raw_label in zip(model_boxes, model_scores, model_labels):
            # Low confidence boxes are dropped right away
            if score < thr:
                continue
            label = int(raw_label)
            x1, y1, x2, y2 = (float(box_part[k]) for k in range(4))

            # Box data checks
            if x2 < x1:
                warnings.warn("X2 < X1 value in box. Swap them.")
                x1, x2 = x2, x1
            if y2 < y1:
                warnings.warn("Y2 < Y1 value in box. Swap them.")
                y1, y2 = y2, y1

            coords = np.array([x1, x2, y1, y2])
            if coords.min() < 0 or coords.max() > 1:
                warnings.warn("Coordinates outside [0, 1]")
                coords = np.clip(coords, 0, 1)
                x1, x2, y1, y2 = coords

            if (x2 - x1) * (y2 - y1) == 0.0:
                warnings.warn("Zero area box skipped: {}.".format(box_part))
                continue

            # [label, score, weight, model index, x1, y1, x2, y2]
            row = [int(label), float(score) * model_weight, model_weight, model_idx, x1, y1, x2, y2]

            group_key = "*" if class_agnostic else label
            grouped.setdefault(group_key, []).append(row)

    # Turn each group into an array sorted by decreasing score
    for group_key, rows in grouped.items():
        stacked = np.array(rows)
        grouped[group_key] = stacked[stacked[:, 1].argsort()[::-1]]

    return grouped
83
+
84
+
85
def merge_labels(
    labels: npt.NDArray[np.int_], confs: npt.NDArray[np.float64], title_label: int = 2
) -> int:
    """
    Custom function for merging labels.
    If all labels are the same, return the unique value.
    Else, return the label of the most confident non-title box.

    Args:
        labels (np array [n]): Labels.
        confs (np array [n]): Confidence.
        title_label (int, optional): Label id of the title class, which is ignored
            when the labels disagree. Defaults to 2.

    Returns:
        int: Label.
    """
    labels = np.asarray(labels)
    confs = np.asarray(confs)

    if len(np.unique(labels)) == 1:
        return labels[0]

    # Most confident and not a title.
    # The same mask must be applied to both arrays so that they stay aligned.
    not_title = labels != title_label
    return labels[not_title][np.argmax(confs[not_title])]
106
+
107
+
108
def get_weighted_box(
    boxes: npt.NDArray[np.float64], conf_type: Literal["avg", "max"] = "avg"
) -> npt.NDArray[np.float64]:
    """
    Merges boxes by using the weighted fusion.
    Coordinates are averaged with the box scores as weights.

    Args:
        boxes (np array [n x 8]): Boxes to merge, with rows
            [label, score, weight, model index, x1, y1, x2, y2].
        conf_type (str, optional): Confidence merging type. Defaults to "avg".

    Returns:
        np array [8]: Merged box (float32).
    """
    box = np.zeros(8, dtype=np.float32)
    conf = 0
    w = 0
    for b in boxes:
        box[4:] += b[1] * b[4:]
        conf += b[1]
        w += b[2]

    conf_list = [b[1] for b in boxes]

    box[0] = merge_labels(np.array([b[0] for b in boxes]), np.array(conf_list))

    box[1] = np.max(conf_list) if conf_type == "max" else np.mean(conf_list)
    box[2] = w
    box[3] = -1  # model index field is retained for consistency but is not used.

    if conf > 0:
        box[4:] /= conf
    else:
        # All scores are zero: the weighted average is undefined (0 / 0),
        # fall back to the plain average of the coordinates.
        box[4:] = np.mean([b[4:] for b in boxes], axis=0)
    return box
140
+
141
+
142
def get_biggest_box(
    boxes: npt.NDArray[np.float64], conf_type: Literal["avg", "max"] = "avg"
) -> npt.NDArray[np.float64]:
    """
    Merges boxes by keeping the smallest box that contains all of them.

    Args:
        boxes (np array [n x 8]): Boxes to merge.
        conf_type (str, optional): Confidence merging type. Defaults to "avg".

    Returns:
        np array [8]: Merged box.
    """
    rows = np.asarray(boxes)
    scores = [row[1] for row in rows]

    merged = np.zeros(8, dtype=np.float32)
    merged[0] = merge_labels(
        np.array([row[0] for row in rows]), np.array(scores)
    )
    merged[1] = np.mean(scores) if conf_type != "max" else np.max(scores)
    merged[2] = sum(row[2] for row in rows)  # accumulated model weights
    merged[3] = -1  # model index field is retained for consistency but is not used.

    # Enclosing box: smallest top-left corner, largest bottom-right corner
    merged[4:6] = rows[:, 4:6].min(axis=0)
    merged[6:8] = rows[:, 6:8].max(axis=0)
    return merged
176
+
177
+
178
def find_matching_box_fast(
    boxes_list: npt.NDArray[np.float64],
    new_box: npt.NDArray[np.float64],
    match_iou: float,
) -> Tuple[int, float]:
    """
    Finds the box with the highest IoU with new_box, using vectorized numpy
    operations instead of a Python loop over the boxes.

    Only the last four columns (the coordinates) of the 8-value boxes are compared.

    Args:
        boxes_list (np.ndarray): Array of boxes with shape (N, 8).
        new_box (np.ndarray): New box to match with shape (8,).
        match_iou (float): IoU threshold for matching.

    Returns:
        Tuple[int, float]: Index of best matching box and its IoU value.
            When no box has an IoU above match_iou, returns (-1, match_iou).
    """
    if boxes_list.shape[0] == 0:
        return -1, match_iou

    candidates = boxes_list[:, 4:]
    target = new_box[4:]

    # Corners of the intersection rectangles
    left = np.maximum(candidates[:, 0], target[0])
    top = np.maximum(candidates[:, 1], target[1])
    right = np.minimum(candidates[:, 2], target[2])
    bottom = np.minimum(candidates[:, 3], target[3])
    overlap = np.maximum(right - left, 0) * np.maximum(bottom - top, 0)

    # Areas of the candidates and of the new box
    candidate_areas = (candidates[:, 2] - candidates[:, 0]) * (candidates[:, 3] - candidates[:, 1])
    target_area = (target[2] - target[0]) * (target[3] - target[1])

    ious = overlap / (candidate_areas + target_area - overlap)

    best_idx = np.argmax(ious)
    best_iou = ious[best_idx]

    if best_iou <= match_iou:
        return -1, match_iou
    return best_idx, best_iou
230
+
231
+
232
def weighted_boxes_fusion(
    boxes_list: List[npt.NDArray[np.float64]],
    labels_list: List[npt.NDArray[np.int_]],
    scores_list: List[npt.NDArray[np.float64]],
    iou_thr: float = 0.5,
    skip_box_thr: float = 0.0,
    conf_type: Literal["avg", "max"] = "avg",
    merge_type: Literal["weighted", "biggest"] = "weighted",
    class_agnostic: bool = False,
) -> Tuple[npt.NDArray[np.float64], npt.NDArray[np.int_], npt.NDArray[np.float64]]:
    """
    Custom WBF implementation that supports a class_agnostic mode and a biggest box fusion.
    Boxes are expected to be in normalized (x0, y0, x1, y1) format.

    Args:
        boxes_list (list[np.ndarray[n x 4]]): List of boxes. One list per model.
        labels_list (list[np.ndarray[n]]): List of labels.
        scores_list (list[np.ndarray[n]]): List of confidences.
        iou_thr (float, optional): IoU threshold for matching. Defaults to 0.5.
        skip_box_thr (float, optional): Exclude boxes with score < skip_box_thr. Defaults to 0.0.
        conf_type (str, optional): Confidence merging type ("avg" or "max"). Defaults to "avg".
        merge_type (str, optional): Merge type ("weighted" or "biggest"). Defaults to "weighted".
        class_agnostic (bool, optional): Merge boxes from different classes. Defaults to False.

    Returns:
        numpy.ndarray [N x 4]: Array of bounding boxes, sorted by decreasing score.
        numpy.ndarray [N]: Array of labels (stored as floats, like the other outputs).
        numpy.ndarray [N]: Array of scores.
    """
    # All models have the same weight
    weights = np.ones(len(boxes_list))

    assert conf_type in ["avg", "max"], 'Conf type must be "avg" or "max"'
    assert merge_type in ["weighted", "biggest"], 'Merge type must be "weighted" or "biggest"'

    filtered_boxes = prefilter_boxes(
        boxes_list,
        scores_list,
        labels_list,
        weights,
        skip_box_thr,
        class_agnostic=class_agnostic,
    )
    if len(filtered_boxes) == 0:
        return np.zeros((0, 4)), np.zeros((0,)), np.zeros((0,))

    overall_boxes = []
    for label in filtered_boxes:
        boxes = filtered_boxes[label]
        clusters = []

        # Clusterize boxes: each box joins the cluster of its best match (if any)
        # NOTE(review): when j and its match already sit in two different clusters,
        # only the first one is extended, so a box can appear in two clusters.
        # Confirm whether this is intended before changing it.
        for j in range(len(boxes)):
            ids = [i for i in range(len(boxes)) if i != j]
            index, best_iou = find_matching_box_fast(boxes[ids], boxes[j], iou_thr)

            if index != -1:
                index = ids[index]  # back to an index in the full array
                cluster_idx = [
                    clust_idx
                    for clust_idx, clust in enumerate(clusters)
                    if (j in clust or index in clust)
                ]
                if len(cluster_idx):
                    cluster_idx = cluster_idx[0]
                    clusters[cluster_idx] = list(
                        set(clusters[cluster_idx] + [index, j])
                    )
                else:
                    clusters.append([index, j])
            else:
                clusters.append([j])

        # Fuse each cluster into a single box
        for c in clusters:
            if merge_type == "weighted":
                weighted_box = get_weighted_box(boxes[c], conf_type)
            elif merge_type == "biggest":
                weighted_box = get_biggest_box(boxes[c], conf_type)

            if conf_type == "max":
                weighted_box[1] = weighted_box[1] / weights.max()
            else:  # avg
                # Rescale by cluster size over number of models
                weighted_box[1] = weighted_box[1] * len(c) / weights.sum()
            overall_boxes.append(weighted_box)

    overall_boxes = np.array(overall_boxes)
    overall_boxes = overall_boxes[overall_boxes[:, 1].argsort()[::-1]]
    boxes = overall_boxes[:, 4:]
    scores = overall_boxes[:, 1]
    labels = overall_boxes[:, 0]
    return boxes, labels, scores
table_structure_v1.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from typing import List, Tuple
4
+
5
+
6
class Exp:
    """
    Configuration class for the table structure model.

    This class contains all configuration parameters for the YOLOX-based
    table structure detection model, including architecture settings, inference
    parameters, and class-specific thresholds.
    """

    def __init__(self) -> None:
        """Initialize the configuration with default parameters."""
        # NOTE(review): the name mentions "page-element" while the labels below
        # describe table structure - confirm this is the intended value.
        self.name: str = "page-element-v3"
        self.ckpt: str = "weights.pth"
        self.device: str = "cuda:0" if torch.cuda.is_available() else "cpu"

        # YOLOX architecture parameters
        self.act: str = "silu"
        self.depth: float = 1.00
        self.width: float = 1.00
        self.labels: List[str] = [
            "border",  # not used
            "cell",
            "row",
            "column",
            "header"  # not used
        ]
        self.num_classes: int = len(self.labels)

        # Inference parameters
        self.size: Tuple[int, int] = (1024, 1024)
        self.min_bbox_size: int = 0
        self.normalize_boxes: bool = True

        # NMS & thresholding. These can be updated
        self.conf_thresh: float = 0.01
        self.iou_thresh: float = 0.25
        self.class_agnostic: bool = False

        self.threshold: float = 0.05

    def get_model(self) -> nn.Module:
        """
        Get the YOLOX model.

        Builds the YOLOX model with the configured architecture on the first call
        (later calls reuse ``self.model``), and sets the eps and momentum of all
        its BatchNorm2d layers.

        Returns:
            nn.Module: The YOLOX model with configured parameters.
        """
        # Build model
        if getattr(self, "model", None) is None:
            # Only needed when the model actually has to be built
            from yolox import YOLOX, YOLOPAFPN, YOLOXHead

            in_channels = [256, 512, 1024]
            backbone = YOLOPAFPN(
                self.depth, self.width, in_channels=in_channels, act=self.act
            )
            head = YOLOXHead(
                self.num_classes, self.width, in_channels=in_channels, act=self.act
            )
            self.model = YOLOX(backbone, head)

        # Update batch-norm parameters
        def init_yolo(M: nn.Module) -> None:
            for m in M.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.eps = 1e-3
                    m.momentum = 0.03

        # modules() already walks the whole tree: a single call is enough
        # (going through nn.Module.apply would repeat the walk for every submodule).
        init_yolo(self.model)

        return self.model
utils.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import numpy as np
3
+ import pandas as pd
4
+ import numpy.typing as npt
5
+ import matplotlib.pyplot as plt
6
+ from matplotlib.patches import Rectangle
7
+ from typing import Dict, List, Tuple, Optional, Union
8
+
9
+
10
# Hex colours for drawing boxes; plot_sample pairs them with classes by position.
COLORS = [
    "#003EFF", "#FF8F00", "#079700", "#A123FF", "#87CEEB",
    "#FF5733", "#C70039", "#900C3F", "#581845", "#11998E",
]
22
+
23
+
24
def reformat_for_plotting(
    boxes: npt.NDArray[np.float64],
    labels: npt.NDArray[np.int_],
    scores: npt.NDArray[np.float64],
    shape: Tuple[int, int],
    num_classes: int,
) -> Tuple[List[npt.NDArray[np.int_]], List[npt.NDArray[np.float64]]]:
    """
    Turn YOLOX predictions into per-class lists ready for plotting.

    The x coordinates are multiplied by the image width and the y coordinates
    by the image height, the result is truncated to integers, corner boxes
    become [xmin, ymin, width, height], and boxes/scores are grouped by class.
    The input arrays are left untouched.

    Args:
        boxes (np.ndarray [N, 4]): Array of bounding boxes in format [xmin, ymin, xmax, ymax].
        labels (np.ndarray [N]): Array of labels.
        scores (np.ndarray [N]): Array of confidence scores.
        shape (tuple [2]): Shape of the image (height, width).
        num_classes (int): Number of classes.

    Returns:
        list[np.ndarray[N]]: List of box bounding boxes per class.
        list[np.ndarray[N]]: List of confidence scores per class.
    """
    img_height, img_width = shape[0], shape[1]

    # Work on a copy so the caller's array keeps its values.
    scaled = boxes.copy()
    scaled[:, [0, 2]] *= img_width
    scaled[:, [1, 3]] *= img_height

    # Integer coordinates, then (xmax, ymax) -> (width, height).
    xywh = scaled.astype(int)
    xywh[:, 2:4] -= xywh[:, 0:2]

    per_class_boxes, per_class_confs = [], []
    for class_id in range(num_classes):
        selected = labels == class_id
        per_class_boxes.append(xywh[selected])
        per_class_confs.append(scores[selected])
    return per_class_boxes, per_class_confs
57
+
58
+
59
def plot_sample(
    img: npt.NDArray[np.uint8],
    boxes_list: List[npt.NDArray[np.int_]],
    confs_list: List[npt.NDArray[np.float64]],
    labels: List[str],
    show_text: bool = True,
) -> None:
    """
    Plots an image with bounding boxes.
    Coordinates are expected in format [x_min, y_min, width, height].

    Colors are taken from ``COLORS`` in order and reused cyclically when there
    are more classes than colors.

    Args:
        img (numpy.ndarray): The input image to be plotted, either 2-D
            (grayscale) or with a trailing channel axis.
        boxes_list (list[np.ndarray]): List of box bounding boxes per class.
        confs_list (list[np.ndarray]): List of confidence scores per class.
        labels (list): List of class labels.
        show_text (bool, optional): Whether to show the text. Defaults to True.
    """
    # Function-scope import so this block does not depend on the file header.
    from itertools import cycle

    plt.imshow(img, cmap="gray")
    plt.axis(False)

    # Only the two leading dims are needed; slicing also accepts 2-D images,
    # which a three-way unpack of ``img.shape`` would reject.
    h, w = img.shape[:2]

    # cycle() keeps classes beyond len(COLORS) from being silently dropped by zip.
    for boxes, confs, col, l in zip(boxes_list, confs_list, cycle(COLORS), labels):
        for box_idx, box in enumerate(boxes):
            # Better display around boundaries: keep the rectangle 2 px inside
            # the image so its outline is not cut off.
            box = np.copy(box)
            box[:2] = np.clip(box[:2], 2, max(h, w))
            box[2] = min(box[2], w - 2 - box[0])
            box[3] = min(box[3], h - 2 - box[1])

            rect = Rectangle(
                (box[0], box[1]),
                box[2],
                box[3],
                linewidth=2,
                facecolor="none",
                edgecolor=col,
            )
            plt.gca().add_patch(rect)

            # Add class and index label with proper alignment
            if show_text:
                plt.text(
                    box[0], box[1],
                    f"{l}_{box_idx} conf={confs[box_idx]:.3f}",
                    color='white',
                    fontsize=8,
                    bbox=dict(facecolor=col, alpha=1, edgecolor=col, pad=0, linewidth=2),
                    verticalalignment='bottom',
                    horizontalalignment='left'
                )
110
+
111
+
112
def reorder_boxes(
    boxes: npt.NDArray[np.float64],
    labels: npt.NDArray[np.int_],
    classes: Optional[List[str]] = None,
    scores: Optional[npt.NDArray[np.float64]] = None,
) -> Union[
    Tuple[npt.NDArray[np.float64], npt.NDArray[np.int_]],
    Tuple[npt.NDArray[np.float64], npt.NDArray[np.int_], npt.NDArray[np.float64]],
]:
    """
    Reorder boxes, labels and scores by box coordinates.

    Boxes are grouped by class index (ascending). Within a class, boxes of the
    class named "column" are sorted by x first, all others by y first.
    Coordinates are compared relative to the class minimum and rounded to two
    decimals, so nearly aligned boxes count as aligned. Boxes whose label is
    outside ``range(n_classes)`` are dropped.

    Args:
        boxes (np.ndarray [N, 4]): Array of bounding boxes in format [xmin, ymin, xmax, ymax].
        labels (np.ndarray [N]): Array of labels.
        classes (list, optional): List of class labels. Defaults to None, in
            which case the class count is ``labels.max() + 1`` and, since no
            names are known, every class is sorted by y first.
        scores (np.ndarray [N], optional): Array of confidence scores. Defaults to None.

    Returns:
        np.ndarray [N, 4]: Ordered boxes in format [xmin, ymin, xmax, ymax].
        np.ndarray [N]: Ordered labels.
        np.ndarray [N]: Ordered scores if scores is not None.
    """
    if classes is None:
        # Labels are class indices, so the class count is the largest index + 1.
        n_classes = int(labels.max()) + 1 if len(labels) else 0
    else:
        n_classes = len(classes)

    ordered_boxes, ordered_labels, ordered_scores = [], [], []
    for c in range(n_classes):
        mask = labels == c
        boxes_class = boxes[mask]
        if not len(boxes_class):
            continue

        # Offsets from the class minimum, rounded to absorb small jitter.
        y0 = np.round(boxes_class[:, 1] - boxes_class[:, 1].min(), 2)
        x0 = np.round(boxes_class[:, 0] - boxes_class[:, 0].min(), 2)

        is_column = classes is not None and classes[c] == "column"
        primary, secondary = (x0, y0) if is_column else (y0, x0)
        # np.lexsort treats the LAST key as the most significant one.
        idxs = np.lexsort((secondary, primary))

        ordered_boxes.append(boxes_class[idxs])
        ordered_labels.append(labels[mask][idxs])
        if scores is not None:
            ordered_scores.append(scores[mask][idxs])

    if not ordered_boxes:
        # Nothing to order: return empty arrays instead of failing in concatenate.
        if scores is not None:
            return boxes[:0], labels[:0], scores[:0]
        return boxes[:0], labels[:0]

    ordered_boxes = np.concatenate(ordered_boxes)
    ordered_labels = np.concatenate(ordered_labels)
    if scores is not None:
        ordered_scores = np.concatenate(ordered_scores)
        return ordered_boxes, ordered_labels, ordered_scores
    return ordered_boxes, ordered_labels
165
+
166
+
167
def postprocess_preds_table_structure(
    preds: Dict[str, npt.NDArray],
    threshold: float = 0.1,
    class_labels: Optional[List[str]] = None,
    reorder: bool = True,
) -> Tuple[npt.NDArray[np.float64], npt.NDArray[np.int_], npt.NDArray[np.float64]]:
    """
    Post process predictions for table structure task.
    - Keeps only predictions whose score is strictly above the threshold
    - Optionally reorders the remaining boxes with ``reorder_boxes``

    Args:
        preds (dict): Predictions. Keys are "scores", "boxes", "labels". Each
            value is converted with ``.cpu().numpy()``, so it must provide
            those methods.
        threshold (float, optional): Threshold for the confidence scores. Defaults to 0.1.
        class_labels (list, optional): List of class labels. Defaults to None.
        reorder (bool, optional): Whether to apply reordering. Defaults to True.

    Returns:
        numpy.ndarray [N x 4]: Array of bounding boxes.
        numpy.ndarray [N]: Array of labels.
        numpy.ndarray [N]: Array of scores.
    """
    boxes, labels, scores = (
        preds[key].cpu().numpy() for key in ("boxes", "labels", "scores")
    )

    # Threshold: one mask, applied to all three arrays.
    confident = scores > threshold
    boxes, labels, scores = boxes[confident], labels[confident], scores[confident]

    if reorder and len(boxes) > 0:
        boxes, labels, scores = reorder_boxes(boxes, labels, class_labels, scores)

    return boxes, labels, scores
yolox/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding:utf-8 -*-
3
+ # Copyright (c) Megvii Inc. All rights reserved.
4
+
5
+ from .yolo_head import YOLOXHead
6
+ from .yolo_pafpn import YOLOPAFPN
7
+ from .yolox import YOLOX
yolox/boxes.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Copyright (c) Megvii Inc. All rights reserved.
3
+
4
+ import torch
5
+ import torchvision
6
+
7
+
8
def postprocess(prediction, num_classes, conf_thre=0.7, nms_thre=0.45, class_agnostic=False):
    """
    Filter raw predictions by confidence and apply non-maximum suppression.

    Adapted from YOLOX/yolox/utils/boxes.py.

    NOTE: ``prediction`` is modified in place: its first four channels are
    rewritten from (cx, cy, w, h) to (x1, y1, x2, y2).

    Args:
        prediction (torch.Tensor): Indexed as ``[image, candidate, channel]``
            with channels ``cx, cy, w, h, obj_conf`` followed by
            ``num_classes`` class scores.
        num_classes (int): Number of class-score channels to read.
        conf_thre (float): Minimum ``obj_conf * class_conf`` to keep a candidate.
        nms_thre (float): IoU threshold handed to the NMS operator.
        class_agnostic (bool): If True, run one NMS over all classes; otherwise
            run NMS per predicted class.

    Returns:
        list: One entry per image, either ``None`` (nothing kept) or a tensor
        with columns ``(x1, y1, x2, y2, obj_conf, class_conf, class_pred)``.
    """
    # Centre/size -> corner format. Both halves are computed before the
    # in-place write, so no full-size scratch tensor is required.
    half_size = prediction[:, :, 2:4] / 2
    centers = prediction[:, :, :2]
    corners = torch.cat((centers - half_size, centers + half_size), dim=2)
    prediction[:, :, :4] = corners

    output = [None for _ in range(len(prediction))]
    for i, image_pred in enumerate(prediction):

        # If none are remaining => process next image
        if not image_pred.size(0):
            continue
        # Get score and class with highest confidence
        class_conf, class_pred = torch.max(image_pred[:, 5: 5 + num_classes], 1, keepdim=True)

        # Index column 0 rather than calling squeeze(): a bare squeeze() also
        # drops the candidate axis when there is exactly one candidate, which
        # turns the mask into a 0-dim tensor and breaks the indexing below.
        conf_mask = image_pred[:, 4] * class_conf[:, 0] >= conf_thre
        # Detections ordered as (x1, y1, x2, y2, obj_conf, class_conf, class_pred)
        detections = torch.cat((image_pred[:, :5], class_conf, class_pred.float()), 1)
        detections = detections[conf_mask]
        if not detections.size(0):
            continue

        nms_scores = detections[:, 4] * detections[:, 5]
        if class_agnostic:
            nms_out_index = torchvision.ops.nms(
                detections[:, :4],
                nms_scores,
                nms_thre,
            )
        else:
            nms_out_index = torchvision.ops.batched_nms(
                detections[:, :4],
                nms_scores,
                detections[:, 6],
                nms_thre,
            )

        # Each image index is visited once, so output[i] is still None here.
        output[i] = detections[nms_out_index]

    return output
yolox/darknet.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- encoding: utf-8 -*-
3
+ # Copyright (c) Megvii Inc. All rights reserved.
4
+
5
+ from torch import nn
6
+
7
+ from .network_blocks import BaseConv, CSPLayer, DWConv, Focus, ResLayer, SPPBottleneck
8
+
9
+
10
class Darknet(nn.Module):
    """Plain Darknet backbone made of strided conv + residual stages."""

    # number of blocks from dark2 to dark5.
    depth2blocks = {21: [1, 2, 2, 1], 53: [2, 8, 8, 4]}

    def __init__(
        self,
        depth,
        in_channels=3,
        stem_out_channels=32,
        out_features=("dark3", "dark4", "dark5"),
    ):
        """
        Args:
            depth (int): depth of darknet used in model, usually use [21, 53] for this param.
            in_channels (int): number of input channels, for example, use 3 for RGB image.
            stem_out_channels (int): number of output channels of darknet stem.
                It decides channels of darknet layer2 to layer5.
            out_features (Tuple[str]): desired output layer name.
        """
        super().__init__()
        assert out_features, "please provide output features of Darknet"
        self.out_features = out_features

        stem_layers = [
            BaseConv(in_channels, stem_out_channels, ksize=3, stride=1, act="lrelu"),
            *self.make_group_layer(stem_out_channels, num_blocks=1, stride=2),
        ]
        self.stem = nn.Sequential(*stem_layers)

        blocks_per_stage = Darknet.depth2blocks[depth]

        # dark2..dark4 share one recipe; the channel count doubles after each stage.
        channels = stem_out_channels * 2  # 64 with the default stem
        for stage_name, n_blocks in zip(("dark2", "dark3", "dark4"), blocks_per_stage[:3]):
            stage = nn.Sequential(*self.make_group_layer(channels, n_blocks, stride=2))
            setattr(self, stage_name, stage)
            channels *= 2

        # dark5 appends an SPP block to its group layer.
        self.dark5 = nn.Sequential(
            *self.make_group_layer(channels, blocks_per_stage[3], stride=2),
            *self.make_spp_block([channels, channels * 2], channels * 2),
        )

    def make_group_layer(self, in_channels: int, num_blocks: int, stride: int = 1):
        """Return a channel-doubling conv followed by `num_blocks` `ResLayer` modules."""
        doubled = in_channels * 2
        layers = [BaseConv(in_channels, doubled, ksize=3, stride=stride, act="lrelu")]
        for _ in range(num_blocks):
            layers.append(ResLayer(doubled))
        return layers

    def make_spp_block(self, filters_list, in_filters):
        """Return the conv / SPP / conv stack used at the end of dark5."""
        narrow, wide = filters_list[0], filters_list[1]
        return nn.Sequential(
            BaseConv(in_filters, narrow, 1, stride=1, act="lrelu"),
            BaseConv(narrow, wide, 3, stride=1, act="lrelu"),
            SPPBottleneck(
                in_channels=wide,
                out_channels=narrow,
                activation="lrelu",
            ),
            BaseConv(narrow, wide, 3, stride=1, act="lrelu"),
            BaseConv(wide, narrow, 1, stride=1, act="lrelu"),
        )

    def forward(self, x):
        """Run all stages in order and return the ones named in `out_features`."""
        selected = {}
        for stage_name in ("stem", "dark2", "dark3", "dark4", "dark5"):
            x = getattr(self, stage_name)(x)
            if stage_name in self.out_features:
                selected[stage_name] = x
        return selected
95
+
96
+
97
class CSPDarknet(nn.Module):
    """CSP variant of the Darknet backbone: a Focus stem plus four conv + CSP stages."""

    def __init__(
        self,
        dep_mul,
        wid_mul,
        out_features=("dark3", "dark4", "dark5"),
        depthwise=False,
        act="silu",
    ):
        """
        Args:
            dep_mul (float): multiplier applied to the base bottleneck count (3).
            wid_mul (float): multiplier applied to the base channel count (64).
            out_features (Tuple[str]): names of the stages returned by `forward`.
            depthwise (bool): use `DWConv` instead of `BaseConv` for downsampling.
            act (str): activation name passed to every block.
        """
        super().__init__()
        assert out_features, "please provide output features of Darknet"
        self.out_features = out_features
        Conv = DWConv if depthwise else BaseConv

        base_channels = int(wid_mul * 64)  # 64
        base_depth = max(round(dep_mul * 3), 1)  # 3

        def downsample(c_in, c_out):
            # 3x3 conv with stride 2
            return Conv(c_in, c_out, 3, 2, act=act)

        def csp(channels, n, **extra):
            return CSPLayer(
                channels, channels, n=n, depthwise=depthwise, act=act, **extra
            )

        c2, c4, c8, c16 = (base_channels * k for k in (2, 4, 8, 16))

        # stem
        self.stem = Focus(3, base_channels, ksize=3, act=act)

        # dark2
        self.dark2 = nn.Sequential(downsample(base_channels, c2), csp(c2, base_depth))

        # dark3
        self.dark3 = nn.Sequential(downsample(c2, c4), csp(c4, base_depth * 3))

        # dark4
        self.dark4 = nn.Sequential(downsample(c4, c8), csp(c8, base_depth * 3))

        # dark5: SPP between the downsampling conv and a CSP layer without shortcuts
        self.dark5 = nn.Sequential(
            downsample(c8, c16),
            SPPBottleneck(c16, c16, activation=act),
            csp(c16, base_depth, shortcut=False),
        )

    def forward(self, x):
        """Run all stages in order and return the ones named in `out_features`."""
        selected = {}
        for stage_name in ("stem", "dark2", "dark3", "dark4", "dark5"):
            x = getattr(self, stage_name)(x)
            if stage_name in self.out_features:
                selected[stage_name] = x
        return selected
yolox/network_blocks.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- encoding: utf-8 -*-
3
+ # Copyright (c) Megvii Inc. All rights reserved.
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+
9
class SiLU(nn.Module):
    """SiLU written with elementary ops (export-friendly version of nn.SiLU())."""

    @staticmethod
    def forward(x):
        # x * sigmoid(x)
        gate = torch.sigmoid(x)
        return x * gate
15
+
16
+
17
def get_activation(name="silu", inplace=True):
    """
    Build an activation module from its short name.

    Args:
        name (str): "silu", "relu" or "lrelu" (LeakyReLU with slope 0.1).
        inplace (bool): forwarded to the activation constructor.

    Returns:
        nn.Module: the requested activation.

    Raises:
        AttributeError: if `name` is not one of the supported names.
    """
    if name == "silu":
        return nn.SiLU(inplace=inplace)
    if name == "relu":
        return nn.ReLU(inplace=inplace)
    if name == "lrelu":
        return nn.LeakyReLU(0.1, inplace=inplace)
    raise AttributeError("Unsupported act type: {}".format(name))
27
+
28
+
29
class BaseConv(nn.Module):
    """Conv2d followed by BatchNorm2d and an activation (silu / relu / leaky relu)."""

    def __init__(
        self, in_channels, out_channels, ksize, stride, groups=1, bias=False, act="silu"
    ):
        super().__init__()
        # "same" padding for odd kernel sizes
        same_pad = (ksize - 1) // 2
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            ksize,
            stride,
            same_pad,
            groups=groups,
            bias=bias,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.act = get_activation(act, inplace=True)

    def forward(self, x):
        """conv -> batch norm -> activation."""
        y = self.conv(x)
        y = self.bn(y)
        return self.act(y)

    def fuseforward(self, x):
        """conv -> activation, skipping the batch-norm layer."""
        y = self.conv(x)
        return self.act(y)
55
+
56
+
57
class DWConv(nn.Module):
    """Depthwise conv (one group per input channel) followed by a 1x1 pointwise conv."""

    def __init__(self, in_channels, out_channels, ksize, stride=1, act="silu"):
        super().__init__()
        # Depthwise step: channel count unchanged, groups == in_channels.
        self.dconv = BaseConv(
            in_channels, in_channels, ksize=ksize, stride=stride, groups=in_channels, act=act
        )
        # Pointwise step: 1x1 conv that maps to out_channels.
        self.pconv = BaseConv(
            in_channels, out_channels, ksize=1, stride=1, groups=1, act=act
        )

    def forward(self, x):
        depthwise_out = self.dconv(x)
        return self.pconv(depthwise_out)
77
+
78
+
79
class Bottleneck(nn.Module):
    """Standard bottleneck: 1x1 conv, 3x3 conv and an optional residual add."""

    def __init__(
        self,
        in_channels,
        out_channels,
        shortcut=True,
        expansion=0.5,
        depthwise=False,
        act="silu",
    ):
        super().__init__()
        mid_channels = int(out_channels * expansion)
        # Only the 3x3 conv switches to the depthwise variant.
        SecondConv = DWConv if depthwise else BaseConv
        self.conv1 = BaseConv(in_channels, mid_channels, 1, stride=1, act=act)
        self.conv2 = SecondConv(mid_channels, out_channels, 3, stride=1, act=act)
        # The residual add needs matching input and output channel counts.
        self.use_add = shortcut and in_channels == out_channels

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)
        return out + x if self.use_add else out
102
+
103
+
104
class ResLayer(nn.Module):
    """Residual layer with `in_channels` inputs: 1x1 squeeze, 3x3 expand, then add."""

    def __init__(self, in_channels: int):
        super().__init__()
        half = in_channels // 2
        self.layer1 = BaseConv(in_channels, half, ksize=1, stride=1, act="lrelu")
        self.layer2 = BaseConv(half, in_channels, ksize=3, stride=1, act="lrelu")

    def forward(self, x):
        residual = self.layer1(x)
        residual = self.layer2(residual)
        return x + residual
120
+
121
+
122
class SPPBottleneck(nn.Module):
    """Spatial pyramid pooling layer used in YOLOv3-SPP"""

    def __init__(
        self, in_channels, out_channels, kernel_sizes=(5, 9, 13), activation="silu"
    ):
        super().__init__()
        reduced = in_channels // 2
        self.conv1 = BaseConv(in_channels, reduced, 1, stride=1, act=activation)
        # Stride-1 max pools padded by ks // 2.
        pools = []
        for ks in kernel_sizes:
            pools.append(nn.MaxPool2d(kernel_size=ks, stride=1, padding=ks // 2))
        self.m = nn.ModuleList(pools)
        # conv2 sees the reduced map plus one pooled copy per kernel size.
        stacked = reduced * (len(kernel_sizes) + 1)
        self.conv2 = BaseConv(stacked, out_channels, 1, stride=1, act=activation)

    def forward(self, x):
        x = self.conv1(x)
        branches = [x]
        branches.extend(pool(x) for pool in self.m)
        return self.conv2(torch.cat(branches, dim=1))
145
+
146
+
147
class CSPLayer(nn.Module):
    """C3 in yolov5, CSP Bottleneck with 3 convolutions"""

    def __init__(
        self,
        in_channels,
        out_channels,
        n=1,
        shortcut=True,
        expansion=0.5,
        depthwise=False,
        act="silu",
    ):
        """
        Args:
            in_channels (int): input channels.
            out_channels (int): output channels.
            n (int): number of Bottlenecks. Default value: 1.
            shortcut (bool): passed to every Bottleneck.
            expansion (float): ratio of hidden channels to `out_channels`.
            depthwise (bool): passed to every Bottleneck.
            act (str): activation name passed to every block.
        """
        super().__init__()
        mid = int(out_channels * expansion)  # hidden channels
        # Two parallel 1x1 projections of the input, and a 1x1 fusion conv.
        self.conv1 = BaseConv(in_channels, mid, 1, stride=1, act=act)
        self.conv2 = BaseConv(in_channels, mid, 1, stride=1, act=act)
        self.conv3 = BaseConv(2 * mid, out_channels, 1, stride=1, act=act)
        # Bottlenecks keep `mid` channels (expansion 1.0).
        bottlenecks = []
        for _ in range(n):
            bottlenecks.append(
                Bottleneck(mid, mid, shortcut, 1.0, depthwise, act=act)
            )
        self.m = nn.Sequential(*bottlenecks)

    def forward(self, x):
        # Branch 1 goes through the bottleneck chain; branch 2 is a plain projection.
        deep = self.m(self.conv1(x))
        skip = self.conv2(x)
        return self.conv3(torch.cat((deep, skip), dim=1))
186
+
187
+
188
class Focus(nn.Module):
    """Focus width and height information into channel space."""

    def __init__(self, in_channels, out_channels, ksize=1, stride=1, act="silu"):
        super().__init__()
        # The four stacked sub-grids give the conv 4x the input channels.
        self.conv = BaseConv(in_channels * 4, out_channels, ksize, stride, act=act)

    def forward(self, x):
        # Take every second element along the last two axes at the four
        # possible offsets and stack the results on dim 1. Order: top-left,
        # bottom-left, top-right, bottom-right.
        offsets = ((0, 0), (1, 0), (0, 1), (1, 1))
        patches = [x[..., row::2, col::2] for row, col in offsets]
        return self.conv(torch.cat(patches, dim=1))
yolox/yolo_fpn.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- encoding: utf-8 -*-
3
+ # Copyright (c) Megvii Inc. All rights reserved.
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from .darknet import Darknet
9
+ from .network_blocks import BaseConv
10
+
11
+
12
class YOLOFPN(nn.Module):
    """
    YOLOFPN module. Darknet 53 is the default backbone of this model.
    """

    def __init__(
        self,
        depth=53,
        in_features=("dark3", "dark4", "dark5"),
    ):
        """
        Args:
            depth (int): Darknet depth passed to the backbone.
            in_features (Sequence[str]): backbone outputs to consume, listed
                from the first to the last backbone stage. An immutable tuple
                is the default so the default value cannot be shared and
                mutated across instances.
        """
        super().__init__()

        self.backbone = Darknet(depth)
        self.in_features = in_features

        # out 1
        self.out1_cbl = self._make_cbl(512, 256, 1)
        self.out1 = self._make_embedding([256, 512], 512 + 256)

        # out 2
        self.out2_cbl = self._make_cbl(256, 128, 1)
        self.out2 = self._make_embedding([128, 256], 256 + 128)

        # upsample
        self.upsample = nn.Upsample(scale_factor=2, mode="nearest")

    def _make_cbl(self, _in, _out, ks):
        """Build a stride-1 `BaseConv` with leaky-ReLU activation."""
        return BaseConv(_in, _out, ks, stride=1, act="lrelu")

    def _make_embedding(self, filters_list, in_filters):
        """Build five conv blocks alternating 1x1 and 3x3 kernels."""
        m = nn.Sequential(
            *[
                self._make_cbl(in_filters, filters_list[0], 1),
                self._make_cbl(filters_list[0], filters_list[1], 3),
                self._make_cbl(filters_list[1], filters_list[0], 1),
                self._make_cbl(filters_list[0], filters_list[1], 3),
                self._make_cbl(filters_list[1], filters_list[0], 1),
            ]
        )
        return m

    def load_pretrained_model(self, filename="./weights/darknet53.mix.pth"):
        """Load a backbone state dict from `filename` (mapped to CPU) into `self.backbone`."""
        with open(filename, "rb") as f:
            state_dict = torch.load(f, map_location="cpu")
        print("loading pretrained weights...")
        self.backbone.load_state_dict(state_dict)

    def forward(self, inputs):
        """
        Args:
            inputs (Tensor): input image.

        Returns:
            Tuple[Tensor]: FPN output features.
        """
        # backbone
        out_features = self.backbone(inputs)
        x2, x1, x0 = [out_features[f] for f in self.in_features]

        # yolo branch 1: reduce the last feature map, upsample it and merge with x1
        x1_in = self.out1_cbl(x0)
        x1_in = self.upsample(x1_in)
        x1_in = torch.cat([x1_in, x1], 1)
        out_dark4 = self.out1(x1_in)

        # yolo branch 2: same scheme one level up, merging with x2
        x2_in = self.out2_cbl(out_dark4)
        x2_in = self.upsample(x2_in)
        x2_in = torch.cat([x2_in, x2], 1)
        out_dark3 = self.out2(x2_in)

        outputs = (out_dark3, out_dark4, x0)
        return outputs
yolox/yolo_head.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding:utf-8 -*-
3
+ # Copyright (c) Megvii Inc. All rights reserved.
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from .network_blocks import BaseConv, DWConv
8
+
9
+
10
# [major, minor] of the installed torch, kept as a list so it compares
# element-wise against [1, 10] below.
_TORCH_VER = list(map(int, torch.__version__.split(".")[:2]))


def meshgrid(*tensors):
    """
    Version-tolerant wrapper around torch.meshgrid.

    Copied from YOLOX/yolox/utils/compat.py. The ``indexing="ij"`` keyword is
    passed only when the installed torch is at least 1.10; older versions are
    called without it.
    """
    if _TORCH_VER < [1, 10]:
        return torch.meshgrid(*tensors)
    return torch.meshgrid(*tensors, indexing="ij")
21
+
22
+
23
def bboxes_iou(bboxes_a, bboxes_b, xyxy=True):
    """
    Pairwise IoU between two sets of boxes.

    Copied from YOLOX/yolox/utils/boxes.py.

    Args:
        bboxes_a (Tensor [N, 4]): first set of boxes.
        bboxes_b (Tensor [M, 4]): second set of boxes.
        xyxy (bool): True if boxes are (x1, y1, x2, y2); False if they are
            (cx, cy, w, h).

    Returns:
        Tensor [N, M]: IoU of every box in `bboxes_a` with every box in `bboxes_b`.

    Raises:
        IndexError: if either input does not have 4 columns.
    """
    if not (bboxes_a.shape[1] == 4 and bboxes_b.shape[1] == 4):
        raise IndexError

    a_head, a_tail = bboxes_a[:, :2], bboxes_a[:, 2:]
    b_head, b_tail = bboxes_b[:, :2], bboxes_b[:, 2:]

    if xyxy:
        # head = top-left corner, tail = bottom-right corner
        top_left = torch.max(a_head[:, None], b_head)
        bottom_right = torch.min(a_tail[:, None], b_tail)
        area_a = torch.prod(a_tail - a_head, 1)
        area_b = torch.prod(b_tail - b_head, 1)
    else:
        # head = centre, tail = (width, height)
        top_left = torch.max((a_head - a_tail / 2)[:, None], b_head - b_tail / 2)
        bottom_right = torch.min((a_head + a_tail / 2)[:, None], b_head + b_tail / 2)
        area_a = torch.prod(a_tail, 1)
        area_b = torch.prod(b_tail, 1)

    # 1 where the boxes overlap on both axes, 0 otherwise.
    overlaps = (top_left < bottom_right).type(top_left.type()).prod(dim=2)
    inter = torch.prod(bottom_right - top_left, 2) * overlaps
    union = area_a[:, None] + area_b - inter
    return inter / union
50
+
51
+
52
class YOLOXHead(nn.Module):
    """Decoupled YOLOX head: per-level stem, class/regression towers and 1x1 predictors."""

    def __init__(
        self,
        num_classes,
        width=1.0,
        strides=(8, 16, 32),
        in_channels=(256, 512, 1024),
        act="silu",
        depthwise=False,
    ):
        """
        Args:
            num_classes (int): number of class-score channels to predict.
            width (float): multiplier applied to all channel counts.
            strides (Sequence[int]): stride of each input level.
            in_channels (Sequence[int]): channels of each input level, before
                the `width` multiplier.
            act (str): activation type of conv. Default value: "silu".
            depthwise (bool): whether apply depthwise conv in conv branch. Default value: False.
        """
        super().__init__()

        self.num_classes = num_classes
        self.decode_in_inference = True  # for deploy, set to False

        self.cls_convs = nn.ModuleList()
        self.reg_convs = nn.ModuleList()
        self.cls_preds = nn.ModuleList()
        self.reg_preds = nn.ModuleList()
        self.obj_preds = nn.ModuleList()
        self.stems = nn.ModuleList()
        Conv = DWConv if depthwise else BaseConv
        hidden = int(256 * width)

        for level_channels in in_channels:
            self.stems.append(
                BaseConv(
                    in_channels=int(level_channels * width),
                    out_channels=hidden,
                    ksize=1,
                    stride=1,
                    act=act,
                )
            )
            self.cls_convs.append(self._make_tower(Conv, hidden, act))
            self.reg_convs.append(self._make_tower(Conv, hidden, act))
            self.cls_preds.append(self._make_pred(hidden, self.num_classes))
            self.reg_preds.append(self._make_pred(hidden, 4))
            self.obj_preds.append(self._make_pred(hidden, 1))

        self.use_l1 = False
        self.l1_loss = nn.L1Loss(reduction="none")
        self.bcewithlog_loss = nn.BCEWithLogitsLoss(reduction="none")
        self.iou_loss = None
        # Copy into a list so each instance owns its strides (the defaults are
        # immutable tuples and are never shared in mutable form).
        self.strides = list(strides)
        self.grids = [torch.zeros(1)] * len(in_channels)

    @staticmethod
    def _make_tower(conv_cls, channels, act):
        """Return two stacked 3x3 convs that keep `channels` unchanged."""
        return nn.Sequential(
            *[
                conv_cls(
                    in_channels=channels,
                    out_channels=channels,
                    ksize=3,
                    stride=1,
                    act=act,
                )
                for _ in range(2)
            ]
        )

    @staticmethod
    def _make_pred(channels, out_channels):
        """Return the 1x1 prediction conv mapping `channels` to `out_channels`."""
        return nn.Conv2d(
            in_channels=channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
        )

    def forward(self, xin, labels=None, imgs=None):
        """
        Run the head on the per-level feature maps in `xin`.

        `labels` and `imgs` are accepted but not used by this method.

        Returns:
            Tensor indexed as [batch, n_anchors_all, 5 + num_classes]; box
            channels are decoded when `decode_in_inference` is True.
        """
        outputs = []
        for k, (cls_conv, reg_conv, stride_this_level, x) in enumerate(
            zip(self.cls_convs, self.reg_convs, self.strides, xin)
        ):
            x = self.stems[k](x)
            cls_x = x
            reg_x = x

            cls_feat = cls_conv(cls_x)
            cls_output = self.cls_preds[k](cls_feat)

            # Box regression and objectness share the regression tower.
            reg_feat = reg_conv(reg_x)
            reg_output = self.reg_preds[k](reg_feat)
            obj_output = self.obj_preds[k](reg_feat)

            output = torch.cat(
                [reg_output, obj_output.sigmoid(), cls_output.sigmoid()], 1
            )

            outputs.append(output)

        # Remember each level's spatial size; decode_outputs rebuilds the grids from it.
        self.hw = [x.shape[-2:] for x in outputs]
        # [batch, n_anchors_all, 5 + num_classes]
        outputs = torch.cat(
            [x.flatten(start_dim=2) for x in outputs], dim=2
        ).permute(0, 2, 1)
        if self.decode_in_inference:
            return self.decode_outputs(outputs, dtype=xin[0].type())
        else:
            return outputs

    def get_output_and_grid(self, output, k, stride, dtype):
        """
        Decode one level's raw output and return it together with its grid.

        The grid for level `k` is cached in `self.grids` and rebuilt when its
        spatial size no longer matches `output`.
        """
        grid = self.grids[k]

        batch_size = output.shape[0]
        n_ch = 5 + self.num_classes
        hsize, wsize = output.shape[-2:]
        if grid.shape[2:4] != output.shape[2:4]:
            yv, xv = meshgrid([torch.arange(hsize), torch.arange(wsize)])
            grid = torch.stack((xv, yv), 2).view(1, 1, hsize, wsize, 2).type(dtype)
            self.grids[k] = grid

        output = output.view(batch_size, 1, n_ch, hsize, wsize)
        output = output.permute(0, 1, 3, 4, 2).reshape(
            batch_size, hsize * wsize, -1
        )
        grid = grid.view(1, -1, 2)
        # Centres: add the cell index and scale by the stride; sizes: exp, then scale.
        output[..., :2] = (output[..., :2] + grid) * stride
        output[..., 2:4] = torch.exp(output[..., 2:4]) * stride
        return output, grid

    def decode_outputs(self, outputs, dtype):
        """Convert the first four channels from grid-relative values to (cx, cy, w, h) scaled by the stride."""
        grids = []
        strides = []
        for (hsize, wsize), stride in zip(self.hw, self.strides):
            yv, xv = meshgrid([torch.arange(hsize), torch.arange(wsize)])
            grid = torch.stack((xv, yv), 2).view(1, -1, 2)
            grids.append(grid)
            shape = grid.shape[:2]
            strides.append(torch.full((*shape, 1), stride))

        grids = torch.cat(grids, dim=1).type(dtype)
        strides = torch.cat(strides, dim=1).type(dtype)

        outputs = torch.cat([
            (outputs[..., 0:2] + grids) * strides,
            torch.exp(outputs[..., 2:4]) * strides,
            outputs[..., 4:]
        ], dim=-1)
        return outputs
yolox/yolo_pafpn.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- encoding: utf-8 -*-
3
+ # Copyright (c) Megvii Inc. All rights reserved.
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from .darknet import CSPDarknet
9
+ from .network_blocks import BaseConv, CSPLayer, DWConv
10
+
11
+
12
class YOLOPAFPN(nn.Module):
    """
    Feature-pyramid neck on top of a ``CSPDarknet`` backbone.

    Three backbone feature maps are fused first top-down (upsample and
    concatenate) and then bottom-up (stride-2 convolution and concatenate),
    producing one output feature map per input level.

    Args:
        depth: Depth multiplier; each CSP layer gets ``round(3 * depth)`` as
            its third argument, and the value is forwarded to the backbone.
        width: Width multiplier applied to every entry of ``in_channels``
            and forwarded to the backbone.
        in_features: Names of the three backbone outputs to fuse, ordered
            from the highest-resolution level to the lowest.
        in_channels: Channel counts (before ``width`` scaling) for the three
            levels, in the same order as ``in_features``. ``None`` selects
            ``[256, 512, 1024]``.
        depthwise: When true, ``DWConv`` is used for the two bottom-up
            convolutions instead of ``BaseConv``; also forwarded to the
            backbone and the CSP layers.
        act: Activation name forwarded to every sub-module.
    """

    def __init__(
        self,
        depth=1.0,
        width=1.0,
        in_features=("dark3", "dark4", "dark5"),
        in_channels=None,
        depthwise=False,
        act="silu",
    ):
        super().__init__()
        # A fresh list per instance: the previous list default was a single
        # object shared by every model through ``self.in_channels``.
        if in_channels is None:
            in_channels = [256, 512, 1024]

        self.backbone = CSPDarknet(depth, width, depthwise=depthwise, act=act)
        self.in_features = in_features
        self.in_channels = in_channels
        Conv = DWConv if depthwise else BaseConv

        # Top-down path.
        self.upsample = nn.Upsample(scale_factor=2, mode="nearest")
        self.lateral_conv0 = BaseConv(
            int(in_channels[2] * width), int(in_channels[1] * width), 1, 1, act=act
        )
        self.C3_p4 = CSPLayer(
            int(2 * in_channels[1] * width),
            int(in_channels[1] * width),
            round(3 * depth),
            False,
            depthwise=depthwise,
            act=act,
        )  # input is a concatenation, hence the doubled channel count

        self.reduce_conv1 = BaseConv(
            int(in_channels[1] * width), int(in_channels[0] * width), 1, 1, act=act
        )
        self.C3_p3 = CSPLayer(
            int(2 * in_channels[0] * width),
            int(in_channels[0] * width),
            round(3 * depth),
            False,
            depthwise=depthwise,
            act=act,
        )

        # Bottom-up path: stride-2 conv from the finest level.
        self.bu_conv2 = Conv(
            int(in_channels[0] * width), int(in_channels[0] * width), 3, 2, act=act
        )
        self.C3_n3 = CSPLayer(
            int(2 * in_channels[0] * width),
            int(in_channels[1] * width),
            round(3 * depth),
            False,
            depthwise=depthwise,
            act=act,
        )

        # Bottom-up path: stride-2 conv from the middle level.
        self.bu_conv1 = Conv(
            int(in_channels[1] * width), int(in_channels[1] * width), 3, 2, act=act
        )
        self.C3_n4 = CSPLayer(
            int(2 * in_channels[1] * width),
            int(in_channels[2] * width),
            round(3 * depth),
            False,
            depthwise=depthwise,
            act=act,
        )

    def forward(self, input):
        """
        Run the backbone and fuse its selected feature maps.

        Args:
            input: Input images, passed unchanged to the backbone.

        Returns:
            Tuple[Tensor]: ``(pan_out2, pan_out1, pan_out0)``, ordered from
            the highest-resolution level to the lowest.
        """
        # The backbone result is indexed by feature name; exactly three
        # names are required by the unpacking below.
        out_features = self.backbone(input)
        features = [out_features[f] for f in self.in_features]
        [x2, x1, x0] = features

        # Channel/stride notes below correspond to width=1.0 and the
        # default in_channels.
        fpn_out0 = self.lateral_conv0(x0)  # 1024->512/32
        f_out0 = self.upsample(fpn_out0)  # 512/16
        f_out0 = torch.cat([f_out0, x1], 1)  # 512->1024/16
        f_out0 = self.C3_p4(f_out0)  # 1024->512/16

        fpn_out1 = self.reduce_conv1(f_out0)  # 512->256/16
        f_out1 = self.upsample(fpn_out1)  # 256/8
        f_out1 = torch.cat([f_out1, x2], 1)  # 256->512/8
        pan_out2 = self.C3_p3(f_out1)  # 512->256/8

        p_out1 = self.bu_conv2(pan_out2)  # 256->256/16
        p_out1 = torch.cat([p_out1, fpn_out1], 1)  # 256->512/16
        pan_out1 = self.C3_n3(p_out1)  # 512->512/16

        p_out0 = self.bu_conv1(pan_out1)  # 512->512/32
        p_out0 = torch.cat([p_out0, fpn_out0], 1)  # 512->1024/32
        pan_out0 = self.C3_n4(p_out0)  # 1024->1024/32

        outputs = (pan_out2, pan_out1, pan_out0)
        return outputs
yolox/yolox.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- encoding: utf-8 -*-
3
+ # Copyright (c) Megvii Inc. All rights reserved.
4
+
5
+ import torch.nn as nn
6
+
7
+ from .yolo_head import YOLOXHead
8
+ from .yolo_pafpn import YOLOPAFPN
9
+
10
+
11
class YOLOX(nn.Module):
    """
    Inference-only detector that chains a backbone and a detection head.

    Args:
        backbone: Module producing the features consumed by ``head``.
            ``None`` builds a default ``YOLOPAFPN()``.
        head: Module turning the backbone features into predictions.
            ``None`` builds a default ``YOLOXHead(80)``.
    """

    def __init__(self, backbone=None, head=None):
        super().__init__()
        self.backbone = YOLOPAFPN() if backbone is None else backbone
        self.head = YOLOXHead(80) if head is None else head

    def forward(self, x, targets=None):
        """
        Return the head's predictions for ``x``.

        ``targets`` is accepted but not used. Calling the module while it is
        in training mode trips the assertion below.
        """
        assert not self.training, "Training mode not supported, please refer to the YOLOX repo"
        return self.head(self.backbone(x))