File size: 17,444 Bytes
0a6452f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
"""
简化的预测CLI工具
Simplified CLI Tool for emotion and physiological state prediction

该工具提供了简化的命令行界面,支持:
- 交互式输入
- 批量文件处理
- 清晰的输出格式和解释
- 快速预测模式
"""

import argparse
import sys
import os
import json
import csv
import logging
from pathlib import Path
from typing import List, Dict, Any, Optional
import numpy as np

# 添加项目根目录到Python路径
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root))

from src.utils.inference_engine import create_inference_engine
from src.utils.logger import setup_logger


class PredictCLI:
    """简化的预测CLI类"""
    
    def __init__(self, model_path: str, preprocessor_path: Optional[str] = None):
        """
        初始化预测CLI
        
        Args:
            model_path: 模型文件路径
            preprocessor_path: 预处理器文件路径
        """
        self.logger = logging.getLogger(__name__)
        
        try:
            # 创建推理引擎
            self.engine = create_inference_engine(
                model_path=model_path,
                preprocessor_path=preprocessor_path,
                device='auto'
            )
            
            # 获取模型信息
            self.model_info = self.engine.get_model_info()
            self.logger.info("预测CLI初始化成功")
            
        except Exception as e:
            self.logger.error(f"预测CLI初始化失败: {e}")
            raise
    
    def interactive_mode(self):
        """交互式预测模式"""
        print("\n" + "="*60)
        print("情绪与生理状态变化预测工具 - 交互式模式")
        print("="*60)
        print("\n请输入以下7个参数:")
        print("1. 用户快乐度 (User Pleasure): [-1.0, 1.0]")
        print("2. 用户激活度 (User Arousal): [-1.0, 1.0]")
        print("3. 用户支配度 (User Dominance): [-1.0, 1.0]")
        print("4. 活力值 (Vitality): [0.0, 100.0]")
        print("5. 当前快乐度 (Current Pleasure): [-1.0, 1.0]")
        print("6. 当前激活度 (Current Arousal): [-1.0, 1.0]")
        print("7. 当前支配度 (Current Dominance): [-1.0, 1.0]")
        print("\n输入 'quit' 退出交互模式")
        print("-"*60)
        
        while True:
            try:
                print("\n请输入7个数值 (用空格分隔):")
                user_input = input("> ").strip()
                
                if user_input.lower() in ['quit', 'exit', 'q']:
                    print("退出交互模式")
                    break
                
                # 解析输入
                values = user_input.split()
                if len(values) != 7:
                    print(f"错误: 需要输入7个数值,但得到{len(values)}个")
                    continue
                
                try:
                    input_data = np.array([float(v) for v in values], dtype=np.float32)
                except ValueError:
                    print("错误: 输入必须是数字")
                    continue
                
                # 验证输入范围
                if not self._validate_input_ranges(input_data):
                    continue
                
                # 执行预测
                result = self.predict_single(input_data)
                
                # 显示结果
                self._display_result(result, input_data)
                
            except KeyboardInterrupt:
                print("\n\n用户中断,退出交互模式")
                break
            except Exception as e:
                print(f"预测出错: {e}")
    
    def _validate_input_ranges(self, input_data: np.ndarray) -> bool:
        """验证输入数据范围"""
        user_pad = input_data[:3]
        vitality = input_data[3]
        current_pad = input_data[4:]
        
        # 检查PAD值范围
        if np.any(np.abs(user_pad) > 1.5):
            print("警告: 用户PAD值超出正常范围 [-1.0, 1.0]")
            response = input("是否继续? (y/n): ").strip().lower()
            if response != 'y':
                return False
        
        if np.any(np.abs(current_pad) > 1.5):
            print("警告: 当前PAD值超出正常范围 [-1.0, 1.0]")
            response = input("是否继续? (y/n): ").strip().lower()
            if response != 'y':
                return False
        
        # 检查活力值范围
        if not (0 <= vitality <= 150):
            print("警告: 活力值超出正常范围 [0.0, 100.0]")
            response = input("是否继续? (y/n): ").strip().lower()
            if response != 'y':
                return False
        
        return True
    
    def predict_single(self, input_data: np.ndarray):
        """预测单个样本"""
        try:
            result = self.engine.predict(input_data)
            return result
        except Exception as e:
            raise RuntimeError(f"预测失败: {e}")
    
    def _display_result(self, result, input_data: np.ndarray):
        """显示预测结果"""
        print("\n" + "="*50)
        print("预测结果")
        print("="*50)
        
        # 显示输入信息
        print(f"\n输入信息:")
        print(f"  用户PAD: 快乐度={input_data[0]:+.3f}, 激活度={input_data[1]:+.3f}, 支配度={input_data[2]:+.3f}")
        print(f"  活力值: {input_data[3]:.1f}")
        print(f"  当前PAD: 快乐度={input_data[4]:+.3f}, 激活度={input_data[5]:+.3f}, 支配度={input_data[6]:+.3f}")
        
        # 显示预测结果
        delta_pad = result.delta_pad[0]
        delta_pressure = result.delta_pressure[0]
        confidence = result.confidence[0]
        
        print(f"\n预测变化:")
        print(f"  情绪变化 (ΔPAD):")
        print(f"    快乐度变化: {delta_pad:+.6f} {'↗' if delta_pad > 0 else '↘' if delta_pad < 0 else '→'}")
        print(f"    激活度变化: {delta_pad:+.6f} {'↗' if delta_pad > 0 else '↘' if delta_pad < 0 else '→'}")
        print(f"    支配度变化: {delta_pad:+.6f} {'↗' if delta_pad > 0 else '↘' if delta_pad < 0 else '→'}")
        print(f"  压力变化: {delta_pressure:+.6f} {'↗' if delta_pressure > 0 else '↘' if delta_pressure < 0 else '→'}")
        print(f"  预测置信度: {confidence:.6f} ({confidence*100:.1f}%)")
        
        # 提供解释
        self._provide_interpretation(delta_pad, delta_pressure, confidence)
        
        # 显示性能信息
        print(f"\n性能信息:")
        print(f"  推理时间: {result.inference_time*1000:.2f}ms")
        
        print("="*50)
    
    def _provide_interpretation(self, delta_pad: np.ndarray, delta_pressure: float, confidence: float):
        """提供预测结果解释"""
        print(f"\n结果解释:")
        
        # 情绪变化解释
        pleasure_change = delta_pad[0]
        arousal_change = delta_pad[1]
        dominance_change = delta_pad[2]
        
        if abs(pleasure_change) > 0.1:
            if pleasure_change > 0:
                print("  • 情绪趋向积极愉快")
            else:
                print("  • 情绪趋向消极低落")
        
        if abs(arousal_change) > 0.1:
            if arousal_change > 0:
                print("  • 激活度提升,趋向兴奋")
            else:
                print("  • 激活度降低,趋向平静")
        
        if abs(dominance_change) > 0.1:
            if dominance_change > 0:
                print("  • 支配感增强,趋向自信")
            else:
                print("  • 支配感减弱,趋向顺从")
        
        # 压力变化解释
        if abs(delta_pressure) > 0.05:
            if delta_pressure > 0:
                print("  • 压力水平可能上升")
            else:
                print("  • 压力水平可能下降")
        
        # 置信度解释
        if confidence > 0.8:
            print("  • 预测置信度很高")
        elif confidence > 0.6:
            print("  • 预测置信度中等")
        else:
            print("  • 预测置信度较低,结果可能不太准确")
    
    def batch_predict(self, input_file: str, output_file: Optional[str] = None):
        """批量预测"""
        print(f"\n批量预测模式")
        print(f"输入文件: {input_file}")
        
        try:
            # 加载输入数据
            input_data = self._load_batch_input(input_file)
            print(f"加载了 {len(input_data)} 个样本")
            
            # 执行批量预测
            print("执行批量预测...")
            results = self.engine.predict_batch(input_data)
            
            # 处理结果
            processed_results = []
            for i, result in enumerate(results):
                processed_results.append({
                    'sample_id': i + 1,
                    'delta_pleasure': float(result.delta_pad[0]),
                    'delta_arousal': float(result.delta_pad[1]),
                    'delta_dominance': float(result.delta_pad[2]),
                    'delta_pressure': float(result.delta_pressure[0]),
                    'confidence': float(result.confidence[0]),
                    'inference_time': float(result.inference_time)
                })
            
            # 保存结果
            if output_file:
                self._save_batch_results(processed_results, output_file)
                print(f"结果已保存到: {output_file}")
            else:
                self._display_batch_summary(processed_results)
            
        except Exception as e:
            print(f"批量预测失败: {e}")
            raise
    
    def _load_batch_input(self, input_file: str) -> np.ndarray:
        """加载批量输入数据"""
        file_path = Path(input_file)
        
        if not file_path.exists():
            raise FileNotFoundError(f"输入文件不存在: {input_file}")
        
        if file_path.suffix.lower() == '.json':
            return self._load_json_batch(file_path)
        elif file_path.suffix.lower() == '.csv':
            return self._load_csv_batch(file_path)
        else:
            raise ValueError(f"不支持的文件格式: {file_path.suffix}")
    
    def _load_json_batch(self, file_path: Path) -> np.ndarray:
        """加载JSON格式的批量数据"""
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        
        if isinstance(data, list):
            return np.array(data, dtype=np.float32)
        elif isinstance(data, dict) and 'data' in data:
            return np.array(data['data'], dtype=np.float32)
        else:
            raise ValueError("JSON格式不正确,需要数据数组或包含'data'字段的对象")
    
    def _load_csv_batch(self, file_path: Path) -> np.ndarray:
        """加载CSV格式的批量数据"""
        import pandas as pd
        
        df = pd.read_csv(file_path)
        
        # 检查列数
        if len(df.columns) < 7:
            raise ValueError(f"CSV文件至少需要7列,但得到{len(df.columns)}列")
        
        # 使用前7列
        data = df.iloc[:, :7].values
        return data.astype(np.float32)
    
    def _save_batch_results(self, results: List[Dict[str, Any]], output_file: str):
        """保存批量结果"""
        file_path = Path(output_file)
        
        if file_path.suffix.lower() == '.json':
            with open(file_path, 'w', encoding='utf-8') as f:
                json.dump({
                    'results': results,
                    'summary': {
                        'total_samples': len(results),
                        'avg_confidence': np.mean([r['confidence'] for r in results]),
                        'avg_inference_time': np.mean([r['inference_time'] for r in results])
                    }
                }, f, indent=2, ensure_ascii=False)
        
        elif file_path.suffix.lower() == '.csv':
            with open(file_path, 'w', newline='', encoding='utf-8') as f:
                if results:
                    fieldnames = results[0].keys()
                    writer = csv.DictWriter(f, fieldnames=fieldnames)
                    writer.writeheader()
                    writer.writerows(results)
        
        else:
            raise ValueError(f"不支持的输出格式: {file_path.suffix}")
    
    def _display_batch_summary(self, results: List[Dict[str, Any]]):
        """显示批量预测摘要"""
        if not results:
            print("没有预测结果")
            return
        
        print(f"\n批量预测摘要:")
        print(f"总样本数: {len(results)}")
        
        # 统计信息
        confidences = [r['confidence'] for r in results]
        inference_times = [r['inference_time'] for r in results]
        
        print(f"平均置信度: {np.mean(confidences):.4f}")
        print(f"置信度范围: [{np.min(confidences):.4f}, {np.max(confidences):.4f}]")
        print(f"平均推理时间: {np.mean(inference_times)*1000:.2f}ms")
        print(f"总推理时间: {np.sum(inference_times):.4f}s")
        
        # 显示前几个结果
        print(f"\n前5个样本结果:")
        for i, result in enumerate(results[:5]):
            print(f"样本{result['sample_id']}: "
                  f"ΔPAD=[{result['delta_pleasure']:+.3f}, {result['delta_arousal']:+.3f}, {result['delta_dominance']:+.3f}], "
                  f"Δ压力={result['delta_pressure']:+.3f}, "
                  f"置信度={result['confidence']:.3f}")
    
    def quick_predict(self, values: List[float]):
        """快速预测模式"""
        if len(values) != 7:
            raise ValueError(f"需要7个输入值,但得到{len(values)}个")
        
        input_data = np.array(values, dtype=np.float32)
        
        try:
            result = self.predict_single(input_data)
            
            # 简洁输出
            delta_pad = result.delta_pad[0]
            delta_pressure = result.delta_pressure[0]
            confidence = result.confidence[0]
            
            print(f"ΔPAD: [{delta_pad[0]:+.4f}, {delta_pad[1]:+.4f}, {delta_pad[2]:+.4f}], "
                  f"Δ压力: {delta_pressure:+.4f}, "
                  f"置信度: {confidence:.4f}")
            
        except Exception as e:
            print(f"预测失败: {e}")
            return False
        
        return True


def main():
    """主函数"""
    parser = argparse.ArgumentParser(
        description="情绪与生理状态变化预测工具",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
使用示例:
  # 交互式模式
  python predict.py --model model.pth
  
  # 快速预测
  python predict.py --model model.pth --quick 0.5 0.3 -0.2 80 0.1 0.4 -0.1
  
  # 批量预测
  python predict.py --model model.pth --batch input.json --output results.json
        """
    )
    
    # 必需参数
    parser.add_argument('--model', '-m', type=str, required=True,
                       help='模型文件路径 (.pth)')
    parser.add_argument('--preprocessor', '-p', type=str,
                       help='预处理器文件路径')
    
    # 模式选择
    mode_group = parser.add_mutually_exclusive_group()
    mode_group.add_argument('--interactive', '-i', action='store_true',
                           help='交互式模式')
    mode_group.add_argument('--quick', nargs=7, type=float, metavar='VALUE',
                           help='快速预测模式 (7个数值)')
    mode_group.add_argument('--batch', type=str, metavar='FILE',
                           help='批量预测模式 (输入文件)')
    
    # 输出选项
    parser.add_argument('--output', '-o', type=str,
                       help='输出文件路径 (批量模式)')
    parser.add_argument('--verbose', '-v', action='store_true',
                       help='详细输出')
    parser.add_argument('--log-level', type=str, 
                       choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'],
                       default='WARNING', help='日志级别')
    
    args = parser.parse_args()
    
    # 设置日志
    setup_logger(level=args.log_level)
    
    try:
        # 创建预测CLI
        cli = PredictCLI(args.model, args.preprocessor)
        
        # 显示模型信息
        if args.verbose:
            print(f"模型信息:")
            print(f"  设备: {cli.model_info['device']}")
            print(f"  总参数量: {cli.model_info['total_parameters']:,}")
            print(f"  输入维度: {cli.model_info['input_dim']}")
            print(f"  输出维度: {cli.model_info['output_dim']}")
        
        # 根据模式执行
        if args.interactive or (not args.quick and not args.batch):
            # 默认进入交互模式
            cli.interactive_mode()
        
        elif args.quick:
            # 快速预测
            success = cli.quick_predict(args.quick)
            if not success:
                sys.exit(1)
        
        elif args.batch:
            # 批量预测
            cli.batch_predict(args.batch, args.output)
        
    except Exception as e:
        print(f"错误: {e}")
        if args.verbose:
            import traceback
            traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()