# import json
# import re
# from PIL import Image
# from transformers import AutoModelForVision2Seq, AutoProcessor
# import torch
# import os
# from qwen_vl_utils import process_vision_info
# # --- 1. Helper functions ---
# def load_test_data(file_path):
# """
#     Automatically load a .json or .jsonl file based on its extension.
#     For .json files, try several common keys to find the list of samples.
# """
# _, ext = os.path.splitext(file_path)
# ext = ext.lower()
# test_samples = []
# if ext == '.jsonl':
# print(f"Loading data from JSON Lines file: {file_path}")
# with open(file_path, 'r', encoding='utf-8') as f:
# for i, line in enumerate(f):
# line = line.strip()
# if not line:
# continue
# try:
# test_samples.append(json.loads(line))
# except json.JSONDecodeError as e:
# print(f"Warning: Skipping invalid JSON line {i+1} in {file_path}: {e}")
# elif ext == '.json':
# print(f"Loading data from JSON file: {file_path}")
# try:
# with open(file_path, 'r', encoding='utf-8') as f:
# data = json.load(f)
# if isinstance(data, list):
# print(" Detected JSON array format.")
# test_samples = data
# elif isinstance(data, dict):
# print(" Detected JSON object format. Searching for samples...")
# possible_keys = ['data', 'samples', 'instances', 'items', 'conversations', 'messages']
# found = False
# for key in possible_keys:
# if key in data and isinstance(data[key], list) and len(data[key]) > 0:
#                         # Quick check that the first list element looks like a sample (a dict with 'messages')
# first_item = data[key][0]
# if isinstance(first_item, dict) and 'messages' in first_item:
# print(f" Found samples under key '{key}'.")
# test_samples = data[key]
# found = True
# break
# if not found:
#                     # Heuristic: find the first key whose value is a list of dicts containing 'messages'
# for key, value in data.items():
# if isinstance(value, list) and len(value) > 0 and isinstance(value[0], dict) and 'messages' in value[0]:
# print(f" Found samples under key '{key}' (heuristic).")
# test_samples = value
# found = True
# break
# if not found:
# print(f" Error: Could not find a list of samples in the JSON object. Keys found: {list(data.keys())}")
# else:
# print(f" Error: Unexpected JSON structure. Root element type: {type(data)}")
# except json.JSONDecodeError as e:
# print(f"Error: Failed to decode JSON from {file_path}: {e}")
# except Exception as e:
# print(f"Error: An unexpected error occurred while loading {file_path}: {e}")
# else:
# print(f"Error: Unsupported file extension '{ext}'. Please provide a .json or .jsonl file.")
# print(f"Loaded {len(test_samples)} samples.")
#     # Validate the structure of the loaded samples
# if test_samples and isinstance(test_samples, list):
# print("Performing basic structure validation on loaded samples...")
# sample_count_to_check = min(5, len(test_samples))
# for i in range(sample_count_to_check):
# s = test_samples[i]
# if not isinstance(s, dict):
# print(f" CRITICAL: Sample {i} is not a dict. Type: {type(s)}")
#                 # Optionally abort here or clean the data
# # return []
# elif 'messages' not in s or 'images' not in s:
# print(f" WARNING: Sample {i} might be missing 'messages' or 'images' keys. Found keys: {list(s.keys())}")
# else:
# if not isinstance(s['messages'], list):
# print(f" CRITICAL: Sample {i} 'messages' is not a list. Type: {type(s['messages'])}")
# if not isinstance(s['images'], list):
# print(f" CRITICAL: Sample {i} 'images' is not a list. Type: {type(s['images'])}")
# print("Structure validation complete.")
# elif test_samples:
# print(f"CRITICAL: Expected test_samples to be a list after loading, got {type(test_samples)}.")
# test_samples = [] # Reset to empty list on critical error
# return test_samples
# def extract_components(text):
# """从模型输出或标签中提取 <think>, <control>, <answer> 组件"""
# think_match = re.search(r'<think>(.*?)</think>', text, re.DOTALL)
# control_match = re.search(r'<control>(.*?)</control>', text)
# answer_match = re.search(r'<answer>(.*?)</answer>', text)
# return {
# 'think': think_match.group(1).strip() if think_match else "",
# 'control': control_match.group(1).strip() if control_match else "",
# 'answer': answer_match.group(1).strip() if answer_match else ""
# }
# def calculate_accuracy(pred_list, true_list):
# """计算准确率 (用于 <answer>)"""
# if len(pred_list) != len(true_list):
# raise ValueError("Prediction and truth lists must have the same length for accuracy calculation.")
# if not pred_list:
# return 0.0
# correct = sum(p == t for p, t in zip(pred_list, true_list))
# return correct / len(pred_list)
# # --- 2. Main evaluation logic ---
# def main():
#     # --- Configuration ---
#     # Replace with your model path
#     model_path = "/data/LLM-SFT/SFT_Output/multiclsTask/Qwen2.5-VL-3B-Instruct/SFT/checkpoint-894"
#     # Replace with your test set path (.json or .jsonl)
# test_data_path = "/data/LLM-SFT/datasets/driver_behavior_datasets/output_test.jsonl"
# output_file = model_path + "/eval/detailed_model_evaluation_results.json"
#     # --- Load model and processor ---
# print("Loading model and processor...")
# try:
# processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
# model = AutoModelForVision2Seq.from_pretrained(
# model_path,
# trust_remote_code=True,
# torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
# )
# model.eval()
# if torch.cuda.is_available():
# model = model.to('cuda')
# print("Model loaded on GPU.")
# else:
# print("Model loaded on CPU.")
# except Exception as e:
# print(f"Failed to load model/processor: {e}")
# return # Exit if model loading fails
#     # --- Load test data ---
# try:
# test_samples = load_test_data(test_data_path)
# # print('test_samples',test_samples)
# if not test_samples:
# print("No samples loaded. Exiting.")
# return
# except Exception as e:
# print(f"Failed to load test data: {e}")
# return
#     # --- Inference and result collection (with parsing) ---
# results = []
# pred_answers = []
# true_answers = []
#     pred_controls = []  # store control strings for later analysis
# true_controls = []
# print("Starting inference...")
# for i, sample in enumerate(test_samples):
# try:
# conversation = sample['messages']
# image_path = sample['images'][0]
# if not os.path.exists(image_path):
# print(f"Warning: Image not found: {image_path}. Skipping sample {i}.")
#                 # Append empty placeholders to keep the lists aligned
# pred_answers.append("")
# true_answers.append(extract_components(conversation[-1]['content'])['answer'])
# pred_controls.append("")
# true_controls.append(extract_components(conversation[-1]['content'])['control'])
# continue
# image = Image.open(image_path).convert('RGB')
#             # Prepare inputs
#             # Note: the Qwen VL series expects messages as a list of dicts with role and content;
#             # the processor aligns <image> tokens with the images
# # print('conversation[:-1]',conversation[:-1])
# texts = processor.apply_chat_template(conversation[:-1], tokenize=False, add_generation_prompt=True)
# image_inputs, video_inputs = process_vision_info(conversation[:-1])
# inputs = processor(
# text=texts,
# images=image_inputs,
# videos=video_inputs,
# padding=True,
# return_tensors="pt",
# )
# if torch.cuda.is_available():
#                 inputs = inputs.to('cuda')  # keep the BatchFeature intact so inputs.input_ids below still works
#             # Generate
# with torch.no_grad():
# generated_ids = model.generate(**inputs,
# max_new_tokens=200,
# # num_beams=5,
# do_sample=True,
# top_p=0.75,
# top_k=50,
# temperature=0.2
# # repetition_penalty=1.2,
# # early_stopping=True
# )
# generated_ids_trimmed = [
# out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
# ]
#             output_text = processor.batch_decode(
#                 generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
#             )[0]  # batch_decode returns a list; take the single decoded string
#             # Extract the label
#             ground_truth = conversation[-1]['content']
#             # Parse the model output and the label
# pred_components = extract_components(output_text)
# true_components = extract_components(ground_truth)
#             # Collect data for evaluation
# pred_answers.append(pred_components['answer'])
# true_answers.append(true_components['answer'])
# pred_controls.append(pred_components['control'])
# true_controls.append(true_components['control'])
#             # Print a few samples for inspection
#             if i < 3:  # print the first 3 samples
#                 print(f"\n--- Sample {i+1} ---")
#                 print(f" Image: {image_path}")
#                 print(f" Input Text: {texts}")
#                 print(f" Decoded Output: {output_text}")
# print(f" Parsed Prediction: {pred_components}")
# print(f" Ground Truth: {ground_truth}")
# print(f" Parsed Truth: {true_components}")
#             # Store the detailed result
#             results.append({
#                 "sample_id": i,
#                 "image_path": image_path,
#                 "input_text": texts,
#                 "model_output_raw": output_text,
# "model_output_processed": output_text,
# "parsed_prediction": pred_components,
# "ground_truth_raw": ground_truth,
# "parsed_truth": true_components
# })
# except Exception as e:
# print(f"Error processing sample {i}: {e}")
#             # Failed samples are still added to the evaluation lists, marked empty or as errors
# pred_answers.append("ERROR")
# true_answers.append(extract_components(conversation[-1]['content'])['answer'] if 'conversation' in locals() else "")
# pred_controls.append("ERROR")
# true_controls.append(extract_components(conversation[-1]['content'])['control'] if 'conversation' in locals() else "")
# results.append({
# "sample_id": i,
# "image_path": image_path if 'image_path' in locals() else "N/A",
# "input_text": conversation[-2]['content'] if 'conversation' in locals() else "N/A",
# "model_output_raw": f"ERROR: {e}",
# "model_output_processed": f"ERROR: {e}",
# "parsed_prediction": {"think": "", "control": "", "answer": "ERROR"},
# "ground_truth_raw": conversation[-1]['content'] if 'conversation' in locals() else "N/A",
# "parsed_truth": extract_components(conversation[-1]['content']) if 'conversation' in locals() else {"think": "", "control": "", "answer": ""}
# })
#     # --- Save detailed results ---
#     os.makedirs(os.path.dirname(output_file), exist_ok=True)  # the eval/ directory may not exist yet
#     with open(output_file, 'w', encoding='utf-8') as f:
# json.dump(results, f, indent=2, ensure_ascii=False)
# print(f"\nDetailed results saved to {output_file}")
#     # --- In-depth quantitative evaluation ---
# print(f"\n--- Quantitative Evaluation ---")
# total_samples = len(test_samples)
# successful_samples = len([r for r in results if not r['model_output_raw'].startswith("ERROR")])
# print(f"Total samples: {total_samples}, Successfully processed: {successful_samples}")
# if successful_samples == 0:
# print("No samples were processed successfully. Skipping quantitative evaluation.")
# return
#     # a. <answer> tag accuracy (computed on successfully processed samples only)
#     # Filter out failed samples
# filtered_pred_answers = [p for p in pred_answers if p != "ERROR"]
# filtered_true_answers = [t for p, t in zip(pred_answers, true_answers) if p != "ERROR"]
# if filtered_pred_answers:
# answer_accuracy = calculate_accuracy(filtered_pred_answers, filtered_true_answers)
# print(f"<answer> Tag Accuracy (on successful samples): {answer_accuracy:.4f} ({sum(p==t for p,t in zip(filtered_pred_answers, filtered_true_answers))}/{len(filtered_true_answers)})")
# else:
# print("No valid <answer> predictions to evaluate.")
# answer_accuracy = 0.0
#     # b. <control> command analysis
# filtered_pred_controls = [c for p, c in zip(pred_answers, pred_controls) if p != "ERROR"]
# filtered_true_controls = [t for p, t in zip(pred_answers, true_controls) if p != "ERROR"]
# if filtered_pred_controls:
# control_non_empty_pred = [c != "" for c in filtered_pred_controls]
# control_non_empty_true = [c != "" for c in filtered_true_controls]
# control_existence_acc = calculate_accuracy(control_non_empty_pred, control_non_empty_true)
# print(f"<control> Tag Presence Accuracy (on successful samples): {control_existence_acc:.4f}")
# else:
# print("No valid <control> predictions to evaluate.")
# control_existence_acc = 0.0
#     # c. Per-class <answer> accuracy
# if filtered_true_answers:
# unique_labels = sorted(list(set(filtered_true_answers + filtered_pred_answers)))
# print("\nPer-class <answer> accuracy:")
# class_acc = {}
# for label in unique_labels:
# tp = sum(1 for p, t in zip(filtered_pred_answers, filtered_true_answers) if p == label and t == label)
# total_true = sum(1 for t in filtered_true_answers if t == label)
# class_acc[label] = tp / total_true if total_true > 0 else 0.0
# print(f" Accuracy for '{label}': {class_acc[label]:.4f} ({tp}/{total_true if total_true > 0 else 'N/A'})")
#     # d. (Optional) text-similarity evaluation (requires nltk or rouge-score)
#     # Example using ROUGE (requires: pip install rouge)
# from rouge import Rouge
# rouge = Rouge()
# avg_rouge_scores = {'rouge-1': 0.0, 'rouge-2': 0.0, 'rouge-l': 0.0}
# valid_samples_for_rouge = 0
# for res in results:
# if not res['model_output_raw'].startswith("ERROR") and res['parsed_truth']['think'] and res['parsed_prediction']['think']:
# try:
# scores = rouge.get_scores(res['parsed_prediction']['think'], res['parsed_truth']['think'])
# for metric in avg_rouge_scores:
# avg_rouge_scores[metric] += scores[0][metric]['f']
# valid_samples_for_rouge += 1
# except Exception as e:
# print(f"ROUGE calculation error for sample {res['sample_id']}: {e}")
# if valid_samples_for_rouge > 0:
# for metric in avg_rouge_scores:
# avg_rouge_scores[metric] /= valid_samples_for_rouge
# print(f"\nAverage ROUGE Scores (on <think> tags, {valid_samples_for_rouge} valid samples):")
# for metric, score in avg_rouge_scores.items():
# print(f" {metric.upper()}: {score:.4f}")
# else:
# print("\nNo valid samples for ROUGE calculation on <think> tags.")
#     # --- 7. Error-case analysis ---
# print(f"\n--- Error Analysis ---")
# error_count = sum(1 for r in results if r['model_output_raw'].startswith("ERROR"))
# if error_count > 0:
# print(f"Number of samples with processing errors: {error_count}")
#         # Error details could be printed here
# else:
# print("No processing errors detected during inference.")
# print("Samples where <answer> prediction was incorrect (excluding errors):")
# incorrect_count = 0
# for res in results:
#         # Only analyze samples that were processed successfully but predicted incorrectly
# if not res['model_output_raw'].startswith("ERROR") and \
# res['parsed_prediction']['answer'] != res['parsed_truth']['answer']:
# incorrect_count += 1
#             if incorrect_count <= 5:  # only print the first 5 error cases
# print(f" Sample ID: {res['sample_id']}")
# print(f" Image: {res['image_path']}")
# print(f" Input: {res['input_text']}")
# print(f" Predicted Answer: '{res['parsed_prediction']['answer']}'")
# print(f" True Answer: '{res['parsed_truth']['answer']}'")
# print(f" Predicted Control: '{res['parsed_prediction']['control']}'")
# print(f" True Control: '{res['parsed_truth']['control']}'")
# # print(f" Predicted Think: '{res['parsed_prediction']['think']}'") # 可选
# # print(f" True Think: '{res['parsed_truth']['think']}'") # 可选
# print("-" * 20)
# if incorrect_count > 5:
# print(f"... and {incorrect_count - 5} more incorrect predictions.")
# elif incorrect_count == 0:
# print(" All successful predictions matched the ground truth <answer>.")
#     # --- 8. Summary ---
# print(f"\n--- Summary ---")
# print(f"Total samples processed: {total_samples}")
# print(f"Successfully processed samples: {successful_samples}")
# if filtered_pred_answers:
# print(f"<answer> Accuracy (successful samples): {answer_accuracy:.4f}")
# if filtered_pred_controls:
# print(f"<control> Presence Accuracy (successful samples): {control_existence_acc:.4f}")
# print("Per-class accuracies calculated above (if applicable).")
# print("Detailed results are available in the output file.")
# if __name__ == "__main__":
# main()
import json
import os
from typing import Dict, List, Any, Optional
import re
from collections import Counter
def load_data(file_path: str) -> List[Dict]:
"""
Load data from JSON or JSONL file
"""
data = []
# Check file extension to determine format
if file_path.lower().endswith('.json'):
with open(file_path, 'r', encoding='utf-8') as f:
data = json.load(f)
elif file_path.lower().endswith('.jsonl'):
with open(file_path, 'r', encoding='utf-8') as f:
for line in f:
line = line.strip()
if line:
data.append(json.loads(line))
else:
# Try to auto-detect based on content
with open(file_path, 'r', encoding='utf-8') as f:
first_line = f.readline().strip()
f.seek(0)
if first_line.startswith('['): # JSON array
data = json.load(f)
else: # Assume JSONL
for line in f:
line = line.strip()
if line:
data.append(json.loads(line))
return data
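# Usage sketch (file names are illustrative, not paths from this repo):
# load_data falls back to sniffing the first character when the extension is
# neither .json nor .jsonl:
#   load_data("preds.json")   # whole file loaded with json.load
#   load_data("preds.jsonl")  # one JSON object per line
#   load_data("preds.txt")    # leading '[' -> JSON array; otherwise parsed as JSONL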
def parse_think_content(think_str: str) -> Dict[str, str]:
"""
Parse <think> content to extract behavior description
"""
if not think_str:
return {"raw": "", "behavior": ""}
# Remove <think> tags and extract content
clean_str = re.sub(r'<think>|</think>', '', think_str).strip()
return {
"raw": clean_str,
"behavior": clean_str # For now, the behavior is the full content
}
def parse_control_content(control_str: str) -> Dict[str, Any]:
"""
Parse <control> content to extract control command and parameters
"""
if not control_str:
return {"raw": "", "command": "", "parameters": {}, "type": "none"}
clean_str = re.sub(r'<control>|</control>', '', control_str).strip()
# Extract command and parameters
command = clean_str
params = {}
control_type = "other"
if "(" in clean_str and ")" in clean_str:
# Pattern like: MonitorPassenger(SwellingDetected)
match = re.match(r'(\w+)\(([^)]+)\)', clean_str)
if match:
command = match.group(1)
param_str = match.group(2)
params = {"parameter": param_str}
if "Monitor" in command:
control_type = "monitoring"
elif "Alert" in command:
control_type = "alerting"
elif "set" in command:
control_type = "setting"
elif "|" in clean_str:
# Pattern like: setMute|false|
parts = clean_str.split("|")
command = parts[0] if parts else ""
params = {"params": parts[1:] if len(parts) > 1 else []}
control_type = "command"
else:
command = clean_str
control_type = "function"
return {
"raw": clean_str,
"command": command,
"parameters": params,
"type": control_type
}
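# Illustrative examples of the three <control> formats this parser targets
# (inputs invented for demonstration; outputs are what the code above returns):
#   parse_control_content("<control>MonitorPassenger(SwellingDetected)</control>")
#     -> {'raw': 'MonitorPassenger(SwellingDetected)', 'command': 'MonitorPassenger',
#         'parameters': {'parameter': 'SwellingDetected'}, 'type': 'monitoring'}
#   parse_control_content("<control>setMute|false|</control>")
#     -> {'raw': 'setMute|false|', 'command': 'setMute',
#         'parameters': {'params': ['false', '']}, 'type': 'command'}
#   parse_control_content("<control>openWindow</control>")
#     -> {'raw': 'openWindow', 'command': 'openWindow', 'parameters': {}, 'type': 'function'}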
def parse_answer_content(answer_str: str) -> Dict[str, str]:
"""
Parse <answer> content to extract the final answer
"""
if not answer_str:
return {"raw": "", "category": "", "description": ""}
clean_str = re.sub(r'<answer>|</answer>', '', answer_str).strip()
# Try to categorize the answer
category = "other"
if any(keyword in clean_str.lower() for keyword in ["swelling", "eye", "face", "facial"]):
category = "physical_symptom"
elif any(keyword in clean_str.lower() for keyword in ["sleep", "drowsy", "tired", "yawn"]):
category = "drowsiness"
elif any(keyword in clean_str.lower() for keyword in ["phone", "call", "text", "mobile"]):
category = "distraction"
elif any(keyword in clean_str.lower() for keyword in ["smoke", "cigarette"]):
category = "smoking"
elif any(keyword in clean_str.lower() for keyword in ["drunk", "alcohol", "intoxicated"]):
category = "intoxication"
elif any(keyword in clean_str.lower() for keyword in ["mouth", "corner", "slanting"]):
category = "facial_expression"
elif any(keyword in clean_str.lower() for keyword in ["head", "cover", "hold"]):
category = "head_behavior"
elif any(keyword in clean_str.lower() for keyword in ["arm", "hand", "slip", "droop"]):
category = "limb_behavior"
elif any(keyword in clean_str.lower() for keyword in ["radio", "adjust", "control"]):
category = "vehicle_control"
return {
"raw": clean_str,
"category": category,
"description": clean_str
}
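# Note: keyword matching above is first-hit-wins in the order written, so an
# answer mentioning both "face" and "tired" is categorized as physical_symptom.
# Invented example:
#   parse_answer_content("<answer>The driver is yawning and looks tired</answer>")
#     -> category 'drowsiness' ("tired"/"yawn" match, no earlier keyword does)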
def extract_all_components(text: str) -> Dict[str, str]:
"""
Extract think, control, and answer components from text
"""
components = {
"think": "",
"control": "",
"answer": ""
}
# Extract <think> content
think_match = re.search(r'<think>(.*?)</think>', text, re.DOTALL)
if think_match:
components["think"] = think_match.group(1).strip()
# Extract <control> content
control_match = re.search(r'<control>(.*?)</control>', text, re.DOTALL)
if control_match:
components["control"] = control_match.group(1).strip()
# Extract <answer> content
answer_match = re.search(r'<answer>(.*?)</answer>', text, re.DOTALL)
if answer_match:
components["answer"] = answer_match.group(1).strip()
return components
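# Invented example; each tag is optional and a missing tag yields an empty string:
#   extract_all_components("<think>eyes closing</think><answer>drowsy</answer>")
#     -> {'think': 'eyes closing', 'control': '', 'answer': 'drowsy'}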
def calculate_component_accuracy(predicted_components: Dict, actual_components: Dict) -> Dict[str, float]:
"""
Calculate accuracy for each component
"""
accuracy = {}
# Think component accuracy
accuracy['think'] = calculate_similarity(
predicted_components.get('think', ''),
actual_components.get('think', '')
)
# Control component accuracy
accuracy['control'] = calculate_similarity(
predicted_components.get('control', ''),
actual_components.get('control', '')
)
# Answer component accuracy
accuracy['answer'] = calculate_similarity(
predicted_components.get('answer', ''),
actual_components.get('answer', '')
)
return accuracy
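# Returns one similarity score per tag, e.g. (invented values):
#   {'think': 0.62, 'control': 1.0, 'answer': 0.0}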
def calculate_similarity(str1: str, str2: str) -> float:
"""
Calculate similarity between two strings
"""
if not str1 and not str2:
return 1.0
if not str1 or not str2:
return 0.0
str1_lower = str1.lower().strip()
str2_lower = str2.lower().strip()
if str1_lower == str2_lower:
return 1.0
# Calculate word overlap
words1 = set(str1_lower.split())
words2 = set(str2_lower.split())
if len(words1) == 0 and len(words2) == 0:
return 1.0
if len(words1) == 0 or len(words2) == 0:
return 0.0
intersection = words1.intersection(words2)
union = words1.union(words2)
# Jaccard similarity
jaccard = len(intersection) / len(union) if union else 0
# Also consider sequence similarity for exact matches
if str1_lower in str2_lower or str2_lower in str1_lower:
return max(jaccard, 0.8)
return jaccard
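# Worked example (invented strings): for "driver is yawning" vs
# "the driver is yawning heavily", the word sets share 3 of 5 unique words,
# so the Jaccard score is 0.6; because the first string is also a substring
# of the second, the result is bumped to max(0.6, 0.8) = 0.8.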
def evaluate_component_quality(parsed_component: Dict, expected_component: Dict) -> Dict[str, float]:
"""
Evaluate the quality of component parsing and prediction
"""
quality = {}
if parsed_component.get('type') == expected_component.get('type'):
quality['type_match'] = 1.0
else:
quality['type_match'] = 0.0
# Evaluate content quality based on component type
if parsed_component.get('type') == 'monitoring':
quality['content_quality'] = 1.0 if 'Monitor' in parsed_component.get('command', '') else 0.0
elif parsed_component.get('type') == 'alerting':
quality['content_quality'] = 1.0 if 'Alert' in parsed_component.get('command', '') else 0.0
else:
quality['content_quality'] = 0.5 # Default medium quality
return quality
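# Note: only parse_control_content emits a 'type' key, so for think/answer
# components both .get('type') calls return None and type_match is always 1.0;
# this score is mainly meaningful for <control>. Invented example:
#   evaluate_component_quality({'type': 'alerting', 'command': 'AlertDriver'},
#                              {'type': 'alerting'})
#     -> {'type_match': 1.0, 'content_quality': 1.0}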
def comprehensive_evaluation(data: List[Dict]) -> Dict[str, Any]:
"""
Comprehensive evaluation of all three components
"""
total_samples = len(data)
results = {
'overall_metrics': {},
'component_wise_metrics': {
'think': {'accuracy_scores': [], 'quality_scores': []},
'control': {'accuracy_scores': [], 'quality_scores': []},
'answer': {'accuracy_scores': [], 'quality_scores': []}
},
'detailed_analysis': [],
'error_patterns': {
'think_errors': [],
'control_errors': [],
'answer_errors': []
}
}
for idx, sample in enumerate(data):
# Extract components from response and labels
response_components = extract_all_components(sample.get('response', ''))
label_components = extract_all_components(sample.get('labels', ''))
# Parse components for deeper analysis
parsed_think = parse_think_content(response_components['think'])
parsed_control = parse_control_content(response_components['control'])
parsed_answer = parse_answer_content(response_components['answer'])
actual_think = parse_think_content(label_components['think'])
actual_control = parse_control_content(label_components['control'])
actual_answer = parse_answer_content(label_components['answer'])
# Calculate component-wise accuracy
component_accuracy = calculate_component_accuracy(response_components, label_components)
# Calculate component quality
think_quality = evaluate_component_quality(parsed_think, actual_think)
control_quality = evaluate_component_quality(parsed_control, actual_control)
answer_quality = evaluate_component_quality(parsed_answer, actual_answer)
# Store component-wise metrics
for comp in ['think', 'control', 'answer']:
results['component_wise_metrics'][comp]['accuracy_scores'].append(component_accuracy[comp])
results['component_wise_metrics'][comp]['quality_scores'].append(
think_quality.get('content_quality', 0.5) if comp == 'think' else
control_quality.get('content_quality', 0.5) if comp == 'control' else
answer_quality.get('content_quality', 0.5)
)
# Store detailed analysis
detailed_result = {
'index': idx,
'response_components': response_components,
'label_components': label_components,
'parsed_response': {
'think': parsed_think,
'control': parsed_control,
'answer': parsed_answer
},
'parsed_labels': {
'think': actual_think,
'control': actual_control,
'answer': actual_answer
},
'component_accuracy': component_accuracy,
'component_quality': {
'think': think_quality,
'control': control_quality,
'answer': answer_quality
},
'overall_score': sum(component_accuracy.values()) / 3 if component_accuracy else 0
}
results['detailed_analysis'].append(detailed_result)
# Analyze errors
if component_accuracy['think'] < 0.5:
results['error_patterns']['think_errors'].append(idx)
if component_accuracy['control'] < 0.5:
results['error_patterns']['control_errors'].append(idx)
if component_accuracy['answer'] < 0.5:
results['error_patterns']['answer_errors'].append(idx)
# Calculate overall metrics
overall_metrics = {}
for comp in ['think', 'control', 'answer']:
acc_scores = results['component_wise_metrics'][comp]['accuracy_scores']
qual_scores = results['component_wise_metrics'][comp]['quality_scores']
overall_metrics[f'{comp}_avg_accuracy'] = sum(acc_scores) / len(acc_scores) if acc_scores else 0
overall_metrics[f'{comp}_avg_quality'] = sum(qual_scores) / len(qual_scores) if qual_scores else 0
overall_metrics[f'{comp}_std_accuracy'] = (
sum((x - overall_metrics[f'{comp}_avg_accuracy'])**2 for x in acc_scores) / len(acc_scores)
)**0.5 if acc_scores else 0
# Calculate overall system performance
overall_metrics['total_samples'] = total_samples
overall_metrics['avg_overall_score'] = sum(
d['overall_score'] for d in results['detailed_analysis']
) / total_samples if total_samples > 0 else 0
results['overall_metrics'] = overall_metrics
return results
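# Expected input record shape, inferred from the .get() calls above (field
# values invented): each sample is a dict with string fields 'response'
# (model output) and 'labels' (ground truth), e.g.
#   {"response": "<think>...</think><control>...</control><answer>...</answer>",
#    "labels": "<think>...</think><control>...</control><answer>...</answer>"}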
def generate_evaluation_report(results: Dict[str, Any]) -> str:
"""
Generate comprehensive evaluation report
"""
report = []
report.append("="*100)
report.append("COMPREHENSIVE EVALUATION OF IN-VEHICLE MULTIMODAL AI MODEL")
report.append("="*100)
metrics = results['overall_metrics']
report.append(f"\n📊 OVERALL SYSTEM PERFORMANCE:")
report.append(f" Total Samples: {metrics['total_samples']}")
report.append(f" Average Overall Score: {metrics['avg_overall_score']:.4f}")
report.append(f"\n🔍 COMPONENT-WISE PERFORMANCE:")
for comp in ['think', 'control', 'answer']:
avg_acc = metrics.get(f'{comp}_avg_accuracy', 0)
avg_qual = metrics.get(f'{comp}_avg_quality', 0)
std_acc = metrics.get(f'{comp}_std_accuracy', 0)
report.append(f" {comp.upper()}:")
report.append(f" Average Accuracy: {avg_acc:.4f}")
report.append(f" Average Quality: {avg_qual:.4f}")
report.append(f" Std Deviation: {std_acc:.4f}")
# Error analysis
error_patterns = results['error_patterns']
report.append(f"\n❌ ERROR ANALYSIS:")
report.append(f" Think component errors: {len(error_patterns['think_errors'])} samples")
report.append(f" Control component errors: {len(error_patterns['control_errors'])} samples")
report.append(f" Answer component errors: {len(error_patterns['answer_errors'])} samples")
# Sample error analysis
if results['detailed_analysis']:
sample_analysis = results['detailed_analysis'][0] # Show first sample as example
report.append(f"\n📋 SAMPLE ANALYSIS (First Sample):")
report.append(f" Think Accuracy: {sample_analysis['component_accuracy']['think']:.4f}")
report.append(f" Control Accuracy: {sample_analysis['component_accuracy']['control']:.4f}")
report.append(f" Answer Accuracy: {sample_analysis['component_accuracy']['answer']:.4f}")
report.append(f" Overall Score: {sample_analysis['overall_score']:.4f}")
# Component type analysis
report.append(f"\n🔧 COMPONENT TYPE ANALYSIS:")
# Analyze control command types
control_types = []
for analysis in results['detailed_analysis']:
control_type = analysis['parsed_response']['control'].get('type', 'unknown')
control_types.append(control_type)
type_counts = Counter(control_types)
report.append(" Control Command Types:")
for control_type, count in type_counts.most_common():
report.append(f" {control_type}: {count} samples")
# Answer category analysis
answer_categories = []
for analysis in results['detailed_analysis']:
answer_category = analysis['parsed_response']['answer'].get('category', 'unknown')
answer_categories.append(answer_category)
category_counts = Counter(answer_categories)
report.append(" Answer Categories:")
for category, count in category_counts.most_common():
report.append(f" {category}: {count} samples")
report.append(f"\n🎯 RECOMMENDATIONS:")
if metrics.get('think_avg_accuracy', 0) < 0.7:
report.append(" - Improve think component (behavior analysis)")
if metrics.get('control_avg_accuracy', 0) < 0.7:
report.append(" - Improve control component (command generation)")
if metrics.get('answer_avg_accuracy', 0) < 0.7:
report.append(" - Improve answer component (final classification)")
return "\n".join(report)
def save_evaluation_results(results: Dict[str, Any], output_path: str):
    """
    Save evaluation results to a JSON file, creating the output directory if needed
    """
    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, ensure_ascii=False, indent=2)
def main(input_file: str, output_file: Optional[str] = None):
"""
Main function to perform comprehensive evaluation
"""
print(f"Loading data from: {input_file}")
# Load data
data = load_data(input_file)
print(f"Loaded {len(data)} samples")
# Perform comprehensive evaluation
print("Performing comprehensive evaluation...")
results = comprehensive_evaluation(data)
# Generate and print report
report = generate_evaluation_report(results)
print(report)
# Save results if output path provided
if output_file:
save_evaluation_results(results, output_file)
print(f"\nDetailed evaluation results saved to: {output_file}")
return results
if __name__ == "__main__":
import sys
# if len(sys.argv) < 2:
# print("Usage: python comprehensive_evaluation.py <input_file> [output_file]")
# print(" input_file: Path to JSON or JSONL file containing model predictions")
# print(" output_file: Optional path to save detailed evaluation results")
# sys.exit(1)
# input_file = sys.argv[1]
# output_file = sys.argv[2] if len(sys.argv) > 2 else None
input_file = r"/data/LLM-SFT/SFT_Output/multiclsTask/Qwen2.5-VL-3B-Instruct/v0-20251123-182828/checkpoint-264/infer_result/20251124-175009.jsonl"
output_file = r"/data/LLM-SFT/SFT_Output/multiclsTask/Qwen2.5-VL-3B-Instruct/v0-20251123-182828/checkpoint-264/eval/20251124-175009.jsonl"
results = main(input_file, output_file)