#!/usr/bin/env python3
"""
Probe 2: Layer-wise CLS Probing for GlycanBERT V6.

Extracts the [CLS] token representation at each of the 12 transformer layers
and trains a logistic regression classifier to measure how much task-relevant
information is encoded at each depth.
"""

import os
import sys
import json
import csv
import argparse
from collections import Counter
from pathlib import Path

import numpy as np

PROJECT_ROOT = Path(__file__).resolve().parents[2]
VOCAB_PATH = PROJECT_ROOT / 'bert_training_v4' / 'data' / 'vocabulary.json'
CHECKPOINTS = {
    'v5': PROJECT_ROOT / 'checkpoints_v5_bpe_topo' / 'best_v5_bpe_topo_model.pt',
    'v6': PROJECT_ROOT / 'bert_v5.1_contrastive' / 'checkpoints' / 'best_v51_contrastive_model.pt',
}
BENCH_DIR = PROJECT_ROOT / 'bench' / 'GlycanML' / 'data'

# The project packages are resolved relative to the repository root, so the
# search path has to be extended before they can be imported.
sys.path.insert(0, str(PROJECT_ROOT))
sys.path.insert(0, str(PROJECT_ROOT / 'bert_training_v4'))

from model.multimodal_glycan_bert_v3 import MultimodalGlycanBERT, MultimodalGlycanBERTConfig
from downstream_tasks.utils.tokenizer import WURCSTokenizer


def load_model(ckpt_path, device='cuda'):
    """Build a MultimodalGlycanBERT and load backbone weights from a checkpoint.

    Keys starting with ``proj_head.`` are dropped before loading. The sequence
    vocabulary size (and, when MS embeddings are present, the MS vocabulary
    size) is read from the shapes of the stored embedding matrices; the other
    architecture settings are fixed in this function.

    Args:
        ckpt_path: Path to the checkpoint file.
        device: Device the model is moved to after loading.

    Returns:
        The model in eval mode on ``device``.
    """
    import torch

    print(f"Loading model from {ckpt_path}...")
    # NOTE(security): weights_only=False unpickles arbitrary Python objects.
    # Only point this at checkpoints from a trusted source.
    ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=False)
    state_dict = ckpt.get('model_state_dict', ckpt)

    backbone_sd = {k: v for k, v in state_dict.items() if not k.startswith('proj_head.')}
    n_stripped = len(state_dict) - len(backbone_sd)
    if n_stripped > 0:
        print(f" Stripped {n_stripped} projection head keys")

    vocab_size = backbone_sd['seq_embeddings.token_embeddings.weight'].shape[0]
    ms_total_vocab = None
    if 'ms_embeddings.token_embeddings.weight' in backbone_sd:
        ms_total_vocab = backbone_sd['ms_embeddings.token_embeddings.weight'].shape[0]

    config_kwargs = dict(
        seq_vocab_size=vocab_size,
        seq_hidden_size=768,
        seq_num_layers=12,
        seq_num_heads=12,
        seq_max_length=256,
        use_cnn_frontend=True,
        cnn_kernel_size=3,
    )
    if ms_total_vocab is not None:
        config_kwargs['ms_vocab_size'] = ms_total_vocab - vocab_size

    config = MultimodalGlycanBERTConfig(**config_kwargs)
    model = MultimodalGlycanBERT(config)

    # strict=False tolerates key mismatches; surface them so that a checkpoint
    # that does not actually match the architecture is not probed silently.
    load_report = model.load_state_dict(backbone_sd, strict=False)
    missing_keys = list(getattr(load_report, 'missing_keys', []) or [])
    unexpected_keys = list(getattr(load_report, 'unexpected_keys', []) or [])
    if missing_keys:
        print(f" WARNING: {len(missing_keys)} model keys missing from checkpoint "
              f"(e.g. {missing_keys[:3]})")
    if unexpected_keys:
        print(f" WARNING: {len(unexpected_keys)} checkpoint keys not used by model "
              f"(e.g. {unexpected_keys[:3]})")

    model.to(device).eval()
    print(f" Model loaded: {sum(p.numel() for p in model.parameters()):,} params")
    return model


def _encode_sample(tokenizer, wurcs, max_len):
    """Tokenize one WURCS string into three 1-D long tensors of length ``max_len``.

    The three sequences are first trimmed to their common length, then
    truncated to ``max_len`` and right-padded with zeros.
    """
    import torch
    import torch.nn.functional as F

    result = tokenizer.tokenize(wurcs, max_length=max_len)
    n_tokens = len(result['token_ids'])
    token_ids = torch.tensor(result['token_ids'], dtype=torch.long)
    branch_depths = torch.tensor(result.get('branch_depths', [0] * n_tokens), dtype=torch.long)
    linkage_types = torch.tensor(result.get('linkage_types', [0] * n_tokens), dtype=torch.long)

    keep = min(len(token_ids), len(branch_depths), len(linkage_types), max_len)
    pad = max_len - keep
    encoded = []
    for seq in (token_ids, branch_depths, linkage_types):
        seq = seq[:keep]
        if pad > 0:
            seq = F.pad(seq, (0, pad), value=0)
        encoded.append(seq)
    return tuple(encoded)


def extract_layerwise_cls(model, samples, device='cuda', max_len=256):
    """Collect the position-0 hidden vector after the embeddings and each layer.

    Args:
        model: Model exposing ``seq_embeddings`` and an iterable ``seq_layers``.
        samples: Sequence of dicts with a ``'wurcs'`` entry.
        device: Device the input tensors are moved to.
        max_len: Fixed sequence length every sample is truncated/padded to.

    Returns:
        Dict mapping layer index to an array with one row per sample. Index 0
        is the output of ``model.seq_embeddings``; index ``i + 1`` is the
        output of ``model.seq_layers[i]``. A sample that fails is represented
        by an all-zero row at every layer, so rows stay aligned with
        ``samples`` (and therefore with the caller's labels).
    """
    import torch

    tokenizer = WURCSTokenizer(str(VOCAB_PATH))
    n_layers = len(model.seq_layers)

    # One entry per sample: a list of n_layers + 1 vectors, or None on failure.
    # A sample is committed only after every layer succeeded, so a failure
    # half-way through the stack cannot leave the per-layer lists with
    # different lengths.
    per_sample = []
    n_errors = 0
    for si, s in enumerate(samples):
        try:
            token_ids, branch_depths, linkage_types = _encode_sample(tokenizer, s['wurcs'], max_len)
            cls_vectors = []
            with torch.no_grad():
                hidden = model.seq_embeddings(
                    token_ids.unsqueeze(0).to(device),
                    branch_depths=branch_depths.unsqueeze(0).to(device),
                    linkage_types=linkage_types.unsqueeze(0).to(device),
                )
                cls_vectors.append(hidden[0, 0, :].cpu().numpy())
                # NOTE(review): no attention mask is passed to the layers, so
                # padded positions take part in attention -- confirm intended.
                for layer in model.seq_layers:
                    hidden = layer(hidden)
                    cls_vectors.append(hidden[0, 0, :].cpu().numpy())
            per_sample.append(cls_vectors)
        except Exception as e:
            # Best-effort extraction: keep going, but keep a placeholder so the
            # row index still matches the sample index.
            n_errors += 1
            if n_errors <= 3:
                print(f" ERROR sample {si}: {e}")
            per_sample.append(None)

        if si > 0 and si % 500 == 0:
            print(f" Processed {si}/{len(samples)}")

    if n_errors > 0:
        print(f" WARNING: {n_errors}/{len(samples)} errors")

    # Width of the zero placeholder comes from a real embedding when one
    # exists; 768 is only the fallback when every sample failed.
    hidden_dim = next((vecs[0].shape[0] for vecs in per_sample if vecs is not None), 768)
    zero_row = np.zeros(hidden_dim)
    return {
        layer_idx: np.array([
            vecs[layer_idx] if vecs is not None else zero_row
            for vecs in per_sample
        ])
        for layer_idx in range(n_layers + 1)
    }


def _load_labelled_wurcs(csv_name, label_column, to_label, display_name):
    """Read a benchmark CSV and return parallel lists of samples and labels.

    Rows are kept when the ``wurcs`` column starts with ``'WURCS'`` and
    ``to_label`` returns something other than ``None`` for the raw label cell.
    """
    csv_path = BENCH_DIR / csv_name
    samples, labels = [], []
    # newline='' is what the csv module expects for files it reads.
    with open(csv_path, newline='') as f:
        for row in csv.DictReader(f):
            wurcs = row.get('wurcs') or ''
            if not wurcs.startswith('WURCS'):
                continue
            label = to_label(row.get(label_column) or '')
            if label is None:
                continue
            samples.append({'wurcs': wurcs})
            labels.append(label)
    print(f" {display_name} data: {len(samples)} samples, {Counter(labels)}")
    return samples, labels


def _domain_label(raw):
    """Keep only the three domain classes used by the probe."""
    return raw if raw in ('Eukarya', 'Bacteria', 'Virus') else None


def _link_label(raw):
    """Keep only N- and O-linked rows."""
    return raw if raw in ('N', 'O') else None


def _immunogenicity_label(raw):
    """Convert a non-empty cell such as '1.0' to an int; skip empty cells."""
    return int(float(raw)) if raw else None


def load_domain_data():
    """Load WURCS samples labelled Eukarya / Bacteria / Virus."""
    return _load_labelled_wurcs(
        'glycan_classification_wurcs_subset.csv', 'domain', _domain_label, 'Domain')


def load_glycosylation_data():
    """Load WURCS samples whose ``link`` column is 'N' or 'O'."""
    return _load_labelled_wurcs(
        'glycan_link_wurcs_subset.csv', 'link', _link_label, 'Glycosylation')


def load_immunogenicity_data():
    """Load WURCS samples with an integer immunogenicity label."""
    return _load_labelled_wurcs(
        'glycan_immunogenicity_wurcs_subset.csv', 'immunogenicity',
        _immunogenicity_label, 'Immunogenicity')


def train_linear_probe(X, y, task_name, n_splits=5):
    """Cross-validate a standardised logistic-regression probe on ``X``.

    Args:
        X: 2-D array of features, one row per sample.
        y: Sequence of labels aligned with the rows of ``X``.
        task_name: Name stored in the result under ``'task'``.
        n_splits: Number of stratified folds.

    Returns:
        Dict with mean/std accuracy and F1 across folds, the sample and class
        counts, and the class names. F1 is macro-averaged for more than two
        classes and binary otherwise.
    """
    from sklearn.linear_model import LogisticRegression
    from sklearn.model_selection import StratifiedKFold
    from sklearn.preprocessing import StandardScaler, LabelEncoder
    from sklearn.metrics import accuracy_score, f1_score

    le = LabelEncoder()
    y_enc = le.fit_transform(y)
    n_classes = len(le.classes_)
    f1_average = 'macro' if n_classes > 2 else 'binary'

    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)
    accs, f1s = [], []
    for train_idx, test_idx in skf.split(X, y_enc):
        # The scaler is fit on the training fold only, to avoid leaking
        # test-fold statistics into the probe.
        scaler = StandardScaler()
        X_train = scaler.fit_transform(X[train_idx])
        X_test = scaler.transform(X[test_idx])

        # `multi_class` is deprecated in recent scikit-learn releases; with
        # the lbfgs solver the default already fits a multinomial model when
        # there are more than two classes, so it is simply omitted.
        clf = LogisticRegression(max_iter=1000, solver='lbfgs', random_state=42)
        clf.fit(X_train, y_enc[train_idx])
        y_pred = clf.predict(X_test)

        accs.append(accuracy_score(y_enc[test_idx], y_pred))
        f1s.append(f1_score(y_enc[test_idx], y_pred, average=f1_average))

    return {
        'task': task_name,
        'accuracy_mean': float(np.mean(accs)),
        'accuracy_std': float(np.std(accs)),
        'f1_mean': float(np.mean(f1s)),
        'f1_std': float(np.std(f1s)),
        'n_samples': len(y),
        'n_classes': n_classes,
        'classes': le.classes_.tolist(),
    }


def run_layerwise_probing(layer_embs, labels, task_name, n_layers=13):
    """Run ``train_linear_probe`` on layers ``0 .. n_layers - 1`` and log each.

    Args:
        layer_embs: Mapping from layer index to feature array.
        labels: Labels aligned with the rows of every feature array.
        task_name: Task name passed through to each result.
        n_layers: Number of layer indices to probe, starting at 0.

    Returns:
        List of result dicts, each extended with a ``'layer'`` entry.
    """
    print(f"\n Probing: {task_name} ({len(labels)} samples)")
    results = []
    for layer_idx in range(n_layers):
        res = train_linear_probe(layer_embs[layer_idx], labels, task_name)
        res['layer'] = layer_idx
        results.append(res)
        layer_label = "emb" if layer_idx == 0 else str(layer_idx)
        print(f" Layer {layer_label:>3s}: acc={res['accuracy_mean']:.4f}+/-{res['accuracy_std']:.4f} f1={res['f1_mean']:.4f}+/-{res['f1_std']:.4f}")
    return results


def _draw_metric_curve(ax, layers, means, stds, color, marker, label,
                       markersize=7, linewidth=2):
    """Draw one mean +/- std curve with error bars and a shaded band."""
    ax.errorbar(layers, means, yerr=stds, color=color, marker=marker,
                markersize=markersize, linewidth=linewidth, capsize=3,
                capthick=1.2, label=label, zorder=3)
    lower = [m - s for m, s in zip(means, stds)]
    upper = [m + s for m, s in zip(means, stds)]
    ax.fill_between(layers, lower, upper, color=color, alpha=0.1)


def plot_layerwise_results(all_results, output_dir, model_name='V6'):
    """Save the accuracy/F1 two-panel figure and the accuracy-only figure.

    Args:
        all_results: Flat list of result dicts as produced by
            ``run_layerwise_probing`` (all tasks concatenated).
        output_dir: Directory the PNG and PDF files are written to; created
            if it does not exist.
        model_name: Label used in titles and (lower-cased) in file names.
    """
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt

    plt.rcParams.update({
        'font.family': 'sans-serif',
        'font.sans-serif': ['Arial', 'Helvetica', 'DejaVu Sans'],
        'font.size': 11,
        'axes.titlesize': 13,
        'axes.labelsize': 12,
        'xtick.labelsize': 10,
        'ytick.labelsize': 10,
        'legend.fontsize': 10,
        'figure.dpi': 300,
        'savefig.dpi': 300,
        'savefig.bbox': 'tight',
        'axes.linewidth': 0.8,
        'axes.spines.top': False,
        'axes.spines.right': False,
    })
    task_colors = {'Domain': '#E63946', 'Immunogenicity': '#457B9D', 'Glycosylation': '#2A9D8F'}
    task_markers = {'Domain': 'o', 'Immunogenicity': 's', 'Glycosylation': '^'}

    # Group results per task, each ordered by layer.
    tasks = {}
    for r in all_results:
        tasks.setdefault(r['task'], []).append(r)
    tasks = {name: sorted(rs, key=lambda r: r['layer']) for name, rs in tasks.items()}

    # Tick positions follow the deepest layer actually present in the results
    # (12 is the fallback for an empty result list).
    max_layer = max((r['layer'] for r in all_results), default=12)
    tick_positions = range(max_layer + 1)
    numbered_ticks = [str(i) for i in range(1, max_layer + 1)]

    out = Path(output_dir)
    out.mkdir(parents=True, exist_ok=True)

    # Two-panel figure
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5.5))
    for task_name, results in tasks.items():
        layers = [r['layer'] for r in results]
        color = task_colors.get(task_name, '#333')
        marker = task_markers.get(task_name, 'o')
        _draw_metric_curve(
            ax1, layers,
            [r['accuracy_mean'] for r in results],
            [r['accuracy_std'] for r in results],
            color, marker, f"{task_name} (n={results[0]['n_samples']})")
        _draw_metric_curve(
            ax2, layers,
            [r['f1_mean'] for r in results],
            [r['f1_std'] for r in results],
            color, marker, task_name)

    # NOTE(review): the 'Macro F1' axis label is kept as-is, but two-class
    # tasks are scored with binary F1 in train_linear_probe.
    for ax, ylabel, title, panel in [(ax1, 'Accuracy', 'Accuracy', '(a)'),
                                     (ax2, 'Macro F1', 'F1 Score', '(b)')]:
        ax.set_xlabel('Layer')
        ax.set_ylabel(ylabel)
        ax.set_title(f'Layer-wise Linear Probe {title} — GlycanBERT {model_name}')
        ax.set_xticks(tick_positions)
        ax.set_xticklabels(['emb'] + numbered_ticks)
        ax.legend(frameon=False, loc='lower right')
        ax.grid(axis='y', alpha=0.3, linestyle='--')
        ax.axhline(y=0.5, color='gray', linestyle=':', alpha=0.5)
        ax.text(-0.08, 1.05, panel, transform=ax.transAxes, fontsize=14, fontweight='bold')
    fig.tight_layout()
    for fmt in ['png', 'pdf']:
        fp = out / f'accuracy_vs_layer_{model_name.lower()}.{fmt}'
        fig.savefig(fp, dpi=300, bbox_inches='tight', facecolor='white')
        print(f" Saved: {fp}")
    plt.close(fig)

    # Standalone accuracy-only plot
    fig2, ax = plt.subplots(1, 1, figsize=(8, 5.5))
    for task_name, results in tasks.items():
        _draw_metric_curve(
            ax,
            [r['layer'] for r in results],
            [r['accuracy_mean'] for r in results],
            [r['accuracy_std'] for r in results],
            task_colors.get(task_name, '#333'),
            task_markers.get(task_name, 'o'),
            f"{task_name} (n={results[0]['n_samples']})",
            markersize=8, linewidth=2.5)
    ax.set_xlabel('Transformer Layer', fontsize=13)
    ax.set_ylabel('Linear Probe Accuracy (5-fold CV)', fontsize=13)
    ax.set_title(f'Layer-wise Representation Quality — GlycanBERT {model_name}',
                 fontsize=14, fontweight='bold')
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(['Emb'] + numbered_ticks, fontsize=10)
    ax.legend(frameon=True, fancybox=True, shadow=False, edgecolor='#ccc',
              loc='lower right', fontsize=11)
    ax.grid(axis='y', alpha=0.3, linestyle='--')
    ax.axhline(y=0.5, color='gray', linestyle=':', alpha=0.4)
    fig2.tight_layout()
    for fmt in ['png', 'pdf']:
        fp = out / f'accuracy_vs_layer_standalone_{model_name.lower()}.{fmt}'
        fig2.savefig(fp, dpi=300, bbox_inches='tight', facecolor='white')
        print(f" Saved: {fp}")
    plt.close(fig2)


def main():
    """Command-line entry point: extract embeddings, probe, save, summarise, plot."""
    parser = argparse.ArgumentParser(description='Probe 2: Layer-wise CLS probing')
    parser.add_argument('--model', type=str, default='v6', choices=['v5', 'v6'])
    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument('--output_dir', type=str,
                        default=str(PROJECT_ROOT / 'bert_v6_contrastive' / 'analysis' / 'probe_results_v6' / 'probe_2_layerwise_cls'))
    args = parser.parse_args()

    model_name = args.model.upper()
    print(f"\n{'='*60}")
    print(f" Probe 2: Layer-wise CLS — GlycanBERT {model_name}")
    print(f"{'='*60}")

    ckpt = CHECKPOINTS[args.model]
    model = load_model(str(ckpt), device=args.device)

    print("\nLoading datasets...")
    domain_samples, domain_labels = load_domain_data()
    glyco_samples, glyco_labels = load_glycosylation_data()
    immuno_samples, immuno_labels = load_immunogenicity_data()

    # Cap the domain task at 3000 samples, drawn with a fixed seed.
    if len(domain_samples) > 3000:
        np.random.seed(42)
        indices = np.random.choice(len(domain_samples), 3000, replace=False)
        domain_samples = [domain_samples[i] for i in indices]
        domain_labels = [domain_labels[i] for i in indices]
        print(f" Subsampled domain to {len(domain_samples)} samples")

    print(f"\nExtracting layer-wise CLS embeddings...")
    print(f" Domain ({len(domain_samples)} samples)...")
    domain_layer_embs = extract_layerwise_cls(model, domain_samples, device=args.device)
    print(f" Glycosylation ({len(glyco_samples)} samples)...")
    glyco_layer_embs = extract_layerwise_cls(model, glyco_samples, device=args.device)
    print(f" Immunogenicity ({len(immuno_samples)} samples)...")
    immuno_layer_embs = extract_layerwise_cls(model, immuno_samples, device=args.device)

    # The model is no longer needed once the embeddings are on the host.
    import gc
    import torch
    del model
    torch.cuda.empty_cache()
    gc.collect()

    # Embedding output plus one entry per transformer layer, as actually
    # returned by the extractor (13 for a 12-layer model).
    n_layers = len(domain_layer_embs)
    print(f"\nRunning linear probes (5-fold CV at each of {n_layers} layers)...")
    domain_results = run_layerwise_probing(domain_layer_embs, domain_labels, 'Domain', n_layers)
    immuno_results = run_layerwise_probing(immuno_layer_embs, immuno_labels, 'Immunogenicity', n_layers)
    glyco_results = run_layerwise_probing(glyco_layer_embs, glyco_labels, 'Glycosylation', n_layers)
    all_results = domain_results + immuno_results + glyco_results

    out = Path(args.output_dir)
    out.mkdir(parents=True, exist_ok=True)

    csv_path = out / f'layerwise_results_{model_name.lower()}.csv'
    with open(csv_path, 'w', newline='') as f:
        writer = csv.DictWriter(f, fieldnames=['task', 'layer', 'accuracy_mean', 'accuracy_std',
                                               'f1_mean', 'f1_std', 'n_samples', 'n_classes'])
        writer.writeheader()
        for r in all_results:
            writer.writerow({k: r[k] for k in writer.fieldnames})
    print(f"\n Saved: {csv_path}")

    json_path = out / f'layerwise_results_{model_name.lower()}.json'
    with open(json_path, 'w') as f:
        json.dump(all_results, f, indent=2, default=str)
    print(f" Saved: {json_path}")

    print(f"\n{'='*60}")
    print(f" SUMMARY")
    print(f"{'='*60}")
    for task_name in ['Domain', 'Immunogenicity', 'Glycosylation']:
        task_res = [r for r in all_results if r['task'] == task_name]
        best = max(task_res, key=lambda r: r['accuracy_mean'])
        emb = next(r for r in task_res if r['layer'] == 0)
        last = max(task_res, key=lambda r: r['layer'])
        print(f"\n {task_name}:")
        print(f" Embedding layer (0): {emb['accuracy_mean']:.4f}")
        print(f" Best layer ({best['layer']}): {best['accuracy_mean']:.4f}")
        print(f" Final layer ({last['layer']}): {last['accuracy_mean']:.4f}")
        print(f" Gain (best - emb): {best['accuracy_mean'] - emb['accuracy_mean']:+.4f}")

    print(f"\nGenerating figures...")
    plot_layerwise_results(all_results, args.output_dir, model_name)

    print(f"\n{'='*60}")
    print(f" COMPLETE")
    print(f"{'='*60}")


if __name__ == '__main__':
    main()