Spaces:
Runtime error
Runtime error
| import sys | |
| import os | |
| sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
| import json | |
| from pathlib import Path | |
| from tqdm import tqdm | |
| import pandas as pd | |
| import numpy as np | |
| from utils.openai_access import call_chatgpt | |
| from utils.mpr import MultipleProcessRunnerSimplifier | |
| from utils.prompts import LLM_SCORE_PROMPT | |
| import re | |
# Module-level cache of QA records keyed by each record's 'index' field.
# It is (re)built by load_qa_results_from_dir() and read by the scoring
# functions in this file.
qa_data = {}


def load_qa_results_from_dir(results_dir):
    """Load every ``*.json`` QA result in ``results_dir`` into ``qa_data``.

    A file is kept only when it holds a JSON object containing all of the
    keys ``index``, ``protein_id``, ``ground_truth`` and ``llm_answer``.
    Files that cannot be read or parsed are reported and skipped.

    Args:
        results_dir: Directory containing one JSON file per QA result.

    Returns:
        The module-level ``qa_data`` dict, mapping ``index`` to the record.
    """
    global qa_data
    qa_data = {}

    required_keys = ('index', 'protein_id', 'ground_truth', 'llm_answer')
    json_files = list(Path(results_dir).glob("*.json"))
    print(f"找到 {len(json_files)} 个结果文件")

    for json_file in tqdm(json_files, desc="加载QA结果"):
        try:
            # Decode explicitly as UTF-8 so non-ASCII answers load the same
            # way regardless of the platform's default locale encoding.
            with open(json_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
            if isinstance(data, dict) and all(key in data for key in required_keys):
                qa_data[data['index']] = data
        except (OSError, ValueError, TypeError) as e:
            # ValueError covers JSON and Unicode decode errors; TypeError
            # covers an unhashable 'index' value.
            print(f"加载文件 {json_file} 时出错: {e}")

    print(f"成功加载 {len(qa_data)} 个QA对")
    return qa_data
def _coerce_score(value):
    """Return ``value`` as a number, or None when it is not numeric."""
    if isinstance(value, bool):
        # bool is a subclass of int but is never a meaningful score.
        return None
    if isinstance(value, (int, float)):
        return value
    if isinstance(value, str):
        try:
            return float(value.strip())
        except ValueError:
            return None
    return None


def extract_score_from_response(response):
    """Extract a numeric score from an LLM scoring response.

    Strategies are tried in order, each one falling through to the next
    when it yields nothing usable:

    1. a flat JSON object containing a ``"score"`` key,
    2. a bare ``"score": <number>`` fragment,
    3. the first number in the text, accepted only if it lies in [0, 100].

    Args:
        response: The raw response, either a string or an already-parsed
            dict with a ``score`` key.

    Returns:
        The score as a number, or None if no score could be extracted.
    """
    if not response:
        return None
    if isinstance(response, dict):
        return _coerce_score(response.get('score'))
    if not isinstance(response, str):
        return None

    json_match = re.search(r'\{[^}]*"score"[^}]*\}', response)
    if json_match:
        try:
            parsed = json.loads(json_match.group())
        except json.JSONDecodeError:
            # Malformed JSON (e.g. trailing comma): keep going so the regex
            # fallbacks below still get a chance to recover the score.
            parsed = None
        if isinstance(parsed, dict):
            score = _coerce_score(parsed.get('score'))
            if score is not None:
                return score

    score_match = re.search(r'"score":\s*(\d+(?:\.\d+)?)', response)
    if score_match:
        return float(score_match.group(1))

    number_match = re.search(r'(\d+(?:\.\d+)?)', response)
    if number_match:
        score = float(number_match.group(1))
        if 0 <= score <= 100:
            return score

    return None
def process_single_scoring(process_id, idx, qa_index, writer, save_dir):
    """Score one QA pair with the LLM and save the result as a JSON file.

    The prompt is built by substituting the ground truth and the LLM answer
    into ``LLM_SCORE_PROMPT``. The result is written to
    ``<save_dir>/score_<protein_id>_<qa_index>.json``. Any failure is
    reported and swallowed so that one bad item does not stop the caller.

    Args:
        process_id: Unused here; presumably supplied by the process runner.
        idx: Unused here; presumably the item's position in the work list.
        qa_index: Key of the QA record in the module-level ``qa_data``.
        writer: Unused here; presumably supplied by the process runner.
        save_dir: Directory the score file is written to.
    """
    try:
        qa_item = qa_data[qa_index]
        protein_id = qa_item['protein_id']
        question = qa_item.get('question', '')
        ground_truth = qa_item['ground_truth']
        llm_answer = qa_item['llm_answer']

        # Build the scoring prompt.
        scoring_prompt = LLM_SCORE_PROMPT.replace('{{ground_truth}}', str(ground_truth))
        scoring_prompt = scoring_prompt.replace('{{llm_answer}}', str(llm_answer))

        # Ask the LLM for a score and parse it out of the reply.
        score_response = call_chatgpt(scoring_prompt)
        score = extract_score_from_response(score_response)

        result = {
            'index': qa_index,
            'protein_id': protein_id,
            'question': question,
            'ground_truth': ground_truth,
            'llm_answer': llm_answer,
            'score': score,
            'raw_score_response': score_response
        }

        save_path = os.path.join(save_dir, f"score_{protein_id}_{qa_index}.json")
        # ensure_ascii=False emits raw non-ASCII characters, so the file must
        # be opened as UTF-8 explicitly instead of relying on the locale.
        with open(save_path, 'w', encoding='utf-8') as f:
            json.dump(result, f, indent=2, ensure_ascii=False)
    except Exception as e:
        # Deliberate best effort: an item that fails leaves no score file.
        print(f"处理QA索引 {qa_index} 时出错: {str(e)}")
def get_missing_score_indices(save_dir):
    """Return the QA indices that do not yet have a usable score file.

    An index is reported when its score file is missing, unreadable, not a
    JSON object, or has no non-null ``score``. Files that exist but are
    unusable are deleted.

    Args:
        save_dir: Directory that holds the ``score_*.json`` files.

    Returns:
        A set of keys from the module-level ``qa_data`` that need scoring.
    """
    save_path = Path(save_dir)
    problem_qa_indices = set()

    for qa_index in tqdm(list(qa_data.keys()), desc="检查打分文件"):
        protein_id = qa_data[qa_index]['protein_id']
        json_file = save_path / f"score_{protein_id}_{qa_index}.json"

        if not json_file.exists():
            problem_qa_indices.add(qa_index)
            continue

        try:
            with open(json_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
            # Anything that is not a JSON object with a non-null score is
            # treated as unusable.
            is_valid = isinstance(data, dict) and data.get('score') is not None
        except (OSError, ValueError):
            # Unreadable, truncated or non-UTF-8 file.
            is_valid = False

        if is_valid:
            continue

        problem_qa_indices.add(qa_index)
        try:
            json_file.unlink()
        except OSError:
            # Best effort: the index is already recorded as missing.
            pass

    return problem_qa_indices
def _index_sort_key(record):
    """Build a sort key that tolerates missing and mixed-type indices."""
    index = record.get('index')
    if index is None:
        # A missing index sorts as 0, matching the intent of the old default.
        return (0, 0, '')
    if isinstance(index, (int, float)) and not isinstance(index, bool):
        return (0, index, '')
    # Non-numeric indices sort after numeric ones, by their string form.
    return (1, 0, str(index))


def collect_scores_to_json(save_dir, output_json):
    """Merge all ``score_*.json`` files into one JSON file and a DataFrame.

    Each score file contributes one record with the fields ``index``,
    ``protein_id``, ``question``, ``ground_truth``, ``llm_answer`` and
    ``score``. Files that cannot be read are reported and skipped. Records
    are sorted by ``index`` before being written.

    Args:
        save_dir: Directory that holds the per-item score files.
        output_json: Path of the combined JSON file to write.

    Returns:
        A pandas DataFrame built from the collected records.
    """
    results = []
    score_files = list(Path(save_dir).glob("score_*.json"))

    for score_file in tqdm(score_files, desc="收集打分结果"):
        try:
            with open(score_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
            results.append({
                'index': data.get('index'),
                'protein_id': data.get('protein_id'),
                'question': data.get('question', ''),
                'ground_truth': data.get('ground_truth'),
                'llm_answer': data.get('llm_answer'),
                'score': data.get('score')
            })
        except (OSError, ValueError, AttributeError) as e:
            # AttributeError: the file holds valid JSON that is not an object.
            print(f"读取文件 {score_file} 时出错: {e}")

    # 'index' is always present in each record (possibly as None), so a plain
    # .get('index', 0) key would hand None to the comparison and raise.
    results.sort(key=_index_sort_key)

    with open(output_json, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    print(f"打分结果已保存到: {output_json}")

    # DataFrame form is used by the statistical analysis.
    return pd.DataFrame(results)
def analyze_scores(df):
    """Print and return summary statistics for the collected scores.

    Scores are converted to numbers first; anything non-numeric counts as
    invalid. The histogram uses the non-overlapping buckets [0, 20],
    (20, 40], (40, 60], (60, 80] and (80, 100], so a score on a bucket
    boundary is counted exactly once.

    Args:
        df: DataFrame with a ``score`` column and, optionally, a
            ``protein_id`` column.

    Returns:
        A dict with the keys ``basic_stats``, ``distribution`` and
        ``quantiles``, or None when there is no valid score to analyse.
    """
    print("\n=== 打分结果统计分析 ===")

    # An empty DataFrame has no 'score' column at all.
    if df.empty or 'score' not in df.columns:
        print("没有有效的打分结果")
        return None

    numeric_scores = pd.to_numeric(df['score'], errors='coerce')
    valid_mask = numeric_scores.notna()
    valid_scores = numeric_scores[valid_mask]
    if valid_scores.empty:
        print("没有有效的打分结果")
        return None

    total = len(df)
    n_valid = len(valid_scores)
    valid_rate = n_valid / total * 100

    # Histogram buckets: only the first bucket includes its lower bound.
    bins = [0, 20, 40, 60, 80, 100]
    labels = ['0-20', '21-40', '41-60', '61-80', '81-100']
    distribution = {}
    for i, (label, low, high) in enumerate(zip(labels, bins[:-1], bins[1:])):
        above_low = valid_scores >= low if i == 0 else valid_scores > low
        # int() keeps the count JSON-serialisable (no numpy integer).
        count = int((above_low & (valid_scores <= high)).sum())
        distribution[label] = {
            "count": count,
            "percentage": count / n_valid * 100
        }

    quantiles = [0.1, 0.25, 0.5, 0.75, 0.9]
    quantile_values = {
        f"{int(q*100)}%": float(valid_scores.quantile(q)) for q in quantiles
    }

    print(f"总样本数: {total}")
    print(f"有效打分数: {n_valid}")
    print(f"无效打分数: {total - n_valid}")
    print(f"有效率: {valid_rate:.2f}%")
    print(f"\n分数统计:")
    print(f"平均分: {valid_scores.mean():.2f}")
    print(f"中位数: {valid_scores.median():.2f}")
    print(f"标准差: {valid_scores.std():.2f}")
    print(f"最高分: {valid_scores.max():.2f}")
    print(f"最低分: {valid_scores.min():.2f}")

    print(f"\n分数分布:")
    for label, bucket in distribution.items():
        print(f"{label}: {bucket['count']} ({bucket['percentage']:.1f}%)")

    print(f"\n分位数:")
    for name, value in quantile_values.items():
        print(f"{name}分位数: {value:.2f}")

    # Per-protein breakdown, shown only when there is more than one protein.
    if 'protein_id' in df.columns and df['protein_id'].nunique(dropna=False) > 1:
        print(f"\n按蛋白质ID分析:")
        valid_rows = df.loc[valid_mask].assign(score=valid_scores)
        protein_stats = (
            valid_rows.groupby('protein_id')['score']
            .agg(['count', 'mean', 'std'])
            .round(2)
        )
        print(protein_stats.head(10))

    return {
        "basic_stats": {
            "total_samples": total,
            "valid_scores": n_valid,
            "invalid_scores": total - n_valid,
            "valid_rate": valid_rate,
            "mean_score": float(valid_scores.mean()),
            "median_score": float(valid_scores.median()),
            "std_score": float(valid_scores.std()),
            "max_score": float(valid_scores.max()),
            "min_score": float(valid_scores.min())
        },
        "distribution": distribution,
        "quantiles": quantile_values
    }
def main():
    """Command-line entry point: score all QA pairs, then summarise them.

    Steps:
        1. Load the QA results from ``--results_dir``.
        2. For up to ``--max_iterations`` rounds, score every QA pair that
           still lacks a usable score file.
        3. Merge the score files into ``--output_json``.
        4. Write summary statistics to ``--stats_json``.
        5. Report how many QA pairs are still unscored.
    """
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--results_dir", type=str,
                        default="data/evolla_hard_motif_go",
                        help="包含LLM答案结果的目录")
    parser.add_argument("--n_process", type=int, default=32,
                        help="并行进程数")
    parser.add_argument("--save_dir", type=str,
                        default="data/llm_scores",
                        help="保存打分结果的目录")
    parser.add_argument("--output_json", type=str,
                        default="data/llm_scores_results.json",
                        help="输出JSON文件路径")
    parser.add_argument("--stats_json", type=str,
                        default="data/llm_scores_stats.json",
                        help="统计分析结果JSON文件路径")
    parser.add_argument("--max_iterations", type=int, default=3,
                        help="最大迭代次数")
    args = parser.parse_args()

    # Create the output directories. A bare file name has an empty dirname,
    # which os.makedirs rejects, so that case is skipped.
    os.makedirs(args.save_dir, exist_ok=True)
    for out_file in (args.output_json, args.stats_json):
        out_dir = os.path.dirname(out_file)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

    # Load the QA results into the module-level cache.
    load_qa_results_from_dir(args.results_dir)
    if not qa_data:
        print("没有找到有效的QA结果数据")
        return

    # Re-run scoring until nothing is missing or the round limit is reached.
    for iteration in range(1, args.max_iterations + 1):
        print(f"\n开始第 {iteration} 轮打分")

        missing_indices = get_missing_score_indices(args.save_dir)
        if not missing_indices:
            print("所有QA对已完成打分!")
            break
        print(f"发现 {len(missing_indices)} 个待打分的QA对")
        missing_indices_list = sorted(list(missing_indices))

        # Score the missing items through the project's process runner.
        mprs = MultipleProcessRunnerSimplifier(
            data=missing_indices_list,
            do=lambda process_id, idx, qa_index, writer: process_single_scoring(process_id, idx, qa_index, writer, args.save_dir),
            n_process=args.n_process,
            split_strategy="static"
        )
        mprs.run()
        print(f"第 {iteration} 轮打分完成")

    # Merge the per-item score files into one JSON file.
    df = collect_scores_to_json(args.save_dir, args.output_json)

    # Compute and save the summary statistics.
    stats_result = analyze_scores(df)
    with open(args.stats_json, 'w', encoding='utf-8') as f:
        json.dump(stats_result, f, indent=2, ensure_ascii=False)
    print(f"统计分析结果已保存到: {args.stats_json}")

    # Final check for items that are still unscored.
    final_missing = get_missing_score_indices(args.save_dir)
    if final_missing:
        print(f"\n仍有 {len(final_missing)} 个QA对未能成功打分")
    else:
        print(f"\n所有 {len(qa_data)} 个QA对已成功完成打分!")


if __name__ == "__main__":
    main()