bertose-affinose-training-code / code /probes /probe_2_layerwise_cls.py
supanthadey1's picture
Add BERTose and AFFINose training code release
1d6f391 verified
Raw
History Blame Contribute Delete
16.6 kB
#!/usr/bin/env python3
"""
Probe 2: Layer-wise CLS Probing for GlycanBERT (V6 by default; V5 via --model).

Extracts the [CLS] (sequence position 0) representation after the embedding
layer and after each transformer layer of the sequence encoder, then trains a
cross-validated logistic regression classifier on each of those
representations to measure how much task-relevant information (taxonomic
domain, N-/O-glycosylation, immunogenicity) is linearly decodable at each
depth.
"""
import os, sys, json, csv, argparse
import numpy as np
from pathlib import Path
from collections import Counter
# Grandparent of this file's directory; presumably the repository root
# (e.g. when this file lives at code/probes/).
PROJECT_ROOT = Path(__file__).resolve().parents[2]
# Token vocabulary handed to WURCSTokenizer during embedding extraction.
VOCAB_PATH = PROJECT_ROOT / 'bert_training_v4' / 'data' / 'vocabulary.json'
# Checkpoint selected by the --model CLI flag.
# NOTE(review): the 'v6' entry points at the bert_v5.1_contrastive checkpoint;
# confirm this naming/mapping is intentional.
CHECKPOINTS = {
'v5': PROJECT_ROOT / 'checkpoints_v5_bpe_topo' / 'best_v5_bpe_topo_model.pt',
'v6': PROJECT_ROOT / 'bert_v5.1_contrastive' / 'checkpoints' / 'best_v51_contrastive_model.pt',
}
# Directory holding the GlycanML benchmark CSV subsets read by the data loaders.
BENCH_DIR = PROJECT_ROOT / 'bench' / 'GlycanML' / 'data'
# Make the project-local packages imported next (model.*, downstream_tasks.*)
# importable; presumably they live under one of these two roots.
sys.path.insert(0, str(PROJECT_ROOT))
sys.path.insert(0, str(PROJECT_ROOT / 'bert_training_v4'))
from model.multimodal_glycan_bert_v3 import MultimodalGlycanBERT, MultimodalGlycanBERTConfig
from downstream_tasks.utils.tokenizer import WURCSTokenizer
def load_model(ckpt_path, device='cuda'):
    """Load a GlycanBERT backbone checkpoint for feature extraction.

    Projection-head weights (keys prefixed ``proj_head.``) are dropped so only
    the encoder is loaded. The sequence vocabulary size is read from the
    checkpoint's token-embedding matrix. When an MS embedding table is present,
    the MS vocabulary size is derived from it as well. All other architecture
    hyper-parameters are fixed below.

    Any key mismatch between checkpoint and model is printed. Loading uses
    ``strict=False``, so a mismatch would otherwise pass silently.

    Args:
        ckpt_path: Path to a ``torch.save`` file holding either a raw state
            dict or a dict with the state dict under ``'model_state_dict'``.
        device: Torch device string the model is moved to.

    Returns:
        The ``MultimodalGlycanBERT`` model in eval mode on ``device``.
    """
    import torch
    print(f"Loading model from {ckpt_path}...")
    # weights_only=False unpickles arbitrary Python objects: only use this on
    # checkpoints from a trusted source.
    ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=False)
    state_dict = ckpt.get('model_state_dict', ckpt)
    # Only the backbone is probed, so projection-head weights are discarded.
    backbone_sd = {k: v for k, v in state_dict.items() if not k.startswith('proj_head.')}
    n_stripped = len(state_dict) - len(backbone_sd)
    if n_stripped > 0:
        print(f" Stripped {n_stripped} projection head keys")
    vocab_size = backbone_sd['seq_embeddings.token_embeddings.weight'].shape[0]
    config_kwargs = dict(
        seq_vocab_size=vocab_size, seq_hidden_size=768, seq_num_layers=12,
        seq_num_heads=12, seq_max_length=256, use_cnn_frontend=True, cnn_kernel_size=3,
    )
    ms_key = 'ms_embeddings.token_embeddings.weight'
    if ms_key in backbone_sd:
        # NOTE(review): the MS table is assumed to hold seq_vocab + ms_vocab
        # rows, so the MS-only vocabulary is the difference; confirm.
        config_kwargs['ms_vocab_size'] = backbone_sd[ms_key].shape[0] - vocab_size
    model = MultimodalGlycanBERT(MultimodalGlycanBERTConfig(**config_kwargs))
    # strict=False tolerates key drift. Any parameter it skips keeps its freshly
    # constructed value, which would distort every probe, so report mismatches.
    incompatible = model.load_state_dict(backbone_sd, strict=False)
    if incompatible.missing_keys:
        print(f" WARNING: {len(incompatible.missing_keys)} model keys not found in checkpoint, "
              f"e.g. {incompatible.missing_keys[:5]}")
    if incompatible.unexpected_keys:
        print(f" WARNING: {len(incompatible.unexpected_keys)} checkpoint keys not used by the model, "
              f"e.g. {incompatible.unexpected_keys[:5]}")
    model.to(device).eval()
    print(f" Model loaded: {sum(p.numel() for p in model.parameters()):,} params")
    return model
def _tokenize_padded(tokenizer, wurcs, max_len):
    """Tokenize one WURCS string into three aligned, fixed-length id tensors.

    Returns ``[token_ids, branch_depths, linkage_types]``. Each is a 1-D
    ``torch.long`` tensor of exactly ``max_len`` entries. The three tracks are
    first cut to their common length, then truncated to ``max_len`` or
    right-padded with 0. Missing ``branch_depths`` / ``linkage_types`` in the
    tokenizer output default to all-zero tracks.
    """
    import torch
    import torch.nn.functional as F
    result = tokenizer.tokenize(wurcs, max_length=max_len)
    ids = result['token_ids']
    zeros = [0] * len(ids)
    tracks = [
        torch.tensor(ids, dtype=torch.long),
        torch.tensor(result.get('branch_depths', zeros), dtype=torch.long),
        torch.tensor(result.get('linkage_types', zeros), dtype=torch.long),
    ]
    keep = min(min(len(t) for t in tracks), max_len)
    return [F.pad(t[:keep], (0, max_len - keep), value=0) for t in tracks]


def extract_layerwise_cls(model, samples, device='cuda', max_len=256):
    """Collect the [CLS] (position 0) vector after every encoder stage.

    Samples are processed one at a time. Each WURCS string is tokenized, padded
    to ``max_len``, and passed through ``model.seq_embeddings`` and then each
    module of ``model.seq_layers``. The hidden state at sequence position 0 is
    recorded after the embeddings (stage 0) and after every layer
    (stages 1..L).

    If a sample raises during tokenization or the forward pass, it gets an
    all-zero vector at every stage. This keeps output rows aligned with
    ``samples`` and hence with the caller's label list.

    Args:
        model: Model exposing ``seq_embeddings`` and an iterable ``seq_layers``.
        samples: Sequence of dicts carrying a ``'wurcs'`` string.
        device: Device the inputs are moved to (should match the model).
        max_len: Fixed sequence length after truncation/padding.

    Returns:
        Dict mapping stage index ``0..len(model.seq_layers)`` to an array of
        shape ``(len(samples), hidden_dim)``.
    """
    import torch
    tokenizer = WURCSTokenizer(str(VOCAB_PATH))
    n_stages = len(model.seq_layers) + 1
    per_sample = []  # one (n_stages, hidden_dim) array per sample; None on failure
    n_errors = 0
    for si, s in enumerate(samples):
        try:
            token_ids, branch_depths, linkage_types = _tokenize_padded(tokenizer, s['wurcs'], max_len)
            cls_rows = []
            with torch.no_grad():
                hidden = model.seq_embeddings(
                    token_ids.unsqueeze(0).to(device),
                    branch_depths=branch_depths.unsqueeze(0).to(device),
                    linkage_types=linkage_types.unsqueeze(0).to(device),
                )
                cls_rows.append(hidden[0, 0, :].cpu().numpy())
                # NOTE(review): no attention mask is passed here, so padding
                # positions may take part in attention; confirm the layer API.
                for layer in model.seq_layers:
                    hidden = layer(hidden)
                    cls_rows.append(hidden[0, 0, :].cpu().numpy())
            # Commit all stages at once. A failure part-way through the layer
            # stack must not leave some stage lists longer than others.
            per_sample.append(np.stack(cls_rows))
        except Exception as e:
            n_errors += 1
            if n_errors <= 3:
                print(f" ERROR sample {si}: {e}")
            per_sample.append(None)
        if si > 0 and si % 500 == 0:
            print(f" Processed {si}/{len(samples)}")
    if n_errors > 0:
        print(f" WARNING: {n_errors}/{len(samples)} errors")
    # Size the zero placeholders from a real embedding; fall back to 768 only
    # when no sample succeeded.
    hidden_dim = next((r.shape[1] for r in per_sample if r is not None), 768)
    placeholder = np.zeros((n_stages, hidden_dim))
    if per_sample:
        stacked = np.stack([placeholder if r is None else r for r in per_sample])
    else:
        stacked = np.zeros((0, n_stages, hidden_dim))
    return {i: stacked[:, i, :] for i in range(n_stages)}
def load_domain_data(csv_path=None):
    """Load the taxonomic-domain classification subset.

    Keeps rows whose ``wurcs`` field starts with ``'WURCS'`` and whose
    ``domain`` is Eukarya, Bacteria or Virus.

    Args:
        csv_path: CSV to read. Defaults to
            ``BENCH_DIR / 'glycan_classification_wurcs_subset.csv'``.

    Returns:
        ``(samples, labels)``: a list of ``{'wurcs': str}`` dicts and the
        parallel list of domain strings.
    """
    if csv_path is None:
        csv_path = BENCH_DIR / 'glycan_classification_wurcs_subset.csv'
    samples, labels = [], []
    # Explicit encoding so parsing does not depend on the platform locale;
    # utf-8-sig also strips a BOM that would otherwise corrupt the first header.
    with open(csv_path, encoding='utf-8-sig', newline='') as f:
        for row in csv.DictReader(f):
            # DictReader yields None for columns missing from a short row.
            w = row.get('wurcs') or ''
            domain = row.get('domain') or ''
            if w.startswith('WURCS') and domain in ('Eukarya', 'Bacteria', 'Virus'):
                samples.append({'wurcs': w})
                labels.append(domain)
    print(f" Domain data: {len(samples)} samples, {Counter(labels)}")
    return samples, labels
def load_glycosylation_data(csv_path=None):
    """Load the N- vs O-glycosylation link subset.

    Keeps rows whose ``wurcs`` field starts with ``'WURCS'`` and whose
    ``link`` is ``'N'`` or ``'O'``.

    Args:
        csv_path: CSV to read. Defaults to
            ``BENCH_DIR / 'glycan_link_wurcs_subset.csv'``.

    Returns:
        ``(samples, labels)``: a list of ``{'wurcs': str}`` dicts and the
        parallel list of link-type strings.
    """
    if csv_path is None:
        csv_path = BENCH_DIR / 'glycan_link_wurcs_subset.csv'
    samples, labels = [], []
    # Explicit encoding so parsing does not depend on the platform locale;
    # utf-8-sig also strips a BOM that would otherwise corrupt the first header.
    with open(csv_path, encoding='utf-8-sig', newline='') as f:
        for row in csv.DictReader(f):
            # DictReader yields None for columns missing from a short row.
            w = row.get('wurcs') or ''
            link = row.get('link') or ''
            if w.startswith('WURCS') and link in ('N', 'O'):
                samples.append({'wurcs': w})
                labels.append(link)
    print(f" Glycosylation data: {len(samples)} samples, {Counter(labels)}")
    return samples, labels
def load_immunogenicity_data(csv_path=None):
    """Load the immunogenicity classification subset.

    Keeps rows whose ``wurcs`` field starts with ``'WURCS'`` and whose
    ``immunogenicity`` field is non-empty. The label is parsed as a float and
    truncated to ``int``. Rows whose label is non-numeric or non-finite are
    skipped and counted; previously they aborted the whole load.

    Args:
        csv_path: CSV to read. Defaults to
            ``BENCH_DIR / 'glycan_immunogenicity_wurcs_subset.csv'``.

    Returns:
        ``(samples, labels)``: a list of ``{'wurcs': str}`` dicts and the
        parallel list of integer labels.
    """
    import math
    if csv_path is None:
        csv_path = BENCH_DIR / 'glycan_immunogenicity_wurcs_subset.csv'
    samples, labels = [], []
    n_bad = 0
    # Explicit encoding so parsing does not depend on the platform locale;
    # utf-8-sig also strips a BOM that would otherwise corrupt the first header.
    with open(csv_path, encoding='utf-8-sig', newline='') as f:
        for row in csv.DictReader(f):
            # DictReader yields None for columns missing from a short row.
            w = row.get('wurcs') or ''
            imm = (row.get('immunogenicity') or '').strip()
            if not (w.startswith('WURCS') and imm):
                continue
            try:
                value = float(imm)
            except ValueError:
                value = math.nan
            if not math.isfinite(value):
                # int() raises on NaN/inf, so skip such rows instead of crashing.
                n_bad += 1
                continue
            samples.append({'wurcs': w})
            labels.append(int(value))
    if n_bad:
        print(f" Skipped {n_bad} rows with non-numeric immunogenicity labels")
    print(f" Immunogenicity data: {len(samples)} samples, {Counter(labels)}")
    return samples, labels
def train_linear_probe(X, y, task_name, n_splits=5):
    """Score a linear probe on features ``X`` with stratified k-fold CV.

    In each fold, features are standardized using the training split only, and
    an lbfgs logistic regression is fitted. Test-fold accuracy and F1 are
    averaged over folds. F1 is macro-averaged when there are more than two
    classes. For two classes it is the F1 of the positive class ``classes[1]``
    (LabelEncoder order).

    Args:
        X: Array-like of shape (n_samples, n_features).
        y: Sequence of n_samples class labels.
        task_name: Name copied into the returned dict.
        n_splits: Number of CV folds (shuffled with a fixed seed of 42).

    Returns:
        Dict with ``task``, ``accuracy_mean``/``accuracy_std``,
        ``f1_mean``/``f1_std``, ``n_samples``, ``n_classes`` and ``classes``.
    """
    from sklearn.linear_model import LogisticRegression
    from sklearn.model_selection import StratifiedKFold
    from sklearn.preprocessing import StandardScaler, LabelEncoder
    from sklearn.metrics import accuracy_score, f1_score
    X = np.asarray(X)  # fancy indexing with fold indices below needs an ndarray
    le = LabelEncoder()
    y_enc = le.fit_transform(y)
    n_classes = len(le.classes_)
    f1_average = 'macro' if n_classes > 2 else 'binary'
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)
    accs, f1s = [], []
    for train_idx, test_idx in skf.split(X, y_enc):
        scaler = StandardScaler()
        X_train = scaler.fit_transform(X[train_idx])
        X_test = scaler.transform(X[test_idx])
        # `multi_class` is deprecated since scikit-learn 1.5 and slated for
        # removal. lbfgs already fits a multinomial model for >2 classes, so
        # omitting it keeps the original behaviour.
        clf = LogisticRegression(max_iter=1000, solver='lbfgs', random_state=42)
        clf.fit(X_train, y_enc[train_idx])
        y_pred = clf.predict(X_test)
        y_true = y_enc[test_idx]
        accs.append(accuracy_score(y_true, y_pred))
        f1s.append(f1_score(y_true, y_pred, average=f1_average))
    return {'task': task_name, 'accuracy_mean': float(np.mean(accs)), 'accuracy_std': float(np.std(accs)),
            'f1_mean': float(np.mean(f1s)), 'f1_std': float(np.std(f1s)),
            'n_samples': len(y), 'n_classes': n_classes, 'classes': le.classes_.tolist()}
def run_layerwise_probing(layer_embs, labels, task_name, n_layers=13):
    """Fit one linear probe per encoder stage and print a line per stage.

    Args:
        layer_embs: Mapping from stage index (0 = embeddings) to a feature
            matrix aligned with ``labels``.
        labels: Class label per sample.
        task_name: Task name forwarded to ``train_linear_probe``.
        n_layers: Number of stages to probe, indices ``0..n_layers-1``.

    Returns:
        List of ``train_linear_probe`` result dicts, each with an added
        ``'layer'`` key holding its stage index.
    """
    print(f"\n Probing: {task_name} ({len(labels)} samples)")
    collected = []
    for stage in range(n_layers):
        metrics = train_linear_probe(layer_embs[stage], labels, task_name)
        metrics['layer'] = stage
        collected.append(metrics)
        tag = str(stage) if stage else "emb"
        acc_m, acc_s = metrics['accuracy_mean'], metrics['accuracy_std']
        f1_m, f1_s = metrics['f1_mean'], metrics['f1_std']
        print(f" Layer {tag:>3s}: acc={acc_m:.4f}+/-{acc_s:.4f} f1={f1_m:.4f}+/-{f1_s:.4f}")
    return collected
def _group_results_by_task(all_results):
    """Group probe result dicts by ``task`` (first-seen order), each list sorted by ``layer``."""
    grouped = {}
    for r in all_results:
        grouped.setdefault(r['task'], []).append(r)
    return {task: sorted(rs, key=lambda r: r['layer']) for task, rs in grouped.items()}


def _std_band(means, stds):
    """Return the (lower, upper) curves of a mean +/- std band."""
    lower = [m - s for m, s in zip(means, stds)]
    upper = [m + s for m, s in zip(means, stds)]
    return lower, upper


def plot_layerwise_results(all_results, output_dir, model_name='V6'):
    """Save accuracy/F1-vs-layer figures (PNG and PDF) for every probed task.

    Writes two figures to ``output_dir``, creating it if needed:
    ``accuracy_vs_layer_<model>.{png,pdf}`` has accuracy (a) and F1 (b)
    panels; ``accuracy_vs_layer_standalone_<model>.{png,pdf}`` shows accuracy
    only. Error bars and shaded bands show the CV standard deviation. The
    x-axis spans the stages present in ``all_results``, so it is not tied to a
    12-layer encoder.

    Args:
        all_results: Result dicts carrying ``task``, ``layer``,
            ``accuracy_mean``/``accuracy_std``, ``f1_mean``/``f1_std`` and
            ``n_samples``.
        output_dir: Destination directory.
        model_name: Label used in titles and file names.
    """
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt
    style = {
        'font.family': 'sans-serif', 'font.sans-serif': ['Arial', 'Helvetica', 'DejaVu Sans'],
        'font.size': 11, 'axes.titlesize': 13, 'axes.labelsize': 12,
        'xtick.labelsize': 10, 'ytick.labelsize': 10, 'legend.fontsize': 10,
        'figure.dpi': 300, 'savefig.dpi': 300, 'savefig.bbox': 'tight',
        'axes.linewidth': 0.8, 'axes.spines.top': False, 'axes.spines.right': False,
    }
    task_colors = {'Domain': '#E63946', 'Immunogenicity': '#457B9D', 'Glycosylation': '#2A9D8F'}
    task_markers = {'Domain': 'o', 'Immunogenicity': 's', 'Glycosylation': '^'}
    tasks = _group_results_by_task(all_results)
    # Size the x-axis from the data (stage 0 = embeddings) instead of assuming 13 stages.
    n_stages = 1 + max((r['layer'] for r in all_results), default=12)
    ticks = list(range(n_stages))
    layer_names = [str(i) for i in range(1, n_stages)]
    out = Path(output_dir)
    out.mkdir(parents=True, exist_ok=True)

    # rc_context scopes the style to these figures instead of mutating global rcParams.
    with plt.rc_context(style):
        # Two-panel figure: (a) accuracy, (b) F1.
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5.5))
        for task_name, results in tasks.items():
            layers = [r['layer'] for r in results]
            accs = [r['accuracy_mean'] for r in results]
            acc_stds = [r['accuracy_std'] for r in results]
            f1s = [r['f1_mean'] for r in results]
            f1_stds = [r['f1_std'] for r in results]
            color = task_colors.get(task_name, '#333')
            marker = task_markers.get(task_name, 'o')
            ax1.errorbar(layers, accs, yerr=acc_stds, color=color, marker=marker, markersize=7, linewidth=2,
                         capsize=3, capthick=1.2, label=f"{task_name} (n={results[0]['n_samples']})", zorder=3)
            ax1.fill_between(layers, *_std_band(accs, acc_stds), color=color, alpha=0.1)
            ax2.errorbar(layers, f1s, yerr=f1_stds, color=color, marker=marker, markersize=7, linewidth=2,
                         capsize=3, capthick=1.2, label=task_name, zorder=3)
            ax2.fill_between(layers, *_std_band(f1s, f1_stds), color=color, alpha=0.1)
        # train_linear_probe reports macro F1 only for multi-class tasks; for
        # 2-class tasks it reports positive-class F1, so the label says so.
        panels = [(ax1, 'Accuracy', 'Accuracy', '(a)'),
                  (ax2, 'F1 (macro; binary for 2-class tasks)', 'F1 Score', '(b)')]
        for ax, ylabel, title, panel in panels:
            ax.set_xlabel('Layer')
            ax.set_ylabel(ylabel)
            ax.set_title(f'Layer-wise Linear Probe {title} — GlycanBERT {model_name}')
            ax.set_xticks(ticks)
            ax.set_xticklabels(['emb'] + layer_names)
            ax.legend(frameon=False, loc='lower right')
            ax.grid(axis='y', alpha=0.3, linestyle='--')
            # Reference line at 0.5 (chance level only for a balanced binary task).
            ax.axhline(y=0.5, color='gray', linestyle=':', alpha=0.5)
            ax.text(-0.08, 1.05, panel, transform=ax.transAxes, fontsize=14, fontweight='bold')
        fig.tight_layout()
        for fmt in ['png', 'pdf']:
            fp = out / f'accuracy_vs_layer_{model_name.lower()}.{fmt}'
            fig.savefig(fp, dpi=300, bbox_inches='tight', facecolor='white')
            print(f" Saved: {fp}")
        plt.close(fig)

        # Standalone accuracy-only plot
        fig2, ax = plt.subplots(1, 1, figsize=(8, 5.5))
        for task_name, results in tasks.items():
            layers = [r['layer'] for r in results]
            accs = [r['accuracy_mean'] for r in results]
            acc_stds = [r['accuracy_std'] for r in results]
            color = task_colors.get(task_name, '#333')
            marker = task_markers.get(task_name, 'o')
            ax.errorbar(layers, accs, yerr=acc_stds, color=color, marker=marker, markersize=8, linewidth=2.5,
                        capsize=3, capthick=1.2, label=f"{task_name} (n={results[0]['n_samples']})", zorder=3)
            ax.fill_between(layers, *_std_band(accs, acc_stds), color=color, alpha=0.1)
        ax.set_xlabel('Transformer Layer', fontsize=13)
        ax.set_ylabel('Linear Probe Accuracy (5-fold CV)', fontsize=13)
        ax.set_title(f'Layer-wise Representation Quality — GlycanBERT {model_name}', fontsize=14, fontweight='bold')
        ax.set_xticks(ticks)
        ax.set_xticklabels(['Emb'] + layer_names, fontsize=10)
        ax.legend(frameon=True, fancybox=True, shadow=False, edgecolor='#ccc', loc='lower right', fontsize=11)
        ax.grid(axis='y', alpha=0.3, linestyle='--')
        ax.axhline(y=0.5, color='gray', linestyle=':', alpha=0.4)
        fig2.tight_layout()
        for fmt in ['png', 'pdf']:
            fp = out / f'accuracy_vs_layer_standalone_{model_name.lower()}.{fmt}'
            fig2.savefig(fp, dpi=300, bbox_inches='tight', facecolor='white')
            print(f" Saved: {fp}")
        plt.close(fig2)
def _subsample(samples, labels, limit, seed=42):
    """Reproducibly keep ``limit`` aligned (sample, label) pairs chosen without replacement."""
    # Same draw as np.random.seed(seed) followed by np.random.choice(...),
    # but without touching the global NumPy RNG state.
    indices = np.random.RandomState(seed).choice(len(samples), limit, replace=False)
    return [samples[i] for i in indices], [labels[i] for i in indices]


def _write_results(all_results, output_dir, model_name):
    """Write probe results as CSV (summary columns) and JSON (every field)."""
    out = Path(output_dir)
    out.mkdir(parents=True, exist_ok=True)
    stem = f'layerwise_results_{model_name.lower()}'
    csv_path = out / f'{stem}.csv'
    fieldnames = ['task', 'layer', 'accuracy_mean', 'accuracy_std', 'f1_mean', 'f1_std', 'n_samples', 'n_classes']
    with open(csv_path, 'w', newline='', encoding='utf-8') as f:
        # extrasaction='ignore' drops fields such as 'classes' from the CSV.
        writer = csv.DictWriter(f, fieldnames=fieldnames, extrasaction='ignore')
        writer.writeheader()
        writer.writerows(all_results)
    print(f"\n Saved: {csv_path}")
    json_path = out / f'{stem}.json'
    with open(json_path, 'w', encoding='utf-8') as f:
        json.dump(all_results, f, indent=2, default=str)
    print(f" Saved: {json_path}")


def _print_summary(all_results):
    """Print embedding-stage, best-stage and final-stage accuracy for each task."""
    print(f"\n{'='*60}"); print(f" SUMMARY"); print(f"{'='*60}")
    for task_name in dict.fromkeys(r['task'] for r in all_results):
        task_res = [r for r in all_results if r['task'] == task_name]
        best = max(task_res, key=lambda r: r['accuracy_mean'])
        emb = min(task_res, key=lambda r: r['layer'])
        last = max(task_res, key=lambda r: r['layer'])
        print(f"\n {task_name}:")
        print(f" Embedding layer ({emb['layer']}): {emb['accuracy_mean']:.4f}")
        print(f" Best layer ({best['layer']}): {best['accuracy_mean']:.4f}")
        print(f" Final layer ({last['layer']}): {last['accuracy_mean']:.4f}")
        print(f" Gain (best - emb): {best['accuracy_mean'] - emb['accuracy_mean']:+.4f}")


def main():
    """CLI entry point: extract layer-wise CLS embeddings, probe them, and save results and figures.

    The probe depth follows the extracted embeddings rather than a hard-coded
    13 stages, so the summary's "final layer" is the last stage actually
    probed.
    """
    parser = argparse.ArgumentParser(description='Probe 2: Layer-wise CLS probing')
    parser.add_argument('--model', type=str, default='v6', choices=['v5', 'v6'])
    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument('--output_dir', type=str, default=str(PROJECT_ROOT / 'bert_v6_contrastive' / 'analysis' / 'probe_results_v6' / 'probe_2_layerwise_cls'))
    args = parser.parse_args()
    model_name = args.model.upper()
    print(f"\n{'='*60}")
    print(f" Probe 2: Layer-wise CLS — GlycanBERT {model_name}")
    print(f"{'='*60}")
    model = load_model(str(CHECKPOINTS[args.model]), device=args.device)

    print("\nLoading datasets...")
    domain_samples, domain_labels = load_domain_data()
    glyco_samples, glyco_labels = load_glycosylation_data()
    immuno_samples, immuno_labels = load_immunogenicity_data()
    if len(domain_samples) > 3000:
        domain_samples, domain_labels = _subsample(domain_samples, domain_labels, 3000)
        print(f" Subsampled domain to {len(domain_samples)} samples")

    print(f"\nExtracting layer-wise CLS embeddings...")
    print(f" Domain ({len(domain_samples)} samples)...")
    domain_layer_embs = extract_layerwise_cls(model, domain_samples, device=args.device)
    print(f" Glycosylation ({len(glyco_samples)} samples)...")
    glyco_layer_embs = extract_layerwise_cls(model, glyco_samples, device=args.device)
    print(f" Immunogenicity ({len(immuno_samples)} samples)...")
    immuno_layer_embs = extract_layerwise_cls(model, immuno_samples, device=args.device)

    import gc, torch
    # Drop the model and collect garbage first so freed tensors reach the CUDA
    # cache before it is emptied.
    del model
    gc.collect()
    torch.cuda.empty_cache()

    # One entry per stage (embeddings + each encoder layer), as extracted.
    n_layers = len(domain_layer_embs)
    print(f"\nRunning linear probes (5-fold CV at each of {n_layers} layers)...")
    domain_results = run_layerwise_probing(domain_layer_embs, domain_labels, 'Domain', n_layers)
    immuno_results = run_layerwise_probing(immuno_layer_embs, immuno_labels, 'Immunogenicity', n_layers)
    glyco_results = run_layerwise_probing(glyco_layer_embs, glyco_labels, 'Glycosylation', n_layers)
    all_results = domain_results + immuno_results + glyco_results

    _write_results(all_results, args.output_dir, model_name)
    _print_summary(all_results)

    print(f"\nGenerating figures...")
    plot_layerwise_results(all_results, args.output_dir, model_name)
    print(f"\n{'='*60}"); print(f" COMPLETE"); print(f"{'='*60}")


if __name__ == '__main__':
    main()