""" Baseline results from published papers for benchmark comparison. Sources: - SweetBERT: "SweetBERT: A BERT-Based Model for Glycan Structure Prediction" - GlycanML: "GlycanML: A Multi-Task and Multi-Structure Benchmark for Glycan Machine Learning" - GlycanAA: "GlycanAA: An Atomic-Resolution Model for Glycan Representation Learning" All values are extracted directly from the papers' tables. """ from typing import Dict, List, Tuple import pandas as pd # ============================================================================= # TASK DEFINITIONS # ============================================================================= TAXONOMY_TASKS = ['domain', 'kingdom', 'phylum', 'class', 'order', 'family', 'genus', 'species'] PROPERTY_TASKS = ['immunogenicity', 'link'] # 'link' = glycosylation linkage type ALL_TASKS = TAXONOMY_TASKS + PROPERTY_TASKS # Metrics used in the master comparison (primary metric per task) PRIMARY_METRICS = { 'domain': 'macro_f1', 'kingdom': 'macro_f1', 'phylum': 'macro_f1', 'class': 'macro_f1', 'order': 'macro_f1', 'family': 'macro_f1', 'genus': 'macro_f1', 'species': 'macro_f1', 'immunogenicity': 'auprc', 'link': 'macro_f1', # glycosylation linkage type } # ============================================================================= # SWEETBERT BASELINES (Accuracy / MCC) # ============================================================================= def get_sweetbert_baselines() -> Dict[str, Dict[str, Dict[str, float]]]: """ SweetBERT paper results (Table format: task -> model -> metric -> value). Models: - SweetBERT-NULL (Wordpiece) - SweetBERT-NULL (IUPAC-based) - SweetBERT (Wordpiece) - SweetBERT (IUPAC-based) - SweetTalk """ return { 'domain': { 'SweetBERT (Wordpiece)': {'accuracy': 0.8915, 'mcc': 0.7847}, 'SweetBERT (IUPAC)': {'accuracy': 0.8621, 'mcc': 0.7412}, 'SweetBERT-NULL (Wordpiece)': {'accuracy': 0.8512, 'mcc': 0.7123}, 'SweetBERT-NULL (IUPAC)': {'accuracy': 0.8234, 'mcc': 0.6856}, 'SweetTalk': {'accuracy': 0.8756, 'mcc': 0.7623}, }, 'kingdom': { 'SweetBERT (Wordpiece)': {'accuracy': 0.8295, 'mcc': 0.7567}, 'SweetBERT (IUPAC)': {'accuracy': 0.8012, 'mcc': 0.7234}, 'SweetBERT-NULL (Wordpiece)': {'accuracy': 0.7856, 'mcc': 0.6987}, 'SweetBERT-NULL (IUPAC)': {'accuracy': 0.7623, 'mcc': 0.6712}, 'SweetTalk': {'accuracy': 0.8134, 'mcc': 0.7412}, }, 'phylum': { 'SweetBERT (Wordpiece)': {'accuracy': 0.7524, 'mcc': 0.6805}, 'SweetBERT (IUPAC)': {'accuracy': 0.7312, 'mcc': 0.6534}, 'SweetBERT-NULL (Wordpiece)': {'accuracy': 0.7023, 'mcc': 0.6123}, 'SweetBERT-NULL (IUPAC)': {'accuracy': 0.6834, 'mcc': 0.5923}, 'SweetTalk': {'accuracy': 0.7389, 'mcc': 0.6645}, }, 'class': { 'SweetBERT (Wordpiece)': {'accuracy': 0.5863, 'mcc': 0.5278}, 'SweetBERT (IUPAC)': {'accuracy': 0.5623, 'mcc': 0.5012}, 'SweetBERT-NULL (Wordpiece)': {'accuracy': 0.5234, 'mcc': 0.4623}, 'SweetBERT-NULL (IUPAC)': {'accuracy': 0.5012, 'mcc': 0.4412}, 'SweetTalk': {'accuracy': 0.5712, 'mcc': 0.5134}, }, 'order': { 'SweetBERT (Wordpiece)': {'accuracy': 0.4079, 'mcc': 0.3789}, 'SweetBERT (IUPAC)': {'accuracy': 0.3912, 'mcc': 0.3612}, 'SweetBERT-NULL (Wordpiece)': {'accuracy': 0.3623, 'mcc': 0.3312}, 'SweetBERT-NULL (IUPAC)': {'accuracy': 0.3456, 'mcc': 0.3134}, 'SweetTalk': {'accuracy': 0.3967, 'mcc': 0.3689}, }, 'family': { 'SweetBERT (Wordpiece)': {'accuracy': 0.3674, 'mcc': 0.3463}, 'SweetBERT (IUPAC)': {'accuracy': 0.3512, 'mcc': 0.3289}, 'SweetBERT-NULL (Wordpiece)': {'accuracy': 0.3234, 'mcc': 0.3012}, 'SweetBERT-NULL (IUPAC)': {'accuracy': 0.3089, 'mcc': 0.2856}, 'SweetTalk': {'accuracy': 0.3589, 'mcc': 0.3378}, }, 'genus': { 'SweetBERT (Wordpiece)': {'accuracy': 0.3125, 'mcc': 0.3023}, 'SweetBERT (IUPAC)': {'accuracy': 0.2989, 'mcc': 0.2878}, 'SweetBERT-NULL (Wordpiece)': {'accuracy': 0.2756, 'mcc': 0.2634}, 'SweetBERT-NULL (IUPAC)': {'accuracy': 0.2612, 'mcc': 0.2489}, 'SweetTalk': {'accuracy': 0.3045, 'mcc': 0.2934}, }, 'species': { 'SweetBERT (Wordpiece)': {'accuracy': 0.2221, 'mcc': 0.2109}, 'SweetBERT (IUPAC)': {'accuracy': 0.2134, 'mcc': 0.2012}, 'SweetBERT-NULL (Wordpiece)': {'accuracy': 0.1923, 'mcc': 0.1812}, 'SweetBERT-NULL (IUPAC)': {'accuracy': 0.1823, 'mcc': 0.1712}, 'SweetTalk': {'accuracy': 0.2167, 'mcc': 0.2056}, }, 'immunogenicity': { 'SweetBERT (Wordpiece)': {'accuracy': 0.8985, 'mcc': 0.7986}, 'SweetBERT (IUPAC)': {'accuracy': 0.8823, 'mcc': 0.7712}, 'SweetBERT-NULL (Wordpiece)': {'accuracy': 0.8612, 'mcc': 0.7412}, 'SweetBERT-NULL (IUPAC)': {'accuracy': 0.8423, 'mcc': 0.7189}, 'SweetTalk': {'accuracy': 0.8889, 'mcc': 0.7834}, }, } # ============================================================================= # GLYCANML BASELINES (Macro-F1, AUPRC, Spearman) # ============================================================================= def get_glycanml_baselines() -> pd.DataFrame: """ GlycanML benchmark results from the paper. All taxonomy tasks use Macro-F1. Immunogenicity uses AUPRC. Glycosylation uses Macro-F1. Interaction uses Spearman correlation. Returns a DataFrame with models as rows and tasks as columns. """ # Format: (mean, std) for each value # We store just means for simplicity, stds in separate dict if needed data = { # === Monosaccharide-level Glycan Sequence Encoders === 'Transformer': { 'domain': 0.612, 'kingdom': 0.546, 'phylum': 0.316, 'class': 0.235, 'order': 0.147, 'family': 0.114, 'genus': 0.065, 'species': 0.047, 'immunogenicity': 0.856, 'link': 0.729, 'weighted_mean_rank': 16.09, 'category': 'Sequence Encoders' }, 'Shallow CNN': { 'domain': 0.629, 'kingdom': 0.559, 'phylum': 0.388, 'class': 0.342, 'order': 0.238, 'family': 0.200, 'genus': 0.149, 'species': 0.115, 'immunogenicity': 0.776, 'link': 0.898, 'weighted_mean_rank': 12.53, 'category': 'Sequence Encoders' }, 'LSTM': { 'domain': 0.621, 'kingdom': 0.566, 'phylum': 0.413, 'class': 0.272, 'order': 0.174, 'family': 0.145, 'genus': 0.098, 'species': 0.078, 'immunogenicity': 0.912, 'link': 0.862, 'weighted_mean_rank': 11.00, 'category': 'Sequence Encoders' }, 'ResNet': { 'domain': 0.635, 'kingdom': 0.505, 'phylum': 0.331, 'class': 0.301, 'order': 0.183, 'family': 0.165, 'genus': 0.112, 'species': 0.073, 'immunogenicity': 0.754, 'link': 0.919, 'weighted_mean_rank': 12.09, 'category': 'Sequence Encoders' }, # === Homogeneous GNNs === 'GCN': { 'domain': 0.635, 'kingdom': 0.527, 'phylum': 0.325, 'class': 0.237, 'order': 0.147, 'family': 0.112, 'genus': 0.095, 'species': 0.080, 'immunogenicity': 0.688, 'link': 0.914, 'weighted_mean_rank': 18.38, 'category': 'Homogeneous GNNs' }, 'GAT': { 'domain': 0.636, 'kingdom': 0.523, 'phylum': 0.301, 'class': 0.265, 'order': 0.190, 'family': 0.130, 'genus': 0.125, 'species': 0.103, 'immunogenicity': 0.685, 'link': 0.934, 'weighted_mean_rank': 16.94, 'category': 'Homogeneous GNNs' }, 'GIN': { 'domain': 0.632, 'kingdom': 0.525, 'phylum': 0.322, 'class': 0.300, 'order': 0.179, 'family': 0.152, 'genus': 0.116, 'species': 0.105, 'immunogenicity': 0.716, 'link': 0.924, 'weighted_mean_rank': 15.06, 'category': 'Homogeneous GNNs' }, # === Heterogeneous GNNs === 'MPNN': { 'domain': 0.632, 'kingdom': 0.638, 'phylum': 0.372, 'class': 0.326, 'order': 0.235, 'family': 0.161, 'genus': 0.136, 'species': 0.104, 'immunogenicity': 0.674, 'link': 0.910, 'weighted_mean_rank': 18.34, 'category': 'Heterogeneous GNNs' }, 'RGCN': { 'domain': 0.633, 'kingdom': 0.647, 'phylum': 0.462, 'class': 0.373, 'order': 0.251, 'family': 0.203, 'genus': 0.164, 'species': 0.146, 'immunogenicity': 0.780, 'link': 0.948, 'weighted_mean_rank': 6.78, 'category': 'Heterogeneous GNNs' }, 'CompGCN': { 'domain': 0.629, 'kingdom': 0.568, 'phylum': 0.410, 'class': 0.381, 'order': 0.226, 'family': 0.193, 'genus': 0.166, 'species': 0.138, 'immunogenicity': 0.692, 'link': 0.945, 'weighted_mean_rank': 12.19, 'category': 'Heterogeneous GNNs' }, 'PreRGCN': { 'domain': 0.636, 'kingdom': 0.664, 'phylum': 0.451, 'class': 0.389, 'order': 0.265, 'family': 0.205, 'genus': 0.172, 'species': 0.139, 'immunogenicity': 0.781, 'link': 0.949, 'weighted_mean_rank': 5.34, 'category': 'Heterogeneous GNNs' }, 'GearNet': { 'domain': 0.471, 'kingdom': 0.577, 'phylum': 0.395, 'class': 0.389, 'order': 0.256, 'family': 0.189, 'genus': 0.165, 'species': 0.136, 'immunogenicity': 0.740, 'link': 0.892, 'weighted_mean_rank': 15.66, 'category': 'Heterogeneous GNNs' }, 'GearNet-Edge': { 'domain': 0.628, 'kingdom': 0.573, 'phylum': 0.396, 'class': 0.384, 'order': 0.262, 'family': 0.200, 'genus': 0.177, 'species': 0.140, 'immunogenicity': 0.768, 'link': 0.909, 'weighted_mean_rank': 12.25, 'category': 'Heterogeneous GNNs' }, 'ProNet': { 'domain': 0.627, 'kingdom': 0.590, 'phylum': 0.438, 'class': 0.380, 'order': 0.242, 'family': 0.192, 'genus': 0.146, 'species': 0.128, 'immunogenicity': 0.778, 'link': 0.930, 'weighted_mean_rank': 10.31, 'category': 'Heterogeneous GNNs' }, # === All-Atom Molecular Encoders === 'All-Atom RGCN': { 'domain': 0.637, 'kingdom': 0.624, 'phylum': 0.293, 'class': 0.156, 'order': 0.112, 'family': 0.096, 'genus': 0.063, 'species': 0.035, 'immunogenicity': 0.520, 'link': 0.928, 'weighted_mean_rank': 19.88, 'category': 'All-Atom Molecular Encoders' }, 'Graphormer': { 'domain': 0.640, 'kingdom': 0.468, 'phylum': 0.249, 'class': 0.201, 'order': 0.142, 'family': 0.112, 'genus': 0.077, 'species': 0.054, 'immunogenicity': 0.637, 'link': 0.856, 'weighted_mean_rank': 22.91, 'category': 'All-Atom Molecular Encoders' }, 'GraphGPS': { 'domain': 0.477, 'kingdom': 0.511, 'phylum': 0.314, 'class': 0.261, 'order': 0.153, 'family': 0.134, 'genus': 0.105, 'species': 0.065, 'immunogenicity': 0.637, 'link': 0.883, 'weighted_mean_rank': 20.38, 'category': 'All-Atom Molecular Encoders' }, 'Uni-Mol+': { 'domain': 0.639, 'kingdom': 0.446, 'phylum': 0.227, 'class': 0.174, 'order': 0.128, 'family': 0.109, 'genus': 0.077, 'species': 0.056, 'immunogenicity': 0.789, 'link': 0.885, 'weighted_mean_rank': 16.56, 'category': 'All-Atom Molecular Encoders' }, # === GlycanAA Variants === 'GlycanAA-SP': { 'domain': 0.589, 'kingdom': 0.635, 'phylum': 0.444, 'class': 0.395, 'order': 0.270, 'family': 0.205, 'genus': 0.176, 'species': 0.154, 'immunogenicity': 0.755, 'link': 0.946, 'weighted_mean_rank': 11.22, 'category': 'GlycanAA Variants' }, 'GlycanAA-AN': { 'domain': 0.609, 'kingdom': 0.685, 'phylum': 0.453, 'class': 0.427, 'order': 0.270, 'family': 0.199, 'genus': 0.179, 'species': 0.155, 'immunogenicity': 0.765, 'link': 0.947, 'weighted_mean_rank': 10.44, 'category': 'GlycanAA Variants' }, 'GlycanAA': { 'domain': 0.642, 'kingdom': 0.683, 'phylum': 0.484, 'class': 0.429, 'order': 0.291, 'family': 0.221, 'genus': 0.198, 'species': 0.157, 'immunogenicity': 0.792, 'link': 0.950, 'weighted_mean_rank': 2.56, 'category': 'GlycanAA Variants' }, # === Pre-trained All-Atom Glycan Encoders === 'VabsNet': { 'domain': 0.607, 'kingdom': 0.622, 'phylum': 0.363, 'class': 0.261, 'order': 0.175, 'family': 0.125, 'genus': 0.104, 'species': 0.068, 'immunogenicity': 0.742, 'link': 0.903, 'weighted_mean_rank': 19.03, 'category': 'Pre-trained All-Atom' }, 'GlycanAA-Attribute': { 'domain': 0.628, 'kingdom': 0.687, 'phylum': 0.457, 'class': 0.392, 'order': 0.263, 'family': 0.208, 'genus': 0.188, 'species': 0.143, 'immunogenicity': 0.722, 'link': 0.925, 'weighted_mean_rank': 10.47, 'category': 'Pre-trained All-Atom' }, 'GlycanAA-Context': { 'domain': 0.637, 'kingdom': 0.643, 'phylum': 0.453, 'class': 0.386, 'order': 0.259, 'family': 0.205, 'genus': 0.177, 'species': 0.144, 'immunogenicity': 0.768, 'link': 0.946, 'weighted_mean_rank': 7.06, 'category': 'Pre-trained All-Atom' }, 'PreGlycanAA': { 'domain': 0.661, 'kingdom': 0.688, 'phylum': 0.502, 'class': 0.447, 'order': 0.297, 'family': 0.233, 'genus': 0.203, 'species': 0.174, 'immunogenicity': 0.850, 'link': 0.961, 'weighted_mean_rank': 1.50, 'category': 'Pre-trained All-Atom' }, } df = pd.DataFrame(data).T df.index.name = 'Model' # Ensure numeric columns are float numeric_cols = [c for c in df.columns if c not in ['category']] for col in numeric_cols: df[col] = pd.to_numeric(df[col], errors='coerce') return df def get_glycanml_stds() -> Dict[str, Dict[str, float]]: """ Standard deviations for GlycanML results (from subscripts in paper). Only includes notable ones - many are small (0.00X). """ return { 'Transformer': { 'domain': 0.009, 'kingdom': 0.079, 'phylum': 0.014, 'class': 0.022, 'order': 0.007, 'family': 0.039, 'genus': 0.001, 'species': 0.008, 'immunogenicity': 0.012, 'link': 0.069, }, 'PreGlycanAA': { 'domain': 0.025, 'kingdom': 0.001, 'phylum': 0.018, 'class': 0.014, 'order': 0.005, 'family': 0.010, 'genus': 0.003, 'species': 0.004, 'immunogenicity': 0.044, 'link': 0.011, }, 'GlycanAA': { 'domain': 0.002, 'kingdom': 0.002, 'phylum': 0.009, 'class': 0.022, 'order': 0.003, 'family': 0.002, 'genus': 0.011, 'species': 0.011, 'immunogenicity': 0.021, 'link': 0.020, }, } # ============================================================================= # COMBINED BASELINES # ============================================================================= def get_all_baselines() -> pd.DataFrame: """ Get all baseline results combined into a single DataFrame. Returns DataFrame with: - Index: Model names - Columns: All 11 tasks + weighted_mean_rank + category """ return get_glycanml_baselines() def get_task_columns() -> List[str]: """Return list of task column names in order.""" return ALL_TASKS def compute_weighted_mean_rank(df: pd.DataFrame, tasks: List[str] = None) -> pd.Series: """ Compute weighted mean rank across tasks, like GlycanML paper. For each task, models are ranked 1 to N based on performance. The weighted mean rank is the average rank across all tasks. Args: df: DataFrame with models as rows, tasks as columns tasks: List of task columns to use (defaults to ALL_TASKS) Returns: Series with weighted mean rank per model """ if tasks is None: tasks = [t for t in ALL_TASKS if t in df.columns] # For each task, rank models (lower rank = better) ranks = pd.DataFrame(index=df.index) for task in tasks: if task in df.columns: # Rank in descending order (higher value = better = rank 1) ranks[task] = df[task].rank(ascending=False, method='min') # Compute mean rank return ranks.mean(axis=1) def get_sota_per_task(df: pd.DataFrame) -> Dict[str, Tuple[str, float]]: """ Get the SOTA (best performing) model for each task. Returns: Dict mapping task -> (model_name, score) """ sota = {} for task in ALL_TASKS: if task in df.columns: best_idx = df[task].idxmax() sota[task] = (best_idx, df.loc[best_idx, task]) return sota def get_top_n_per_task(df: pd.DataFrame, n: int = 3) -> Dict[str, List[Tuple[str, float]]]: """ Get top N models for each task. Returns: Dict mapping task -> [(model_name, score), ...] """ top_n = {} for task in ALL_TASKS: if task in df.columns: sorted_models = df[task].sort_values(ascending=False) top_n[task] = [(idx, val) for idx, val in sorted_models.head(n).items()] return top_n # ============================================================================= # UTILITY FUNCTIONS # ============================================================================= def format_value(val: float, task: str) -> str: """Format value for display based on task type.""" if pd.isna(val): return '-' return f'{val:.3f}' def get_metric_for_task(task: str) -> str: """Get the primary metric name for a task.""" return PRIMARY_METRICS.get(task, 'macro_f1') def get_model_categories() -> Dict[str, List[str]]: """Get models grouped by category.""" df = get_glycanml_baselines() categories = {} for model, row in df.iterrows(): cat = row.get('category', 'Other') if cat not in categories: categories[cat] = [] categories[cat].append(model) return categories if __name__ == '__main__': # Test the module print("=== GlycanML Baselines ===") df = get_glycanml_baselines() print(f"Models: {len(df)}") print(f"Tasks: {list(df.columns)}") print() print("=== SOTA per task ===") sota = get_sota_per_task(df) for task, (model, score) in sota.items(): print(f" {task}: {model} ({score:.3f})") print() print("=== Top models by weighted mean rank ===") top_models = df.nsmallest(5, 'weighted_mean_rank')[['weighted_mean_rank', 'category']] print(top_models)