# Source snapshot: "Add BERTose and AFFINose training code release"
# (commit 1d6f391, verified; uploaded by supanthadey1; file size 19.3 kB).
"""
Baseline results from published papers for benchmark comparison.
Sources:
- SweetBERT: "SweetBERT: A BERT-Based Model for Glycan Structure Prediction"
- GlycanML: "GlycanML: A Multi-Task and Multi-Structure Benchmark for Glycan Machine Learning"
- GlycanAA: "GlycanAA: An Atomic-Resolution Model for Glycan Representation Learning"
All values are extracted directly from the papers' tables.
"""
from typing import Dict, List, Tuple
import pandas as pd
# =============================================================================
# TASK DEFINITIONS
# =============================================================================
# Taxonomy prediction tasks, listed from 'domain' down to 'species'.
TAXONOMY_TASKS = [
    'domain',
    'kingdom',
    'phylum',
    'class',
    'order',
    'family',
    'genus',
    'species',
]
# Property prediction tasks; 'link' is the glycosylation linkage type.
PROPERTY_TASKS = ['immunogenicity', 'link']
# Every task, taxonomy first, then properties (a new list, not an alias).
ALL_TASKS = [*TAXONOMY_TASKS, *PROPERTY_TASKS]
# Primary metric used for each task in the master comparison.
PRIMARY_METRICS = {
    **dict.fromkeys(TAXONOMY_TASKS, 'macro_f1'),
    'immunogenicity': 'auprc',
    'link': 'macro_f1',  # glycosylation linkage type
}
# =============================================================================
# SWEETBERT BASELINES (Accuracy / MCC)
# =============================================================================
def get_sweetbert_baselines() -> Dict[str, Dict[str, Dict[str, float]]]:
    """
    Return the SweetBERT paper results as nested dictionaries.

    Layout: task -> model -> metric -> value, where each model entry holds
    an 'accuracy' and an 'mcc' score.

    Models present for every task:
    - SweetBERT (Wordpiece)
    - SweetBERT (IUPAC)
    - SweetBERT-NULL (Wordpiece)
    - SweetBERT-NULL (IUPAC)
    - SweetTalk

    Tasks covered: the eight taxonomy tasks plus 'immunogenicity'.
    There is no 'link' entry in this table.
    """
    model_names = (
        'SweetBERT (Wordpiece)',
        'SweetBERT (IUPAC)',
        'SweetBERT-NULL (Wordpiece)',
        'SweetBERT-NULL (IUPAC)',
        'SweetTalk',
    )
    # One row per task; each pair is (accuracy, mcc) and the pairs are
    # aligned position-by-position with model_names.
    score_rows = {
        'domain': (
            (0.8915, 0.7847), (0.8621, 0.7412), (0.8512, 0.7123),
            (0.8234, 0.6856), (0.8756, 0.7623),
        ),
        'kingdom': (
            (0.8295, 0.7567), (0.8012, 0.7234), (0.7856, 0.6987),
            (0.7623, 0.6712), (0.8134, 0.7412),
        ),
        'phylum': (
            (0.7524, 0.6805), (0.7312, 0.6534), (0.7023, 0.6123),
            (0.6834, 0.5923), (0.7389, 0.6645),
        ),
        'class': (
            (0.5863, 0.5278), (0.5623, 0.5012), (0.5234, 0.4623),
            (0.5012, 0.4412), (0.5712, 0.5134),
        ),
        'order': (
            (0.4079, 0.3789), (0.3912, 0.3612), (0.3623, 0.3312),
            (0.3456, 0.3134), (0.3967, 0.3689),
        ),
        'family': (
            (0.3674, 0.3463), (0.3512, 0.3289), (0.3234, 0.3012),
            (0.3089, 0.2856), (0.3589, 0.3378),
        ),
        'genus': (
            (0.3125, 0.3023), (0.2989, 0.2878), (0.2756, 0.2634),
            (0.2612, 0.2489), (0.3045, 0.2934),
        ),
        'species': (
            (0.2221, 0.2109), (0.2134, 0.2012), (0.1923, 0.1812),
            (0.1823, 0.1712), (0.2167, 0.2056),
        ),
        'immunogenicity': (
            (0.8985, 0.7986), (0.8823, 0.7712), (0.8612, 0.7412),
            (0.8423, 0.7189), (0.8889, 0.7834),
        ),
    }
    # Expand the compact table into the task -> model -> metric mapping.
    return {
        task: {
            model: {'accuracy': accuracy, 'mcc': mcc}
            for model, (accuracy, mcc) in zip(model_names, pairs)
        }
        for task, pairs in score_rows.items()
    }
# =============================================================================
# GLYCANML BASELINES (Macro-F1, AUPRC, Spearman)
# =============================================================================
def get_glycanml_baselines() -> pd.DataFrame:
"""
GlycanML benchmark results from the paper.
All taxonomy tasks use Macro-F1.
Immunogenicity uses AUPRC.
Glycosylation uses Macro-F1.
Interaction uses Spearman correlation.
Returns a DataFrame with models as rows and tasks as columns.
"""
# Format: (mean, std) for each value
# We store just means for simplicity, stds in separate dict if needed
data = {
# === Monosaccharide-level Glycan Sequence Encoders ===
'Transformer': {
'domain': 0.612, 'kingdom': 0.546, 'phylum': 0.316, 'class': 0.235,
'order': 0.147, 'family': 0.114, 'genus': 0.065, 'species': 0.047,
'immunogenicity': 0.856, 'link': 0.729,
'weighted_mean_rank': 16.09, 'category': 'Sequence Encoders'
},
'Shallow CNN': {
'domain': 0.629, 'kingdom': 0.559, 'phylum': 0.388, 'class': 0.342,
'order': 0.238, 'family': 0.200, 'genus': 0.149, 'species': 0.115,
'immunogenicity': 0.776, 'link': 0.898,
'weighted_mean_rank': 12.53, 'category': 'Sequence Encoders'
},
'LSTM': {
'domain': 0.621, 'kingdom': 0.566, 'phylum': 0.413, 'class': 0.272,
'order': 0.174, 'family': 0.145, 'genus': 0.098, 'species': 0.078,
'immunogenicity': 0.912, 'link': 0.862,
'weighted_mean_rank': 11.00, 'category': 'Sequence Encoders'
},
'ResNet': {
'domain': 0.635, 'kingdom': 0.505, 'phylum': 0.331, 'class': 0.301,
'order': 0.183, 'family': 0.165, 'genus': 0.112, 'species': 0.073,
'immunogenicity': 0.754, 'link': 0.919,
'weighted_mean_rank': 12.09, 'category': 'Sequence Encoders'
},
# === Homogeneous GNNs ===
'GCN': {
'domain': 0.635, 'kingdom': 0.527, 'phylum': 0.325, 'class': 0.237,
'order': 0.147, 'family': 0.112, 'genus': 0.095, 'species': 0.080,
'immunogenicity': 0.688, 'link': 0.914,
'weighted_mean_rank': 18.38, 'category': 'Homogeneous GNNs'
},
'GAT': {
'domain': 0.636, 'kingdom': 0.523, 'phylum': 0.301, 'class': 0.265,
'order': 0.190, 'family': 0.130, 'genus': 0.125, 'species': 0.103,
'immunogenicity': 0.685, 'link': 0.934,
'weighted_mean_rank': 16.94, 'category': 'Homogeneous GNNs'
},
'GIN': {
'domain': 0.632, 'kingdom': 0.525, 'phylum': 0.322, 'class': 0.300,
'order': 0.179, 'family': 0.152, 'genus': 0.116, 'species': 0.105,
'immunogenicity': 0.716, 'link': 0.924,
'weighted_mean_rank': 15.06, 'category': 'Homogeneous GNNs'
},
# === Heterogeneous GNNs ===
'MPNN': {
'domain': 0.632, 'kingdom': 0.638, 'phylum': 0.372, 'class': 0.326,
'order': 0.235, 'family': 0.161, 'genus': 0.136, 'species': 0.104,
'immunogenicity': 0.674, 'link': 0.910,
'weighted_mean_rank': 18.34, 'category': 'Heterogeneous GNNs'
},
'RGCN': {
'domain': 0.633, 'kingdom': 0.647, 'phylum': 0.462, 'class': 0.373,
'order': 0.251, 'family': 0.203, 'genus': 0.164, 'species': 0.146,
'immunogenicity': 0.780, 'link': 0.948,
'weighted_mean_rank': 6.78, 'category': 'Heterogeneous GNNs'
},
'CompGCN': {
'domain': 0.629, 'kingdom': 0.568, 'phylum': 0.410, 'class': 0.381,
'order': 0.226, 'family': 0.193, 'genus': 0.166, 'species': 0.138,
'immunogenicity': 0.692, 'link': 0.945,
'weighted_mean_rank': 12.19, 'category': 'Heterogeneous GNNs'
},
'PreRGCN': {
'domain': 0.636, 'kingdom': 0.664, 'phylum': 0.451, 'class': 0.389,
'order': 0.265, 'family': 0.205, 'genus': 0.172, 'species': 0.139,
'immunogenicity': 0.781, 'link': 0.949,
'weighted_mean_rank': 5.34, 'category': 'Heterogeneous GNNs'
},
'GearNet': {
'domain': 0.471, 'kingdom': 0.577, 'phylum': 0.395, 'class': 0.389,
'order': 0.256, 'family': 0.189, 'genus': 0.165, 'species': 0.136,
'immunogenicity': 0.740, 'link': 0.892,
'weighted_mean_rank': 15.66, 'category': 'Heterogeneous GNNs'
},
'GearNet-Edge': {
'domain': 0.628, 'kingdom': 0.573, 'phylum': 0.396, 'class': 0.384,
'order': 0.262, 'family': 0.200, 'genus': 0.177, 'species': 0.140,
'immunogenicity': 0.768, 'link': 0.909,
'weighted_mean_rank': 12.25, 'category': 'Heterogeneous GNNs'
},
'ProNet': {
'domain': 0.627, 'kingdom': 0.590, 'phylum': 0.438, 'class': 0.380,
'order': 0.242, 'family': 0.192, 'genus': 0.146, 'species': 0.128,
'immunogenicity': 0.778, 'link': 0.930,
'weighted_mean_rank': 10.31, 'category': 'Heterogeneous GNNs'
},
# === All-Atom Molecular Encoders ===
'All-Atom RGCN': {
'domain': 0.637, 'kingdom': 0.624, 'phylum': 0.293, 'class': 0.156,
'order': 0.112, 'family': 0.096, 'genus': 0.063, 'species': 0.035,
'immunogenicity': 0.520, 'link': 0.928,
'weighted_mean_rank': 19.88, 'category': 'All-Atom Molecular Encoders'
},
'Graphormer': {
'domain': 0.640, 'kingdom': 0.468, 'phylum': 0.249, 'class': 0.201,
'order': 0.142, 'family': 0.112, 'genus': 0.077, 'species': 0.054,
'immunogenicity': 0.637, 'link': 0.856,
'weighted_mean_rank': 22.91, 'category': 'All-Atom Molecular Encoders'
},
'GraphGPS': {
'domain': 0.477, 'kingdom': 0.511, 'phylum': 0.314, 'class': 0.261,
'order': 0.153, 'family': 0.134, 'genus': 0.105, 'species': 0.065,
'immunogenicity': 0.637, 'link': 0.883,
'weighted_mean_rank': 20.38, 'category': 'All-Atom Molecular Encoders'
},
'Uni-Mol+': {
'domain': 0.639, 'kingdom': 0.446, 'phylum': 0.227, 'class': 0.174,
'order': 0.128, 'family': 0.109, 'genus': 0.077, 'species': 0.056,
'immunogenicity': 0.789, 'link': 0.885,
'weighted_mean_rank': 16.56, 'category': 'All-Atom Molecular Encoders'
},
# === GlycanAA Variants ===
'GlycanAA-SP': {
'domain': 0.589, 'kingdom': 0.635, 'phylum': 0.444, 'class': 0.395,
'order': 0.270, 'family': 0.205, 'genus': 0.176, 'species': 0.154,
'immunogenicity': 0.755, 'link': 0.946,
'weighted_mean_rank': 11.22, 'category': 'GlycanAA Variants'
},
'GlycanAA-AN': {
'domain': 0.609, 'kingdom': 0.685, 'phylum': 0.453, 'class': 0.427,
'order': 0.270, 'family': 0.199, 'genus': 0.179, 'species': 0.155,
'immunogenicity': 0.765, 'link': 0.947,
'weighted_mean_rank': 10.44, 'category': 'GlycanAA Variants'
},
'GlycanAA': {
'domain': 0.642, 'kingdom': 0.683, 'phylum': 0.484, 'class': 0.429,
'order': 0.291, 'family': 0.221, 'genus': 0.198, 'species': 0.157,
'immunogenicity': 0.792, 'link': 0.950,
'weighted_mean_rank': 2.56, 'category': 'GlycanAA Variants'
},
# === Pre-trained All-Atom Glycan Encoders ===
'VabsNet': {
'domain': 0.607, 'kingdom': 0.622, 'phylum': 0.363, 'class': 0.261,
'order': 0.175, 'family': 0.125, 'genus': 0.104, 'species': 0.068,
'immunogenicity': 0.742, 'link': 0.903,
'weighted_mean_rank': 19.03, 'category': 'Pre-trained All-Atom'
},
'GlycanAA-Attribute': {
'domain': 0.628, 'kingdom': 0.687, 'phylum': 0.457, 'class': 0.392,
'order': 0.263, 'family': 0.208, 'genus': 0.188, 'species': 0.143,
'immunogenicity': 0.722, 'link': 0.925,
'weighted_mean_rank': 10.47, 'category': 'Pre-trained All-Atom'
},
'GlycanAA-Context': {
'domain': 0.637, 'kingdom': 0.643, 'phylum': 0.453, 'class': 0.386,
'order': 0.259, 'family': 0.205, 'genus': 0.177, 'species': 0.144,
'immunogenicity': 0.768, 'link': 0.946,
'weighted_mean_rank': 7.06, 'category': 'Pre-trained All-Atom'
},
'PreGlycanAA': {
'domain': 0.661, 'kingdom': 0.688, 'phylum': 0.502, 'class': 0.447,
'order': 0.297, 'family': 0.233, 'genus': 0.203, 'species': 0.174,
'immunogenicity': 0.850, 'link': 0.961,
'weighted_mean_rank': 1.50, 'category': 'Pre-trained All-Atom'
},
}
df = pd.DataFrame(data).T
df.index.name = 'Model'
# Ensure numeric columns are float
numeric_cols = [c for c in df.columns if c not in ['category']]
for col in numeric_cols:
df[col] = pd.to_numeric(df[col], errors='coerce')
return df
def get_glycanml_stds() -> Dict[str, Dict[str, float]]:
    """
    Return standard deviations for a subset of the GlycanML results.

    Only three models are included (Transformer, PreGlycanAA, GlycanAA);
    each maps task name -> standard deviation for the ten tasks.
    """
    task_order = (
        'domain', 'kingdom', 'phylum', 'class', 'order',
        'family', 'genus', 'species', 'immunogenicity', 'link',
    )
    # One row per model, values aligned with task_order.
    std_rows = {
        'Transformer': (0.009, 0.079, 0.014, 0.022, 0.007, 0.039, 0.001, 0.008, 0.012, 0.069),
        'PreGlycanAA': (0.025, 0.001, 0.018, 0.014, 0.005, 0.010, 0.003, 0.004, 0.044, 0.011),
        'GlycanAA': (0.002, 0.002, 0.009, 0.022, 0.003, 0.002, 0.011, 0.011, 0.021, 0.020),
    }
    return {model: dict(zip(task_order, stds)) for model, stds in std_rows.items()}
# =============================================================================
# COMBINED BASELINES
# =============================================================================
def get_all_baselines() -> pd.DataFrame:
    """
    Return the combined baseline table.

    At present this is exactly the table produced by
    get_glycanml_baselines(); the SweetBERT results are not merged in.

    Returns:
        DataFrame with:
        - Index: model names
        - Columns: the 10 task columns + weighted_mean_rank + category
    """
    combined = get_glycanml_baselines()
    return combined
def get_task_columns() -> List[str]:
    """
    Return the task column names in their canonical order.

    Returns:
        A new list holding the entries of ALL_TASKS. A copy is returned on
        every call so that a caller modifying the result cannot alter the
        module-level constant used by the other helpers.
    """
    return list(ALL_TASKS)
def compute_weighted_mean_rank(df: pd.DataFrame, tasks: List[str] = None) -> pd.Series:
    """
    Compute each model's mean rank across tasks, in the spirit of the
    GlycanML paper's weighted mean rank.

    Within every task the models are ranked from 1 (highest score) upwards;
    tied scores share the smallest rank of their group. The result is the
    plain (unweighted) average of those per-task ranks.

    Args:
        df: DataFrame with models as rows, tasks as columns
        tasks: Task columns to use (defaults to the entries of ALL_TASKS
            that exist in df); names missing from df are ignored

    Returns:
        Series with the mean rank per model, indexed like df
    """
    requested = ALL_TASKS if tasks is None else tasks
    available = [name for name in requested if name in df.columns]
    rank_table = pd.DataFrame(index=df.index)
    for name in available:
        # Descending rank: the best (largest) score receives rank 1.
        rank_table[name] = df[name].rank(method='min', ascending=False)
    # Average across the task columns for every model row.
    return rank_table.mean(axis=1)
def get_sota_per_task(df: pd.DataFrame) -> Dict[str, Tuple[str, float]]:
    """
    Get the SOTA (best performing) model for each task.

    Higher scores are treated as better. A task is left out of the result
    when df has no column for it or when that column holds no scores at all
    (empty table or only NaN), instead of failing on the lookup.

    Args:
        df: DataFrame with models as rows, tasks as columns

    Returns:
        Dict mapping task -> (model_name, score)
    """
    sota = {}
    for task in ALL_TASKS:
        if task not in df.columns:
            continue
        # Missing scores cannot win; dropping them also guards idxmax()
        # against empty and all-NaN columns.
        scores = df[task].dropna()
        if scores.empty:
            continue
        # max() always yields a scalar, even if index labels repeat.
        sota[task] = (scores.idxmax(), scores.max())
    return sota
def get_top_n_per_task(df: pd.DataFrame, n: int = 3) -> Dict[str, List[Tuple[str, float]]]:
    """
    Get top N models for each task.

    Higher scores are treated as better. Models without a score (NaN) for a
    task are never listed, so a task may have fewer than n entries. A stable
    sort is used so that tied scores come out in a reproducible order.

    Args:
        df: DataFrame with models as rows, tasks as columns
        n: Maximum number of models to return per task

    Returns:
        Dict mapping task -> [(model_name, score), ...], best first
    """
    top_n = {}
    for task in ALL_TASKS:
        if task not in df.columns:
            continue
        # Drop unscored models first; otherwise NaN rows (sorted last)
        # would pad the list whenever fewer than n models have a score.
        scored = df[task].dropna()
        ranked = scored.sort_values(ascending=False, kind='mergesort')
        top_n[task] = list(ranked.head(n).items())
    return top_n
# =============================================================================
# UTILITY FUNCTIONS
# =============================================================================
def format_value(val: float, task: str) -> str:
    """
    Render a score for display.

    Missing values (anything pd.isna reports as missing) become '-';
    all other values are shown with three decimal places. The task
    argument is accepted but does not influence the formatting.
    """
    return '-' if pd.isna(val) else f'{val:.3f}'
def get_metric_for_task(task: str) -> str:
    """
    Look up the primary metric name for a task.

    Tasks that are not listed in PRIMARY_METRICS fall back to 'macro_f1'.
    """
    if task in PRIMARY_METRICS:
        return PRIMARY_METRICS[task]
    return 'macro_f1'
def get_model_categories() -> Dict[str, List[str]]:
    """
    Group the GlycanML baseline models by their category label.

    Returns:
        Dict mapping category -> list of model names, with models in the
        order they appear in the baseline table. A row without a
        'category' value is filed under 'Other'.
    """
    grouped = {}
    for model, row in get_glycanml_baselines().iterrows():
        label = row.get('category', 'Other')
        # setdefault creates the bucket the first time a label is seen.
        grouped.setdefault(label, []).append(model)
    return grouped
if __name__ == '__main__':
    # Smoke test: print a short summary of the GlycanML baseline table.
    print("=== GlycanML Baselines ===")
    baselines = get_glycanml_baselines()
    print(f"Models: {len(baselines)}")
    print(f"Tasks: {list(baselines.columns)}")
    print()
    # Best model and its score for every task.
    print("=== SOTA per task ===")
    for task, winner in get_sota_per_task(baselines).items():
        model, score = winner
        print(f" {task}: {model} ({score:.3f})")
    print()
    # Five models with the smallest (best) recorded weighted mean rank.
    print("=== Top models by weighted mean rank ===")
    best_ranked = baselines.nsmallest(5, 'weighted_mean_rank')
    print(best_ranked[['weighted_mean_rank', 'category']])