# import json
# import re
# from PIL import Image
# from transformers import AutoModelForVision2Seq, AutoProcessor
# import torch
# import os
# from qwen_vl_utils import process_vision_info
# # --- 1. Helper functions ---
# def load_test_data(file_path):
#     """
#     Load a .json or .jsonl file based on its extension.
#     For .json files, try several common keys to locate the list of samples.
#     """
#     _, ext = os.path.splitext(file_path)
#     ext = ext.lower()
#     test_samples = []
#     if ext == '.jsonl':
#         print(f"Loading data from JSON Lines file: {file_path}")
#         with open(file_path, 'r', encoding='utf-8') as f:
#             for i, line in enumerate(f):
#                 line = line.strip()
#                 if not line:
#                     continue
#                 try:
#                     test_samples.append(json.loads(line))
#                 except json.JSONDecodeError as e:
#                     print(f"Warning: Skipping invalid JSON line {i+1} in {file_path}: {e}")
#     elif ext == '.json':
#         print(f"Loading data from JSON file: {file_path}")
#         try:
#             with open(file_path, 'r', encoding='utf-8') as f:
#                 data = json.load(f)
#             if isinstance(data, list):
#                 print("  Detected JSON array format.")
#                 test_samples = data
#             elif isinstance(data, dict):
#                 print("  Detected JSON object format. Searching for samples...")
#                 possible_keys = ['data', 'samples', 'instances', 'items', 'conversations', 'messages']
#                 found = False
#                 for key in possible_keys:
#                     if key in data and isinstance(data[key], list) and len(data[key]) > 0:
#                         # Quick check: does the first list element look like a sample (a dict with 'messages')?
#                         first_item = data[key][0]
#                         if isinstance(first_item, dict) and 'messages' in first_item:
#                             print(f"  Found samples under key '{key}'.")
#                             test_samples = data[key]
#                             found = True
#                             break
#                 if not found:
#                     # Heuristic: take the first value that is a list of dicts containing 'messages'
#                     for key, value in data.items():
#                         if isinstance(value, list) and len(value) > 0 and isinstance(value[0], dict) and 'messages' in value[0]:
#                             print(f"  Found samples under key '{key}' (heuristic).")
#                             test_samples = value
#                             found = True
#                             break
#                 if not found:
#                     print(f"  Error: Could not find a list of samples in the JSON object. Keys found: {list(data.keys())}")
#             else:
#                 print(f"  Error: Unexpected JSON structure. Root element type: {type(data)}")
#         except json.JSONDecodeError as e:
#             print(f"Error: Failed to decode JSON from {file_path}: {e}")
#         except Exception as e:
#             print(f"Error: An unexpected error occurred while loading {file_path}: {e}")
#     else:
#         print(f"Error: Unsupported file extension '{ext}'. Please provide a .json or .jsonl file.")
#     print(f"Loaded {len(test_samples)} samples.")
#     # Validate the structure of the loaded samples
#     if test_samples and isinstance(test_samples, list):
#         print("Performing basic structure validation on loaded samples...")
#         sample_count_to_check = min(5, len(test_samples))
#         for i in range(sample_count_to_check):
#             s = test_samples[i]
#             if not isinstance(s, dict):
#                 print(f"  CRITICAL: Sample {i} is not a dict. Type: {type(s)}")
#                 # Optionally abort here or drop the bad sample
#                 # return []
#             elif 'messages' not in s or 'images' not in s:
#                 print(f"  WARNING: Sample {i} might be missing 'messages' or 'images' keys. Found keys: {list(s.keys())}")
#             else:
#                 if not isinstance(s['messages'], list):
#                     print(f"  CRITICAL: Sample {i} 'messages' is not a list. Type: {type(s['messages'])}")
#                 if not isinstance(s['images'], list):
#                     print(f"  CRITICAL: Sample {i} 'images' is not a list. Type: {type(s['images'])}")
#         print("Structure validation complete.")
#     elif test_samples:
#         print(f"CRITICAL: Expected test_samples to be a list after loading, got {type(test_samples)}.")
#         test_samples = []  # Reset to empty list on critical error
#     return test_samples
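#
# # A sample record is assumed to look roughly like this (inferred from the
# # 'messages'/'images' checks above; not a guaranteed schema):
# #   {"messages": [{"role": "user", "content": "...prompt referring to the image..."},
# #                 {"role": "assistant", "content": "<think>...</think><control>...</control><answer>...</answer>"}],
# #    "images": ["/path/to/frame.jpg"]}
#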
# def extract_components(text):
#     """Extract the <think>, <control>, and <answer> components from model output or labels."""
#     think_match = re.search(r'<think>(.*?)</think>', text, re.DOTALL)
#     control_match = re.search(r'<control>(.*?)</control>', text)
#     answer_match = re.search(r'<answer>(.*?)</answer>', text)
#     return {
#         'think': think_match.group(1).strip() if think_match else "",
#         'control': control_match.group(1).strip() if control_match else "",
#         'answer': answer_match.group(1).strip() if answer_match else ""
#     }
# def calculate_accuracy(pred_list, true_list):
#     """Compute exact-match accuracy (used for <answer>)."""
#     if len(pred_list) != len(true_list):
#         raise ValueError("Prediction and truth lists must have the same length for accuracy calculation.")
#     if not pred_list:
#         return 0.0
#     correct = sum(p == t for p, t in zip(pred_list, true_list))
#     return correct / len(pred_list)
# # --- 2. Main evaluation logic ---
# def main():
#     # --- Configuration ---
#     # Replace with your model path
#     model_path = "/data/LLM-SFT/SFT_Output/multiclsTask/Qwen2.5-VL-3B-Instruct/SFT/checkpoint-894"
#     # Replace with your test-set path (.json or .jsonl)
#     test_data_path = "/data/LLM-SFT/datasets/driver_behavior_datasets/output_test.jsonl"
#     output_file = model_path + "/eval/detailed_model_evaluation_results.json"
#     # --- Load model and processor ---
#     print("Loading model and processor...")
#     try:
#         processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
#         model = AutoModelForVision2Seq.from_pretrained(
#             model_path,
#             trust_remote_code=True,
#             torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
#         )
#         model.eval()
#         if torch.cuda.is_available():
#             model = model.to('cuda')
#             print("Model loaded on GPU.")
#         else:
#             print("Model loaded on CPU.")
#     except Exception as e:
#         print(f"Failed to load model/processor: {e}")
#         return  # Exit if model loading fails
#     # --- Load test data ---
#     try:
#         test_samples = load_test_data(test_data_path)
#         # print('test_samples', test_samples)
#         if not test_samples:
#             print("No samples loaded. Exiting.")
#             return
#     except Exception as e:
#         print(f"Failed to load test data: {e}")
#         return
#     # --- Inference and result collection (with parsing) ---
#     results = []
#     pred_answers = []
#     true_answers = []
#     pred_controls = []  # control strings kept for later analysis
#     true_controls = []
#     print("Starting inference...")
#     for i, sample in enumerate(test_samples):
#         try:
#             conversation = sample['messages']
#             image_path = sample['images'][0]
#             if not os.path.exists(image_path):
#                 print(f"Warning: Image not found: {image_path}. Skipping sample {i}.")
#                 # Append empty placeholders to keep the evaluation lists aligned
#                 pred_answers.append("")
#                 true_answers.append(extract_components(conversation[-1]['content'])['answer'])
#                 pred_controls.append("")
#                 true_controls.append(extract_components(conversation[-1]['content'])['control'])
#                 continue
#             image = Image.open(image_path).convert('RGB')
#             # Prepare inputs. The Qwen-VL series expects `messages` as a list of
#             # role/content entries; the processor aligns <image> tokens with the images.
#             # print('conversation[:-1]', conversation[:-1])
#             input_text = processor.apply_chat_template(conversation[:-1], tokenize=False, add_generation_prompt=True)
#             image_inputs, video_inputs = process_vision_info(conversation[:-1])
#             inputs = processor(
#                 text=input_text,
#                 images=image_inputs,
#                 videos=video_inputs,
#                 padding=True,
#                 return_tensors="pt",
#             )
#             if torch.cuda.is_available():
#                 # Keep the BatchFeature (rather than a plain dict) so inputs.input_ids still works below
#                 inputs = inputs.to('cuda')
#             # Generate
#             with torch.no_grad():
#                 generated_ids = model.generate(
#                     **inputs,
#                     max_new_tokens=200,
#                     # num_beams=5,
#                     do_sample=True,
#                     top_p=0.75,
#                     top_k=50,
#                     temperature=0.2,
#                     # repetition_penalty=1.2,
#                     # early_stopping=True
#                 )
#             # Full decode (prompt + completion), kept for the detailed results
#             decoded_output = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
#             generated_ids_trimmed = [
#                 out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
#             ]
#             # Batch size is 1, so take the single decoded completion string
#             output_text = processor.batch_decode(
#                 generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
#             )[0]
#             # Extract the label
#             ground_truth = conversation[-1]['content']
#             # Parse the model output and the label
#             pred_components = extract_components(output_text)
#             true_components = extract_components(ground_truth)
#             # Collect data for evaluation
#             pred_answers.append(pred_components['answer'])
#             true_answers.append(true_components['answer'])
#             pred_controls.append(pred_components['control'])
#             true_controls.append(true_components['control'])
#             # Print a few samples for inspection
#             if i < 3:  # print the first 3 samples
#                 print(f"\n--- Sample {i+1} ---")
#                 print(f"  Image: {image_path}")
#                 print(f"  Input Text: {input_text}")
#                 print(f"  Full Decoded Output: {decoded_output}")
#                 print(f"  Processed Output Text: {output_text}")
#                 print(f"  Parsed Prediction: {pred_components}")
#                 print(f"  Ground Truth: {ground_truth}")
#                 print(f"  Parsed Truth: {true_components}")
#             # Store detailed results
#             results.append({
#                 "sample_id": i,
#                 "image_path": image_path,
#                 "input_text": input_text,
#                 "model_output_raw": decoded_output,
#                 "model_output_processed": output_text,
#                 "parsed_prediction": pred_components,
#                 "ground_truth_raw": ground_truth,
#                 "parsed_truth": true_components
#             })
#         except Exception as e:
#             print(f"Error processing sample {i}: {e}")
#             # Failed samples still count in the evaluation lists, marked as empty or ERROR
#             pred_answers.append("ERROR")
#             true_answers.append(extract_components(conversation[-1]['content'])['answer'] if 'conversation' in locals() else "")
#             pred_controls.append("ERROR")
#             true_controls.append(extract_components(conversation[-1]['content'])['control'] if 'conversation' in locals() else "")
#             results.append({
#                 "sample_id": i,
#                 "image_path": image_path if 'image_path' in locals() else "N/A",
#                 "input_text": conversation[-2]['content'] if 'conversation' in locals() else "N/A",
#                 "model_output_raw": f"ERROR: {e}",
#                 "model_output_processed": f"ERROR: {e}",
#                 "parsed_prediction": {"think": "", "control": "", "answer": "ERROR"},
#                 "ground_truth_raw": conversation[-1]['content'] if 'conversation' in locals() else "N/A",
#                 "parsed_truth": extract_components(conversation[-1]['content']) if 'conversation' in locals() else {"think": "", "control": "", "answer": ""}
#             })
#     # --- Save detailed results ---
#     os.makedirs(os.path.dirname(output_file), exist_ok=True)  # the eval/ directory may not exist yet
#     with open(output_file, 'w', encoding='utf-8') as f:
#         json.dump(results, f, indent=2, ensure_ascii=False)
#     print(f"\nDetailed results saved to {output_file}")
#     # --- Quantitative evaluation ---
#     print(f"\n--- Quantitative Evaluation ---")
#     total_samples = len(test_samples)
#     successful_samples = len([r for r in results if not r['model_output_raw'].startswith("ERROR")])
#     print(f"Total samples: {total_samples}, Successfully processed: {successful_samples}")
#     if successful_samples == 0:
#         print("No samples were processed successfully. Skipping quantitative evaluation.")
#         return
#     # a. <answer> tag accuracy (successful samples only)
#     filtered_pred_answers = [p for p in pred_answers if p != "ERROR"]
#     filtered_true_answers = [t for p, t in zip(pred_answers, true_answers) if p != "ERROR"]
#     if filtered_pred_answers:
#         answer_accuracy = calculate_accuracy(filtered_pred_answers, filtered_true_answers)
#         print(f"<answer> Tag Accuracy (on successful samples): {answer_accuracy:.4f} ({sum(p == t for p, t in zip(filtered_pred_answers, filtered_true_answers))}/{len(filtered_true_answers)})")
#     else:
#         print("No valid <answer> predictions to evaluate.")
#         answer_accuracy = 0.0
#     # b. <control> command analysis
#     filtered_pred_controls = [c for p, c in zip(pred_answers, pred_controls) if p != "ERROR"]
#     filtered_true_controls = [t for p, t in zip(pred_answers, true_controls) if p != "ERROR"]
#     if filtered_pred_controls:
#         control_non_empty_pred = [c != "" for c in filtered_pred_controls]
#         control_non_empty_true = [c != "" for c in filtered_true_controls]
#         control_existence_acc = calculate_accuracy(control_non_empty_pred, control_non_empty_true)
#         print(f"<control> Tag Presence Accuracy (on successful samples): {control_existence_acc:.4f}")
#     else:
#         print("No valid <control> predictions to evaluate.")
#         control_existence_acc = 0.0
#     # c. Per-class <answer> accuracy (per-class recall)
#     if filtered_true_answers:
#         unique_labels = sorted(set(filtered_true_answers + filtered_pred_answers))
#         print("\nPer-class <answer> accuracy:")
#         class_acc = {}
#         for label in unique_labels:
#             tp = sum(1 for p, t in zip(filtered_pred_answers, filtered_true_answers) if p == label and t == label)
#             total_true = sum(1 for t in filtered_true_answers if t == label)
#             class_acc[label] = tp / total_true if total_true > 0 else 0.0
#             print(f"  Accuracy for '{label}': {class_acc[label]:.4f} ({tp}/{total_true if total_true > 0 else 'N/A'})")
#     # d. (Optional) text-similarity evaluation on <think> (requires `pip install rouge`)
#     try:
#         from rouge import Rouge
#         rouge = Rouge()
#     except ImportError:
#         rouge = None
#         print("\n`rouge` is not installed; skipping ROUGE evaluation on <think> tags.")
#     if rouge is not None:
#         avg_rouge_scores = {'rouge-1': 0.0, 'rouge-2': 0.0, 'rouge-l': 0.0}
#         valid_samples_for_rouge = 0
#         for res in results:
#             if not res['model_output_raw'].startswith("ERROR") and res['parsed_truth']['think'] and res['parsed_prediction']['think']:
#                 try:
#                     scores = rouge.get_scores(res['parsed_prediction']['think'], res['parsed_truth']['think'])
#                     for metric in avg_rouge_scores:
#                         avg_rouge_scores[metric] += scores[0][metric]['f']
#                     valid_samples_for_rouge += 1
#                 except Exception as e:
#                     print(f"ROUGE calculation error for sample {res['sample_id']}: {e}")
#         if valid_samples_for_rouge > 0:
#             for metric in avg_rouge_scores:
#                 avg_rouge_scores[metric] /= valid_samples_for_rouge
#             print(f"\nAverage ROUGE Scores (on <think> tags, {valid_samples_for_rouge} valid samples):")
#             for metric, score in avg_rouge_scores.items():
#                 print(f"  {metric.upper()}: {score:.4f}")
#         else:
#             print("\nNo valid samples for ROUGE calculation on <think> tags.")
#     # --- Error case analysis ---
#     print(f"\n--- Error Analysis ---")
#     error_count = sum(1 for r in results if r['model_output_raw'].startswith("ERROR"))
#     if error_count > 0:
#         print(f"Number of samples with processing errors: {error_count}")
#         # Error details could be printed here
#     else:
#         print("No processing errors detected during inference.")
#     print("Samples where <answer> prediction was incorrect (excluding errors):")
#     incorrect_count = 0
#     for res in results:
#         # Only analyze samples that were processed successfully but predicted incorrectly
#         if not res['model_output_raw'].startswith("ERROR") and \
#            res['parsed_prediction']['answer'] != res['parsed_truth']['answer']:
#             incorrect_count += 1
#             if incorrect_count <= 5:  # print only the first 5 error cases
#                 print(f"  Sample ID: {res['sample_id']}")
#                 print(f"  Image: {res['image_path']}")
#                 print(f"  Input: {res['input_text']}")
#                 print(f"  Predicted Answer: '{res['parsed_prediction']['answer']}'")
#                 print(f"  True Answer: '{res['parsed_truth']['answer']}'")
#                 print(f"  Predicted Control: '{res['parsed_prediction']['control']}'")
#                 print(f"  True Control: '{res['parsed_truth']['control']}'")
#                 # print(f"  Predicted Think: '{res['parsed_prediction']['think']}'")  # optional
#                 # print(f"  True Think: '{res['parsed_truth']['think']}'")  # optional
#                 print("-" * 20)
#     if incorrect_count > 5:
#         print(f"... and {incorrect_count - 5} more incorrect predictions.")
#     elif incorrect_count == 0:
#         print("  All successful predictions matched the ground truth <answer>.")
#     # --- Summary ---
#     print(f"\n--- Summary ---")
#     print(f"Total samples processed: {total_samples}")
#     print(f"Successfully processed samples: {successful_samples}")
#     if filtered_pred_answers:
#         print(f"<answer> Accuracy (successful samples): {answer_accuracy:.4f}")
#     if filtered_pred_controls:
#         print(f"<control> Presence Accuracy (successful samples): {control_existence_acc:.4f}")
#     print("Per-class accuracies calculated above (if applicable).")
#     print("Detailed results are available in the output file.")
# if __name__ == "__main__":
#     main()
import json
import os
import re
from collections import Counter
from typing import Any, Dict, List
def load_data(file_path: str) -> List[Dict]:
    """
    Load data from a JSON or JSONL file.
    """
    data = []
    # Check the file extension to determine the format
    if file_path.lower().endswith('.json'):
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    elif file_path.lower().endswith('.jsonl'):
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if line:
                    data.append(json.loads(line))
    else:
        # Try to auto-detect based on content
        with open(file_path, 'r', encoding='utf-8') as f:
            first_line = f.readline().strip()
            f.seek(0)
            if first_line.startswith('['):  # JSON array
                data = json.load(f)
            else:  # Assume JSONL
                for line in f:
                    line = line.strip()
                    if line:
                        data.append(json.loads(line))
    return data
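
# The evaluation below assumes each loaded record carries the model prediction and
# the reference text in "response" and "labels" fields (see how samples are read
# in comprehensive_evaluation()); adjust those two keys if your inference dump
# names them differently. A minimal sketch of one JSONL record under that assumption:
#   {"response": "<think>...</think><control>...</control><answer>...</answer>",
#    "labels":   "<think>...</think><control>...</control><answer>...</answer>"}
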
def parse_think_content(think_str: str) -> Dict[str, str]:
    """
    Parse <think> content to extract the behavior description.
    """
    if not think_str:
        return {"raw": "", "behavior": ""}
    # Remove <think> tags and extract the content
    clean_str = re.sub(r'<think>|</think>', '', think_str).strip()
    return {
        "raw": clean_str,
        "behavior": clean_str  # for now, the behavior is the full content
    }
def parse_control_content(control_str: str) -> Dict[str, Any]:
    """
    Parse <control> content to extract the control command and its parameters.
    """
    if not control_str:
        return {"raw": "", "command": "", "parameters": {}, "type": "none"}
    clean_str = re.sub(r'<control>|</control>', '', control_str).strip()
    # Extract command and parameters
    command = clean_str
    params = {}
    control_type = "other"
    if "(" in clean_str and ")" in clean_str:
        # Pattern like: MonitorPassenger(SwellingDetected)
        match = re.match(r'(\w+)\(([^)]+)\)', clean_str)
        if match:
            command = match.group(1)
            param_str = match.group(2)
            params = {"parameter": param_str}
            if "Monitor" in command:
                control_type = "monitoring"
            elif "Alert" in command:
                control_type = "alerting"
            elif "set" in command:
                control_type = "setting"
    elif "|" in clean_str:
        # Pattern like: setMute|false|
        parts = clean_str.split("|")
        command = parts[0] if parts else ""
        params = {"params": parts[1:] if len(parts) > 1 else []}
        control_type = "command"
    else:
        command = clean_str
        control_type = "function"
    return {
        "raw": clean_str,
        "command": command,
        "parameters": params,
        "type": control_type
    }
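
# Illustrative expectations for the three formats handled above (the first two
# command strings come from the comments in parse_control_content; the third is
# hypothetical):
#   parse_control_content("<control>MonitorPassenger(SwellingDetected)</control>")
#     -> command "MonitorPassenger", parameters {"parameter": "SwellingDetected"}, type "monitoring"
#   parse_control_content("<control>setMute|false|</control>")
#     -> command "setMute", parameters {"params": ["false", ""]}, type "command"
#   parse_control_content("<control>OpenWindow</control>")
#     -> command "OpenWindow", parameters {}, type "function"
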
def parse_answer_content(answer_str: str) -> Dict[str, str]:
    """
    Parse <answer> content to extract the final answer.
    """
    if not answer_str:
        return {"raw": "", "category": "", "description": ""}
    clean_str = re.sub(r'<answer>|</answer>', '', answer_str).strip()
    lowered = clean_str.lower()
    # Try to categorize the answer
    category = "other"
    if any(keyword in lowered for keyword in ["swelling", "eye", "face", "facial"]):
        category = "physical_symptom"
    elif any(keyword in lowered for keyword in ["sleep", "drowsy", "tired", "yawn"]):
        category = "drowsiness"
    elif any(keyword in lowered for keyword in ["phone", "call", "text", "mobile"]):
        category = "distraction"
    elif any(keyword in lowered for keyword in ["smoke", "cigarette"]):
        category = "smoking"
    elif any(keyword in lowered for keyword in ["drunk", "alcohol", "intoxicated"]):
        category = "intoxication"
    elif any(keyword in lowered for keyword in ["mouth", "corner", "slanting"]):
        category = "facial_expression"
    elif any(keyword in lowered for keyword in ["head", "cover", "hold"]):
        category = "head_behavior"
    elif any(keyword in lowered for keyword in ["arm", "hand", "slip", "droop"]):
        category = "limb_behavior"
    elif any(keyword in lowered for keyword in ["radio", "adjust", "control"]):
        category = "vehicle_control"
    return {
        "raw": clean_str,
        "category": category,
        "description": clean_str
    }
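
# The branches above are checked in order, so a string matching several keyword
# groups gets the first category: e.g. the hypothetical answer "tired eyes"
# contains "eye" and lands in "physical_symptom" before the "drowsiness" branch
# ("tired") is ever reached.
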
def extract_all_components(text: str) -> Dict[str, str]:
    """
    Extract the think, control, and answer components from text.
    """
    components = {
        "think": "",
        "control": "",
        "answer": ""
    }
    # Extract <think> content
    think_match = re.search(r'<think>(.*?)</think>', text, re.DOTALL)
    if think_match:
        components["think"] = think_match.group(1).strip()
    # Extract <control> content
    control_match = re.search(r'<control>(.*?)</control>', text, re.DOTALL)
    if control_match:
        components["control"] = control_match.group(1).strip()
    # Extract <answer> content
    answer_match = re.search(r'<answer>(.*?)</answer>', text, re.DOTALL)
    if answer_match:
        components["answer"] = answer_match.group(1).strip()
    return components
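
# For example (illustrative input), given
#   text = "<think>Driver appears tired</think><control>AlertDriver(Fatigue)</control><answer>drowsy</answer>"
# extract_all_components(text) returns
#   {"think": "Driver appears tired", "control": "AlertDriver(Fatigue)", "answer": "drowsy"}
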
def calculate_component_accuracy(predicted_components: Dict, actual_components: Dict) -> Dict[str, float]:
    """
    Calculate an accuracy score for each component.
    """
    accuracy = {}
    # Think component accuracy
    accuracy['think'] = calculate_similarity(
        predicted_components.get('think', ''),
        actual_components.get('think', '')
    )
    # Control component accuracy
    accuracy['control'] = calculate_similarity(
        predicted_components.get('control', ''),
        actual_components.get('control', '')
    )
    # Answer component accuracy
    accuracy['answer'] = calculate_similarity(
        predicted_components.get('answer', ''),
        actual_components.get('answer', '')
    )
    return accuracy
def calculate_similarity(str1: str, str2: str) -> float:
    """
    Calculate a similarity score between two strings.
    """
    if not str1 and not str2:
        return 1.0
    if not str1 or not str2:
        return 0.0
    str1_lower = str1.lower().strip()
    str2_lower = str2.lower().strip()
    if str1_lower == str2_lower:
        return 1.0
    # Calculate word overlap
    words1 = set(str1_lower.split())
    words2 = set(str2_lower.split())
    if len(words1) == 0 and len(words2) == 0:
        return 1.0
    if len(words1) == 0 or len(words2) == 0:
        return 0.0
    intersection = words1.intersection(words2)
    union = words1.union(words2)
    # Jaccard similarity
    jaccard = len(intersection) / len(union) if union else 0
    # Treat substring containment as a strong match (at least 0.8)
    if str1_lower in str2_lower or str2_lower in str1_lower:
        return max(jaccard, 0.8)
    return jaccard
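
# A worked example (hypothetical strings): comparing "driver is yawning" with
# "the driver is yawning" gives an intersection of 3 words over a union of 4,
# i.e. Jaccard 0.75; since the first string is also a substring of the second,
# the function returns max(0.75, 0.8) = 0.8.
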
def evaluate_component_quality(parsed_component: Dict, expected_component: Dict) -> Dict[str, float]:
    """
    Evaluate the quality of component parsing and prediction.
    """
    quality = {}
    if parsed_component.get('type') == expected_component.get('type'):
        quality['type_match'] = 1.0
    else:
        quality['type_match'] = 0.0
    # Evaluate content quality based on component type
    if parsed_component.get('type') == 'monitoring':
        quality['content_quality'] = 1.0 if 'Monitor' in parsed_component.get('command', '') else 0.0
    elif parsed_component.get('type') == 'alerting':
        quality['content_quality'] = 1.0 if 'Alert' in parsed_component.get('command', '') else 0.0
    else:
        quality['content_quality'] = 0.5  # default medium quality
    return quality
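
# Note: only parse_control_content() emits a 'type' key, so when this is called
# on parsed <think> or <answer> components both .get('type') lookups return None,
# 'type_match' is trivially 1.0, and 'content_quality' falls through to the 0.5
# default; the quality check is only really informative for <control>.
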
def comprehensive_evaluation(data: List[Dict]) -> Dict[str, Any]:
    """
    Comprehensive evaluation of all three components.
    """
    total_samples = len(data)
    results = {
        'overall_metrics': {},
        'component_wise_metrics': {
            'think': {'accuracy_scores': [], 'quality_scores': []},
            'control': {'accuracy_scores': [], 'quality_scores': []},
            'answer': {'accuracy_scores': [], 'quality_scores': []}
        },
        'detailed_analysis': [],
        'error_patterns': {
            'think_errors': [],
            'control_errors': [],
            'answer_errors': []
        }
    }
    for idx, sample in enumerate(data):
        # Extract components from response and labels
        response_components = extract_all_components(sample.get('response', ''))
        label_components = extract_all_components(sample.get('labels', ''))
        # Parse components for deeper analysis
        parsed_think = parse_think_content(response_components['think'])
        parsed_control = parse_control_content(response_components['control'])
        parsed_answer = parse_answer_content(response_components['answer'])
        actual_think = parse_think_content(label_components['think'])
        actual_control = parse_control_content(label_components['control'])
        actual_answer = parse_answer_content(label_components['answer'])
        # Calculate component-wise accuracy
        component_accuracy = calculate_component_accuracy(response_components, label_components)
        # Calculate component quality
        think_quality = evaluate_component_quality(parsed_think, actual_think)
        control_quality = evaluate_component_quality(parsed_control, actual_control)
        answer_quality = evaluate_component_quality(parsed_answer, actual_answer)
        # Store component-wise metrics
        for comp in ['think', 'control', 'answer']:
            results['component_wise_metrics'][comp]['accuracy_scores'].append(component_accuracy[comp])
            results['component_wise_metrics'][comp]['quality_scores'].append(
                think_quality.get('content_quality', 0.5) if comp == 'think' else
                control_quality.get('content_quality', 0.5) if comp == 'control' else
                answer_quality.get('content_quality', 0.5)
            )
        # Store detailed analysis
        detailed_result = {
            'index': idx,
            'response_components': response_components,
            'label_components': label_components,
            'parsed_response': {
                'think': parsed_think,
                'control': parsed_control,
                'answer': parsed_answer
            },
            'parsed_labels': {
                'think': actual_think,
                'control': actual_control,
                'answer': actual_answer
            },
            'component_accuracy': component_accuracy,
            'component_quality': {
                'think': think_quality,
                'control': control_quality,
                'answer': answer_quality
            },
            'overall_score': sum(component_accuracy.values()) / 3 if component_accuracy else 0
        }
        results['detailed_analysis'].append(detailed_result)
        # Analyze errors
        if component_accuracy['think'] < 0.5:
            results['error_patterns']['think_errors'].append(idx)
        if component_accuracy['control'] < 0.5:
            results['error_patterns']['control_errors'].append(idx)
        if component_accuracy['answer'] < 0.5:
            results['error_patterns']['answer_errors'].append(idx)
    # Calculate overall metrics
    overall_metrics = {}
    for comp in ['think', 'control', 'answer']:
        acc_scores = results['component_wise_metrics'][comp]['accuracy_scores']
        qual_scores = results['component_wise_metrics'][comp]['quality_scores']
        overall_metrics[f'{comp}_avg_accuracy'] = sum(acc_scores) / len(acc_scores) if acc_scores else 0
        overall_metrics[f'{comp}_avg_quality'] = sum(qual_scores) / len(qual_scores) if qual_scores else 0
        overall_metrics[f'{comp}_std_accuracy'] = (
            sum((x - overall_metrics[f'{comp}_avg_accuracy'])**2 for x in acc_scores) / len(acc_scores)
        )**0.5 if acc_scores else 0
    # Calculate overall system performance
    overall_metrics['total_samples'] = total_samples
    overall_metrics['avg_overall_score'] = sum(
        d['overall_score'] for d in results['detailed_analysis']
    ) / total_samples if total_samples > 0 else 0
    results['overall_metrics'] = overall_metrics
    return results
def generate_evaluation_report(results: Dict[str, Any]) -> str:
    """
    Generate a comprehensive evaluation report.
    """
    report = []
    report.append("=" * 100)
    report.append("COMPREHENSIVE EVALUATION OF IN-VEHICLE MULTIMODAL AI MODEL")
    report.append("=" * 100)
    metrics = results['overall_metrics']
    report.append(f"\n📊 OVERALL SYSTEM PERFORMANCE:")
    report.append(f"  Total Samples: {metrics['total_samples']}")
    report.append(f"  Average Overall Score: {metrics['avg_overall_score']:.4f}")
    report.append(f"\n🔍 COMPONENT-WISE PERFORMANCE:")
    for comp in ['think', 'control', 'answer']:
        avg_acc = metrics.get(f'{comp}_avg_accuracy', 0)
        avg_qual = metrics.get(f'{comp}_avg_quality', 0)
        std_acc = metrics.get(f'{comp}_std_accuracy', 0)
        report.append(f"  {comp.upper()}:")
        report.append(f"    Average Accuracy: {avg_acc:.4f}")
        report.append(f"    Average Quality: {avg_qual:.4f}")
        report.append(f"    Std Deviation: {std_acc:.4f}")
    # Error analysis
    error_patterns = results['error_patterns']
    report.append(f"\n❌ ERROR ANALYSIS:")
    report.append(f"  Think component errors: {len(error_patterns['think_errors'])} samples")
    report.append(f"  Control component errors: {len(error_patterns['control_errors'])} samples")
    report.append(f"  Answer component errors: {len(error_patterns['answer_errors'])} samples")
    # Sample-level analysis (the first sample is shown as an example)
    if results['detailed_analysis']:
        sample_analysis = results['detailed_analysis'][0]
        report.append(f"\n📋 SAMPLE ANALYSIS (First Sample):")
        report.append(f"  Think Accuracy: {sample_analysis['component_accuracy']['think']:.4f}")
        report.append(f"  Control Accuracy: {sample_analysis['component_accuracy']['control']:.4f}")
        report.append(f"  Answer Accuracy: {sample_analysis['component_accuracy']['answer']:.4f}")
        report.append(f"  Overall Score: {sample_analysis['overall_score']:.4f}")
    # Component type analysis
    report.append(f"\n🔧 COMPONENT TYPE ANALYSIS:")
    # Distribution of predicted control command types
    control_types = []
    for analysis in results['detailed_analysis']:
        control_type = analysis['parsed_response']['control'].get('type', 'unknown')
        control_types.append(control_type)
    type_counts = Counter(control_types)
    report.append("  Control Command Types:")
    for control_type, count in type_counts.most_common():
        report.append(f"    {control_type}: {count} samples")
    # Distribution of predicted answer categories
    answer_categories = []
    for analysis in results['detailed_analysis']:
        answer_category = analysis['parsed_response']['answer'].get('category', 'unknown')
        answer_categories.append(answer_category)
    category_counts = Counter(answer_categories)
    report.append("  Answer Categories:")
    for category, count in category_counts.most_common():
        report.append(f"    {category}: {count} samples")
    report.append(f"\n🎯 RECOMMENDATIONS:")
    if metrics.get('think_avg_accuracy', 0) < 0.7:
        report.append("  - Improve think component (behavior analysis)")
    if metrics.get('control_avg_accuracy', 0) < 0.7:
        report.append("  - Improve control component (command generation)")
    if metrics.get('answer_avg_accuracy', 0) < 0.7:
        report.append("  - Improve answer component (final classification)")
    return "\n".join(report)
def save_evaluation_results(results: Dict[str, Any], output_path: str):
    """
    Save evaluation results to a JSON file.
    """
    # Create the output directory if it does not exist yet
    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, ensure_ascii=False, indent=2)
def main(input_file: str, output_file: str = None):
    """
    Main function to perform the comprehensive evaluation.
    """
    print(f"Loading data from: {input_file}")
    # Load data
    data = load_data(input_file)
    print(f"Loaded {len(data)} samples")
    # Perform comprehensive evaluation
    print("Performing comprehensive evaluation...")
    results = comprehensive_evaluation(data)
    # Generate and print report
    report = generate_evaluation_report(results)
    print(report)
    # Save results if an output path was provided
    if output_file:
        save_evaluation_results(results, output_file)
        print(f"\nDetailed evaluation results saved to: {output_file}")
    return results
if __name__ == "__main__":
    import sys
    # Hardcoded defaults; pass paths on the command line to override:
    #   python comprehensive_evaluation.py <input_file> [output_file]
    input_file = r"/data/LLM-SFT/SFT_Output/multiclsTask/Qwen2.5-VL-3B-Instruct/v0-20251123-182828/checkpoint-264/infer_result/20251124-175009.jsonl"
    # Note: the results are written as a single JSON document despite the .jsonl suffix
    output_file = r"/data/LLM-SFT/SFT_Output/multiclsTask/Qwen2.5-VL-3B-Instruct/v0-20251123-182828/checkpoint-264/eval/20251124-175009.jsonl"
    if len(sys.argv) > 1:
        input_file = sys.argv[1]
    if len(sys.argv) > 2:
        output_file = sys.argv[2]
    results = main(input_file, output_file)