| |
| """ |
| Probe 2: Layer-wise CLS Probing for GlycanBERT V6. |
| |
| Extracts the [CLS] token representation from the embedding output and from |
| each of the 12 transformer layers (13 probe points in total), then trains a |
| cross-validated logistic regression classifier at every depth to measure how |
| much task-relevant information is encoded there. |
| """ |
|
|
| import os, sys, json, csv, argparse |
| import numpy as np |
| from pathlib import Path |
| from collections import Counter |
|
|
# Third ancestor directory of this file's resolved path; every other path
# below is anchored here.
PROJECT_ROOT = Path(__file__).resolve().parents[2]
# Vocabulary file handed to WURCSTokenizer when extracting embeddings.
VOCAB_PATH = PROJECT_ROOT / 'bert_training_v4' / 'data' / 'vocabulary.json'
# Checkpoint file per value accepted by the --model command-line option.
# NOTE(review): the 'v6' entry points into a 'bert_v5.1_contrastive' directory
# and a 'v51' file name -- presumably intentional, confirm against the repo.
CHECKPOINTS = {
    'v5': PROJECT_ROOT / 'checkpoints_v5_bpe_topo' / 'best_v5_bpe_topo_model.pt',
    'v6': PROJECT_ROOT / 'bert_v5.1_contrastive' / 'checkpoints' / 'best_v51_contrastive_model.pt',
}
# Directory holding the benchmark CSV files read by the load_*_data functions.
BENCH_DIR = PROJECT_ROOT / 'bench' / 'GlycanML' / 'data'
|
|
| sys.path.insert(0, str(PROJECT_ROOT)) |
| sys.path.insert(0, str(PROJECT_ROOT / 'bert_training_v4')) |
| from model.multimodal_glycan_bert_v3 import MultimodalGlycanBERT, MultimodalGlycanBERTConfig |
| from downstream_tasks.utils.tokenizer import WURCSTokenizer |
|
|
|
|
def load_model(ckpt_path, device='cuda'):
    """Build a MultimodalGlycanBERT and load backbone weights from a checkpoint.

    The checkpoint may either be a dict with a 'model_state_dict' entry or a
    bare state dict. Keys starting with 'proj_head.' are dropped before
    loading. Vocabulary sizes are read from the embedding tables stored in the
    checkpoint; the remaining architecture settings are fixed in this function.

    Args:
        ckpt_path: Path of the checkpoint file passed to ``torch.load``.
        device: Device the model is moved to after loading.

    Returns:
        The model, moved to ``device`` and switched to eval mode.

    Raises:
        KeyError: If the checkpoint has no sequence token-embedding weight.
    """
    import torch

    print(f"Loading model from {ckpt_path}...")
    # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
    # objects; only use this with checkpoints from a trusted source.
    ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=False)
    state_dict = ckpt.get('model_state_dict', ckpt)

    backbone_sd = {k: v for k, v in state_dict.items() if not k.startswith('proj_head.')}
    n_stripped = len(state_dict) - len(backbone_sd)
    if n_stripped > 0:
        print(f" Stripped {n_stripped} projection head keys")

    seq_key = 'seq_embeddings.token_embeddings.weight'
    if seq_key not in backbone_sd:
        raise KeyError(f"checkpoint {ckpt_path} has no '{seq_key}' entry")
    vocab_size = backbone_sd[seq_key].shape[0]

    ms_total_vocab = None
    ms_key = 'ms_embeddings.token_embeddings.weight'
    if ms_key in backbone_sd:
        ms_total_vocab = backbone_sd[ms_key].shape[0]

    config_kwargs = dict(
        seq_vocab_size=vocab_size, seq_hidden_size=768, seq_num_layers=12,
        seq_num_heads=12, seq_max_length=256, use_cnn_frontend=True, cnn_kernel_size=3,
    )
    if ms_total_vocab is not None:
        # presumably the MS embedding table also covers the sequence
        # vocabulary, hence the subtraction -- TODO confirm against the model.
        config_kwargs['ms_vocab_size'] = ms_total_vocab - vocab_size

    config = MultimodalGlycanBERTConfig(**config_kwargs)
    model = MultimodalGlycanBERT(config)

    # strict=False tolerates key mismatches, but a mismatch means some
    # parameters keep their random initialisation, so report it loudly
    # instead of discarding the result.
    load_result = model.load_state_dict(backbone_sd, strict=False)
    for kind, keys in (('missing', load_result.missing_keys),
                       ('unexpected', load_result.unexpected_keys)):
        if keys:
            preview = ', '.join(keys[:5])
            more = ', ...' if len(keys) > 5 else ''
            print(f" WARNING: {len(keys)} {kind} keys while loading checkpoint: {preview}{more}")

    model.to(device).eval()
    print(f" Model loaded: {sum(p.numel() for p in model.parameters()):,} params")
    return model
|
|
|
|
def extract_layerwise_cls(model, samples, device='cuda', max_len=256):
    """Collect the position-0 hidden vector of every sample at every depth.

    Each sample is tokenized, truncated or zero-padded to ``max_len``, passed
    through ``model.seq_embeddings`` and then through each module of
    ``model.seq_layers`` in order. The vector at sequence position 0 is
    recorded after the embedding step (index 0) and after each layer
    (indices 1..n_layers).

    A sample that raises at any point contributes a zero vector at *every*
    depth, so all returned arrays stay row-aligned with ``samples``.

    Args:
        model: Object exposing ``seq_embeddings`` and ``seq_layers``.
        samples: Sequence of dicts with a 'wurcs' entry.
        device: Device the input tensors are moved to.
        max_len: Fixed sequence length fed to the model.

    Returns:
        Dict mapping depth index to an array with one row per sample.
    """
    import torch
    import torch.nn.functional as F

    tokenizer = WURCSTokenizer(str(VOCAB_PATH))
    n_layers = len(model.seq_layers)
    # One entry per sample: a list of n_layers + 1 vectors, or None on failure.
    per_sample = []
    hidden_dim = None
    n_errors = 0

    for si, s in enumerate(samples):
        vectors = None
        try:
            result = tokenizer.tokenize(s['wurcs'], max_length=max_len)
            n_tokens = len(result['token_ids'])
            token_ids = torch.tensor(result['token_ids'], dtype=torch.long)
            branch_depths = torch.tensor(result.get('branch_depths', [0] * n_tokens), dtype=torch.long)
            linkage_types = torch.tensor(result.get('linkage_types', [0] * n_tokens), dtype=torch.long)

            # Trim all three streams to a common length no longer than max_len.
            seq_len = min(len(token_ids), len(branch_depths), len(linkage_types), max_len)
            token_ids = token_ids[:seq_len]
            branch_depths = branch_depths[:seq_len]
            linkage_types = linkage_types[:seq_len]

            pad = max_len - seq_len
            if pad > 0:
                # assumes id 0 is the padding value for all three streams -- TODO confirm
                token_ids = F.pad(token_ids, (0, pad), value=0)
                branch_depths = F.pad(branch_depths, (0, pad), value=0)
                linkage_types = F.pad(linkage_types, (0, pad), value=0)

            collected = []
            with torch.no_grad():
                hidden = model.seq_embeddings(
                    token_ids.unsqueeze(0).to(device),
                    branch_depths=branch_depths.unsqueeze(0).to(device),
                    linkage_types=linkage_types.unsqueeze(0).to(device),
                )
                collected.append(hidden[0, 0, :].cpu().numpy())
                # NOTE(review): layers are called without an attention mask, so
                # padded positions take part in the computation -- confirm intended.
                for layer in model.seq_layers:
                    hidden = layer(hidden)
                    collected.append(hidden[0, 0, :].cpu().numpy())
            # Commit only once every depth succeeded; a failure part-way through
            # must not leave some depths with an extra row.
            vectors = collected
        except Exception as e:
            # Deliberate best effort: keep going and substitute zeros below.
            n_errors += 1
            if n_errors <= 3:
                print(f" ERROR sample {si}: {e}")

        if vectors is not None and hidden_dim is None:
            hidden_dim = vectors[0].shape[0]
        per_sample.append(vectors)

        if si > 0 and si % 500 == 0:
            print(f" Processed {si}/{len(samples)}")

    if n_errors > 0:
        print(f" WARNING: {n_errors}/{len(samples)} errors")

    # Width of the zero placeholder follows the real embeddings; 768 is only
    # the fallback when no sample succeeded.
    placeholder = np.zeros(768 if hidden_dim is None else hidden_dim)
    return {
        depth: np.array([placeholder if vecs is None else vecs[depth] for vecs in per_sample])
        for depth in range(n_layers + 1)
    }
|
|
|
|
def load_domain_data():
    """Read the domain-classification CSV from BENCH_DIR.

    Keeps rows whose 'wurcs' field starts with 'WURCS' and whose 'domain'
    field is one of 'Eukarya', 'Bacteria' or 'Virus'.

    Returns:
        Tuple ``(samples, labels)``: a list of ``{'wurcs': ...}`` dicts and the
        parallel list of domain strings.
    """
    csv_path = BENCH_DIR / 'glycan_classification_wurcs_subset.csv'
    samples, labels = [], []
    # newline='' is what the csv module expects of its file object; the
    # explicit encoding avoids depending on the platform default.
    with open(csv_path, newline='', encoding='utf-8') as f:
        for row in csv.DictReader(f):
            # DictReader fills fields missing from short rows with None, so a
            # .get() default is not enough to guarantee a string.
            w = row.get('wurcs') or ''
            domain = row.get('domain') or ''
            if w.startswith('WURCS') and domain in ('Eukarya', 'Bacteria', 'Virus'):
                samples.append({'wurcs': w})
                labels.append(domain)
    print(f" Domain data: {len(samples)} samples, {Counter(labels)}")
    return samples, labels
|
|
|
|
def load_glycosylation_data():
    """Read the glycosylation-link CSV from BENCH_DIR.

    Keeps rows whose 'wurcs' field starts with 'WURCS' and whose 'link' field
    is 'N' or 'O'.

    Returns:
        Tuple ``(samples, labels)``: a list of ``{'wurcs': ...}`` dicts and the
        parallel list of link strings.
    """
    csv_path = BENCH_DIR / 'glycan_link_wurcs_subset.csv'
    samples, labels = [], []
    # newline='' is what the csv module expects of its file object; the
    # explicit encoding avoids depending on the platform default.
    with open(csv_path, newline='', encoding='utf-8') as f:
        for row in csv.DictReader(f):
            # DictReader fills fields missing from short rows with None, so a
            # .get() default is not enough to guarantee a string.
            w = row.get('wurcs') or ''
            link = row.get('link') or ''
            if w.startswith('WURCS') and link in ('N', 'O'):
                samples.append({'wurcs': w})
                labels.append(link)
    print(f" Glycosylation data: {len(samples)} samples, {Counter(labels)}")
    return samples, labels
|
|
|
|
def load_immunogenicity_data():
    """Read the immunogenicity CSV from BENCH_DIR.

    Keeps rows whose 'wurcs' field starts with 'WURCS' and whose
    'immunogenicity' field is non-empty; the label is that field parsed as a
    float and truncated to int.

    Returns:
        Tuple ``(samples, labels)``: a list of ``{'wurcs': ...}`` dicts and the
        parallel list of integer labels.

    Raises:
        ValueError: If a non-empty 'immunogenicity' value is not numeric.
    """
    csv_path = BENCH_DIR / 'glycan_immunogenicity_wurcs_subset.csv'
    samples, labels = [], []
    # newline='' is what the csv module expects of its file object; the
    # explicit encoding avoids depending on the platform default.
    with open(csv_path, newline='', encoding='utf-8') as f:
        for row in csv.DictReader(f):
            # DictReader fills fields missing from short rows with None, so a
            # .get() default is not enough to guarantee a string.
            w = row.get('wurcs') or ''
            imm = row.get('immunogenicity') or ''
            if w.startswith('WURCS') and imm:
                samples.append({'wurcs': w})
                labels.append(int(float(imm)))
    print(f" Immunogenicity data: {len(samples)} samples, {Counter(labels)}")
    return samples, labels
|
|
|
|
def train_linear_probe(X, y, task_name, n_splits=5):
    """Cross-validate a logistic-regression probe on one set of features.

    Labels are integer-encoded, the data is split with a shuffled stratified
    K-fold (seed 42), and in every fold the features are standardised using
    statistics of the training part only before fitting the classifier.

    Args:
        X: Feature matrix with one row per sample (array-like).
        y: Labels, one per row of ``X``.
        task_name: Name stored in the returned summary.
        n_splits: Number of cross-validation folds.

    Returns:
        Dict with the task name, mean/std of accuracy and F1 over folds,
        sample count, class count and the list of classes. F1 is macro-averaged
        for more than two classes and binary otherwise.
    """
    from sklearn.linear_model import LogisticRegression
    from sklearn.model_selection import StratifiedKFold
    from sklearn.preprocessing import StandardScaler, LabelEncoder
    from sklearn.metrics import accuracy_score, f1_score

    # Fancy indexing below needs an ndarray; accept any array-like.
    X = np.asarray(X)
    le = LabelEncoder()
    y_enc = le.fit_transform(y)
    n_classes = len(le.classes_)
    f1_average = 'macro' if n_classes > 2 else 'binary'

    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)
    accs, f1s = [], []
    for train_idx, test_idx in skf.split(X, y_enc):
        scaler = StandardScaler()
        X_train = scaler.fit_transform(X[train_idx])
        X_test = scaler.transform(X[test_idx])
        # `multi_class` is deprecated in recent scikit-learn releases. With the
        # lbfgs solver the default already fits a multinomial model for more
        # than two classes, so omitting it keeps the same behaviour.
        clf = LogisticRegression(max_iter=1000, solver='lbfgs', random_state=42)
        clf.fit(X_train, y_enc[train_idx])
        y_pred = clf.predict(X_test)
        accs.append(accuracy_score(y_enc[test_idx], y_pred))
        f1s.append(f1_score(y_enc[test_idx], y_pred, average=f1_average))

    return {'task': task_name, 'accuracy_mean': float(np.mean(accs)), 'accuracy_std': float(np.std(accs)),
            'f1_mean': float(np.mean(f1s)), 'f1_std': float(np.std(f1s)),
            'n_samples': len(y), 'n_classes': n_classes, 'classes': le.classes_.tolist()}
|
|
|
|
def run_layerwise_probing(layer_embs, labels, task_name, n_layers=13):
    """Run the linear probe on each depth's features and log one line per depth.

    Args:
        layer_embs: Mapping from depth index to the feature matrix of that depth.
        labels: Labels shared by every depth.
        task_name: Name forwarded to the probe and shown in the log.
        n_layers: Number of depth indices to probe, starting at 0.

    Returns:
        List of probe summaries in depth order, each with an added 'layer' key.
    """
    print(f"\n Probing: {task_name} ({len(labels)} samples)")
    results = []
    for depth in range(n_layers):
        summary = train_linear_probe(layer_embs[depth], labels, task_name)
        summary['layer'] = depth
        results.append(summary)

        # Depth 0 is the embedding output, shown as "emb" in the log.
        tag = str(depth) if depth != 0 else "emb"
        acc_mean, acc_std = summary['accuracy_mean'], summary['accuracy_std']
        f1_mean, f1_std = summary['f1_mean'], summary['f1_std']
        print(f" Layer {tag:>3s}: acc={acc_mean:.4f}+/-{acc_std:.4f} f1={f1_mean:.4f}+/-{f1_std:.4f}")
    return results
|
|
|
|
def plot_layerwise_results(all_results, output_dir, model_name='V6'):
    """Render and save the layer-wise probing figures.

    Two figures are written to ``output_dir`` (created if needed), each as PNG
    and PDF: a two-panel accuracy / F1 figure and a standalone accuracy figure.
    Every task found in ``all_results`` becomes one curve with error bars and a
    shaded +/- one standard deviation band.

    Args:
        all_results: Iterable of probe summaries with 'task', 'layer',
            'accuracy_mean', 'accuracy_std', 'f1_mean', 'f1_std' and
            'n_samples' entries.
        output_dir: Directory the figures are saved into.
        model_name: Label used in titles and, lower-cased, in file names.
    """
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt

    plt.rcParams.update({
        'font.family': 'sans-serif', 'font.sans-serif': ['Arial', 'Helvetica', 'DejaVu Sans'],
        'font.size': 11, 'axes.titlesize': 13, 'axes.labelsize': 12,
        'xtick.labelsize': 10, 'ytick.labelsize': 10, 'legend.fontsize': 10,
        'figure.dpi': 300, 'savefig.dpi': 300, 'savefig.bbox': 'tight',
        'axes.linewidth': 0.8, 'axes.spines.top': False, 'axes.spines.right': False,
    })
    task_colors = {'Domain': '#E63946', 'Immunogenicity': '#457B9D', 'Glycosylation': '#2A9D8F'}
    task_markers = {'Domain': 'o', 'Immunogenicity': 's', 'Glycosylation': '^'}

    # Group the flat result list by task, each group ordered by depth.
    tasks = {}
    for r in all_results:
        tasks.setdefault(r['task'], []).append(r)
    for name in tasks:
        tasks[name] = sorted(tasks[name], key=lambda r: r['layer'])

    # Tick positions follow the depths actually present instead of assuming
    # 13; with no results the previous 0..12 axis is kept.
    deepest = max((r['layer'] for r in all_results), default=12)
    tick_positions = list(range(deepest + 1))
    depth_tick_labels = [str(i) for i in tick_positions[1:]]

    out = Path(output_dir)
    out.mkdir(parents=True, exist_ok=True)

    def draw_curve(ax, task_name, results, mean_key, std_key, label, markersize, linewidth):
        # One task curve: error bars plus a translucent +/- std band.
        layers = [r['layer'] for r in results]
        means = [r[mean_key] for r in results]
        stds = [r[std_key] for r in results]
        color = task_colors.get(task_name, '#333')
        marker = task_markers.get(task_name, 'o')
        ax.errorbar(layers, means, yerr=stds, color=color, marker=marker, markersize=markersize,
                    linewidth=linewidth, capsize=3, capthick=1.2, label=label, zorder=3)
        ax.fill_between(layers, [m - s for m, s in zip(means, stds)],
                        [m + s for m, s in zip(means, stds)], color=color, alpha=0.1)

    def save_and_close(fig, stem):
        # Write the figure in both formats, then release it.
        for fmt in ['png', 'pdf']:
            fp = out / f'{stem}.{fmt}'
            fig.savefig(fp, dpi=300, bbox_inches='tight', facecolor='white')
            print(f" Saved: {fp}")
        plt.close(fig)

    # Figure 1: accuracy and F1 side by side.
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5.5))
    for task_name, results in tasks.items():
        draw_curve(ax1, task_name, results, 'accuracy_mean', 'accuracy_std',
                   f"{task_name} (n={results[0]['n_samples']})", markersize=7, linewidth=2)
        draw_curve(ax2, task_name, results, 'f1_mean', 'f1_std',
                   task_name, markersize=7, linewidth=2)

    # NOTE(review): the F1 axis is labelled 'Macro F1', but train_linear_probe
    # reports binary F1 for two-class tasks -- confirm the label is acceptable.
    for ax, ylabel, title, panel in [(ax1, 'Accuracy', 'Accuracy', '(a)'), (ax2, 'Macro F1', 'F1 Score', '(b)')]:
        ax.set_xlabel('Layer')
        ax.set_ylabel(ylabel)
        ax.set_title(f'Layer-wise Linear Probe {title} — GlycanBERT {model_name}')
        ax.set_xticks(tick_positions)
        ax.set_xticklabels(['emb'] + depth_tick_labels)
        ax.legend(frameon=False, loc='lower right')
        ax.grid(axis='y', alpha=0.3, linestyle='--')
        ax.axhline(y=0.5, color='gray', linestyle=':', alpha=0.5)
        ax.text(-0.08, 1.05, panel, transform=ax.transAxes, fontsize=14, fontweight='bold')
    fig.tight_layout()
    save_and_close(fig, f'accuracy_vs_layer_{model_name.lower()}')

    # Figure 2: standalone accuracy plot.
    fig2, ax = plt.subplots(1, 1, figsize=(8, 5.5))
    for task_name, results in tasks.items():
        draw_curve(ax, task_name, results, 'accuracy_mean', 'accuracy_std',
                   f"{task_name} (n={results[0]['n_samples']})", markersize=8, linewidth=2.5)
    ax.set_xlabel('Transformer Layer', fontsize=13)
    ax.set_ylabel('Linear Probe Accuracy (5-fold CV)', fontsize=13)
    ax.set_title(f'Layer-wise Representation Quality — GlycanBERT {model_name}', fontsize=14, fontweight='bold')
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(['Emb'] + depth_tick_labels, fontsize=10)
    ax.legend(frameon=True, fancybox=True, shadow=False, edgecolor='#ccc', loc='lower right', fontsize=11)
    ax.grid(axis='y', alpha=0.3, linestyle='--')
    ax.axhline(y=0.5, color='gray', linestyle=':', alpha=0.4)
    fig2.tight_layout()
    save_and_close(fig2, f'accuracy_vs_layer_standalone_{model_name.lower()}')
|
|
|
|
def main():
    """Command-line entry point: extract, probe, save results and plot.

    Loads the selected checkpoint and the three benchmark datasets, extracts
    per-depth position-0 embeddings, runs the cross-validated linear probe at
    every depth for each task, writes the results as CSV and JSON into the
    output directory, prints a per-task summary and generates the figures.
    """
    import gc
    import torch

    parser = argparse.ArgumentParser(description='Probe 2: Layer-wise CLS probing')
    parser.add_argument('--model', type=str, default='v6', choices=['v5', 'v6'])
    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument('--output_dir', type=str, default=str(PROJECT_ROOT / 'bert_v6_contrastive' / 'analysis' / 'probe_results_v6' / 'probe_2_layerwise_cls'))
    args = parser.parse_args()
    model_name = args.model.upper()

    # Run on CPU rather than crash when CUDA was requested (the default) but
    # is not available on this machine.
    device = args.device
    if device.startswith('cuda') and not torch.cuda.is_available():
        print(" WARNING: CUDA is not available, falling back to CPU")
        device = 'cpu'

    print(f"\n{'='*60}")
    print(f" Probe 2: Layer-wise CLS — GlycanBERT {model_name}")
    print(f"{'='*60}")
    ckpt = CHECKPOINTS[args.model]
    model = load_model(str(ckpt), device=device)

    print("\nLoading datasets...")
    domain_samples, domain_labels = load_domain_data()
    glyco_samples, glyco_labels = load_glycosylation_data()
    immuno_samples, immuno_labels = load_immunogenicity_data()
    if len(domain_samples) > 3000:
        # Fixed legacy seeding is kept so the same 3000 rows are drawn as in
        # earlier runs.
        np.random.seed(42)
        indices = np.random.choice(len(domain_samples), 3000, replace=False)
        domain_samples = [domain_samples[i] for i in indices]
        domain_labels = [domain_labels[i] for i in indices]
        print(f" Subsampled domain to {len(domain_samples)} samples")

    print(f"\nExtracting layer-wise CLS embeddings...")
    print(f" Domain ({len(domain_samples)} samples)...")
    domain_layer_embs = extract_layerwise_cls(model, domain_samples, device=device)
    print(f" Glycosylation ({len(glyco_samples)} samples)...")
    glyco_layer_embs = extract_layerwise_cls(model, glyco_samples, device=device)
    print(f" Immunogenicity ({len(immuno_samples)} samples)...")
    immuno_layer_embs = extract_layerwise_cls(model, immuno_samples, device=device)

    # The model is no longer needed; free it before the CPU-side probing.
    del model
    torch.cuda.empty_cache()
    gc.collect()

    # Number of probe points comes from what was actually extracted (embedding
    # output plus one per layer) rather than a hard-coded 13.
    n_layers = len(domain_layer_embs)
    final_layer = n_layers - 1
    print(f"\nRunning linear probes (5-fold CV at each of {n_layers} layers)...")
    domain_results = run_layerwise_probing(domain_layer_embs, domain_labels, 'Domain', n_layers)
    immuno_results = run_layerwise_probing(immuno_layer_embs, immuno_labels, 'Immunogenicity', n_layers)
    glyco_results = run_layerwise_probing(glyco_layer_embs, glyco_labels, 'Glycosylation', n_layers)
    all_results = domain_results + immuno_results + glyco_results

    out = Path(args.output_dir)
    out.mkdir(parents=True, exist_ok=True)
    csv_path = out / f'layerwise_results_{model_name.lower()}.csv'
    with open(csv_path, 'w', newline='') as f:
        writer = csv.DictWriter(f, fieldnames=['task','layer','accuracy_mean','accuracy_std','f1_mean','f1_std','n_samples','n_classes'])
        writer.writeheader()
        for r in all_results:
            # The summaries carry extra keys (e.g. 'classes'); write only the CSV columns.
            writer.writerow({k: r[k] for k in writer.fieldnames})
    print(f"\n Saved: {csv_path}")
    json_path = out / f'layerwise_results_{model_name.lower()}.json'
    with open(json_path, 'w') as f:
        json.dump(all_results, f, indent=2, default=str)
    print(f" Saved: {json_path}")

    print(f"\n{'='*60}"); print(f" SUMMARY"); print(f"{'='*60}")
    for task_name in ['Domain', 'Immunogenicity', 'Glycosylation']:
        task_res = [r for r in all_results if r['task'] == task_name]
        best = max(task_res, key=lambda r: r['accuracy_mean'])
        emb = next(r for r in task_res if r['layer'] == 0)
        last = next(r for r in task_res if r['layer'] == final_layer)
        print(f"\n {task_name}:")
        print(f" Embedding layer (0): {emb['accuracy_mean']:.4f}")
        print(f" Best layer ({best['layer']}): {best['accuracy_mean']:.4f}")
        print(f" Final layer ({final_layer}): {last['accuracy_mean']:.4f}")
        print(f" Gain (best - emb): {best['accuracy_mean'] - emb['accuracy_mean']:+.4f}")

    print(f"\nGenerating figures...")
    plot_layerwise_results(all_results, args.output_dir, model_name)
    print(f"\n{'='*60}"); print(f" COMPLETE"); print(f"{'='*60}")
|
|
|
|
| if __name__ == '__main__': |
| main() |
|
|