PiloBi committed on
Commit
cc427cc
·
verified ·
1 Parent(s): 67c5a2e

Upload RL_infer.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. RL_infer.py +880 -0
RL_infer.py ADDED
@@ -0,0 +1,880 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # import json
2
+ # import re
3
+ # from PIL import Image
4
+ # from transformers import AutoModelForVision2Seq, AutoProcessor
5
+ # import torch
6
+ # import os
7
+ # from qwen_vl_utils import process_vision_info
8
+ # # --- 1. 辅助函数 ---
9
+
10
+ # def load_test_data(file_path):
11
+ # """
12
+ # 根据文件扩展名自动加载 .json 或 .jsonl 文件。
13
+ # 对于 .json 文件,尝试不同的常见键来查找样本列表。
14
+ # """
15
+ # _, ext = os.path.splitext(file_path)
16
+ # ext = ext.lower()
17
+
18
+ # test_samples = []
19
+ # if ext == '.jsonl':
20
+ # print(f"Loading data from JSON Lines file: {file_path}")
21
+ # with open(file_path, 'r', encoding='utf-8') as f:
22
+ # for i, line in enumerate(f):
23
+ # line = line.strip()
24
+ # if not line:
25
+ # continue
26
+ # try:
27
+ # test_samples.append(json.loads(line))
28
+ # except json.JSONDecodeError as e:
29
+ # print(f"Warning: Skipping invalid JSON line {i+1} in {file_path}: {e}")
30
+ # elif ext == '.json':
31
+ # print(f"Loading data from JSON file: {file_path}")
32
+ # try:
33
+ # with open(file_path, 'r', encoding='utf-8') as f:
34
+ # data = json.load(f)
35
+
36
+ # if isinstance(data, list):
37
+ # print(" Detected JSON array format.")
38
+ # test_samples = data
39
+ # elif isinstance(data, dict):
40
+ # print(" Detected JSON object format. Searching for samples...")
41
+ # possible_keys = ['data', 'samples', 'instances', 'items', 'conversations', 'messages']
42
+ # found = False
43
+ # for key in possible_keys:
44
+ # if key in data and isinstance(data[key], list) and len(data[key]) > 0:
45
+ # # 简单检查列表第一个元素是否像样本 (dict with 'messages')
46
+ # first_item = data[key][0]
47
+ # if isinstance(first_item, dict) and 'messages' in first_item:
48
+ # print(f" Found samples under key '{key}'.")
49
+ # test_samples = data[key]
50
+ # found = True
51
+ # break
52
+ # if not found:
53
+ # # 启发式:查找第一个值是列表且列表元素是字典的键
54
+ # for key, value in data.items():
55
+ # if isinstance(value, list) and len(value) > 0 and isinstance(value[0], dict) and 'messages' in value[0]:
56
+ # print(f" Found samples under key '{key}' (heuristic).")
57
+ # test_samples = value
58
+ # found = True
59
+ # break
60
+ # if not found:
61
+ # print(f" Error: Could not find a list of samples in the JSON object. Keys found: {list(data.keys())}")
62
+ # else:
63
+ # print(f" Error: Unexpected JSON structure. Root element type: {type(data)}")
64
+
65
+ # except json.JSONDecodeError as e:
66
+ # print(f"Error: Failed to decode JSON from {file_path}: {e}")
67
+ # except Exception as e:
68
+ # print(f"Error: An unexpected error occurred while loading {file_path}: {e}")
69
+ # else:
70
+ # print(f"Error: Unsupported file extension '{ext}'. Please provide a .json or .jsonl file.")
71
+
72
+ # print(f"Loaded {len(test_samples)} samples.")
73
+
74
+ # # 验证加载的样本结构
75
+ # if test_samples and isinstance(test_samples, list):
76
+ # print("Performing basic structure validation on loaded samples...")
77
+ # sample_count_to_check = min(5, len(test_samples))
78
+ # for i in range(sample_count_to_check):
79
+ # s = test_samples[i]
80
+ # if not isinstance(s, dict):
81
+ # print(f" CRITICAL: Sample {i} is not a dict. Type: {type(s)}")
82
+ # # 可以选择在这里中断或清理数据
83
+ # # return []
84
+ # elif 'messages' not in s or 'images' not in s:
85
+ # print(f" WARNING: Sample {i} might be missing 'messages' or 'images' keys. Found keys: {list(s.keys())}")
86
+ # else:
87
+ # if not isinstance(s['messages'], list):
88
+ # print(f" CRITICAL: Sample {i} 'messages' is not a list. Type: {type(s['messages'])}")
89
+ # if not isinstance(s['images'], list):
90
+ # print(f" CRITICAL: Sample {i} 'images' is not a list. Type: {type(s['images'])}")
91
+ # print("Structure validation complete.")
92
+ # elif test_samples:
93
+ # print(f"CRITICAL: Expected test_samples to be a list after loading, got {type(test_samples)}.")
94
+ # test_samples = [] # Reset to empty list on critical error
95
+
96
+ # return test_samples
97
+
98
+ # def extract_components(text):
99
+ # """从模型输出或标签中提取 <think>, <control>, <answer> 组件"""
100
+ # think_match = re.search(r'<think>(.*?)</think>', text, re.DOTALL)
101
+ # control_match = re.search(r'<control>(.*?)</control>', text)
102
+ # answer_match = re.search(r'<answer>(.*?)</answer>', text)
103
+
104
+ # return {
105
+ # 'think': think_match.group(1).strip() if think_match else "",
106
+ # 'control': control_match.group(1).strip() if control_match else "",
107
+ # 'answer': answer_match.group(1).strip() if answer_match else ""
108
+ # }
109
+
110
+
111
+ # def calculate_accuracy(pred_list, true_list):
112
+ # """计算准确率 (用于 <answer>)"""
113
+ # if len(pred_list) != len(true_list):
114
+ # raise ValueError("Prediction and truth lists must have the same length for accuracy calculation.")
115
+ # if not pred_list:
116
+ # return 0.0
117
+ # correct = sum(p == t for p, t in zip(pred_list, true_list))
118
+ # return correct / len(pred_list)
119
+
120
+ # # --- 2. 主评估逻辑 ---
121
+
122
+ # def main():
123
+ # # --- 配置 ---
124
+ # # 替换为您的模型路径
125
+ # model_path = "/data/LLM-SFT/SFT_Output/multiclsTask/Qwen2.5-VL-3B-Instruct/SFT/checkpoint-894"
126
+ # # 替换为您的测试集路径 (.json 或 .jsonl)
127
+ # test_data_path = "/data/LLM-SFT/datasets/driver_behavior_datasets/output_test.jsonl"
128
+ # output_file = model_path + "/eval/detailed_model_evaluation_results.json"
129
+
130
+ # # --- 加载模型和处理器 ---
131
+ # print("Loading model and processor...")
132
+ # try:
133
+ # processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
134
+ # model = AutoModelForVision2Seq.from_pretrained(
135
+ # model_path,
136
+ # trust_remote_code=True,
137
+ # torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
138
+ # )
139
+ # model.eval()
140
+ # if torch.cuda.is_available():
141
+ # model = model.to('cuda')
142
+ # print("Model loaded on GPU.")
143
+ # else:
144
+ # print("Model loaded on CPU.")
145
+ # except Exception as e:
146
+ # print(f"Failed to load model/processor: {e}")
147
+ # return # Exit if model loading fails
148
+
149
+ # # --- 加载测试数据 ---
150
+ # try:
151
+ # test_samples = load_test_data(test_data_path)
152
+ # # print('test_samples',test_samples)
153
+ # if not test_samples:
154
+ # print("No samples loaded. Exiting.")
155
+ # return
156
+ # except Exception as e:
157
+ # print(f"Failed to load test data: {e}")
158
+ # return
159
+
160
+ # # --- 推理和收集结果 (带解析) ---
161
+ # results = []
162
+ # pred_answers = []
163
+ # true_answers = []
164
+ # pred_controls = [] # 存储 control 字符串用于后续分析
165
+ # true_controls = []
166
+
167
+ # print("Starting inference...")
168
+ # for i, sample in enumerate(test_samples):
169
+ # try:
170
+ # conversation = sample['messages']
171
+ # image_path = sample['images'][0]
172
+
173
+ # if not os.path.exists(image_path):
174
+ # print(f"Warning: Image not found: {image_path}. Skipping sample {i}.")
175
+ # # 为保持列表对齐,添加空占位符
176
+ # pred_answers.append("")
177
+ # true_answers.append(extract_components(conversation[-1]['content'])['answer'])
178
+ # pred_controls.append("")
179
+ # true_controls.append(extract_components(conversation[-1]['content'])['control'])
180
+ # continue
181
+
182
+ # image = Image.open(image_path).convert('RGB')
183
+
184
+ # # 准备输入
185
+ # # 注意:Qwen VL 系列通常期望 messages 是一个列表,其中包含 role 和 content
186
+ # # processor 会处理 <image> token 和图像的对齐
187
+ # # print('conversation[:-1]',conversation[:-1])
188
+ # texts = processor.apply_chat_template(conversation[:-1], tokenize=False, add_generation_prompt=True)
189
+
190
+ # image_inputs, video_inputs = process_vision_info(conversation[:-1])
191
+ # inputs = processor(
192
+ # text=texts,
193
+ # images=image_inputs,
194
+ # videos=video_inputs,
195
+ # padding=True,
196
+ # return_tensors="pt",
197
+ # )
198
+
199
+ # if torch.cuda.is_available():
200
+ # inputs = {k: v.to('cuda') for k, v in inputs.items()}
201
+
202
+ # # 生成
203
+ # with torch.no_grad():
204
+ # generated_ids = model.generate(**inputs,
205
+ # max_new_tokens=200,
206
+ # # num_beams=5,
207
+ # do_sample=True,
208
+ # top_p=0.75,
209
+ # top_k=50,
210
+ # temperature=0.2
211
+ # # repetition_penalty=1.2,
212
+ # # early_stopping=True
213
+ # )
214
+ # generated_ids_trimmed = [
215
+ # out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
216
+ # ]
217
+ # output_text = processor.batch_decode(
218
+ # generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
219
+ # )
220
+
221
+
222
+ # # 提取标签
223
+ # ground_truth = conversation[-1]['content']
224
+
225
+ # # 解析模型输出和标签
226
+ # pred_components = extract_components(output_text)
227
+ # true_components = extract_components(ground_truth)
228
+
229
+ # # 收集用于评估的数据
230
+ # pred_answers.append(pred_components['answer'])
231
+ # true_answers.append(true_components['answer'])
232
+ # pred_controls.append(pred_components['control'])
233
+ # true_controls.append(true_components['control'])
234
+
235
+ # # 打印部分样本进行观察
236
+ # if i < 3: # 打印前3个样本
237
+ # print(f"\n--- Sample {i+1} ---")
238
+ # print(f" Image: {image_path}")
239
+ # print(f" Input Text: {input_text}")
240
+ # print(f" Full Decoded Output: {decoded_output}")
241
+ # print(f" Processed Output Text: {output_text}")
242
+ # print(f" Parsed Prediction: {pred_components}")
243
+ # print(f" Ground Truth: {ground_truth}")
244
+ # print(f" Parsed Truth: {true_components}")
245
+
246
+ # # 存储详细结果
247
+ # results.append({
248
+ # "sample_id": i,
249
+ # "image_path": image_path,
250
+ # "input_text": input_text,
251
+ # "model_output_raw": decoded_output,
252
+ # "model_output_processed": output_text,
253
+ # "parsed_prediction": pred_components,
254
+ # "ground_truth_raw": ground_truth,
255
+ # "parsed_truth": true_components
256
+ # })
257
+
258
+ # except Exception as e:
259
+ # print(f"Error processing sample {i}: {e}")
260
+ # # 错误样本也计入评估列表,但标记为空或错误
261
+ # pred_answers.append("ERROR")
262
+ # true_answers.append(extract_components(conversation[-1]['content'])['answer'] if 'conversation' in locals() else "")
263
+ # pred_controls.append("ERROR")
264
+ # true_controls.append(extract_components(conversation[-1]['content'])['control'] if 'conversation' in locals() else "")
265
+
266
+ # results.append({
267
+ # "sample_id": i,
268
+ # "image_path": image_path if 'image_path' in locals() else "N/A",
269
+ # "input_text": conversation[-2]['content'] if 'conversation' in locals() else "N/A",
270
+ # "model_output_raw": f"ERROR: {e}",
271
+ # "model_output_processed": f"ERROR: {e}",
272
+ # "parsed_prediction": {"think": "", "control": "", "answer": "ERROR"},
273
+ # "ground_truth_raw": conversation[-1]['content'] if 'conversation' in locals() else "N/A",
274
+ # "parsed_truth": extract_components(conversation[-1]['content']) if 'conversation' in locals() else {"think": "", "control": "", "answer": ""}
275
+ # })
276
+
277
+
278
+ # # --- 保存详细结果 ---
279
+ # with open(output_file, 'w', encoding='utf-8') as f:
280
+ # json.dump(results, f, indent=2, ensure_ascii=False)
281
+ # print(f"\nDetailed results saved to {output_file}")
282
+
283
+ # # --- 深入定量评估 ---
284
+
285
+ # print(f"\n--- Quantitative Evaluation ---")
286
+ # total_samples = len(test_samples)
287
+ # successful_samples = len([r for r in results if not r['model_output_raw'].startswith("ERROR")])
288
+ # print(f"Total samples: {total_samples}, Successfully processed: {successful_samples}")
289
+
290
+ # if successful_samples == 0:
291
+ # print("No samples were processed successfully. Skipping quantitative evaluation.")
292
+ # return
293
+
294
+ # # a. <answer> 标签准确率 (仅计算成功处理的样本)
295
+ # # 过滤掉错误样本
296
+ # filtered_pred_answers = [p for p in pred_answers if p != "ERROR"]
297
+ # filtered_true_answers = [t for p, t in zip(pred_answers, true_answers) if p != "ERROR"]
298
+
299
+ # if filtered_pred_answers:
300
+ # answer_accuracy = calculate_accuracy(filtered_pred_answers, filtered_true_answers)
301
+ # print(f"<answer> Tag Accuracy (on successful samples): {answer_accuracy:.4f} ({sum(p==t for p,t in zip(filtered_pred_answers, filtered_true_answers))}/{len(filtered_true_answers)})")
302
+ # else:
303
+ # print("No valid <answer> predictions to evaluate.")
304
+ # answer_accuracy = 0.0
305
+
306
+ # # b. <control> 指令分析
307
+ # filtered_pred_controls = [c for p, c in zip(pred_answers, pred_controls) if p != "ERROR"]
308
+ # filtered_true_controls = [t for p, t in zip(pred_answers, true_controls) if p != "ERROR"]
309
+
310
+ # if filtered_pred_controls:
311
+ # control_non_empty_pred = [c != "" for c in filtered_pred_controls]
312
+ # control_non_empty_true = [c != "" for c in filtered_true_controls]
313
+ # control_existence_acc = calculate_accuracy(control_non_empty_pred, control_non_empty_true)
314
+ # print(f"<control> Tag Presence Accuracy (on successful samples): {control_existence_acc:.4f}")
315
+ # else:
316
+ # print("No valid <control> predictions to evaluate.")
317
+ # control_existence_acc = 0.0
318
+
319
+ # # c. 分类别 <answer> 准确率
320
+ # if filtered_true_answers:
321
+ # unique_labels = sorted(list(set(filtered_true_answers + filtered_pred_answers)))
322
+ # print("\nPer-class <answer> accuracy:")
323
+ # class_acc = {}
324
+ # for label in unique_labels:
325
+ # tp = sum(1 for p, t in zip(filtered_pred_answers, filtered_true_answers) if p == label and t == label)
326
+ # total_true = sum(1 for t in filtered_true_answers if t == label)
327
+ # class_acc[label] = tp / total_true if total_true > 0 else 0.0
328
+ # print(f" Accuracy for '{label}': {class_acc[label]:.4f} ({tp}/{total_true if total_true > 0 else 'N/A'})")
329
+
330
+ # # d. (可选) 文本相似度评估 (需要安装 nltk 或 rouge-score)
331
+ # # 示例使用 ROUGE (需要 pip install rouge)
332
+ # from rouge import Rouge
333
+ # rouge = Rouge()
334
+ # avg_rouge_scores = {'rouge-1': 0.0, 'rouge-2': 0.0, 'rouge-l': 0.0}
335
+ # valid_samples_for_rouge = 0
336
+ # for res in results:
337
+ # if not res['model_output_raw'].startswith("ERROR") and res['parsed_truth']['think'] and res['parsed_prediction']['think']:
338
+ # try:
339
+ # scores = rouge.get_scores(res['parsed_prediction']['think'], res['parsed_truth']['think'])
340
+ # for metric in avg_rouge_scores:
341
+ # avg_rouge_scores[metric] += scores[0][metric]['f']
342
+ # valid_samples_for_rouge += 1
343
+ # except Exception as e:
344
+ # print(f"ROUGE calculation error for sample {res['sample_id']}: {e}")
345
+
346
+ # if valid_samples_for_rouge > 0:
347
+ # for metric in avg_rouge_scores:
348
+ # avg_rouge_scores[metric] /= valid_samples_for_rouge
349
+ # print(f"\nAverage ROUGE Scores (on <think> tags, {valid_samples_for_rouge} valid samples):")
350
+ # for metric, score in avg_rouge_scores.items():
351
+ # print(f" {metric.upper()}: {score:.4f}")
352
+ # else:
353
+ # print("\nNo valid samples for ROUGE calculation on <think> tags.")
354
+
355
+ # # --- 7. 错误案例分析 ---
356
+ # print(f"\n--- Error Analysis ---")
357
+ # error_count = sum(1 for r in results if r['model_output_raw'].startswith("ERROR"))
358
+ # if error_count > 0:
359
+ # print(f"Number of samples with processing errors: {error_count}")
360
+ # # 可以在这里打印错误详情
361
+ # else:
362
+ # print("No processing errors detected during inference.")
363
+
364
+ # print("Samples where <answer> prediction was incorrect (excluding errors):")
365
+ # incorrect_count = 0
366
+ # for res in results:
367
+ # # 只分析成功处理且预测错误的样本
368
+ # if not res['model_output_raw'].startswith("ERROR") and \
369
+ # res['parsed_prediction']['answer'] != res['parsed_truth']['answer']:
370
+ # incorrect_count += 1
371
+ # if incorrect_count <= 5: # 只打印前5个错误案例
372
+ # print(f" Sample ID: {res['sample_id']}")
373
+ # print(f" Image: {res['image_path']}")
374
+ # print(f" Input: {res['input_text']}")
375
+ # print(f" Predicted Answer: '{res['parsed_prediction']['answer']}'")
376
+ # print(f" True Answer: '{res['parsed_truth']['answer']}'")
377
+ # print(f" Predicted Control: '{res['parsed_prediction']['control']}'")
378
+ # print(f" True Control: '{res['parsed_truth']['control']}'")
379
+ # # print(f" Predicted Think: '{res['parsed_prediction']['think']}'") # 可选
380
+ # # print(f" True Think: '{res['parsed_truth']['think']}'") # 可选
381
+ # print("-" * 20)
382
+ # if incorrect_count > 5:
383
+ # print(f"... and {incorrect_count - 5} more incorrect predictions.")
384
+ # elif incorrect_count == 0:
385
+ # print(" All successful predictions matched the ground truth <answer>.")
386
+
387
+ # # --- 8. 总结 ---
388
+ # print(f"\n--- Summary ---")
389
+ # print(f"Total samples processed: {total_samples}")
390
+ # print(f"Successfully processed samples: {successful_samples}")
391
+ # if filtered_pred_answers:
392
+ # print(f"<answer> Accuracy (successful samples): {answer_accuracy:.4f}")
393
+ # if filtered_pred_controls:
394
+ # print(f"<control> Presence Accuracy (successful samples): {control_existence_acc:.4f}")
395
+ # print("Per-class accuracies calculated above (if applicable).")
396
+ # print("Detailed results are available in the output file.")
397
+
398
+
399
+ # if __name__ == "__main__":
400
+ # main()
401
+
402
+
403
+
404
+
405
+ import json
406
+ import os
407
+ from typing import Dict, List, Any, Tuple
408
+ import re
409
+ from collections import defaultdict, Counter
410
+ import ast
411
+
412
def _parse_jsonl(lines) -> List[Dict]:
    """Decode an iterable of text lines as JSON Lines, skipping blank lines."""
    return [json.loads(line) for line in map(str.strip, lines) if line]


def load_data(file_path: str) -> List[Dict]:
    """
    Load evaluation samples from a JSON or JSON Lines file.

    The format is chosen from the file extension (case-insensitive):
    ``.json`` is decoded as a single JSON document and ``.jsonl`` as one
    JSON value per non-blank line.  For any other extension the content is
    sniffed: if its first non-whitespace character is ``[`` it is decoded as
    a JSON array, otherwise it is treated as JSON Lines.

    Args:
        file_path: Path to the UTF-8 encoded input file.

    Returns:
        The decoded samples.  For ``.json`` input this is whatever the
        document's root value is (normally a list).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON / JSON Lines.
    """
    lowered_path = file_path.lower()

    with open(file_path, 'r', encoding='utf-8') as f:
        if lowered_path.endswith('.json'):
            return json.load(f)
        if lowered_path.endswith('.jsonl'):
            return _parse_jsonl(f)
        # Unknown extension: read everything so the format can be sniffed.
        content = f.read()

    # Look at the first non-whitespace character rather than only the first
    # physical line, so a JSON array preceded by blank lines is still
    # recognised instead of being fed line-by-line to the JSONL parser.
    if content.lstrip().startswith('['):
        return json.loads(content)

    # Text mode already normalised newlines to '\n'; splitting on it mirrors
    # iterating over the file object line by line.
    return _parse_jsonl(content.split('\n'))
443
+
444
def parse_think_content(think_str: str) -> Dict[str, str]:
    """
    Normalise the reasoning text of a ``<think>`` section.

    Any literal ``<think>`` / ``</think>`` tags are removed and surrounding
    whitespace is trimmed.  The same cleaned text is reported under both
    ``"raw"`` and ``"behavior"``; an empty or missing input yields empty
    strings for both keys.
    """
    cleaned = ""
    if think_str:
        cleaned = re.sub(r'<think>|</think>', '', think_str).strip()

    # The behavior description is currently just the full cleaned content.
    return {"raw": cleaned, "behavior": cleaned}
458
+
459
def parse_control_content(control_str: str) -> Dict[str, Any]:
    """
    Break a ``<control>`` instruction into command, parameters and a type.

    Three syntaxes are recognised after stripping the tags:

    * ``Name(arg)``  -> command ``Name``, parameters ``{"parameter": arg}``;
      type is ``monitoring`` / ``alerting`` / ``setting`` when the name
      contains ``Monitor`` / ``Alert`` / ``set`` (checked in that order),
      otherwise ``other``.
    * ``name|a|b|``  -> command ``name``, parameters ``{"params": [...]}``
      (every field after the first, including empty ones), type ``command``.
    * anything else -> the whole text is the command, type ``function``.

    An empty input returns empty fields with type ``none``.
    """
    if not control_str:
        return {"raw": "", "command": "", "parameters": {}, "type": "none"}

    body = re.sub(r'<control>|</control>', '', control_str).strip()

    # Defaults cover the case where parentheses are present but the text
    # does not match the Name(arg) shape.
    parsed = {"raw": body, "command": body, "parameters": {}, "type": "other"}

    if "(" in body and ")" in body:
        # Pattern like: MonitorPassenger(SwellingDetected)
        call = re.match(r'(\w+)\(([^)]+)\)', body)
        if call is not None:
            name, argument = call.groups()
            parsed["command"] = name
            parsed["parameters"] = {"parameter": argument}
            for marker, kind in (("Monitor", "monitoring"),
                                 ("Alert", "alerting"),
                                 ("set", "setting")):
                if marker in name:
                    parsed["type"] = kind
                    break
    elif "|" in body:
        # Pattern like: setMute|false|
        head, *rest = body.split("|")
        parsed["command"] = head
        parsed["parameters"] = {"params": rest}
        parsed["type"] = "command"
    else:
        parsed["type"] = "function"

    return parsed
502
+
503
def parse_answer_content(answer_str: str) -> Dict[str, str]:
    """
    Clean an ``<answer>`` section and assign it a coarse category.

    The tags are stripped and the remaining text is matched,
    case-insensitively and by substring, against an ordered keyword table;
    the first category with any matching keyword wins, and ``"other"`` is
    used when nothing matches.  The cleaned text is returned under both
    ``"raw"`` and ``"description"``.  Empty input yields empty strings for
    all three keys.
    """
    if not answer_str:
        return {"raw": "", "category": "", "description": ""}

    text = re.sub(r'<answer>|</answer>', '', answer_str).strip()
    lowered = text.lower()

    # Order matters: earlier rows take precedence over later ones.
    keyword_table = (
        ("physical_symptom", ("swelling", "eye", "face", "facial")),
        ("drowsiness", ("sleep", "drowsy", "tired", "yawn")),
        ("distraction", ("phone", "call", "text", "mobile")),
        ("smoking", ("smoke", "cigarette")),
        ("intoxication", ("drunk", "alcohol", "intoxicated")),
        ("facial_expression", ("mouth", "corner", "slanting")),
        ("head_behavior", ("head", "cover", "hold")),
        ("limb_behavior", ("arm", "hand", "slip", "droop")),
        ("vehicle_control", ("radio", "adjust", "control")),
    )
    category = next(
        (label for label, keywords in keyword_table
         if any(keyword in lowered for keyword in keywords)),
        "other",
    )

    return {"raw": text, "category": category, "description": text}
538
+
539
def extract_all_components(text: str) -> Dict[str, str]:
    """
    Pull the ``<think>``, ``<control>`` and ``<answer>`` sections out of text.

    For each tag the first non-greedy ``<tag>...</tag>`` match (which may
    span several lines) is taken and stripped of surrounding whitespace.
    A tag that is absent or not properly closed maps to an empty string.
    """
    tag_patterns = (
        ("think", r'<think>(.*?)</think>'),
        ("control", r'<control>(.*?)</control>'),
        ("answer", r'<answer>(.*?)</answer>'),
    )

    extracted = {}
    for tag, pattern in tag_patterns:
        found = re.search(pattern, text, re.DOTALL)
        extracted[tag] = found.group(1).strip() if found else ""
    return extracted
565
+
566
def calculate_component_accuracy(predicted_components: Dict, actual_components: Dict) -> Dict[str, float]:
    """
    Score the predicted think/control/answer text against the reference.

    Each component is compared with ``calculate_similarity``; a component
    missing from either mapping is treated as an empty string.  The result
    maps ``'think'``, ``'control'`` and ``'answer'`` to their scores.
    """
    return {
        part: calculate_similarity(
            predicted_components.get(part, ''),
            actual_components.get(part, ''),
        )
        for part in ('think', 'control', 'answer')
    }
591
+
592
def calculate_similarity(str1: str, str2: str) -> float:
    """
    Return a similarity score in [0, 1] for two strings.

    Rules, applied in order:

    * both inputs empty -> 1.0; exactly one empty -> 0.0
    * equal after lower-casing and trimming -> 1.0
    * either side has no words (whitespace only) -> 0.0
    * otherwise the Jaccard overlap of the whitespace-separated word sets,
      raised to at least 0.8 when one normalised string contains the other.
    """
    if not str1 or not str2:
        return 1.0 if not str1 and not str2 else 0.0

    left = str1.lower().strip()
    right = str2.lower().strip()
    if left == right:
        return 1.0

    left_words = set(left.split())
    right_words = set(right.split())
    # Both sets can only be empty together when both normalised strings are
    # empty, which the equality check above has already returned for.
    if not left_words or not right_words:
        return 0.0

    score = len(left_words & right_words) / len(left_words | right_words)

    # Containment of one string in the other is treated as a strong match.
    if left in right or right in left:
        score = max(score, 0.8)
    return score
627
+
628
def evaluate_component_quality(parsed_component: Dict, expected_component: Dict) -> Dict[str, float]:
    """
    Produce simple quality indicators for one parsed component.

    ``type_match`` is 1.0 when both components report the same ``'type'``
    (two missing types also count as equal), else 0.0.  ``content_quality``
    checks only the parsed component: for type ``monitoring`` / ``alerting``
    it is 1.0 if the command contains ``Monitor`` / ``Alert`` and 0.0
    otherwise; every other type gets the neutral value 0.5.
    """
    observed_type = parsed_component.get('type')
    type_match = 1.0 if observed_type == expected_component.get('type') else 0.0

    for kind, marker in (('monitoring', 'Monitor'), ('alerting', 'Alert')):
        if observed_type == kind:
            content_quality = 1.0 if marker in parsed_component.get('command', '') else 0.0
            break
    else:
        content_quality = 0.5  # Default medium quality

    return {'type_match': type_match, 'content_quality': content_quality}
648
+
649
def comprehensive_evaluation(data: List[Dict]) -> Dict[str, Any]:
    """
    Evaluate every sample's think/control/answer output against its labels.

    Each sample's ``'response'`` and ``'labels'`` strings (empty if missing)
    are split into components, parsed, and scored for accuracy and quality.

    Returns a dict with four entries:

    * ``'overall_metrics'``: per-component average accuracy, average quality
      and population standard deviation of accuracy, plus
      ``'total_samples'`` and ``'avg_overall_score'``.
    * ``'component_wise_metrics'``: the raw per-sample score lists.
    * ``'detailed_analysis'``: one record per sample.
    * ``'error_patterns'``: sample indices whose accuracy for a component
      is below 0.5.
    """
    component_names = ('think', 'control', 'answer')
    parsers = {
        'think': parse_think_content,
        'control': parse_control_content,
        'answer': parse_answer_content,
    }

    sample_count = len(data)
    score_lists = {
        name: {'accuracy_scores': [], 'quality_scores': []}
        for name in component_names
    }
    per_sample = []
    low_accuracy = {f'{name}_errors': [] for name in component_names}
    summary = {}

    results = {
        'overall_metrics': summary,
        'component_wise_metrics': score_lists,
        'detailed_analysis': per_sample,
        'error_patterns': low_accuracy,
    }

    for position, record in enumerate(data):
        # Split the raw strings into their tagged sections.
        predicted = extract_all_components(record.get('response', ''))
        expected = extract_all_components(record.get('labels', ''))

        # Parse each section for deeper analysis.
        parsed_predicted = {name: parsers[name](predicted[name]) for name in component_names}
        parsed_expected = {name: parsers[name](expected[name]) for name in component_names}

        accuracy = calculate_component_accuracy(predicted, expected)
        quality = {
            name: evaluate_component_quality(parsed_predicted[name], parsed_expected[name])
            for name in component_names
        }

        for name in component_names:
            bucket = score_lists[name]
            bucket['accuracy_scores'].append(accuracy[name])
            bucket['quality_scores'].append(quality[name].get('content_quality', 0.5))

        per_sample.append({
            'index': position,
            'response_components': predicted,
            'label_components': expected,
            'parsed_response': parsed_predicted,
            'parsed_labels': parsed_expected,
            'component_accuracy': accuracy,
            'component_quality': quality,
            'overall_score': sum(accuracy.values()) / 3 if accuracy else 0
        })

        # Record which components scored poorly for this sample.
        for name in component_names:
            if accuracy[name] < 0.5:
                low_accuracy[f'{name}_errors'].append(position)

    # Aggregate the per-sample scores into summary statistics.
    for name in component_names:
        accuracies = score_lists[name]['accuracy_scores']
        qualities = score_lists[name]['quality_scores']

        mean_accuracy = sum(accuracies) / len(accuracies) if accuracies else 0
        summary[f'{name}_avg_accuracy'] = mean_accuracy
        summary[f'{name}_avg_quality'] = sum(qualities) / len(qualities) if qualities else 0
        if accuracies:
            variance = sum((value - mean_accuracy)**2 for value in accuracies) / len(accuracies)
            summary[f'{name}_std_accuracy'] = variance**0.5
        else:
            summary[f'{name}_std_accuracy'] = 0

    summary['total_samples'] = sample_count
    if sample_count > 0:
        summary['avg_overall_score'] = sum(
            entry['overall_score'] for entry in per_sample
        ) / sample_count
    else:
        summary['avg_overall_score'] = 0

    return results
755
+
756
def generate_evaluation_report(results: Dict[str, Any]) -> str:
    """
    Build a human-readable text report from evaluation results.

    Args:
        results: Evaluation output. This function reads the keys
            'overall_metrics', 'error_patterns' and 'detailed_analysis'.
            'overall_metrics' must contain 'total_samples' and
            'avg_overall_score'; the per-component metrics are looked up
            with a default of 0. Each 'detailed_analysis' entry must provide
            'component_accuracy', 'overall_score' and a 'parsed_response'
            whose 'control' and 'answer' values support ``.get``.

    Returns:
        The report as a single newline-joined string.

    Raises:
        KeyError: If one of the directly indexed keys listed above is missing.
    """
    components = ('think', 'control', 'answer')
    # Components whose average accuracy is below this get a recommendation.
    recommendation_threshold = 0.7

    metrics = results['overall_metrics']
    error_patterns = results['error_patterns']
    detailed = results['detailed_analysis']

    banner = "=" * 100
    report = [
        banner,
        "COMPREHENSIVE EVALUATION OF IN-VEHICLE MULTIMODAL AI MODEL",
        banner,
        "\n📊 OVERALL SYSTEM PERFORMANCE:",
        f" Total Samples: {metrics['total_samples']}",
        f" Average Overall Score: {metrics['avg_overall_score']:.4f}",
        "\n🔍 COMPONENT-WISE PERFORMANCE:",
    ]

    for comp in components:
        avg_acc = metrics.get(f'{comp}_avg_accuracy', 0)
        avg_qual = metrics.get(f'{comp}_avg_quality', 0)
        std_acc = metrics.get(f'{comp}_std_accuracy', 0)
        report.append(f" {comp.upper()}:")
        report.append(f" Average Accuracy: {avg_acc:.4f}")
        report.append(f" Average Quality: {avg_qual:.4f}")
        report.append(f" Std Deviation: {std_acc:.4f}")

    # Error analysis: number of samples recorded per component error list.
    report.append("\n❌ ERROR ANALYSIS:")
    for comp in components:
        error_count = len(error_patterns[f'{comp}_errors'])
        report.append(f" {comp.capitalize()} component errors: {error_count} samples")

    # Show the first sample as a worked example, when there is one.
    if detailed:
        first_sample = detailed[0]
        report.append("\n📋 SAMPLE ANALYSIS (First Sample):")
        for comp in components:
            accuracy = first_sample['component_accuracy'][comp]
            report.append(f" {comp.capitalize()} Accuracy: {accuracy:.4f}")
        report.append(f" Overall Score: {first_sample['overall_score']:.4f}")

    report.append("\n🔧 COMPONENT TYPE ANALYSIS:")

    # Frequency of each control command type, most common first.
    type_counts = Counter(
        analysis['parsed_response']['control'].get('type', 'unknown')
        for analysis in detailed
    )
    report.append(" Control Command Types:")
    for control_type, count in type_counts.most_common():
        report.append(f" {control_type}: {count} samples")

    # Frequency of each answer category, most common first.
    category_counts = Counter(
        analysis['parsed_response']['answer'].get('category', 'unknown')
        for analysis in detailed
    )
    report.append(" Answer Categories:")
    for category, count in category_counts.most_common():
        report.append(f" {category}: {count} samples")

    report.append("\n🎯 RECOMMENDATIONS:")
    advice = (
        ('think', " - Improve think component (behavior analysis)"),
        ('control', " - Improve control component (command generation)"),
        ('answer', " - Improve answer component (final classification)"),
    )
    recommendations = [
        message
        for comp, message in advice
        if metrics.get(f'{comp}_avg_accuracy', 0) < recommendation_threshold
    ]
    if not recommendations:
        # Avoid ending the report with a dangling, empty section header.
        recommendations.append(
            f" - No component is below the {recommendation_threshold} accuracy threshold"
        )
    report.extend(recommendations)

    return "\n".join(report)
831
+
832
def save_evaluation_results(results: Dict[str, Any], output_path: str):
    """
    Save evaluation results to a JSON file.

    The parent directory of ``output_path`` is created if it does not exist,
    so writing into a fresh output folder does not fail.

    Args:
        results: JSON-serializable evaluation results.
        output_path: Destination file path; an existing file is overwritten.
    """
    # Local import keeps this function self-contained regardless of the
    # module-level import list.
    import os

    parent_dir = os.path.dirname(os.path.abspath(output_path))
    os.makedirs(parent_dir, exist_ok=True)

    with open(output_path, 'w', encoding='utf-8') as f:
        # ensure_ascii=False keeps non-ASCII text readable in the output file.
        json.dump(results, f, ensure_ascii=False, indent=2)
838
+
839
def main(input_file: str, output_file: str = None):
    """
    Run the evaluation pipeline end to end.

    Loads samples from ``input_file`` with ``load_data``, evaluates them with
    ``comprehensive_evaluation``, prints the generated report, and saves the
    results when an output path is given.

    Args:
        input_file: Path passed to ``load_data``.
        output_file: Optional path for the saved results; nothing is saved
            when it is falsy.

    Returns:
        The results object produced by ``comprehensive_evaluation``.
    """
    print(f"Loading data from: {input_file}")
    samples = load_data(input_file)
    print(f"Loaded {len(samples)} samples")

    print("Performing comprehensive evaluation...")
    evaluation = comprehensive_evaluation(samples)

    # Build the report and print it straight away.
    print(generate_evaluation_report(evaluation))

    # Without an output path there is nothing left to do.
    if not output_file:
        return evaluation

    save_evaluation_results(evaluation, output_file)
    print(f"\nDetailed evaluation results saved to: {output_file}")
    return evaluation
863
+
864
if __name__ == "__main__":
    import sys

    # Usage: python <this_script> [input_file] [output_file]
    #   input_file:  path to the file containing model predictions
    #   output_file: optional path to save detailed evaluation results
    #
    # With no arguments the script falls back to the paths below, which
    # preserves the previous hard-coded behavior.
    default_input_file = r"/data/LLM-SFT/SFT_Output/multiclsTask/Qwen2.5-VL-3B-Instruct/v0-20251123-182828/checkpoint-264/infer_result/20251124-175009.jsonl"
    default_output_file = r"/data/LLM-SFT/SFT_Output/multiclsTask/Qwen2.5-VL-3B-Instruct/v0-20251123-182828/checkpoint-264/eval/20251124-175009.jsonl"

    if len(sys.argv) > 1:
        input_file = sys.argv[1]
        # Do not pair a caller-supplied input with the default output path:
        # that would overwrite results belonging to a different run.
        output_file = sys.argv[2] if len(sys.argv) > 2 else None
    else:
        input_file = default_input_file
        output_file = default_output_file

    results = main(input_file, output_file)