bertose-affinose-training-code / code /probes /analyze_embeddings.py
supanthadey1's picture
Add BERTose and AFFINose training code release
1d6f391 verified
Raw
History Blame Contribute Delete
13 kB
#!/usr/bin/env python3
"""Embedding Space Deep Dive Analysis - 6 analysis types."""
import os, sys, json, argparse
import numpy as np
from collections import Counter
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
def load_embeddings(npz_path):
    """Open an ``.npz`` embeddings archive and print the shape of every member.

    Args:
        npz_path: Path to a NumPy ``.npz`` archive.

    Returns:
        The object returned by ``np.load`` (an ``NpzFile``). It is left open
        because callers index it by key later (``data['train_embs']`` ...);
        the caller is responsible for closing it.
    """
    print(f"Loading {npz_path}")
    # NOTE(review): allow_pickle=True lets object arrays unpickle arbitrary
    # code on load -- only point this at archives you produced yourself.
    data = np.load(npz_path, allow_pickle=True)
    for key in data.keys():
        # Index once per key: each NpzFile lookup reads (and decompresses) the
        # member from disk again, so a double lookup doubled the I/O.
        arr = data[key]
        if hasattr(arr, 'shape'):
            print(f" {key}: {arr.shape}")
    return data
def compute_umap(embeddings, n_neighbors=15, min_dist=0.1, random_state=42,
                 max_tsne_samples=15000):
    """Project ``embeddings`` to 2-D with UMAP, falling back to t-SNE.

    Args:
        embeddings: 2-D array, one row per sample.
        n_neighbors: UMAP ``n_neighbors``.
        min_dist: UMAP ``min_dist``.
        random_state: Seed for UMAP / t-SNE and for the t-SNE subsample.
        max_tsne_samples: Maximum number of rows fed to the t-SNE fallback
            (t-SNE is slow on large inputs). ``None`` disables the cap.

    Returns:
        Array of shape ``(len(embeddings), 2)`` whose row ``i`` is the
        projection of ``embeddings[i]``. If the t-SNE fallback had to
        subsample, rows that were not projected are NaN. Every caller builds
        boolean masks over the full input, so the rows must stay aligned.
        matplotlib simply does not draw NaN points.
    """
    embeddings = np.asarray(embeddings)
    # Only the import decides the fallback; errors raised while fitting UMAP
    # should surface instead of silently switching algorithms.
    try:
        from umap import UMAP
    except ImportError:
        UMAP = None
    if UMAP is not None:
        return UMAP(n_neighbors=n_neighbors, min_dist=min_dist, random_state=random_state,
                    n_components=2).fit_transform(embeddings)

    print(" umap-learn not installed, falling back to t-SNE")
    from sklearn.manifold import TSNE
    n = len(embeddings)
    if max_tsne_samples is not None and n > max_tsne_samples:
        rng = np.random.default_rng(random_state)
        idx = np.sort(rng.choice(n, max_tsne_samples, replace=False))
    else:
        idx = np.arange(n)
    # scikit-learn requires perplexity < n_samples; shrink it for tiny inputs.
    perplexity = min(30, max(len(idx) - 1, 1))
    sub_proj = TSNE(n_components=2, perplexity=perplexity,
                    random_state=random_state).fit_transform(embeddings[idx])
    proj = np.full((n, 2), np.nan, dtype=float)
    proj[idx] = sub_proj
    return proj
def analysis_1_valid_vs_impossible(data, output_dir, name):
    """UMAP: valid training samples vs impossible negatives by difficulty.

    Randomly draws ``min(len(train_embs), #negatives)`` training rows (without
    replacement) so the valid class is roughly balanced against the negatives.
    Those rows are projected together with every easy, medium and hard
    negative, and the plot is saved as
    ``umap_valid_vs_impossible_{name}.png`` in ``output_dir``.

    Returns:
        None.
    """
    print("\n=== Analysis 1: Valid vs Impossible ===")
    train, easy, medium, hard = data['train_embs'], data['easy_embs'], data['medium_embs'], data['hard_embs']
    n_neg = len(easy) + len(medium) + len(hard)
    if n_neg == 0:
        # No negatives also means an empty balanced subsample; projecting a
        # 0-row matrix would raise inside UMAP / t-SNE.
        print(" No data.")
        return None
    n = min(len(train), n_neg)
    train_sub = train[np.random.choice(len(train), n, replace=False)]
    all_embs = np.vstack([train_sub, easy, medium, hard])
    group_names = ['Valid', 'Easy', 'Medium', 'Hard']
    # One label per row of all_embs, in the same stacking order.
    labels = np.repeat(group_names, [len(train_sub), len(easy), len(medium), len(hard)])
    proj = compute_umap(all_embs)
    fig, ax = plt.subplots(figsize=(12, 10))
    colors = {'Valid': '#2196F3', 'Easy': '#66BB6A', 'Medium': '#FFA726', 'Hard': '#EF5350'}
    for label in group_names:
        mask = labels == label
        ax.scatter(proj[mask, 0], proj[mask, 1], c=colors[label], s=3, alpha=0.4, label=label, rasterized=True)
    ax.set_title(f'{name}: Valid vs Impossible Glycans', fontsize=16, fontweight='bold')
    ax.legend(markerscale=5, fontsize=12)
    ax.set_xlabel('UMAP-1')
    ax.set_ylabel('UMAP-2')
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, f'umap_valid_vs_impossible_{name}.png'), dpi=200, bbox_inches='tight')
    plt.close()
    print(f" Saved umap_valid_vs_impossible_{name}.png")
def analysis_2_train_vs_heldout(data, output_dir, name):
    """UMAP of the benchmark embeddings, coloured by data split.

    Draws the ``train``, ``val`` and ``test`` splits, skipping any that are
    absent, and saves ``umap_train_vs_heldout_{name}.png`` in ``output_dir``.

    Returns:
        ``{split_value: count}`` over every entry of ``benchmark_split``, or
        ``{}`` when there are no benchmark embeddings.
    """
    print("\n=== Analysis 2: Train vs Held-Out ===")
    embs = data['benchmark_embs']
    splits = data['benchmark_split']
    if len(embs) == 0:
        print(" No data.")
        return {}
    proj = compute_umap(embs)
    fig, ax = plt.subplots(figsize=(12, 10))
    palette = {'train': '#2196F3', 'val': '#FFA726', 'test': '#EF5350'}
    for split_name in ('train', 'val', 'test'):
        selected = np.array([s == split_name for s in splits])
        count = selected.sum()
        if not count:
            continue
        ax.scatter(
            proj[selected, 0], proj[selected, 1],
            c=palette.get(split_name, '#999'), s=5, alpha=0.5,
            label=f'{split_name} ({count})', rasterized=True,
        )
    ax.set_title(f'{name}: Train vs Held-Out', fontsize=16, fontweight='bold')
    ax.legend(markerscale=5, fontsize=12)
    ax.set_xlabel('UMAP-1')
    ax.set_ylabel('UMAP-2')
    plt.tight_layout()
    fname = f'umap_train_vs_heldout_{name}.png'
    plt.savefig(os.path.join(output_dir, fname), dpi=200, bbox_inches='tight')
    plt.close()
    print(f" Saved {fname}")
    return dict(Counter(splits))
def analysis_3_taxonomy(data, output_dir, name):
    """UMAP coloured by taxonomy level, plus silhouette scores.

    Runs for each of ``kingdom``, ``phylum`` and ``class`` that has at least
    10 labelled samples (``''`` and ``'nan'`` count as missing). It saves
    ``umap_taxonomy_{level}_{name}.png``, highlighting the 12 most frequent
    labels and greying out the rest. It also records the silhouette score of
    the labels on the 2-D projection as ``silhouette_{level}``.

    Returns:
        Dict of silhouette metrics, or ``{}`` when there is no data.
    """
    print("\n=== Analysis 3: Taxonomy Clustering ===")
    embs = data['benchmark_embs']
    if len(embs) == 0:
        print(" No data.")
        return {}
    proj = compute_umap(embs)
    metrics = {}
    for level in ['kingdom', 'phylum', 'class']:
        labels = data[f'benchmark_{level}']
        valid = np.array([l != '' and l != 'nan' for l in labels])
        if valid.sum() < 10:
            continue
        proj_v, labels_v = proj[valid], labels[valid]
        top12 = [l for l, _ in Counter(labels_v).most_common(12)]
        top12_set = set(top12)
        # plt.cm.get_cmap (matplotlib.cm.get_cmap) was removed in Matplotlib
        # 3.9; pyplot.get_cmap still takes a lut size on old and new versions.
        cmap = plt.get_cmap('tab20', len(top12))
        fig, ax = plt.subplots(figsize=(14, 10))
        other = np.array([l not in top12_set for l in labels_v])
        if other.sum():
            ax.scatter(proj_v[other, 0], proj_v[other, 1], c='#CCC', s=3, alpha=0.2, label='Other', rasterized=True)
        for i, lab in enumerate(top12):
            m = np.array([l == lab for l in labels_v])
            ax.scatter(proj_v[m, 0], proj_v[m, 1], c=[cmap(i)], s=5, alpha=0.5, label=f'{lab} ({m.sum()})', rasterized=True)
        ax.set_title(f'{name}: {level.capitalize()} Clustering', fontsize=16, fontweight='bold')
        ax.legend(markerscale=5, fontsize=9, ncol=2)
        ax.set_xlabel('UMAP-1')
        ax.set_ylabel('UMAP-2')
        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, f'umap_taxonomy_{level}_{name}.png'), dpi=200, bbox_inches='tight')
        plt.close()
        # Best-effort metric: a missing sklearn or degenerate labels must not
        # abort the remaining taxonomy levels.
        try:
            from sklearn.metrics import silhouette_score
            # A subsampling projection (e.g. the t-SNE fallback) may leave NaN
            # rows, which silhouette_score rejects; score only finite points.
            finite = np.isfinite(proj_v).all(axis=1)
            pts, labs = proj_v[finite], labels_v[finite]
            label_map = {l: i for i, l in enumerate(set(labs))}
            numeric = np.array([label_map[l] for l in labs])
            if len(label_map) > 1:
                # Fixed seed so the subsampled score is reproducible run to run.
                sil = silhouette_score(pts, numeric, sample_size=min(5000, len(pts)), random_state=42)
                metrics[f'silhouette_{level}'] = round(float(sil), 4)
                print(f" Silhouette ({level}): {sil:.4f}")
        except Exception as e:
            print(f" Silhouette error: {e}")
    return metrics
def analysis_4_distances(data, output_dir, name, max_samples=2000, window=200):
    """Cosine-similarity distributions for same- vs different-kingdom pairs.

    Draws up to ``max_samples`` benchmark embeddings in random order and pairs
    each one with the following ``window - 1`` rows of that ordering. This
    windowed pairing is a cheap stand-in for all pairs. Pairs where either
    kingdom is missing (``''`` or ``'nan'``) are ignored. Saves
    ``distance_distributions_{name}.png`` in ``output_dir``.

    Args:
        max_samples: Maximum number of embeddings to sample.
        window: Pair row ``i`` with rows ``i+1 .. i+window-1``.

    Returns:
        ``mean_same_sim``, ``mean_diff_sim`` and ``separation_gap``, or
        ``{}`` when there are fewer than 100 embeddings or either pair group
        is empty.
    """
    print("\n=== Analysis 4: Distance Distributions ===")
    embs, kingdoms = data['benchmark_embs'], data['benchmark_kingdom']
    if len(embs) < 100:
        print(" Not enough data.")
        return {}
    n = min(max_samples, len(embs))
    idx = np.random.choice(len(embs), n, replace=False)
    embs_sub = embs[idx]
    labels = np.asarray(kingdoms)[idx]
    norms = np.linalg.norm(embs_sub, axis=1, keepdims=True)
    embs_n = embs_sub / (norms + 1e-8)
    sim = embs_n @ embs_n.T
    # Vectorised form of the pair loop "for i: for j in range(i+1, i+window)",
    # in the same row-major order.
    ii, jj = np.triu_indices(n, k=1)
    near = (jj - ii) < window
    ii, jj = ii[near], jj[near]
    # 'nan' is a missing kingdom too, matching analyses 3 and 5. Without
    # this, every 'nan' sample was pooled into a single fake kingdom.
    known = np.array([l != '' and l != 'nan' for l in labels], dtype=bool)
    both_known = known[ii] & known[jj]
    same_label = np.asarray(labels[ii] == labels[jj], dtype=bool)
    pair_sims = sim[ii, jj].astype(float)
    same = pair_sims[both_known & same_label]
    diff = pair_sims[both_known & ~same_label]
    if same.size == 0 or diff.size == 0:
        return {}
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.hist(same, bins=60, alpha=0.6, color='#2196F3', density=True, label=f'Same kingdom (n={len(same)})')
    ax.hist(diff, bins=60, alpha=0.6, color='#EF5350', density=True, label=f'Diff kingdom (n={len(diff)})')
    ax.axvline(np.mean(same), color='#1565C0', ls='--', alpha=0.7)
    ax.axvline(np.mean(diff), color='#C62828', ls='--', alpha=0.7)
    ax.set_xlabel('Cosine Similarity', fontsize=14)
    ax.set_ylabel('Density', fontsize=14)
    ax.set_title(f'{name}: Cosine Similarity Distribution', fontsize=16, fontweight='bold')
    ax.legend(fontsize=12)
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, f'distance_distributions_{name}.png'), dpi=200, bbox_inches='tight')
    plt.close()
    mean_same, mean_diff = float(np.mean(same)), float(np.mean(diff))
    gap = mean_same - mean_diff
    print(f" Same: {mean_same:.4f}, Diff: {mean_diff:.4f}, Gap: {gap:.4f}")
    return {'mean_same_sim': round(mean_same, 4), 'mean_diff_sim': round(mean_diff, 4), 'separation_gap': round(gap, 4)}
def analysis_5_knn_purity(data, output_dir, name, k=10):
    """KNN purity: do test glycans match their train neighbours' kingdom?

    For every ``test`` row with a known kingdom (not ``''``/``'nan'``), finds
    the ``k`` ``train`` rows with the highest cosine similarity. It then
    records the fraction of those rows that share the test row's kingdom.
    Saves ``knn_purity_{name}.png`` in ``output_dir``.

    Returns:
        ``knn_purity_mean``, ``knn_purity_median`` and ``n_test`` (number of
        scored test rows), or ``{}`` when nothing can be scored.

    Raises:
        ValueError: if ``k < 1``. Previously ``argsort(...)[-0:]`` silently
            selected every training row.
    """
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")
    print(f"\n=== Analysis 5: KNN Purity (K={k}) ===")
    embs, splits, kingdoms = data['benchmark_embs'], data['benchmark_split'], data['benchmark_kingdom']
    train_m = np.array([s == 'train' for s in splits])
    test_m = np.array([s == 'test' for s in splits])
    if train_m.sum() == 0 or test_m.sum() == 0:
        print(" No train/test data.")
        return {}
    tr_e, tr_l = embs[train_m], kingdoms[train_m]
    te_e, te_l = embs[test_m], kingdoms[test_m]
    tr_n = tr_e / (np.linalg.norm(tr_e, axis=1, keepdims=True) + 1e-8)
    te_n = te_e / (np.linalg.norm(te_e, axis=1, keepdims=True) + 1e-8)
    purities = []
    for vec, lab in zip(te_n, te_l):
        # Unlabelled test rows cannot be scored; skip them before paying for a
        # full similarity row against the training set.
        if lab == '' or lab == 'nan':
            continue
        sims = vec @ tr_n.T
        topk = np.argsort(sims)[-k:]
        purities.append(float(np.mean(tr_l[topk] == lab)))
    if not purities:
        return {}
    mean_p, median_p = float(np.mean(purities)), float(np.median(purities))
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.hist(purities, bins=30, color='#4CAF50', alpha=0.7, edgecolor='black')
    ax.axvline(mean_p, color='red', ls='--', lw=2, label=f'Mean: {mean_p:.3f}')
    ax.set_xlabel(f'KNN Purity (K={k})', fontsize=14)
    ax.set_ylabel('Count', fontsize=14)
    ax.set_title(f'{name}: KNN Purity (Generalization Test)', fontsize=14, fontweight='bold')
    ax.legend(fontsize=12)
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, f'knn_purity_{name}.png'), dpi=200, bbox_inches='tight')
    plt.close()
    print(f" Mean: {mean_p:.4f}, Median: {median_p:.4f}")
    return {'knn_purity_mean': round(mean_p, 4), 'knn_purity_median': round(median_p, 4), 'n_test': len(purities)}
def analysis_6_v5_vs_v6(output_dir):
    """Side-by-side V5 vs V6 comparison.

    Reads ``train_embs`` and ``hard_embs`` from ``embeddings_v5.npz`` and
    ``embeddings_v6.npz`` in ``output_dir`` and computes, per version:

    * ``{ver}_mean_pairwise_sim`` / ``{ver}_std_pairwise_sim``: cosine
      similarity over distinct pairs of the first <=1000 training rows;
    * ``{ver}_valid_hard_sim``: mean cosine similarity between the first
      <=500 training rows and the first <=500 hard negatives.

    Also saves ``v5_vs_v6_comparison.png``, a projection of the first <=3000
    rows of each array.

    Returns:
        The metrics dict, or ``{}`` if either file is missing or either
        version has no training or hard-negative rows.
    """
    print("\n=== Analysis 6: V5 vs V6 Comparison ===")
    v5p, v6p = os.path.join(output_dir, 'embeddings_v5.npz'), os.path.join(output_dir, 'embeddings_v6.npz')
    if not os.path.exists(v5p) or not os.path.exists(v6p):
        print(" Need both.")
        return {}
    versions = []
    for ver, path in [('v5', v5p), ('v6', v6p)]:
        # Materialise the two arrays once and close the archive: an open
        # NpzFile re-reads a member on every index and keeps the handle open.
        # NOTE(review): allow_pickle=True can execute code from untrusted files.
        with np.load(path, allow_pickle=True) as f:
            versions.append((ver, f['train_embs'], f['hard_embs']))
    if any(len(tr) == 0 or len(hr) == 0 for _, tr, hr in versions):
        # Empty inputs give NaN metrics and make the projection below raise.
        print(" Need non-empty train_embs and hard_embs.")
        return {}
    metrics = {}
    # Pairwise similarity stats
    for ver, tr, _ in versions:
        e = tr[:1000]
        e_n = e / (np.linalg.norm(e, axis=1, keepdims=True) + 1e-8)
        sim = e_n @ e_n.T
        upper = np.triu(np.ones_like(sim, dtype=bool), k=1)  # distinct pairs only
        metrics[f'{ver}_mean_pairwise_sim'] = round(float(np.mean(sim[upper])), 4)
        metrics[f'{ver}_std_pairwise_sim'] = round(float(np.std(sim[upper])), 4)
    # Valid vs hard impossible separation
    for ver, tr, hr in versions:
        n = min(500, len(tr), len(hr))
        t_n = tr[:n] / (np.linalg.norm(tr[:n], axis=1, keepdims=True) + 1e-8)
        h_n = hr[:n] / (np.linalg.norm(hr[:n], axis=1, keepdims=True) + 1e-8)
        metrics[f'{ver}_valid_hard_sim'] = round(float(np.mean(t_n @ h_n.T)), 4)
    # Side-by-side UMAP
    fig, axes = plt.subplots(1, 2, figsize=(24, 10))
    for ax, (ver, tr, hr) in zip(axes, versions):
        n = min(3000, len(tr), len(hr))
        combined = np.vstack([tr[:n], hr[:n]])
        labels = np.array(['Valid'] * n + ['Hard Impossible'] * n)
        proj = compute_umap(combined)
        for lab, col in [('Valid', '#2196F3'), ('Hard Impossible', '#EF5350')]:
            m = labels == lab
            ax.scatter(proj[m, 0], proj[m, 1], c=col, s=3, alpha=0.4, label=lab, rasterized=True)
        ax.set_title(f'{ver.upper()}: Valid vs Hard Impossible', fontsize=16, fontweight='bold')
        ax.legend(markerscale=5, fontsize=12)
        ax.set_xlabel('UMAP-1')
        ax.set_ylabel('UMAP-2')
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'v5_vs_v6_comparison.png'), dpi=200, bbox_inches='tight')
    plt.close()
    for key, val in metrics.items():
        print(f" {key}: {val}")
    return metrics
def main():
    """Command-line entry point.

    Runs analyses 1-5 on the ``--input`` archive, and analysis 6 as well with
    ``--compare``. Plots go to ``--output_dir`` and the collected metrics are
    written to ``metrics_{name}.json`` in the same directory.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--input', required=True,
                        help='Path to the embeddings .npz archive.')
    parser.add_argument('--name', required=True,
                        help='Model name used in plot titles and output file names.')
    parser.add_argument('--output_dir', default='bert_v6_contrastive/analysis',
                        help='Directory for plots and metrics (default: %(default)s).')
    parser.add_argument('--compare', action='store_true',
                        help='Also run the V5 vs V6 comparison using embeddings_v5.npz '
                             'and embeddings_v6.npz from --output_dir.')
    args = parser.parse_args()
    os.makedirs(args.output_dir, exist_ok=True)
    data = load_embeddings(args.input)
    metrics = {'model': args.name}
    try:
        analysis_1_valid_vs_impossible(data, args.output_dir, args.name)
        metrics.update(analysis_2_train_vs_heldout(data, args.output_dir, args.name))
        metrics.update(analysis_3_taxonomy(data, args.output_dir, args.name))
        metrics.update(analysis_4_distances(data, args.output_dir, args.name))
        metrics.update(analysis_5_knn_purity(data, args.output_dir, args.name))
    finally:
        # load_embeddings hands back the open np.load archive; release it.
        data.close()
    if args.compare:
        metrics.update(analysis_6_v5_vs_v6(args.output_dir))
    out = os.path.join(args.output_dir, f'metrics_{args.name}.json')
    # Context manager so the JSON is flushed and the handle closed promptly.
    with open(out, 'w', encoding='utf-8') as fh:
        json.dump(metrics, fh, indent=2)
    print(f"\nAll metrics saved to {out}")
# Run the command-line entry point only when executed as a script, not on import.
if __name__ == '__main__':
    main()