# Spaces: Runtime error  (page-status residue captured with this file; not part of the program)
| import pickle | |
| import json | |
| import os | |
| import sys | |
| sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
| import re | |
| from collections import defaultdict | |
def load_ground_truth(pkl_file):
    """Load the ground-truth EC numbers of every protein from a pickle file.

    The unpickled object is iterated as a sequence of records. Each record
    must provide ``'uniprot_id'``; if it also has an ``'ec'`` entry, every
    element of it that carries ``['reaction']['ecNumber']`` contributes one
    EC number.

    Args:
        pkl_file: Path of the pickle file to read.

    Returns:
        A dict mapping each UniProt id to the set of its EC numbers (the set
        is empty when the record has no usable EC annotation).
    """
    # NOTE: pickle.load can execute arbitrary code; only feed trusted files.
    with open(pkl_file, 'rb') as handle:
        records = pickle.load(handle)

    ec_by_protein = {}
    for record in records:
        protein_id = record['uniprot_id']
        annotations = record['ec'] if 'ec' in record else ()
        # A set comprehension removes duplicate EC numbers in one pass.
        ec_by_protein[protein_id] = {
            annotation['reaction']['ecNumber']
            for annotation in annotations
            if 'reaction' in annotation and 'ecNumber' in annotation['reaction']
        }
    return ec_by_protein
def extract_ec_prediction(json_content):
    """Extract the predicted EC numbers that follow the ``[EC_PREDICTION]`` tag.

    Only the text between the first ``[EC_PREDICTION]`` tag (after skipping
    any whitespace that follows it) and the next line break is inspected.

    Args:
        json_content: Raw text of a prediction file.

    Returns:
        A list of EC numbers in order of appearance (duplicates kept). Both
        complete numbers such as ``1.2.3.4`` and partial ones such as
        ``1.2.-.-`` are recognised. The list is empty when the tag is missing
        or no EC number follows it.
    """
    tag_match = re.search(r'\[EC_PREDICTION\]\s*([^\n\r]*)', json_content)
    if tag_match is None:
        return []

    tagged_text = tag_match.group(1).strip()
    # Four dot-separated fields; every field after the first may be '-'.
    return re.findall(r'\b\d+\.(?:\d+|-)\.(?:\d+|-)\.(?:\d+|-)', tagged_text)
def load_predictions(predictions_dir):
    """Load the EC predictions of every ``*.json`` file in a directory.

    The file name without its trailing ``.json`` extension is used as the
    UniProt id. Files that cannot be processed are reported on stdout and
    skipped, so one bad file does not abort the whole evaluation.

    Args:
        predictions_dir: Directory that contains one prediction file per
            protein.

    Returns:
        A dict mapping each UniProt id to the list of EC numbers extracted
        from its file (possibly an empty list).
    """
    suffix = '.json'
    predictions = {}
    for filename in os.listdir(predictions_dir):
        if not filename.endswith(suffix):
            continue

        # Strip only the trailing extension; str.replace would also remove
        # any ".json" that happens to occur in the middle of the name.
        uniprot_id = filename[:-len(suffix)]
        filepath = os.path.join(predictions_dir, filename)
        try:
            with open(filepath, 'r', encoding='utf-8') as f:
                content = f.read()
            predictions[uniprot_id] = extract_ec_prediction(content)
        except Exception as e:  # deliberate best-effort: report and go on
            print(f"处理文件 {filename} 时出错: {e}")
    return predictions
def calculate_accuracy(ground_truth, predictions, level=4):
    """Compute top-1 accuracy of the EC predictions at a given EC level.

    Only the first predicted EC number of each protein is scored. It counts
    as correct when its first ``level`` components equal those of at least
    one ground-truth EC number of that protein. Proteins without any
    predicted EC number are left out of the total.

    Args:
        ground_truth: Dict mapping UniProt id to its ground-truth EC numbers.
        predictions: Dict mapping UniProt id to its list of predicted EC
            numbers.
        level: How many leading EC components (1-4) are compared.

    Returns:
        A tuple ``(accuracy, correct, total)``; accuracy is 0 when nothing
        was scored.
    """
    hits = 0
    scored = 0
    for protein_id, true_ecs in ground_truth.items():
        ranked = predictions.get(protein_id)
        if not ranked:
            continue

        top_prediction = ranked[0]
        matched = any(
            true_ec.split('.')[:level] == top_prediction.split('.')[:level]
            for true_ec in true_ecs
        )
        if matched:
            hits += 1
        scored += 1

    accuracy = hits / scored if scored > 0 else 0
    return accuracy, hits, scored
def _truncate_ecs(ec_numbers, level):
    """Return the set of EC numbers cut down to their first ``level`` parts."""
    return {'.'.join(ec.split('.')[:level]) for ec in ec_numbers}


def calculate_prf1(ground_truth, predictions, level=4):
    """Compute micro-averaged precision, recall and F1 at a given EC level.

    Every EC number is truncated to its first ``level`` components before
    ground truth and prediction are compared as sets. Proteins whose ground
    truth is empty are skipped entirely. Proteins that have ground truth but
    no prediction entry, or an empty prediction, contribute only false
    negatives.

    Args:
        ground_truth: Dict mapping UniProt id to its ground-truth EC numbers.
        predictions: Dict mapping UniProt id to its predicted EC numbers.
        level: How many leading EC components (1-4) are compared.

    Returns:
        A dict with ``precision``, ``recall``, ``f1_score``, the summed
        ``tp``/``fp``/``fn`` counts, ``evaluated_proteins`` (proteins with
        non-empty ground truth that also have a prediction entry) and
        ``incorrect_proteins`` (per-category lists describing each error).
        The EC lists inside ``incorrect_proteins`` are sorted so that the
        report is identical from run to run.
    """
    total_tp = 0
    total_fp = 0
    total_fn = 0
    evaluated = 0

    incorrect_proteins = {
        'false_positives': [],   # predicted but absent from the ground truth
        'false_negatives': [],   # present in the ground truth but not predicted
        'no_prediction': [],     # ground truth exists, no prediction entry
        'zero_prediction': [],   # prediction entry exists but holds no EC number
    }

    for uniprot_id, gt_ecs_set in ground_truth.items():
        # Without ground truth there is nothing to score against.
        if not gt_ecs_set:
            continue

        level_gt = _truncate_ecs(gt_ecs_set, level)

        if uniprot_id not in predictions:
            total_fn += len(level_gt)
            incorrect_proteins['no_prediction'].append({
                'protein_id': uniprot_id,
                'gt_ecs': sorted(level_gt),
            })
            continue

        evaluated += 1
        level_pred = _truncate_ecs(predictions[uniprot_id], level)

        if not level_pred:
            total_fn += len(level_gt)
            incorrect_proteins['zero_prediction'].append({
                'protein_id': uniprot_id,
                'gt_ecs': sorted(level_gt),
            })
            continue

        tp = len(level_pred & level_gt)
        fp = len(level_pred) - tp
        fn = len(level_gt) - tp
        total_tp += tp
        total_fp += fp
        total_fn += fn

        if fp > 0:
            incorrect_proteins['false_positives'].append({
                'protein_id': uniprot_id,
                'predicted_ecs': sorted(level_pred - level_gt),
                'gt_ecs': sorted(level_gt),
            })
        if fn > 0:
            incorrect_proteins['false_negatives'].append({
                'protein_id': uniprot_id,
                'missed_ecs': sorted(level_gt - level_pred),
                'predicted_ecs': sorted(level_pred),
            })

    # Micro average: pool the counts of all proteins before dividing.
    precision = total_tp / (total_tp + total_fp) if (total_tp + total_fp) > 0 else 0
    recall = total_tp / (total_tp + total_fn) if (total_tp + total_fn) > 0 else 0
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

    return {
        'precision': precision,
        'recall': recall,
        'f1_score': f1,
        'tp': total_tp,
        'fp': total_fp,
        'fn': total_fn,
        'evaluated_proteins': evaluated,
        'incorrect_proteins': incorrect_proteins,
    }
def main():
    """Evaluate the EC predictions against the ground truth and save a report.

    Command-line options:
        --pkl_file: Pickle file holding the ground-truth records.
        --predictions_dir: Directory holding one ``*.json`` prediction file
            per protein.
        --output_file: Where the JSON report is written; missing parent
            directories are created.

    Prints precision/recall/F1 for EC levels 1-4 together with samples of the
    mispredicted proteins and the distribution of EC counts, then dumps the
    per-level metrics to the output file.
    """
    import argparse

    parser = argparse.ArgumentParser(description='Calculate EC accuracy')
    parser.add_argument('--pkl_file', type=str, default='data/raw_data/difference_20241122_ec_dict_list.pkl')
    parser.add_argument('--predictions_dir', type=str, default='data/clean_test_results_top2go_deepseek-r1')
    parser.add_argument('--output_file', type=str, default='test_results/ec_accuracy_results.json')
    args = parser.parse_args()
    pkl_file = args.pkl_file
    predictions_dir = args.predictions_dir
    output_file = args.output_file

    print("正在加载ground truth数据...")
    ground_truth = load_ground_truth(pkl_file)
    print(f"加载了 {len(ground_truth)} 个蛋白的ground truth数据")
    print("正在加载预测结果...")
    predictions = load_predictions(predictions_dir)
    print(f"加载了 {len(predictions)} 个蛋白的预测结果")

    # Only proteins present on both sides and with a non-empty ground truth
    # can be scored.
    common_ids = set(ground_truth.keys()) & set(predictions.keys())
    valid_ids = {uid for uid in common_ids if ground_truth[uid]}
    print(f"共同且有GT的蛋白数量: {len(valid_ids)}")

    filtered_gt = {uid: ground_truth[uid] for uid in valid_ids}
    filtered_pred = {uid: predictions[uid] for uid in valid_ids}

    # Precision / recall / F1 for each EC level.
    results = {}
    print("\n=== 评估结果 ===")
    for level in [1, 2, 3, 4]:
        metrics = calculate_prf1(filtered_gt, filtered_pred, level=level)
        results[level] = metrics
        print(f"--- EC号前{level}级 ---")
        print(f" Precision: {metrics['precision']:.4f}")
        print(f" Recall: {metrics['recall']:.4f}")
        print(f" F1-Score: {metrics['f1_score']:.4f}")
        print(f" (TP: {metrics['tp']}, FP: {metrics['fp']}, FN: {metrics['fn']})")

        # List the mispredicted proteins; long lists are cut after 10 entries.
        incorrect = metrics['incorrect_proteins']
        if incorrect['false_positives']:
            print(f" 假阳性错误 ({len(incorrect['false_positives'])}个蛋白):")
            for item in incorrect['false_positives'][:10]:
                print(f" {item['protein_id']}: 错误预测了 {item['predicted_ecs']}, GT是 {item['gt_ecs']}")
            if len(incorrect['false_positives']) > 10:
                print(f" ... 还有 {len(incorrect['false_positives']) - 10} 个")
        if incorrect['false_negatives']:
            print(f" 假阴性错误 ({len(incorrect['false_negatives'])}个蛋白):")
            for item in incorrect['false_negatives'][:10]:
                print(f" {item['protein_id']}: 漏掉了 {item['missed_ecs']}, 预测了 {item['predicted_ecs']}")
            if len(incorrect['false_negatives']) > 10:
                print(f" ... 还有 {len(incorrect['false_negatives']) - 10} 个")
        if incorrect['zero_prediction']:
            print(f" 零预测错误 ({len(incorrect['zero_prediction'])}个蛋白):")
            for item in incorrect['zero_prediction']:
                print(f" {item['protein_id']}: GT是 {item['gt_ecs']}, 但预测了0个EC号")
        if incorrect['no_prediction']:
            print(f" 无预测错误 ({len(incorrect['no_prediction'])}个蛋白):")
            for item in incorrect['no_prediction'][:10]:
                print(f" {item['protein_id']}: GT是 {item['gt_ecs']}, 但没有预测")
            if len(incorrect['no_prediction']) > 10:
                print(f" ... 还有 {len(incorrect['no_prediction']) - 10} 个")
        print()  # blank line between levels

    print("\n=== 详细统计信息 ===")

    # How many proteins have 1, 2, ... ground-truth EC numbers.
    gt_ec_counts = defaultdict(int)
    for ecs in filtered_gt.values():
        gt_ec_counts[len(ecs)] += 1
    print("Ground truth中EC号数量分布:")
    for count, freq in sorted(gt_ec_counts.items()):
        print(f" {count}个EC号: {freq}个蛋白")

    # Same distribution for the predictions.
    pred_ec_counts = defaultdict(int)
    for ecs in filtered_pred.values():
        pred_ec_counts[len(ecs)] += 1
    print("\n预测结果中EC号数量分布:")
    for count, freq in sorted(pred_ec_counts.items()):
        print(f" {count}个EC号: {freq}个蛋白")

    # Create the target directory first so that a missing folder does not
    # throw away the whole evaluation at the very last step.
    output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    print(f"\n结果已保存到 {output_file}")
# Run the evaluation only when the file is executed as a script.
if __name__ == '__main__':
    main()