| import torch |
| import torch.nn as nn |
|
|
|
|
class Exp:
    """
    Configuration class for the page element model.

    Holds the YOLOX architecture settings, input/box options, score and
    IoU thresholds, and per-class score thresholds. ``get_model`` lazily
    builds the YOLOX model from these settings and caches it on
    ``self.model``.
    """

    def __init__(self) -> None:
        """Populate the configuration with its default values."""
        self.name: str = "page-element-v3"
        self.ckpt: str = "weights.pth"
        # Use the first CUDA device when one is available, else the CPU.
        self.device: str = "cuda:0" if torch.cuda.is_available() else "cpu"

        # Model architecture (passed to YOLOPAFPN / YOLOXHead in get_model).
        self.act: str = "silu"
        self.depth: float = 1.00
        self.width: float = 1.00
        self.labels: list[str] = [
            "table",
            "chart",
            "title",
            "infographic",
            "paragraph",
            "header_footer",
        ]
        self.num_classes: int = len(self.labels)

        # Input size and box-handling options (consumed outside this class).
        self.size: tuple[int, int] = (1024, 1024)
        self.min_bbox_size: int = 0
        self.normalize_boxes: bool = True

        # Score / IoU thresholds; presumably used by the NMS post-processing
        # step, which lives outside this class -- confirm against the caller.
        self.conf_thresh: float = 0.01
        self.iou_thresh: float = 0.5
        self.class_agnostic: bool = True

        # Per-class score thresholds, keyed by the names in ``self.labels``.
        self.thresholds_per_class: dict[str, float] = {
            "table": 0.1,
            "chart": 0.01,
            "infographic": 0.01,
            "title": 0.1,
            "paragraph": 0.1,
            "header_footer": 0.1,
        }

    def get_model(self) -> nn.Module:
        """
        Get the YOLOX model, building it on first use.

        The model is constructed once from ``depth``, ``width``, ``act`` and
        ``num_classes`` and cached on ``self.model``; later calls return the
        cached instance. On every call, each ``nn.BatchNorm2d`` in the model
        gets ``eps=1e-3`` and ``momentum=0.03`` (overriding PyTorch's
        defaults of 1e-5 and 0.1).

        Returns:
            The cached YOLOX model.
        """
        if getattr(self, "model", None) is None:
            # Imported lazily so ``yolox`` is only needed when building.
            from yolox import YOLOX, YOLOPAFPN, YOLOXHead

            in_channels = [256, 512, 1024]
            backbone = YOLOPAFPN(
                self.depth, self.width, in_channels=in_channels, act=self.act
            )
            head = YOLOXHead(
                self.num_classes, self.width, in_channels=in_channels, act=self.act
            )
            self.model = YOLOX(backbone, head)

        def init_batchnorm(module: nn.Module) -> None:
            # nn.Module.apply already recurses over every submodule, so only
            # the module passed in is handled here; walking module.modules()
            # would rescan each subtree once per ancestor.
            if isinstance(module, nn.BatchNorm2d):
                module.eps = 1e-3
                module.momentum = 0.03

        self.model.apply(init_batchnorm)
        return self.model
|
|