File size: 16,143 Bytes
0a6452f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
"""
推理脚本
Inference Script for emotion and physiological state prediction

该脚本实现了完整的推理功能,支持:
- 单样本和批量推理
- 多种输入格式(JSON、CSV、命令行参数)
- 多种输出格式(JSON、CSV、文本)
- 输入数据验证和预处理
- 性能基准测试
"""

import argparse
import json
import csv
import sys
import os
import logging
import time
from pathlib import Path
from typing import List, Dict, Any, Union, Optional
import numpy as np
import pandas as pd

# 添加项目根目录到Python路径
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root))

from src.utils.inference_engine import InferenceEngine, create_inference_engine
from src.utils.logger import setup_logger


def parse_command_line_input(args: List[str]) -> np.ndarray:
    """
    解析命令行输入数据

    Args:
        args: 命令行参数列表

    Returns:
        输入数据数组(7维,将被推理引擎增强到10维)
    """
    if len(args) != 7:
        raise ValueError(f"需要7个输入参数,但提供了{len(args)}个。参数顺序:user_pleasure, user_arousal, user_dominance, vitality, current_pleasure, current_arousal, current_dominance")
    
    try:
        data = np.array([float(arg) for arg in args], dtype=np.float32)
        return data.reshape(1, -1)
    except ValueError as e:
        raise ValueError(f"输入参数必须是数字: {e}")


def load_json_input(input_path: str) -> np.ndarray:
    """
    从JSON文件加载输入数据
    
    Args:
        input_path: JSON文件路径
        
    Returns:
        输入数据数组
    """
    try:
        with open(input_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        
        # 处理不同的JSON格式
        if isinstance(data, dict):
            if 'data' in data:
                # 格式: {"data": [[...], [...], ...]}
                input_data = np.array(data['data'], dtype=np.float32)
            elif 'features' in data:
                # 格式: {"features": [[...], [...], ...]}
                input_data = np.array(data['features'], dtype=np.float32)
            else:
                # 格式: {"user_pleasure": ..., "user_arousal": ..., ...}
                single_sample = [
                    data.get('user_pleasure', 0),
                    data.get('user_arousal', 0),
                    data.get('user_dominance', 0),
                    data.get('vitality', 0),
                    data.get('current_pleasure', 0),
                    data.get('current_arousal', 0),
                    data.get('current_dominance', 0)
                ]
                input_data = np.array(single_sample, dtype=np.float32).reshape(1, -1)
        
        elif isinstance(data, list):
            # 格式: [[...], [...], ...] 或 [...]
            if len(data) > 0 and isinstance(data[0], list):
                input_data = np.array(data, dtype=np.float32)
            else:
                input_data = np.array(data, dtype=np.float32).reshape(1, -1)
        
        else:
            raise ValueError("不支持的JSON格式")
        
        # 验证数据维度
        if input_data.ndim == 1:
            input_data = input_data.reshape(1, -1)
        elif input_data.ndim > 2:
            raise ValueError("输入数据应该是1维或2维的")
        
        if input_data.shape[1] != 7:
            raise ValueError(f"输入数据应该有7个特征,但得到{input_data.shape[1]}个")
        
        return input_data
        
    except Exception as e:
        raise ValueError(f"无法解析JSON文件 {input_path}: {e}")


def load_csv_input(input_path: str, 
                  feature_columns: Optional[List[str]] = None) -> np.ndarray:
    """
    从CSV文件加载输入数据
    
    Args:
        input_path: CSV文件路径
        feature_columns: 特征列名列表
        
    Returns:
        输入数据数组
    """
    try:
        # 默认列名
        default_columns = [
            'user_pleasure', 'user_arousal', 'user_dominance',
            'vitality', 'current_pleasure', 'current_arousal', 'current_dominance'
        ]
        
        # 读取CSV文件
        if feature_columns:
            df = pd.read_csv(input_path, usecols=feature_columns)
        else:
            df = pd.read_csv(input_path)
            
            # 自动检测列名
            if len(df.columns) >= 7:
                df = df.iloc[:, :7]  # 使用前7列
                df.columns = default_columns
            elif len(df.columns) == 7:
                df.columns = default_columns
            else:
                raise ValueError(f"CSV文件应该至少有7列,但得到{len(df.columns)}列")
        
        # 转换为numpy数组
        input_data = df.values.astype(np.float32)
        
        return input_data
        
    except Exception as e:
        raise ValueError(f"无法解析CSV文件 {input_path}: {e}")


def save_json_output(results: List[Dict[str, Any]], output_path: str) -> None:
    """
    保存结果为JSON格式
    
    Args:
        results: 推理结果列表
        output_path: 输出路径
    """
    output_data = {
        'predictions': results,
        'metadata': {
            'total_samples': len(results),
            'output_format': 'json',
            'description': 'Emotion and physiological state prediction results'
        }
    }
    
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(output_data, f, indent=2, ensure_ascii=False)


def save_csv_output(results: List[Dict[str, Any]], output_path: str) -> None:
    """
    保存结果为CSV格式
    
    Args:
        results: 推理结果列表
        output_path: 输出路径
    """
    if not results:
        return
    
    # 展开结果数据
    rows = []
    for i, result in enumerate(results):
        row = {
            'sample_id': i,
            'delta_pleasure': result['delta_pad'][0],
            'delta_arousal': result['delta_pad'][1],
            'delta_dominance': result['delta_pad'][2],
            'delta_pressure': result['delta_pressure'][0],
            'confidence': result['confidence'][0],
            'inference_time': result.get('inference_time', 0)
        }
        rows.append(row)
    
    # 写入CSV文件
    with open(output_path, 'w', newline='', encoding='utf-8') as f:
        fieldnames = rows[0].keys()
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)


def save_text_output(results: List[Dict[str, Any]], output_path: str) -> None:
    """
    保存结果为文本格式
    
    Args:
        results: 推理结果列表
        output_path: 输出路径
    """
    with open(output_path, 'w', encoding='utf-8') as f:
        f.write("情绪与生理状态变化预测结果\n")
        f.write("=" * 50 + "\n\n")
        
        for i, result in enumerate(results):
            f.write(f"样本 {i+1}:\n")
            f.write(f"  ΔPAD (情绪变化):\n")
            f.write(f"    快乐度变化: {result['delta_pad'][0]:.6f}\n")
            f.write(f"    激活度变化: {result['delta_pad'][1]:.6f}\n")
            f.write(f"    支配度变化: {result['delta_pad'][2]:.6f}\n")
            f.write(f"  Δ压力: {result['delta_pressure'][0]:.6f}\n")
            f.write(f"  置信度: {result['confidence'][0]:.6f}\n")
            f.write(f"  推理时间: {result.get('inference_time', 0):.6f}秒\n")
            f.write("-" * 30 + "\n")


def print_results(results: List[Dict[str, Any]], verbose: bool = True) -> None:
    """
    打印推理结果
    
    Args:
        results: 推理结果列表
        verbose: 是否显示详细信息
    """
    if not verbose:
        # 简洁输出
        for i, result in enumerate(results):
            print(f"样本{i+1}: ΔPAD={result['delta_pad']}, Δ压力={result['delta_pressure'][0]:.4f}, 置信度={result['confidence'][0]:.4f}")
        return
    
    # 详细输出
    print("\n情绪与生理状态变化预测结果")
    print("=" * 60)
    
    for i, result in enumerate(results):
        print(f"\n样本 {i+1}:")
        print(f"  ΔPAD (情绪变化):")
        print(f"    快乐度变化: {result['delta_pad'][0]:+.6f}")
        print(f"    激活度变化: {result['delta_pad'][1]:+.6f}")
        print(f"    支配度变化: {result['delta_pad'][2]:+.6f}")
        print(f"  Δ压力: {result['delta_pressure'][0]:+.6f}")
        print(f"  置信度: {result['confidence'][0]:.6f}")
        if 'inference_time' in result:
            print(f"  推理时间: {result['inference_time']*1000:.2f}ms")
        print("-" * 40)


def run_benchmark(engine: InferenceEngine, 
                 num_samples: int = 1000,
                 batch_size: int = 32) -> None:
    """
    运行性能基准测试
    
    Args:
        engine: 推理引擎
        num_samples: 测试样本数量
        batch_size: 批次大小
    """
    print(f"\n运行性能基准测试...")
    print(f"测试样本数: {num_samples}")
    print(f"批次大小: {batch_size}")
    
    try:
        stats = engine.benchmark(num_samples, batch_size)
        
        print("\n基准测试结果:")
        print(f"  总样本数: {stats['total_samples']}")
        print(f"  总时间: {stats['total_time']:.4f}秒")
        print(f"  吞吐量: {stats['throughput']:.2f} 样本/秒")
        print(f"  平均延迟: {stats['avg_latency']:.2f}ms")
        print(f"  最小延迟: {stats['min_time']*1000:.2f}ms")
        print(f"  最大延迟: {stats['max_time']*1000:.2f}ms")
        print(f"  P95延迟: {stats['p95_latency']:.2f}ms")
        print(f"  P99延迟: {stats['p99_latency']:.2f}ms")
        
    except Exception as e:
        print(f"基准测试失败: {e}")


def main():
    """主函数"""
    parser = argparse.ArgumentParser(
        description="情绪与生理状态变化预测推理脚本",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
输入格式说明:
  1. 命令行参数: --input-cli 0.5 0.3 -0.2 80 0.1 0.4 -0.1
  2. JSON文件: --input-json data.json
  3. CSV文件: --input-csv data.csv

输出格式说明:
  - JSON: 结构化数据,便于程序处理
  - CSV: 表格数据,便于Excel处理
  - TXT: 人类可读的文本格式

示例用法:
  # 单样本推理
  python inference.py --model model.pth --input-cli 0.5 0.3 -0.2 80 0.1 0.4 -0.1
  
  # 批量推理
  python inference.py --model model.pth --input-json batch_data.json --output-json results.json
  
  # 基准测试
  python inference.py --model model.pth --benchmark --num-samples 1000
        """
    )
    
    # 模型相关参数
    parser.add_argument('--model', '-m', type=str, required=True,
                       help='模型文件路径 (.pth)')
    parser.add_argument('--preprocessor', '-p', type=str,
                       help='预处理器文件路径')
    parser.add_argument('--device', type=str, choices=['auto', 'cpu', 'cuda'], 
                       default='auto', help='计算设备')
    
    # 输入相关参数
    input_group = parser.add_mutually_exclusive_group(required=True)
    input_group.add_argument('--input-cli', nargs='+', metavar='VALUE',
                            help='命令行输入 (7个数值: user_pleasure user_arousal user_dominance vitality current_pleasure current_arousal current_dominance)')
    input_group.add_argument('--input-json', type=str, metavar='FILE',
                            help='JSON输入文件路径')
    input_group.add_argument('--input-csv', type=str, metavar='FILE',
                            help='CSV输入文件路径')
    
    # 输出相关参数
    parser.add_argument('--output-json', type=str, metavar='FILE',
                       help='JSON输出文件路径')
    parser.add_argument('--output-csv', type=str, metavar='FILE',
                       help='CSV输出文件路径')
    parser.add_argument('--output-txt', type=str, metavar='FILE',
                       help='文本输出文件路径')
    parser.add_argument('--quiet', '-q', action='store_true',
                       help='静默模式,不打印结果')
    
    # 推理参数
    parser.add_argument('--batch-size', type=int, default=32,
                       help='批量推理的批次大小')
    
    # 基准测试参数
    parser.add_argument('--benchmark', action='store_true',
                       help='运行性能基准测试')
    parser.add_argument('--num-samples', type=int, default=1000,
                       help='基准测试的样本数量')
    
    # 其他参数
    parser.add_argument('--verbose', '-v', action='store_true',
                       help='详细输出')
    parser.add_argument('--log-level', type=str, choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'],
                       default='INFO', help='日志级别')
    
    args = parser.parse_args()
    
    # 设置日志
    setup_logger(level=args.log_level)
    logger = logging.getLogger(__name__)
    
    try:
        # 创建推理引擎
        logger.info("初始化推理引擎...")
        engine = create_inference_engine(
            model_path=args.model,
            preprocessor_path=args.preprocessor,
            device=args.device
        )
        
        # 打印模型信息
        if args.verbose:
            model_info = engine.get_model_info()
            print(f"\n模型信息:")
            print(f"  设备: {model_info['device']}")
            print(f"  总参数量: {model_info['total_parameters']:,}")
            print(f"  输入维度: {model_info['input_dim']}")
            print(f"  输出维度: {model_info['output_dim']}")
        
        # 运行基准测试
        if args.benchmark:
            run_benchmark(engine, args.num_samples, args.batch_size)
            return
        
        # 加载输入数据
        logger.info("加载输入数据...")
        if args.input_cli:
            input_data = parse_command_line_input(args.input_cli)
        elif args.input_json:
            input_data = load_json_input(args.input_json)
        elif args.input_csv:
            input_data = load_csv_input(args.input_csv)
        
        logger.info(f"加载了 {len(input_data)} 个样本")
        
        # 执行推理
        logger.info("执行推理...")
        start_time = time.time()
        
        if len(input_data) == 1:
            # 单样本推理
            result = engine.predict(input_data[0])
            results = [result.to_dict()]
        else:
            # 批量推理
            results = engine.predict_batch(input_data, args.batch_size)
            results = [result.to_dict() for result in results]
        
        total_time = time.time() - start_time
        
        logger.info(f"推理完成,总时间: {total_time:.4f}秒")
        
        # 打印结果
        if not args.quiet:
            print_results(results, verbose=args.verbose)
        
        # 保存结果
        if args.output_json:
            save_json_output(results, args.output_json)
            print(f"结果已保存到: {args.output_json}")
        
        if args.output_csv:
            save_csv_output(results, args.output_csv)
            print(f"结果已保存到: {args.output_csv}")
        
        if args.output_txt:
            save_text_output(results, args.output_txt)
            print(f"结果已保存到: {args.output_txt}")
        
        # 性能统计
        if args.verbose:
            stats = engine.get_performance_stats()
            print(f"\n性能统计:")
            print(f"  总推理次数: {stats['total_inferences']}")
            print(f"  平均时间: {stats['avg_time']*1000:.2f}ms")
            print(f"  最小时间: {stats['min_time']*1000:.2f}ms")
            print(f"  最大时间: {stats['max_time']*1000:.2f}ms")
        
    except Exception as e:
        logger.error(f"推理失败: {e}")
        if args.verbose:
            import traceback
            traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()