| |
| """ |
| Extract all benchmark results from GlycanAA logs, SweetTalk result CSVs, and the BERTv4 publication summary CSVs. |
| Creates comprehensive comparison CSV files. |
| """ |
|
|
| import os |
| import re |
| import glob |
| import pandas as pd |
| from collections import defaultdict |
|
|
| |
# Prediction tasks handled by the extractors; main() also uses this order for
# the columns of the per-dataset comparison tables.
TASKS = [
    'domain',
    'kingdom',
    'phylum',
    'class',
    'order',
    'family',
    'genus',
    'species',
    'immunogenicity',
    'link',
]

# Dataset variants for which comparison tables are produced.
DATASETS = [
    'strict',
    'strict_3',
]
|
|
# Signed decimal with optional exponent. The previous pattern ``[\d.]+`` could
# not match negative values (e.g. MCC) and truncated ``1e-05`` to ``1``.
_METRIC_NUMBER = r'(-?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?)'


def _parse_glycanaa_log_name(basename, tasks):
    """Split a log file name into ``(dataset, model, task, job_id)``.

    Returns ``None`` when the name does not follow
    ``AA_<dataset>_<model>_<task>_<job_id>.out`` with ``<dataset>`` being
    ``strict`` or ``strict_3`` and ``<task>`` being one of *tasks*.
    """
    parts = basename.replace('.out', '').split('_')
    if len(parts) < 4 or parts[1] != 'strict':
        return None

    # "strict_3" is split into two parts by the underscore split above.
    if parts[2] == '3':
        dataset, model_start = 'strict_3', 3
    else:
        dataset, model_start = 'strict', 2

    job_id = parts[-1]
    # Everything between the dataset and the job id is "<model>_<task>";
    # the model name may itself contain underscores.
    remaining = '_'.join(parts[model_start:-1])

    for candidate in tasks:
        if remaining == candidate:
            # No model component in the file name.
            return dataset, '', candidate, job_id
        # Require the "_" separator so that e.g. "..._suborder" is not
        # mistaken for task "order".
        if remaining.endswith('_' + candidate):
            return dataset, remaining[:-len(candidate) - 1], candidate, job_id
    return None


def _search_metric(content, label, task):
    """Return the first ``<label> [<task>]: <number>`` value in *content*, or None."""
    match = re.search(rf'{label} \[{re.escape(task)}\]: {_METRIC_NUMBER}', content)
    return float(match.group(1)) if match else None


def extract_glycanaa_results(log_dir='GlycanAA/logs', tasks=None):
    """Extract results from GlycanAA logs.

    Scans *log_dir* for ``AA_*.out`` files, derives dataset, model, task and
    job id from each file name, and reads the first reported value of each
    metric for that task from the file content.

    Args:
        log_dir: Directory containing the ``AA_*.out`` log files.
        tasks: Task names to recognise in file names. Defaults to the
            module-level ``TASKS``.

    Returns:
        A list of dicts with keys ``source``, ``dataset``, ``model``,
        ``task``, ``accuracy``, ``macro_f1``, ``mcc``, ``auprc``, ``job_id``
        and ``log_file``. Files without an accuracy line for their task are
        skipped; the other metrics are ``None`` when absent.
    """
    if tasks is None:
        tasks = TASKS

    results = []

    # glob returns files in arbitrary order; sort so the output is reproducible.
    log_files = sorted(glob.glob(os.path.join(log_dir, 'AA_*.out')))
    print(f"Found {len(log_files)} GlycanAA log files")

    for log_file in log_files:
        basename = os.path.basename(log_file)

        parsed = _parse_glycanaa_log_name(basename, tasks)
        if parsed is None:
            continue
        dataset, model, task, job_id = parsed

        try:
            # errors='replace' keeps stray undecodable bytes in a log from
            # discarding an otherwise readable result.
            with open(log_file, 'r', errors='replace') as f:
                content = f.read()
        except OSError as e:
            print(f"Error reading {log_file}: {e}")
            continue

        accuracy = _search_metric(content, 'accuracy', task)
        if accuracy is None:
            # No accuracy line for this task: nothing to report for this log.
            continue

        results.append({
            'source': 'GlycanAA',
            'dataset': dataset,
            'model': model,
            'task': task,
            'accuracy': accuracy,
            'macro_f1': _search_metric(content, 'macrof1', task),
            'mcc': _search_metric(content, 'matthews correlation coefficient', task),
            'auprc': _search_metric(content, 'auprc', task),
            'job_id': job_id,
            'log_file': basename,
        })

    return results
|
|
def extract_sweettalk_results(result_dirs=None, tasks=None):
    """Extract results from SweetTalk result CSVs.

    For every dataset directory and task, reads
    ``<result_dir>/<task>/results_<task>.csv`` and takes the ``accuracy`` and
    ``macro_f1`` values from its first row.

    Args:
        result_dirs: Mapping of dataset name to results directory. Defaults
            to ``{'strict': 'sweettalk/results',
            'strict_3': 'sweettalk/results_strict_3'}``.
        tasks: Task names to look for. Defaults to the module-level ``TASKS``.

    Returns:
        A list of result dicts with the keys ``source``, ``dataset``,
        ``model``, ``task``, ``accuracy``, ``macro_f1``, ``mcc``, ``auprc``,
        ``job_id`` and ``log_file``. ``mcc``, ``auprc`` and ``job_id`` are
        always ``None``; ``log_file`` holds the CSV path.
    """
    if result_dirs is None:
        result_dirs = {
            'strict': 'sweettalk/results',
            'strict_3': 'sweettalk/results_strict_3',
        }
    if tasks is None:
        tasks = TASKS

    results = []

    for dataset, result_dir in result_dirs.items():
        if not os.path.exists(result_dir):
            print(f"SweetTalk {dataset} results directory not found: {result_dir}")
            continue

        for task in tasks:
            csv_path = os.path.join(result_dir, task, f'results_{task}.csv')
            if not os.path.exists(csv_path):
                continue

            try:
                df = pd.read_csv(csv_path)
            except Exception as e:
                # Best effort: one unreadable CSV must not abort the whole run.
                print(f"Error reading {csv_path}: {e}")
                continue

            if df.empty:
                continue

            # Only the first row of each results CSV is used.
            row = df.iloc[0]
            results.append({
                'source': 'SweetTalk',
                'dataset': dataset,
                'model': 'SweetTalk',
                'task': task,
                'accuracy': row.get('accuracy'),
                'macro_f1': row.get('macro_f1'),
                'mcc': None,
                'auprc': None,
                'job_id': None,
                'log_file': csv_path,
            })

    return results
|
|
def extract_our_results(results_dir='downstream_tasks/results_publication_summary', datasets=None):
    """Extract our BERTv4 results from the publication summary CSVs.

    Reads ``<results_dir>/<dataset>_results.csv`` for each dataset and emits
    one result per CSV row.

    Args:
        results_dir: Directory holding the ``<dataset>_results.csv`` files.
        datasets: Dataset names to load. Defaults to the module-level
            ``DATASETS``.

    Returns:
        A list of result dicts with the keys ``source``, ``dataset``,
        ``model``, ``task``, ``accuracy``, ``macro_f1``, ``mcc``, ``auprc``,
        ``job_id`` and ``log_file``. The ``Task``, ``Accuracy`` and
        ``Macro-F1`` columns are required; ``MCC`` and ``AUPRC`` are optional
        and yield ``None`` when the column is absent.
    """
    if datasets is None:
        datasets = DATASETS

    required_columns = ('Task', 'Accuracy', 'Macro-F1')
    results = []

    for dataset in datasets:
        # Joined with "/" (not os.path.join) so the recorded log_file value
        # is the same on every platform.
        csv_path = f'{results_dir}/{dataset}_results.csv'
        if not os.path.exists(csv_path):
            print(f"BERTv4 {dataset} results file not found: {csv_path}")
            continue

        try:
            df = pd.read_csv(csv_path)
        except Exception as e:
            # Best effort: one unreadable CSV must not abort the whole run.
            print(f"Error reading {csv_path}: {e}")
            continue

        missing = [col for col in required_columns if col not in df.columns]
        if missing:
            print(f"Error reading {csv_path}: missing columns {missing}")
            continue

        for _, row in df.iterrows():
            results.append({
                'source': 'BERTv4',
                'dataset': dataset,
                'model': 'BERTv4-Topology',
                'task': row['Task'],
                'accuracy': row['Accuracy'],
                'macro_f1': row['Macro-F1'],
                'mcc': row.get('MCC'),
                'auprc': row.get('AUPRC'),
                'job_id': None,
                'log_file': csv_path,
            })

    return results
|
|
def main():
    """Collect results from every source and write the comparison CSVs.

    Files written to the current working directory:

    - ``all_benchmark_results_raw.csv``: every extracted result row.
    - ``all_benchmark_results_summary.csv``: maximum accuracy and maximum
      macro-F1 per (dataset, model, task).
    - ``benchmark_comparison_<dataset>.csv``: a model x task accuracy table
      for each dataset in ``DATASETS`` that has results.

    Nothing is written when no results are found.
    """
    banner = "=" * 60
    print(banner)
    print("EXTRACTING ALL BENCHMARK RESULTS")
    print(banner)

    extractors = [
        ("GlycanAA", extract_glycanaa_results),
        ("SweetTalk", extract_sweettalk_results),
        ("BERTv4", extract_our_results),
    ]

    all_results = []
    for step, (label, extractor) in enumerate(extractors, start=1):
        print(f"\n{step}. Extracting {label} results...")
        found = extractor()
        all_results.extend(found)
        print(f" Found {len(found)} {label} results")

    df = pd.DataFrame(all_results)

    if df.empty:
        print("\nNo results found!")
        return

    df.to_csv('all_benchmark_results_raw.csv', index=False)
    print(f"\nSaved raw results to all_benchmark_results_raw.csv ({len(df)} rows)")

    # NOTE(review): accuracy and macro-F1 are maximised independently, so the
    # two values in one summary row may come from different jobs.
    summary = df.groupby(['dataset', 'model', 'task']).agg({
        'accuracy': 'max',
        'macro_f1': 'max',
        'source': 'first',
    }).reset_index()

    summary.to_csv('all_benchmark_results_summary.csv', index=False)
    print(f"Saved summary to all_benchmark_results_summary.csv ({len(summary)} rows)")

    for dataset in DATASETS:
        subset = summary[summary['dataset'] == dataset]
        if subset.empty:
            continue

        pivot = subset.pivot(index='model', columns='task', values='accuracy')
        # Keep only known tasks, in the canonical TASKS order.
        ordered_tasks = [t for t in TASKS if t in pivot.columns]
        pivot = pivot[ordered_tasks]
        pivot.to_csv(f'benchmark_comparison_{dataset}.csv')
        print(f"\nSaved {dataset} comparison to benchmark_comparison_{dataset}.csv")
        print(pivot.to_string())

    print("\n" + banner)
    print("RESULTS SUMMARY")
    print(banner)
    for dataset in DATASETS:
        print(f"\n{dataset.upper()}:")
        subset = summary[summary['dataset'] == dataset]
        models = subset['model'].unique()
        print(f" Models: {len(models)}")
        for model in sorted(models):
            tasks_complete = len(subset[subset['model'] == model])
            # Denominator follows TASKS instead of a hard-coded 10.
            print(f" - {model}: {tasks_complete}/{len(TASKS)} tasks")
|
|
| if __name__ == '__main__': |
| main() |
|
|