"""
Model factory module for the PAD predictor.

This module provides factory functions that create models, loss functions
and optimizers from configuration files, and supports different model
variants and configurations.
"""
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.optim as optim |
|
|
import yaml |
|
|
import json |
|
|
from typing import Dict, Any, Optional, Union, Tuple |
|
|
from pathlib import Path |
|
|
import logging |
|
|
|
|
|
from .pad_predictor import PADPredictor |
|
|
from .loss_functions import create_loss_function |
|
|
from .metrics import create_metrics |
|
|
|
|
|
|
|
|
class ModelFactory:
    """Factory for models, loss functions, optimizers, schedulers and metrics.

    Every ``create_*`` method accepts either a configuration dictionary or the
    path (``str`` or ``pathlib.Path``) of a YAML/JSON file that contains one.
    Additional model, optimizer and scheduler types can be added at runtime
    through the ``register_*`` methods.
    """

    def __init__(self):
        """Set up the logger and the built-in registries."""
        self.logger = logging.getLogger(__name__)

        # Model type name -> model class.
        self.model_registry = {
            'pad_predictor': PADPredictor,
        }

        # Optimizer type name -> optimizer class.
        self.optimizer_registry = {
            'adam': optim.Adam,
            'adamw': optim.AdamW,
            'sgd': optim.SGD,
            'rmsprop': optim.RMSprop,
            'adagrad': optim.Adagrad,
        }

        # Scheduler type name -> learning-rate scheduler class.
        self.scheduler_registry = {
            'step': optim.lr_scheduler.StepLR,
            'exponential': optim.lr_scheduler.ExponentialLR,
            'cosine': optim.lr_scheduler.CosineAnnealingLR,
            'plateau': optim.lr_scheduler.ReduceLROnPlateau,
            'cyclic': optim.lr_scheduler.CyclicLR,
        }

    def _resolve_config(self,
                        config: Union[str, Path, Dict[str, Any]]) -> Dict[str, Any]:
        """Return *config* as a dictionary, loading it from disk if it is a path.

        Args:
            config: A configuration dictionary, or the path of a YAML/JSON file.

        Returns:
            The configuration dictionary.
        """
        if isinstance(config, (str, Path)):
            return self._load_config(config)
        return config

    def create_model(self,
                     model_config: Union[str, Path, Dict[str, Any]]) -> nn.Module:
        """Create a model.

        Args:
            model_config: Model configuration, either a config file path or a
                configuration dictionary. The model type is read from
                ``model_info.type`` and defaults to ``'pad_predictor'``.

        Returns:
            The model instance.

        Raises:
            ValueError: If the model type is not registered.
        """
        config = self._resolve_config(model_config)
        requested_type = (config.get('model_info') or {}).get('type', 'pad_predictor')

        model_type = requested_type
        if model_type not in self.model_registry:
            # save_model_config() stores the model's class name (for example
            # "PADPredictor") instead of the registry key, so a class name is
            # accepted as well; otherwise saved configs could not be reloaded.
            for registered_name, registered_class in self.model_registry.items():
                if getattr(registered_class, '__name__', None) == requested_type:
                    model_type = registered_name
                    break
            else:
                raise ValueError(f"不支持的模型类型: {requested_type}. 支持的类型: {list(self.model_registry.keys())}")

        model_class = self.model_registry[model_type]

        if model_type == 'pad_predictor':
            model = self._create_pad_predictor(config)
        else:
            # Other registered models are built directly from 'model_params'.
            model = model_class(**(config.get('model_params') or {}))

        self.logger.info(f"成功创建模型: {model_type}")
        return model

    def _create_pad_predictor(self, config: Dict[str, Any]) -> "PADPredictor":
        """Create a PAD predictor.

        Args:
            config: Configuration dictionary. The ``dimensions``,
                ``architecture`` and ``initialization`` sections are read;
                missing or empty sections fall back to defaults.

        Returns:
            The PADPredictor instance.
        """
        # "or {}" also covers sections that are present but null in the file.
        dimensions = config.get('dimensions') or {}
        architecture = config.get('architecture') or {}
        initialization = config.get('initialization') or {}

        hidden_layers = architecture.get('hidden_layers') or []
        # Default hidden layout when the config lists no hidden layers.
        hidden_dims = [layer['size'] for layer in hidden_layers] or [128, 64, 32]

        dropout_rate = (architecture.get('dropout_config') or {}).get('rate', 0.3)

        return PADPredictor(
            input_dim=dimensions.get('input_dim', 10),
            output_dim=dimensions.get('output_dim', 4),
            hidden_dims=hidden_dims,
            dropout_rate=dropout_rate,
            weight_init=initialization.get('weight_init', 'xavier_uniform'),
            bias_init=initialization.get('bias_init', 'zeros'),
        )

    def create_loss_function(self,
                             loss_config: Union[str, Path, Dict[str, Any]]) -> nn.Module:
        """Create a loss function.

        Args:
            loss_config: Loss configuration, either a config file path or a
                configuration dictionary with ``type`` (default ``'wmse'``)
                and optional ``params``.

        Returns:
            The loss function instance.
        """
        config = self._resolve_config(loss_config)

        loss_type = config.get('type', 'wmse')
        loss_params = config.get('params') or {}

        # Inside a method body this name resolves to the module-level
        # create_loss_function imported from .loss_functions, not this method.
        return create_loss_function(loss_type, **loss_params)

    def create_optimizer(self,
                         model: nn.Module,
                         optimizer_config: Union[str, Path, Dict[str, Any]]) -> optim.Optimizer:
        """Create an optimizer.

        Args:
            model: The model whose parameters are optimized.
            optimizer_config: Optimizer configuration (path or dictionary)
                with ``type`` (default ``'adamw'``) and optional ``params``.

        Returns:
            The optimizer instance.

        Raises:
            ValueError: If the optimizer type is not registered.
        """
        config = self._resolve_config(optimizer_config)

        optimizer_type = config.get('type', 'adamw')
        if optimizer_type not in self.optimizer_registry:
            raise ValueError(f"不支持的优化器类型: {optimizer_type}. 支持的类型: {list(self.optimizer_registry.keys())}")

        # Defaults first, then user-supplied values override them.
        optimizer_kwargs = {
            'lr': 1e-3,
            'weight_decay': 1e-4,
        }
        optimizer_kwargs.update(config.get('params') or {})

        optimizer_class = self.optimizer_registry[optimizer_type]
        optimizer = optimizer_class(model.parameters(), **optimizer_kwargs)

        self.logger.info(f"成功创建优化器: {optimizer_type}")
        return optimizer

    def create_scheduler(self,
                         optimizer: optim.Optimizer,
                         scheduler_config: Union[str, Path, Dict[str, Any]]) -> Optional[optim.lr_scheduler._LRScheduler]:
        """Create a learning-rate scheduler.

        Args:
            optimizer: The optimizer to schedule.
            scheduler_config: Scheduler configuration (path or dictionary)
                with ``type`` (default ``'step'``) and optional ``params``.

        Returns:
            The scheduler instance, or ``None`` if the configuration is empty.

        Raises:
            ValueError: If the scheduler type is not registered.
        """
        if not scheduler_config:
            return None

        config = self._resolve_config(scheduler_config)

        scheduler_type = config.get('type', 'step')
        if scheduler_type not in self.scheduler_registry:
            raise ValueError(f"不支持的调度器类型: {scheduler_type}. 支持的类型: {list(self.scheduler_registry.keys())}")

        # Built-in defaults per scheduler type; types not listed here
        # (e.g. 'cyclic') must receive all required arguments via 'params'.
        builtin_defaults = {
            'step': {'step_size': 10, 'gamma': 0.1},
            'exponential': {'gamma': 0.95},
            'cosine': {'T_max': 100},
            'plateau': {'mode': 'min', 'patience': 10, 'factor': 0.5},
        }
        scheduler_kwargs = dict(builtin_defaults.get(scheduler_type, {}))
        scheduler_kwargs.update(config.get('params') or {})

        scheduler_class = self.scheduler_registry[scheduler_type]
        scheduler = scheduler_class(optimizer, **scheduler_kwargs)

        self.logger.info(f"成功创建学习率调度器: {scheduler_type}")
        return scheduler

    def create_metrics(self,
                       metrics_config: Union[str, Path, Dict[str, Any]]) -> Any:
        """Create evaluation metrics.

        Args:
            metrics_config: Metrics configuration (path or dictionary) with
                ``type`` (default ``'pad'``) and optional ``params``.

        Returns:
            The metrics instance.
        """
        config = self._resolve_config(metrics_config)

        metric_type = config.get('type', 'pad')
        metric_params = config.get('params') or {}

        # Resolves to the module-level create_metrics imported from .metrics.
        return create_metrics(metric_type, **metric_params)

    def create_training_components(self,
                                   config: Union[str, Path, Dict[str, Any]]) -> Tuple[nn.Module, nn.Module, optim.Optimizer, Optional[optim.lr_scheduler._LRScheduler]]:
        """Create every component needed for training.

        Args:
            config: The full configuration (path or dictionary). The optional
                ``loss``, ``optimizer`` and ``scheduler`` sections configure
                the respective components.

        Returns:
            A tuple ``(model, loss function, optimizer, scheduler)``; the
            scheduler is ``None`` when no scheduler section is configured.
        """
        full_config = self._resolve_config(config)

        model = self.create_model(full_config)

        # "or" also covers sections that are present but null/empty.
        loss_config = full_config.get('loss') or {'type': 'wmse'}
        loss_function = self.create_loss_function(loss_config)

        optimizer_config = full_config.get('optimizer') or {'type': 'adamw'}
        optimizer = self.create_optimizer(model, optimizer_config)

        scheduler_config = full_config.get('scheduler') or {}
        scheduler = self.create_scheduler(optimizer, scheduler_config)

        return model, loss_function, optimizer, scheduler

    def _load_config(self, config_path: Union[str, Path]) -> Dict[str, Any]:
        """Load a configuration file.

        Args:
            config_path: Path of a ``.yaml``/``.yml`` or ``.json`` file.

        Returns:
            The configuration dictionary; an empty dictionary when the
            document is empty (null).

        Raises:
            FileNotFoundError: If the file does not exist.
            ValueError: If the extension is unsupported or the top-level
                value of the document is not a mapping.
        """
        config_path = Path(config_path)

        if not config_path.exists():
            raise FileNotFoundError(f"配置文件不存在: {config_path}")

        # Pick the parser before opening so an unsupported file is never read.
        suffix = config_path.suffix.lower()
        if suffix in ('.yaml', '.yml'):
            parse = yaml.safe_load
        elif suffix == '.json':
            parse = json.load
        else:
            raise ValueError(f"不支持的配置文件格式: {config_path.suffix}")

        with open(config_path, 'r', encoding='utf-8') as f:
            config = parse(f)

        if config is None:
            # An empty YAML document (or JSON "null") parses to None; callers
            # expect a mapping and then fall back to their defaults.
            config = {}
        elif not isinstance(config, dict):
            raise ValueError(f"配置文件的顶层结构必须是字典: {config_path}")

        self.logger.info(f"成功加载配置文件: {config_path}")
        return config

    def register_model(self, name: str, model_class: type):
        """Register a new model type.

        Args:
            name: Model type name used in configurations.
            model_class: The model class.
        """
        self.model_registry[name] = model_class
        self.logger.info(f"注册新模型类型: {name}")

    def register_optimizer(self, name: str, optimizer_class: type):
        """Register a new optimizer type.

        Args:
            name: Optimizer type name used in configurations.
            optimizer_class: The optimizer class.
        """
        self.optimizer_registry[name] = optimizer_class
        self.logger.info(f"注册新优化器类型: {name}")

    def register_scheduler(self, name: str, scheduler_class: type):
        """Register a new learning-rate scheduler type.

        Args:
            name: Scheduler type name used in configurations.
            scheduler_class: The scheduler class.
        """
        self.scheduler_registry[name] = scheduler_class
        self.logger.info(f"注册新调度器类型: {name}")

    def get_available_models(self) -> list:
        """Return the registered model type names."""
        return list(self.model_registry)

    def get_available_optimizers(self) -> list:
        """Return the registered optimizer type names."""
        return list(self.optimizer_registry)

    def get_available_schedulers(self) -> list:
        """Return the registered scheduler type names."""
        return list(self.scheduler_registry)
|
|
|
|
|
|
|
|
|
|
|
# Shared module-level factory instance; the convenience functions in this
# module delegate to it.
model_factory: ModelFactory = ModelFactory()
|
|
|
|
|
|
|
|
def create_model_from_config(config_path: str) -> nn.Module:
    """Convenience function that creates a model from a configuration file.

    Delegates to ``create_model`` of the module-level ``model_factory``.

    Args:
        config_path: Path of the configuration file.

    Returns:
        The model instance.
    """
    built_model = model_factory.create_model(config_path)
    return built_model
|
|
|
|
|
|
|
|
def create_training_setup(config_path: str) -> Tuple[nn.Module, nn.Module, optim.Optimizer, Optional[optim.lr_scheduler._LRScheduler]]:
    """Convenience function that creates a full training setup from a config file.

    Delegates to ``create_training_components`` of the module-level
    ``model_factory``.

    Args:
        config_path: Path of the configuration file.

    Returns:
        A tuple ``(model, loss function, optimizer, learning-rate scheduler)``.
    """
    training_components = model_factory.create_training_components(config_path)
    return training_components
|
|
|
|
|
|
|
|
def save_model_config(model: nn.Module, config_path: str, additional_info: Optional[Dict[str, Any]] = None):
    """Save a model's configuration to a YAML or JSON file.

    For ``PADPredictor`` models the dimensions, hidden-layer architecture,
    dropout rate and initialization settings are written as well, using the
    keys that the model factory reads when building a PAD predictor.

    Args:
        model: The model instance.
        config_path: Destination path; the extension must be ``.yaml``,
            ``.yml`` or ``.json``. Missing parent directories are created.
        additional_info: Optional extra information stored under the
            ``additional_info`` key.

    Raises:
        ValueError: If the file extension is not supported.
    """
    config_path = Path(config_path)
    suffix = config_path.suffix.lower()
    # Validate the format before touching the filesystem so an unsupported
    # extension does not leave an empty file or stray directories behind.
    if suffix not in ('.yaml', '.yml', '.json'):
        raise ValueError(f"不支持的配置文件格式: {config_path.suffix}")

    config = {
        'model_info': {
            'type': model.__class__.__name__,
            'version': '1.0',
        }
    }

    if isinstance(model, PADPredictor):
        hidden_dims = list(model.hidden_dims)
        last_index = len(hidden_dims) - 1
        # Every hidden layer records the model's dropout rate except the last
        # one, which records 0.0; an empty hidden_dims gives an empty list.
        hidden_layers = [
            {
                'size': dim,
                'activation': 'ReLU',
                'dropout': model.dropout_rate if index < last_index else 0.0,
            }
            for index, dim in enumerate(hidden_dims)
        ]
        config.update({
            'dimensions': {
                'input_dim': model.input_dim,
                'output_dim': model.output_dim,
            },
            'architecture': {
                'hidden_layers': hidden_layers,
                # The factory reads the dropout rate from
                # architecture.dropout_config.rate, so store it there too;
                # otherwise a reloaded model falls back to the default rate.
                'dropout_config': {'rate': model.dropout_rate},
                'output_layer': {'activation': 'Linear'},
            },
            'initialization': {
                'weight_init': model.weight_init,
                'bias_init': model.bias_init,
            },
        })

    if additional_info:
        config['additional_info'] = additional_info

    config_path.parent.mkdir(parents=True, exist_ok=True)

    with open(config_path, 'w', encoding='utf-8') as f:
        if suffix == '.json':
            json.dump(config, f, indent=2, ensure_ascii=False)
        else:
            yaml.dump(config, f, default_flow_style=False, allow_unicode=True)

    # Use the module logger (as ModelFactory does) instead of the root logger.
    logging.getLogger(__name__).info(f"模型配置已保存到: {config_path}")
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test of the factory: write a sample config to a temporary YAML
    # file, then build each component from it.
    import os
    import tempfile

    print("测试模型工厂:")

    # Sample configuration covering model, loss, optimizer and scheduler.
    config = {
        'model_info': {
            'name': 'Test_PAD_Predictor',
            'type': 'pad_predictor',
            'version': '1.0',
        },
        'dimensions': {'input_dim': 10, 'output_dim': 4},
        'architecture': {
            'hidden_layers': [
                {'size': 128, 'activation': 'ReLU', 'dropout': 0.3},
                {'size': 64, 'activation': 'ReLU', 'dropout': 0.3},
                {'size': 32, 'activation': 'ReLU', 'dropout': 0.0},
            ],
            'output_layer': {'activation': 'Linear'},
        },
        'initialization': {
            'weight_init': 'xavier_uniform',
            'bias_init': 'zeros',
        },
        'loss': {
            'type': 'wmse',
            'params': {
                'delta_pad_weight': 1.0,
                'delta_pressure_weight': 1.0,
                'confidence_weight': 0.5,
            },
        },
        'optimizer': {
            'type': 'adamw',
            'params': {'lr': 0.001, 'weight_decay': 0.0001},
        },
        'scheduler': {
            'type': 'step',
            'params': {'step_size': 10, 'gamma': 0.1},
        },
    }

    # delete=False keeps the file after the handle is closed so the factory
    # can reopen it by path; it is removed in the finally clause below.
    with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as tmp_file:
        yaml.dump(config, tmp_file)
        temp_config_path = tmp_file.name

    try:
        # Model built from the temporary config file.
        model = model_factory.create_model(temp_config_path)
        print(f"成功创建模型: {model.__class__.__name__}")

        # Loss function built from the in-memory 'loss' section.
        loss_fn = model_factory.create_loss_function(config['loss'])
        print(f"成功创建损失函数: {loss_fn.__class__.__name__}")

        # Optimizer for the model's parameters.
        optimizer = model_factory.create_optimizer(model, config['optimizer'])
        print(f"成功创建优化器: {optimizer.__class__.__name__}")

        # Learning-rate scheduler (None when no scheduler is configured).
        scheduler = model_factory.create_scheduler(optimizer, config['scheduler'])
        if scheduler:
            print(f"成功创建学习率调度器: {scheduler.__class__.__name__}")

        # All components in a single call, from the config file.
        model, loss_fn, optimizer, scheduler = model_factory.create_training_components(temp_config_path)
        print(f"成功创建完整训练设置")

        # List the registered component types.
        print(f"\n可用模型类型: {model_factory.get_available_models()}")
        print(f"可用优化器类型: {model_factory.get_available_optimizers()}")
        print(f"可用调度器类型: {model_factory.get_available_schedulers()}")
    finally:
        # Always remove the temporary config file.
        os.unlink(temp_config_path)

    print("\n模型工厂测试完成!")