| # API参考文档 | |
| 本文档详细介绍了情绪与生理状态变化预测模型的所有API接口、类和函数。 | |
| ## 目录 | |
| 1. [模型类](#模型类) | |
| 2. [数据处理类](#数据处理类) | |
| 3. [工具类](#工具类) | |
| 4. [损失函数](#损失函数) | |
| 5. [评估指标](#评估指标) | |
| 6. [工厂函数](#工厂函数) | |
| 7. [命令行接口](#命令行接口) | |
| ## 模型类 | |
| ### `PADPredictor` | |
| 基于多层感知机的情绪与生理状态变化预测器。 | |
| ```python | |
| class PADPredictor(nn.Module): | |
| def __init__(self, | |
| input_dim: int = 7, | |
| output_dim: int = 3, | |
| hidden_dims: list = [512, 256, 128], | |
| dropout_rate: float = 0.3, | |
| weight_init: str = "xavier_uniform", | |
| bias_init: str = "zeros") | |
| ``` | |
| #### 参数 | |
| - `input_dim` (int): 输入维度,默认为7(用户PAD 3维 + Vitality 1维 + AI当前PAD 3维) | |
| - `output_dim` (int): 输出维度,默认为3(ΔPAD 3维,压力通过公式动态计算) | |
| - `hidden_dims` (list): 隐藏层维度列表,默认为[512, 256, 128] | |
| - `dropout_rate` (float): Dropout概率,默认为0.3 | |
| - `weight_init` (str): 权重初始化方法,默认为"xavier_uniform" | |
| - `bias_init` (str): 偏置初始化方法,默认为"zeros" | |
| #### 方法 | |
| ##### `forward(self, x: torch.Tensor) -> torch.Tensor` | |
| 前向传播。 | |
| **参数:** | |
| - `x` (torch.Tensor): 输入张量,形状为 (batch_size, input_dim) | |
| **返回:** | |
| - `torch.Tensor`: 输出张量,形状为 (batch_size, output_dim) | |
| **示例:** | |
| ```python | |
| import torch | |
| from src.models.pad_predictor import PADPredictor | |
| model = PADPredictor() | |
| input_data = torch.randn(4, 7) # batch_size=4, input_dim=7 | |
| output = model(input_data) | |
| print(f"Output shape: {output.shape}") # torch.Size([4, 3]) | |
| ``` | |
| ##### `predict_components(self, x: torch.Tensor) -> Dict[str, torch.Tensor]` | |
| 预测并分解输出组件。 | |
| **参数:** | |
| - `x` (torch.Tensor): 输入张量 | |
| **返回:** | |
| - `Dict[str, torch.Tensor]`: 包含各组件的字典 | |
| - `'delta_pad'`: ΔPAD (3维) | |
| - `'delta_pressure'`: ΔPressure (1维,动态计算) | |
| - `'confidence'`: Confidence (1维,可选) | |
| **示例:** | |
| ```python | |
| components = model.predict_components(input_data) | |
| print(f"ΔPAD shape: {components['delta_pad'].shape}") # torch.Size([4, 3]) | |
| print(f"ΔPressure shape: {components['delta_pressure'].shape}") # torch.Size([4, 1]) | |
| print(f"Confidence shape: {components['confidence'].shape}") # torch.Size([4, 1]) | |
| ``` | |
| ##### `get_model_info(self) -> Dict[str, Any]` | |
| 获取模型信息。 | |
| **返回:** | |
| - `Dict[str, Any]`: 包含模型信息的字典 | |
| **示例:** | |
| ```python | |
| info = model.get_model_info() | |
| print(f"Model type: {info['model_type']}") | |
| print(f"Total parameters: {info['total_parameters']}") | |
| print(f"Trainable parameters: {info['trainable_parameters']}") | |
| ``` | |
| ##### `save_model(self, filepath: str, include_optimizer: bool = False, optimizer: Optional[torch.optim.Optimizer] = None)` | |
| 保存模型到文件。 | |
| **参数:** | |
| - `filepath` (str): 保存路径 | |
| - `include_optimizer` (bool): 是否包含优化器状态,默认为False | |
| - `optimizer` (Optional[torch.optim.Optimizer]): 优化器对象 | |
| **示例:** | |
| ```python | |
| model.save_model("model.pth", include_optimizer=True, optimizer=optimizer) | |
| ``` | |
| ##### `load_model(cls, filepath: str, device: str = 'cpu') -> 'PADPredictor'` | |
| 从文件加载模型。 | |
| **参数:** | |
| - `filepath` (str): 模型文件路径 | |
| - `device` (str): 设备类型,默认为'cpu' | |
| **返回:** | |
| - `PADPredictor`: 加载的模型实例 | |
| **示例:** | |
| ```python | |
| loaded_model = PADPredictor.load_model("model.pth", device='cuda') | |
| ``` | |
| ##### `freeze_layers(self, layer_names: list = None)` | |
| 冻结指定层的参数。 | |
| **参数:** | |
| - `layer_names` (list): 要冻结的层名称列表,如果为None则冻结所有层 | |
| **示例:** | |
| ```python | |
| # 冻结所有层 | |
| model.freeze_layers() | |
| # 冻结特定层 | |
| model.freeze_layers(['network.0.weight', 'network.2.weight']) | |
| ``` | |
| ##### `unfreeze_layers(self, layer_names: list = None)` | |
| 解冻指定层的参数。 | |
| **参数:** | |
| - `layer_names` (list): 要解冻的层名称列表,如果为None则解冻所有层 | |
| ## 数据处理类 | |
| ### `DataPreprocessor` | |
| 数据预处理器,负责特征标准化和标签处理。 | |
| ```python | |
| class DataPreprocessor: | |
| def __init__(self, | |
| feature_scaler: str = "standard", | |
| label_scaler: str = "standard", | |
| feature_range: tuple = None, | |
| label_range: tuple = None) | |
| ``` | |
| #### 参数 | |
| - `feature_scaler` (str): 特征标准化方法,默认为"standard" | |
| - `label_scaler` (str): 标签标准化方法,默认为"standard" | |
| - `feature_range` (tuple): 特征范围,用于MinMax缩放 | |
| - `label_range` (tuple): 标签范围,用于MinMax缩放 | |
| #### 方法 | |
| ##### `fit(self, features: np.ndarray, labels: np.ndarray) -> 'DataPreprocessor'` | |
| 拟合预处理器参数。 | |
| **参数:** | |
| - `features` (np.ndarray): 训练特征数据 | |
| - `labels` (np.ndarray): 训练标签数据 | |
| **返回:** | |
| - `DataPreprocessor`: 自身实例 | |
| ##### `transform(self, features: np.ndarray, labels: np.ndarray = None) -> tuple` | |
| 转换数据。 | |
| **参数:** | |
| - `features` (np.ndarray): 输入特征数据 | |
| - `labels` (np.ndarray, optional): 输入标签数据 | |
| **返回:** | |
| - `tuple`: (转换后的特征, 转换后的标签) | |
| ##### `fit_transform(self, features: np.ndarray, labels: np.ndarray = None) -> tuple` | |
| 拟合并转换数据。 | |
| ##### `inverse_transform(self, features: np.ndarray, labels: np.ndarray = None) -> tuple` | |
| 逆转换数据。 | |
| ##### `save(self, filepath: str)` | |
| 保存预处理器到文件。 | |
| ##### `load(cls, filepath: str) -> 'DataPreprocessor'` | |
| 从文件加载预处理器。 | |
| **示例:** | |
| ```python | |
| from src.data.preprocessor import DataPreprocessor | |
| # 创建预处理器 | |
| preprocessor = DataPreprocessor( | |
| feature_scaler="standard", | |
| label_scaler="standard" | |
| ) | |
| # 拟合和转换数据 | |
| processed_features, processed_labels = preprocessor.fit_transform(train_features, train_labels) | |
| # 保存预处理器 | |
| preprocessor.save("preprocessor.pkl") | |
| # 加载预处理器 | |
| loaded_preprocessor = DataPreprocessor.load("preprocessor.pkl") | |
| ``` | |
| ### `SyntheticDataGenerator` | |
| 合成数据生成器,用于生成训练和测试数据。 | |
| ```python | |
| class SyntheticDataGenerator: | |
| def __init__(self, | |
| num_samples: int = 1000, | |
| seed: int = 42, | |
| noise_level: float = 0.1, | |
| correlation_strength: float = 0.5) | |
| ``` | |
| #### 参数 | |
| - `num_samples` (int): 生成的样本数量,默认为1000 | |
| - `seed` (int): 随机种子,默认为42 | |
| - `noise_level` (float): 噪声水平,默认为0.1 | |
| - `correlation_strength` (float): 相关性强度,默认为0.5 | |
| #### 方法 | |
| ##### `generate_data(self) -> tuple` | |
| 生成合成数据。 | |
| **返回:** | |
| - `tuple`: (特征数据, 标签数据) | |
| ##### `save_data(self, features: np.ndarray, labels: np.ndarray, filepath: str, format: str = 'csv')` | |
| 保存数据到文件。 | |
| **示例:** | |
| ```python | |
| from src.data.synthetic_generator import SyntheticDataGenerator | |
| # 创建数据生成器 | |
| generator = SyntheticDataGenerator(num_samples=1000, seed=42) | |
| # 生成数据 | |
| features, labels = generator.generate_data() | |
| # 保存数据 | |
| generator.save_data(features, labels, "synthetic_data.csv", format='csv') | |
| ``` | |
| ### `EmotionDataset` | |
| PyTorch数据集类,用于情绪预测任务。 | |
| ```python | |
| class EmotionDataset(Dataset): | |
| def __init__(self, | |
| features: np.ndarray, | |
| labels: np.ndarray, | |
| transform: callable = None) | |
| ``` | |
| #### 参数 | |
| - `features` (np.ndarray): 特征数据 | |
| - `labels` (np.ndarray): 标签数据 | |
| - `transform` (callable): 数据变换函数 | |
| ## 工具类 | |
| ### `InferenceEngine` | |
| 推理引擎,提供高性能的模型推理功能。 | |
| ```python | |
| class InferenceEngine: | |
| def __init__(self, | |
| model: nn.Module, | |
| preprocessor: DataPreprocessor = None, | |
| device: str = 'auto') | |
| ``` | |
| #### 方法 | |
| ##### `predict(self, input_data: Union[list, np.ndarray]) -> Dict[str, Any]` | |
| 单样本预测。 | |
| **参数:** | |
| - `input_data`: 输入数据,可以是列表或NumPy数组 | |
| **返回:** | |
| - `Dict[str, Any]`: 预测结果字典 | |
| **示例:** | |
| ```python | |
| from src.utils.inference_engine import create_inference_engine | |
| # 创建推理引擎 | |
| engine = create_inference_engine( | |
| model_path="model.pth", | |
| preprocessor_path="preprocessor.pkl" | |
| ) | |
| # 单样本预测 | |
| input_data = [0.5, 0.3, -0.2, 75.0, 0.1, 0.4, -0.1] | |
| result = engine.predict(input_data) | |
| print(f"ΔPAD: {result['delta_pad']}") | |
| print(f"Confidence: {result['confidence']}") | |
| ``` | |
| ##### `predict_batch(self, input_batch: Union[list, np.ndarray]) -> List[Dict[str, Any]]` | |
| 批量预测。 | |
| ##### `benchmark(self, num_samples: int = 1000, batch_size: int = 32) -> Dict[str, float]` | |
| 性能基准测试。 | |
| **返回:** | |
| - `Dict[str, float]`: 性能统计信息 | |
| **示例:** | |
| ```python | |
| # 性能基准测试 | |
| stats = engine.benchmark(num_samples=1000, batch_size=32) | |
| print(f"Throughput: {stats['throughput']:.2f} samples/sec") | |
| print(f"Average latency: {stats['avg_latency']:.2f}ms") | |
| ``` | |
| ### `ModelTrainer` | |
| 模型训练器,提供完整的训练流程管理。 | |
| ```python | |
| class ModelTrainer: | |
| def __init__(self, | |
| model: nn.Module, | |
| preprocessor: DataPreprocessor = None, | |
| device: str = 'auto') | |
| ``` | |
| #### 方法 | |
| ##### `train(self, train_loader: DataLoader, val_loader: DataLoader, config: Dict[str, Any]) -> Dict[str, Any]` | |
| 训练模型。 | |
| **参数:** | |
| - `train_loader` (DataLoader): 训练数据加载器 | |
| - `val_loader` (DataLoader): 验证数据加载器 | |
| - `config` (Dict[str, Any]): 训练配置 | |
| **返回:** | |
| - `Dict[str, Any]`: 训练历史记录 | |
| **示例:** | |
| ```python | |
| from src.utils.trainer import ModelTrainer | |
| # 创建训练器 | |
| trainer = ModelTrainer(model, preprocessor) | |
| # 训练配置 | |
| config = { | |
| 'epochs': 100, | |
| 'learning_rate': 0.001, | |
| 'weight_decay': 1e-4, | |
| 'patience': 10, | |
| 'save_dir': './models' | |
| } | |
| # 开始训练 | |
| history = trainer.train(train_loader, val_loader, config) | |
| ``` | |
| ##### `evaluate(self, test_loader: DataLoader) -> Dict[str, float]` | |
| 评估模型。 | |
| ## 损失函数 | |
| ### `WeightedMSELoss` | |
| 加权均方误差损失函数。 | |
| ```python | |
| class WeightedMSELoss(nn.Module): | |
| def __init__(self, | |
| delta_pad_weight: float = 1.0, | |
| delta_pressure_weight: float = 1.0, | |
| confidence_weight: float = 0.5, | |
| reduction: str = 'mean') | |
| ``` | |
| #### 参数 | |
| - `delta_pad_weight` (float): ΔPAD损失权重,默认为1.0 | |
| - `delta_pressure_weight` (float): ΔPressure损失权重,默认为1.0 | |
| - `confidence_weight` (float): 置信度损失权重,默认为0.5 | |
| - `reduction` (str): 损失缩减方式,默认为'mean' | |
| **示例:** | |
| ```python | |
| from src.models.loss_functions import WeightedMSELoss | |
| criterion = WeightedMSELoss( | |
| delta_pad_weight=1.0, | |
| delta_pressure_weight=1.0, | |
| confidence_weight=0.5 | |
| ) | |
| loss = criterion(predictions, targets) | |
| ``` | |
| ### `ConfidenceLoss` | |
| 置信度损失函数。 | |
| ```python | |
| class ConfidenceLoss(nn.Module): | |
| def __init__(self, reduction: str = 'mean') | |
| ``` | |
| ## 评估指标 | |
| ### `RegressionMetrics` | |
| 回归评估指标计算器。 | |
| ```python | |
| class RegressionMetrics: | |
| def __init__(self) | |
| ``` | |
| #### 方法 | |
| ##### `calculate_all_metrics(self, y_true: np.ndarray, y_pred: np.ndarray) -> Dict[str, float]` | |
| 计算所有回归指标。 | |
| **参数:** | |
| - `y_true` (np.ndarray): 真实值 | |
| - `y_pred` (np.ndarray): 预测值 | |
| **返回:** | |
| - `Dict[str, float]`: 包含所有指标的字典 | |
| **示例:** | |
| ```python | |
| from src.models.metrics import RegressionMetrics | |
| metrics_calculator = RegressionMetrics() | |
| metrics = metrics_calculator.calculate_all_metrics(true_labels, predictions) | |
| print(f"MSE: {metrics['mse']:.4f}") | |
| print(f"MAE: {metrics['mae']:.4f}") | |
| print(f"R²: {metrics['r2']:.4f}") | |
| ``` | |
| ### `PADMetrics` | |
| PAD专用评估指标。 | |
| ```python | |
| class PADMetrics: | |
| def __init__(self) | |
| ``` | |
| #### 方法 | |
| ##### `evaluate_predictions(self, predictions: np.ndarray, targets: np.ndarray) -> Dict[str, Any]` | |
| 评估PAD预测结果。 | |
| ## 工厂函数 | |
| ### `create_pad_predictor(config: Optional[Dict[str, Any]] = None) -> PADPredictor` | |
| 创建PAD预测器的工厂函数。 | |
| **参数:** | |
| - `config` (Dict[str, Any], optional): 配置字典 | |
| **返回:** | |
| - `PADPredictor`: PAD预测器实例 | |
| **示例:** | |
| ```python | |
| from src.models.pad_predictor import create_pad_predictor | |
| # 使用默认配置 | |
| model = create_pad_predictor() | |
| # 使用自定义配置 | |
| config = { | |
| 'dimensions': { | |
| 'input_dim': 7, | |
| 'output_dim': 4或3 | |
| }, | |
| 'architecture': { | |
| 'hidden_layers': [ | |
| {'size': 256, 'activation': 'ReLU', 'dropout': 0.3}, | |
| {'size': 128, 'activation': 'ReLU', 'dropout': 0.2} | |
| ] | |
| } | |
| } | |
| model = create_pad_predictor(config) | |
| ``` | |
| ### `create_inference_engine(model_path: str, preprocessor_path: str = None, device: str = 'auto') -> InferenceEngine` | |
| 创建推理引擎的工厂函数。 | |
| **参数:** | |
| - `model_path` (str): 模型文件路径 | |
| - `preprocessor_path` (str, optional): 预处理器文件路径 | |
| - `device` (str): 设备类型 | |
| **返回:** | |
| - `InferenceEngine`: 推理引擎实例 | |
| ### `create_training_setup(config: Dict[str, Any]) -> tuple` | |
| 创建训练设置的工厂函数。 | |
| **参数:** | |
| - `config` (Dict[str, Any]): 训练配置 | |
| **返回:** | |
| - `tuple`: (模型, 训练器, 数据加载器) | |
| ## 命令行接口 | |
| ### 主CLI工具 | |
| 项目提供了统一的命令行接口,支持多种操作: | |
| ```bash | |
| emotion-prediction <command> [options] | |
| ``` | |
| #### 可用命令 | |
| - `train`: 训练模型 | |
| - `predict`: 进行预测 | |
| - `evaluate`: 评估模型 | |
| - `inference`: 推理脚本 | |
| - `benchmark`: 性能基准测试 | |
| #### 训练命令 | |
| ```bash | |
| emotion-prediction train --config CONFIG_FILE [OPTIONS] | |
| ``` | |
| **参数:** | |
| - `--config, -c`: 训练配置文件路径(必需) | |
| - `--output-dir, -o`: 输出目录(默认: ./outputs) | |
| - `--device`: 计算设备(auto/cpu/cuda,默认: auto) | |
| - `--resume`: 从检查点恢复训练 | |
| - `--epochs`: 覆盖训练轮数 | |
| - `--batch-size`: 覆盖批次大小 | |
| - `--learning-rate`: 覆盖学习率 | |
| - `--seed`: 随机种子(默认: 42) | |
| - `--verbose, -v`: 详细输出 | |
| - `--log-level`: 日志级别(DEBUG/INFO/WARNING/ERROR) | |
| **示例:** | |
| ```bash | |
| # 基础训练 | |
| emotion-prediction train --config configs/training_config.yaml | |
| # GPU训练 | |
| emotion-prediction train --config configs/training_config.yaml --device cuda | |
| # 从检查点恢复 | |
| emotion-prediction train --config configs/training_config.yaml --resume checkpoint.pth | |
| ``` | |
| #### 预测命令 | |
| ```bash | |
| emotion-prediction predict --model MODEL_FILE [OPTIONS] | |
| ``` | |
| **参数:** | |
| - `--model, -m`: 模型文件路径(必需) | |
| - `--preprocessor, -p`: 预处理器文件路径 | |
| - `--interactive, -i`: 交互式模式 | |
| - `--quick`: 快速预测模式(7个数值) | |
| - `--batch`: 批量预测模式(输入文件) | |
| - `--output, -o`: 输出文件路径 | |
| - `--device`: 计算设备 | |
| - `--verbose, -v`: 详细输出 | |
| - `--log-level`: 日志级别 | |
| **示例:** | |
| ```bash | |
| # 交互式预测 | |
| emotion-prediction predict --model model.pth --interactive | |
| # 快速预测 | |
| emotion-prediction predict --model model.pth --quick 0.5 0.3 -0.2 75.0 0.1 0.4 -0.1 | |
| # 批量预测 | |
| emotion-prediction predict --model model.pth --batch input.csv --output results.csv | |
| ``` | |
| #### 评估命令 | |
| ```bash | |
| emotion-prediction evaluate --model MODEL_FILE --data DATA_FILE [OPTIONS] | |
| ``` | |
| **参数:** | |
| - `--model, -m`: 模型文件路径(必需) | |
| - `--data, -d`: 测试数据文件路径(必需) | |
| - `--preprocessor, -p`: 预处理器文件路径 | |
| - `--output, -o`: 评估结果输出路径 | |
| - `--report`: 生成详细报告文件路径 | |
| - `--metrics`: 评估指标列表(默认: mse mae r2) | |
| - `--batch-size`: 批次大小(默认: 32) | |
| - `--device`: 计算设备 | |
| - `--verbose, -v`: 详细输出 | |
| - `--log-level`: 日志级别 | |
| **示例:** | |
| ```bash | |
| # 基础评估 | |
| emotion-prediction evaluate --model model.pth --data test_data.csv | |
| # 生成详细报告 | |
| emotion-prediction evaluate --model model.pth --data test_data.csv --report report.html | |
| ``` | |
| #### 基准测试命令 | |
| ```bash | |
| emotion-prediction benchmark --model MODEL_FILE [OPTIONS] | |
| ``` | |
| **参数:** | |
| - `--model, -m`: 模型文件路径(必需) | |
| - `--preprocessor, -p`: 预处理器文件路径 | |
| - `--num-samples`: 测试样本数量(默认: 1000) | |
| - `--batch-size`: 批次大小(默认: 32) | |
| - `--device`: 计算设备 | |
| - `--report`: 生成性能报告文件路径 | |
| - `--warmup`: 预热轮数(默认: 10) | |
| - `--verbose, -v`: 详细输出 | |
| - `--log-level`: 日志级别 | |
| **示例:** | |
| ```bash | |
| # 标准基准测试 | |
| emotion-prediction benchmark --model model.pth | |
| # 自定义测试 | |
| emotion-prediction benchmark --model model.pth --num-samples 5000 --batch-size 64 | |
| ``` | |
| ## 配置文件API | |
| ### 模型配置 | |
| 模型配置文件使用YAML格式,支持以下参数: | |
| ```yaml | |
| # 模型基本信息 | |
| model_info: | |
| name: str # 模型名称 | |
| type: str # 模型类型 | |
| version: str # 模型版本 | |
| # 输入输出维度 | |
| dimensions: | |
| input_dim: int # 输入维度 | |
| output_dim: int # 输出维度 | |
| # 网络架构 | |
| architecture: | |
| hidden_layers: | |
| - size: int # 层大小 | |
| activation: str # 激活函数 | |
| dropout: float # Dropout率 | |
| output_layer: | |
| activation: str # 输出激活函数 | |
| use_batch_norm: bool # 是否使用批归一化 | |
| use_layer_norm: bool # 是否使用层归一化 | |
| # 初始化参数 | |
| initialization: | |
| weight_init: str # 权重初始化方法 | |
| bias_init: str # 偏置初始化方法 | |
| # 正则化 | |
| regularization: | |
| weight_decay: float # L2正则化系数 | |
| dropout_config: | |
| type: str # Dropout类型 | |
| rate: float # Dropout率 | |
| ``` | |
| ### 训练配置 | |
| 训练配置文件支持以下参数: | |
| ```yaml | |
| # 训练信息 | |
| training_info: | |
| experiment_name: str # 实验名称 | |
| description: str # 实验描述 | |
| seed: int # 随机种子 | |
| # 训练超参数 | |
| training: | |
| optimizer: | |
| type: str # 优化器类型 | |
| learning_rate: float # 学习率 | |
| weight_decay: float # 权重衰减 | |
| scheduler: | |
| type: str # 调度器类型 | |
| epochs: int # 训练轮数 | |
| early_stopping: | |
| enabled: bool # 是否启用早停 | |
| patience: int # 耐心值 | |
| min_delta: float # 最小改善 | |
| ``` | |
| ## 异常处理 | |
| 项目定义了以下自定义异常: | |
| ### `ModelLoadError` | |
| 模型加载错误。 | |
| ### `DataPreprocessingError` | |
| 数据预处理错误。 | |
| ### `InferenceError` | |
| 推理过程错误。 | |
| ### `ConfigurationError` | |
| 配置文件错误。 | |
| **示例:** | |
| ```python | |
| from src.utils.exceptions import ModelLoadError, InferenceError | |
| try: | |
| model = PADPredictor.load_model("invalid_model.pth") | |
| except ModelLoadError as e: | |
| print(f"模型加载失败: {e}") | |
| try: | |
| result = engine.predict(invalid_input) | |
| except InferenceError as e: | |
| print(f"推理失败: {e}") | |
| ``` | |
| ## 日志系统 | |
| 项目使用结构化日志系统: | |
| ```python | |
| from src.utils.logger import setup_logger | |
| import logging | |
| # 设置日志 | |
| setup_logger(level='INFO', log_file='training.log') | |
| logger = logging.getLogger(__name__) | |
| # 使用日志 | |
| logger.info("训练开始") | |
| logger.debug(f"批次大小: {batch_size}") | |
| logger.warning("检测到潜在的过拟合") | |
| logger.error("训练过程中发生错误") | |
| ``` | |
| ## 类型提示 | |
| 项目完全支持类型提示,所有公共API都有详细的类型注解: | |
| ```python | |
| from typing import Dict, List, Optional, Union, Tuple | |
| import numpy as np | |
| import torch | |
| def predict_emotion( | |
| input_data: Union[List[float], np.ndarray], | |
| model_path: str, | |
| preprocessor_path: Optional[str] = None, | |
| device: str = 'auto' | |
| ) -> Dict[str, Any]: | |
| """ | |
| 预测情绪变化 | |
| Args: | |
| input_data: 输入数据,7维向量 | |
| model_path: 模型文件路径 | |
| preprocessor_path: 预处理器文件路径 | |
| device: 计算设备 | |
| Returns: | |
| 包含预测结果的字典 | |
| Raises: | |
| InferenceError: 推理失败时抛出 | |
| """ | |
| pass | |
| ``` | |
| --- | |
| 更多详细信息请参考源代码和示例文件。如有问题,请查看[故障排除指南](TUTORIAL.md#故障排除)或提交Issue。 |