| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """ |
| Factory module for creating models. |
| |
| Provides ModelFactory for instantiating pre-decoder models from config. |
| """ |
|
|
|
|
class ModelFactory:
    """Instantiate pre-decoder models from a configuration object.

    ``cfg.code`` selects the code family (only ``"surface"`` is supported) and
    ``cfg.model.version`` selects the concrete architecture.
    """

    # Constructor keyword arguments and their fallback values for the versions
    # that are built from individual ``cfg.model`` fields. A field present on
    # ``cfg.model`` overrides its fallback. Tuples (not dicts) so these shared,
    # class-level defaults cannot be mutated by accident.
    _SD_LITENET_V1_DEFAULTS = (
        ("input_channels", 4),
        ("out_channels", 4),
        ("hidden_dim", 64),
        ("bottleneck_dim", 16),
        ("dropout_p", 0.05),
    )
    _FASTHYPER_RF13_V1_DEFAULTS = (
        ("input_channels", 4),
        ("out_channels", 4),
        ("hidden_dim", 96),
        ("mid_dim", 144),
        ("mix_groups", 6),
        ("num_blocks", 5),
        ("stem_kernel_size", 3),
        ("dropout_p", 0.02),
        ("gate_reduction", 4),
    )

    @staticmethod
    def create_model(cfg):
        """Build the model described by ``cfg``.

        Args:
            cfg: Configuration exposing ``cfg.code`` and, for ``"surface"``,
                ``cfg.model.version`` plus optional per-model fields.

        Returns:
            The instantiated model.

        Raises:
            ValueError: If ``cfg.code`` or ``cfg.model.version`` is not supported.
        """
        if cfg.code == "surface":
            return ModelFactory._create_surface_model(cfg)
        # Report the offending value (mirrors the version error below) so a
        # misconfigured run is easy to diagnose.
        raise ValueError(f"Invalid model name: {cfg.code}")

    @staticmethod
    def _model_kwargs(model_cfg, defaults):
        """Map each (name, default) pair to ``model_cfg.<name>`` if present, else the default."""
        return {name: getattr(model_cfg, name, default) for name, default in defaults}

    @staticmethod
    def _create_surface_model(cfg):
        """Build the surface-code model selected by ``cfg.model.version``.

        Each model module is imported only inside its own branch, so modules
        for unselected versions are never imported.

        Raises:
            ValueError: If ``cfg.model.version`` is not a known version.
        """
        version = cfg.model.version

        if version == "predecoder_memory_v1":
            from model.predecoder import PreDecoderModelMemory_v1
            # This architecture consumes the whole config object directly.
            return PreDecoderModelMemory_v1(cfg)

        if version == "predecoder_sd_litenet_v1":
            from model.predecoder_sd_litenet_v1 import PredecoderSDLiteNetV1
            return PredecoderSDLiteNetV1(
                **ModelFactory._model_kwargs(
                    cfg.model, ModelFactory._SD_LITENET_V1_DEFAULTS
                )
            )

        if version == "predecoder_fasthyper_rf13_v1":
            from model.predecoder_fasthyper_rf13_v1 import PredecoderFastHyperRF13V1
            return PredecoderFastHyperRF13V1(
                **ModelFactory._model_kwargs(
                    cfg.model, ModelFactory._FASTHYPER_RF13_V1_DEFAULTS
                )
            )

        raise ValueError(f"Invalid model version: {version}")
|
|