Spaces:
Running on Zero
Running on Zero
| from models.attention_fusion import LocalFusion | |
| from models.bezier_control_point_estimator import BCPE | |
| from models.color_naming import ColorNaming | |
| from models.backbone import Backbone | |
| from torch import nn | |
| from PIL import Image | |
| from torchvision.transforms import functional as TF | |
| import torch | |
| class NamedCurves(nn.Module): | |
| def __init__(self, configs: dict): | |
| super().__init__() | |
| self.model_configs = configs | |
| self.backbone = Backbone(**configs['backbone']['params']) | |
| self.color_naming = ColorNaming(num_categories=configs['color_naming']['num_categories']) | |
| self.bcpe = BCPE(**configs['bezier_control_points_estimator']['params']) | |
| self.local_fusion = LocalFusion(**configs['local_fusion']['params']) | |
| def forward(self, x, return_backbone=False): | |
| x_backbone = self.backbone(x) | |
| cn_probs = self.color_naming(x_backbone) | |
| x_global = self.bcpe(x_backbone, cn_probs) | |
| out = self.local_fusion(x_global, cn_probs, q=x_backbone) | |
| if return_backbone: | |
| return out, x_backbone | |
| return out |