Spaces:
Runtime error
Runtime error
| import json | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| from sklearn.metrics import precision_recall_curve, average_precision_score | |
def load_go_data(file_path):
    """Load per-protein GO annotations from a JSON Lines file.

    Every non-blank line must be a JSON object with a ``"protein_id"`` key
    and a ``"GO_id"`` key. ``"GO_id"`` is normally a list of GO term ids; a
    bare string is treated as a single term rather than being split into
    characters.

    Args:
        file_path: Path to the JSON Lines file (read as UTF-8).

    Returns:
        A dict mapping each protein id to the set of its GO ids. If a protein
        id appears on several lines, the last occurrence wins.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record lacks ``"protein_id"`` or ``"GO_id"``.
    """
    data = {}
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            # Tolerate blank lines (e.g. a trailing newline at end of file).
            if not line:
                continue
            entry = json.loads(line)
            go_ids = entry["GO_id"]
            # set("GO:0001") would yield individual characters, not one term.
            if isinstance(go_ids, str):
                go_ids = [go_ids]
            data[entry["protein_id"]] = set(go_ids)
    return data
def _load_go_scores(scores_file):
    """Read ``{protein_id: {go_id: score}}`` from a JSON Lines scores file."""
    scores = {}
    with open(scores_file, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            entry = json.loads(line)
            # A missing or null "GO_scores" means "no scores for this protein".
            scores[entry["protein_id"]] = dict(entry.get("GO_scores") or {})
    return scores


def _set_precision_recall_f1(true_gos, pred_gos):
    """Return ``(precision, recall, f1)`` of a predicted GO set vs. the truth."""
    hits = len(true_gos & pred_gos)
    precision = hits / len(pred_gos) if pred_gos else 0.0
    # With no true annotations there is nothing to miss, so recall is 1.
    recall = hits / len(true_gos) if true_gos else 1.0
    total = precision + recall
    f1 = 2 * precision * recall / total if total > 0 else 0.0
    return precision, recall, f1


def _mean_metric(protein_metrics, key):
    """Average one metric over all proteins; 0.0 when there are no proteins."""
    if not protein_metrics:
        # np.mean([]) would give NaN plus a RuntimeWarning, and NaN is not
        # valid JSON once the results are dumped.
        return 0.0
    return float(np.mean([m[key] for m in protein_metrics.values()]))


def _build_pr_curve(all_true, all_scores, plot_path):
    """Compute the pooled PR curve, save its plot, and return the curve data."""
    y_true = np.asarray(all_true)
    # Force float so integer scores cannot make downstream arrays integer.
    y_score = np.asarray(all_scores, dtype=float)
    precision, recall, thresholds = precision_recall_curve(y_true, y_score)
    ap_score = average_precision_score(y_true, y_score)

    # precision/recall have one more entry than thresholds (the final point
    # has no threshold), so only the first len(thresholds) points are scored.
    p = precision[:-1]
    r = recall[:-1]
    total = p + r
    f1_scores = np.divide(
        2 * p * r,
        total,
        out=np.zeros_like(total, dtype=float),
        where=total > 0,
    )

    best_idx = int(np.argmax(f1_scores))
    best_threshold = thresholds[best_idx]
    best_precision = precision[best_idx]
    best_recall = recall[best_idx]
    best_f1 = f1_scores[best_idx]

    plt.figure(figsize=(10, 8))
    try:
        plt.plot(recall, precision, label=f'平均精确率 = {ap_score:.3f}')
        plt.scatter(best_recall, best_precision, color='red',
                    label=f'最佳F1 = {best_f1:.3f} (阈值 = {best_threshold:.3f})')
        plt.xlabel('Recall')
        plt.ylabel('Precision')
        plt.title('Precision-Recall 曲线')
        plt.legend()
        plt.grid(True)
        plt.savefig(plot_path, dpi=300)
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close()

    return {
        "precision": precision.tolist(),
        "recall": recall.tolist(),
        "thresholds": thresholds.tolist(),
        "best_threshold": float(best_threshold),
        "best_f1": float(best_f1),
    }


def calculate_pr_metrics(true_go_file, pred_go_file, scores_file=None, *,
                         plot_path='pr_curve.png'):
    """Compute per-protein and averaged precision/recall/F1 for GO predictions.

    Only proteins present in both input files are evaluated. When a scores
    file is given, every (protein, GO) pair from the union of true and
    predicted terms is pooled into a single precision-recall curve, which is
    saved as an image. Terms without a score count as 0.0.

    Args:
        true_go_file: JSON Lines file with the reference GO annotations.
        pred_go_file: JSON Lines file with the predicted GO annotations.
        scores_file: Optional JSON Lines file whose records carry a
            ``"protein_id"`` and a ``"GO_scores"`` mapping of GO id to score.
        plot_path: Where to save the PR-curve image (only used when a scores
            file is given and at least one pair was collected).

    Returns:
        A dict with ``"average_precision"``, ``"average_recall"`` and
        ``"average_f1"`` (means over the evaluated proteins, 0.0 if there are
        none), ``"protein_metrics"`` (per-protein precision/recall/f1), and,
        when a curve was computed, ``"pr_curve"`` holding the curve arrays
        plus the threshold with the highest F1.
    """
    true_go_data = load_go_data(true_go_file)
    pred_go_data = load_go_data(pred_go_file)
    scores = _load_go_scores(scores_file) if scores_file else {}

    all_true = []
    all_scores = []
    protein_metrics = {}

    # Sort so the result ordering does not depend on set iteration order.
    common_proteins = sorted(true_go_data.keys() & pred_go_data.keys(), key=str)
    for protein_id in common_proteins:
        true_gos = true_go_data[protein_id]
        pred_gos = pred_go_data[protein_id]

        precision, recall, f1 = _set_precision_recall_f1(true_gos, pred_gos)
        protein_metrics[protein_id] = {
            "precision": precision,
            "recall": recall,
            "f1": f1,
        }

        if scores_file:
            protein_scores = scores.get(protein_id, {})
            for go in sorted(true_gos | pred_gos, key=str):
                all_true.append(1 if go in true_gos else 0)
                all_scores.append(protein_scores.get(go, 0.0))

    results = {
        "average_precision": _mean_metric(protein_metrics, "precision"),
        "average_recall": _mean_metric(protein_metrics, "recall"),
        "average_f1": _mean_metric(protein_metrics, "f1"),
        "protein_metrics": protein_metrics,
    }

    if scores_file and all_true:
        results["pr_curve"] = _build_pr_curve(all_true, all_scores, plot_path)

    return results
def main(argv=None):
    """Command-line entry point: evaluate GO predictions and save the results.

    Parses ``--true``, ``--pred``, optional ``--scores`` and ``--output``,
    runs ``calculate_pr_metrics``, writes the results as JSON to the output
    path, and prints the averaged metrics.

    Args:
        argv: Optional list of argument strings; when ``None`` argparse reads
            the process command line.
    """
    import argparse
    import os

    parser = argparse.ArgumentParser(description='计算GO预测的Precision和Recall并绘制PR曲线')
    parser.add_argument('--true', required=True, help='真实GO的JSON文件路径')
    parser.add_argument('--pred', required=True, help='预测GO的JSON文件路径')
    parser.add_argument('--scores', help='GO分数的JSON文件路径(可选)')
    parser.add_argument('--output', default='test_results/pr_results.json', help='输出结果的JSON文件路径')
    args = parser.parse_args(argv)

    results = calculate_pr_metrics(args.true, args.pred, args.scores)

    # The default output lives in a sub-directory that may not exist yet;
    # open() does not create parent directories.
    output_dir = os.path.dirname(args.output)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    with open(args.output, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)

    print(f"平均精确率: {results['average_precision']:.4f}")
    print(f"平均召回率: {results['average_recall']:.4f}")
    print(f"平均F1分数: {results['average_f1']:.4f}")
    if 'pr_curve' in results:
        print(f"最佳F1分数: {results['pr_curve']['best_f1']:.4f} (阈值: {results['pr_curve']['best_threshold']:.4f})")
        print("PR曲线已保存为 pr_curve.png")
# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()