| |
| """Embedding Space Deep Dive Analysis - 6 analysis types.""" |
|
|
| import os, sys, json, argparse |
| import numpy as np |
| from collections import Counter |
| import matplotlib |
| matplotlib.use('Agg') |
| import matplotlib.pyplot as plt |
|
|
|
|
def load_embeddings(npz_path):
    """Open an ``.npz`` embeddings archive and print the shape of each array.

    Args:
        npz_path: path to an ``.npz`` file.

    Returns:
        The archive object returned by ``np.load``. Arrays are read lazily on
        ``data[key]`` access and the file stays open; the caller should close
        it (it supports ``with``).
    """
    print(f"Loading {npz_path}")
    # NOTE(review): allow_pickle=True lets a crafted file execute arbitrary
    # code when object arrays are loaded -- only point this at trusted files.
    data = np.load(npz_path, allow_pickle=True)
    for key in data.keys():
        # Every data[key] lookup reads the member from the archive again, so
        # fetch each array once instead of once for hasattr and once for shape.
        arr = data[key]
        shape = getattr(arr, 'shape', None)
        if shape is not None:
            print(f" {key}: {shape}")
    return data
|
|
|
|
def compute_umap(embeddings, n_neighbors=15, min_dist=0.1, random_state=42):
    """Project *embeddings* to 2-D with UMAP, or with t-SNE if umap-learn is missing.

    Args:
        embeddings: array of row vectors (presumably ``(n_samples, dim)`` --
            TODO confirm).
        n_neighbors: UMAP neighbourhood size (unused by the t-SNE fallback).
        min_dist: UMAP minimum distance (unused by the t-SNE fallback).
        random_state: seed passed to UMAP / t-SNE.

    Returns:
        A 2-column projection with exactly one row per input row, in input
        order. The analyses in this file index the result with boolean masks
        built over the input rows, so the row count must match.
    """
    try:
        from umap import UMAP
        return UMAP(n_neighbors=n_neighbors, min_dist=min_dist, random_state=random_state, n_components=2).fit_transform(embeddings)
    except ImportError:
        print(" umap-learn not installed, falling back to t-SNE")
        from sklearn.manifold import TSNE
        # Do NOT subsample rows here: returning fewer rows than were passed in
        # makes every caller's mask indexing fail. t-SNE also requires
        # perplexity < n_samples, so shrink it for tiny inputs.
        perplexity = min(30, max(len(embeddings) - 1, 1))
        return TSNE(n_components=2, perplexity=perplexity, random_state=random_state).fit_transform(embeddings)
|
|
|
|
def analysis_1_valid_vs_impossible(data, output_dir, name):
    """Plot valid training embeddings against impossible negatives by difficulty.

    The valid set is randomly subsampled (global NumPy RNG, without
    replacement) to at most the combined number of easy, medium and hard
    negatives, so the two sides of the plot are balanced.

    Args:
        data: mapping with 'train_embs', 'easy_embs', 'medium_embs' and
            'hard_embs' row-vector arrays (presumably all of the same
            embedding width -- TODO confirm).
        output_dir: directory that receives
            ``umap_valid_vs_impossible_<name>.png``.
        name: model name used in the plot title and file name.

    Returns:
        None.
    """
    print("\n=== Analysis 1: Valid vs Impossible ===")
    train = data['train_embs']
    negatives = {'Easy': data['easy_embs'], 'Medium': data['medium_embs'], 'Hard': data['hard_embs']}
    n_neg = sum(len(v) for v in negatives.values())
    if n_neg == 0:
        # With no negatives the balanced train subsample is empty too and the
        # projection of zero rows would fail; bail out like analyses 2 and 3.
        print(" No data.")
        return
    n = min(len(train), n_neg)
    train_sub = train[np.random.choice(len(train), n, replace=False)]

    groups = {'Valid': train_sub, **negatives}
    colors = {'Valid': '#2196F3', 'Easy': '#66BB6A', 'Medium': '#FFA726', 'Hard': '#EF5350'}
    all_embs = np.vstack(list(groups.values()))
    # One label per stacked row, as an array so masks are a vectorised compare.
    labels = np.repeat(list(groups), [len(v) for v in groups.values()])

    proj = compute_umap(all_embs)
    fig, ax = plt.subplots(figsize=(12, 10))
    for label in groups:
        mask = labels == label
        ax.scatter(proj[mask, 0], proj[mask, 1], c=colors[label], s=3, alpha=0.4, label=label, rasterized=True)
    ax.set_title(f'{name}: Valid vs Impossible Glycans', fontsize=16, fontweight='bold')
    ax.legend(markerscale=5, fontsize=12)
    ax.set_xlabel('UMAP-1')
    ax.set_ylabel('UMAP-2')
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, f'umap_valid_vs_impossible_{name}.png'), dpi=200, bbox_inches='tight')
    plt.close(fig)
    print(f" Saved umap_valid_vs_impossible_{name}.png")
|
|
|
|
def analysis_2_train_vs_heldout(data, output_dir, name):
    """Plot the benchmark embeddings' 2-D projection coloured by data split.

    Only the 'train', 'val' and 'test' splits are drawn; any split with no
    members is left out of the legend.

    Args:
        data: mapping with 'benchmark_embs' and 'benchmark_split'.
        output_dir: directory that receives
            ``umap_train_vs_heldout_<name>.png``.
        name: model name used in the plot title and file name.

    Returns:
        ``dict`` of split value -> count over all of 'benchmark_split', or
        ``{}`` when there are no embeddings.
    """
    print("\n=== Analysis 2: Train vs Held-Out ===")
    embeddings = data['benchmark_embs']
    split_names = data['benchmark_split']
    if len(embeddings) == 0:
        print(" No data.")
        return {}

    coords = compute_umap(embeddings)
    palette = (('train', '#2196F3'), ('val', '#FFA726'), ('test', '#EF5350'))
    fig, ax = plt.subplots(figsize=(12, 10))
    for split_name, colour in palette:
        selected = np.array([entry == split_name for entry in split_names])
        count = selected.sum()
        if not count:
            continue
        ax.scatter(coords[selected, 0], coords[selected, 1], c=colour, s=5, alpha=0.5,
                   label=f'{split_name} ({count})', rasterized=True)
    ax.set_title(f'{name}: Train vs Held-Out', fontsize=16, fontweight='bold')
    ax.legend(markerscale=5, fontsize=12)
    ax.set_xlabel('UMAP-1')
    ax.set_ylabel('UMAP-2')
    plt.tight_layout()
    target = os.path.join(output_dir, f'umap_train_vs_heldout_{name}.png')
    plt.savefig(target, dpi=200, bbox_inches='tight')
    plt.close()
    print(f" Saved umap_train_vs_heldout_{name}.png")
    return dict(Counter(split_names))
|
|
|
|
def analysis_3_taxonomy(data, output_dir, name):
    """Plot the projection coloured by taxonomy level and score cluster separation.

    For each of kingdom / phylum / class that is present in *data* as
    ``benchmark_<level>`` and has at least 10 labelled entries (labels other
    than '' and 'nan'), the 12 most frequent labels get their own colour and
    all remaining labelled points are drawn grey as 'Other'. A silhouette
    score of the labelled points is then computed on the 2-D projection
    (NOTE(review): not on the raw embeddings).

    Args:
        data: mapping with 'benchmark_embs' and optional 'benchmark_kingdom',
            'benchmark_phylum', 'benchmark_class' label arrays.
        output_dir: directory that receives
            ``umap_taxonomy_<level>_<name>.png`` files.
        name: model name used in plot titles and file names.

    Returns:
        ``dict`` with ``silhouette_<level>`` entries (rounded to 4 places) for
        the levels that could be scored; ``{}`` when there are no embeddings.
    """
    print("\n=== Analysis 3: Taxonomy Clustering ===")
    embs = data['benchmark_embs']
    if len(embs) == 0:
        print(" No data.")
        return {}

    proj = None  # computed lazily: skip the expensive projection if no level is usable
    metrics = {}

    for level in ('kingdom', 'phylum', 'class'):
        key = f'benchmark_{level}'
        if key not in data:
            print(f" {key} missing; skipping {level}.")
            continue
        labels = data[key]
        valid = np.array([l != '' and l != 'nan' for l in labels], dtype=bool)
        if valid.sum() < 10:
            continue
        if proj is None:
            proj = compute_umap(embs)

        proj_v, labels_v = proj[valid], labels[valid]
        counts = Counter(labels_v)
        top12 = [l for l, _ in counts.most_common(12)]
        # matplotlib.cm.get_cmap (plt.cm.get_cmap) was removed in Matplotlib
        # 3.9; pyplot.get_cmap(name, lut) is still available.
        cmap = plt.get_cmap('tab20', len(top12))

        fig, ax = plt.subplots(figsize=(14, 10))
        other = np.array([l not in top12 for l in labels_v], dtype=bool)
        if other.any():
            ax.scatter(proj_v[other, 0], proj_v[other, 1], c='#CCC', s=3, alpha=0.2, label='Other', rasterized=True)
        for i, lab in enumerate(top12):
            m = np.array([l == lab for l in labels_v], dtype=bool)
            ax.scatter(proj_v[m, 0], proj_v[m, 1], c=[cmap(i)], s=5, alpha=0.5, label=f'{lab} ({m.sum()})', rasterized=True)
        ax.set_title(f'{name}: {level.capitalize()} Clustering', fontsize=16, fontweight='bold')
        ax.legend(markerscale=5, fontsize=9, ncol=2)
        ax.set_xlabel('UMAP-1')
        ax.set_ylabel('UMAP-2')
        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, f'umap_taxonomy_{level}_{name}.png'), dpi=200, bbox_inches='tight')
        plt.close(fig)

        if len(counts) < 2:
            continue  # silhouette needs at least two clusters
        try:
            from sklearn.metrics import silhouette_score
            label_ids = {lab: i for i, lab in enumerate(counts)}
            numeric = np.array([label_ids[l] for l in labels_v])
            # Fixed random_state so the subsampled score is reproducible.
            sil = silhouette_score(proj_v, numeric, sample_size=min(5000, len(proj_v)), random_state=42)
        except (ImportError, ValueError) as e:
            # Best-effort metric: sklearn missing, or too few/too many labels in the sample.
            print(f" Silhouette error: {e}")
            continue
        metrics[f'silhouette_{level}'] = round(float(sil), 4)
        print(f" Silhouette ({level}): {sil:.4f}")

    return metrics
|
|
|
|
def analysis_4_distances(data, output_dir, name):
    """Compare cosine similarity of same-kingdom vs different-kingdom pairs.

    Randomly samples up to 2000 benchmark embeddings (global NumPy RNG),
    L2-normalises them and considers every sampled pair (i, j) with
    ``1 <= j - i < 200``. This is a band of the similarity matrix rather than
    all pairs, presumably to bound cost. Pairs where either kingdom is unknown
    ('' or 'nan', the same convention analyses 3 and 5 use) are ignored.

    Args:
        data: mapping with 'benchmark_embs' and 'benchmark_kingdom'.
        output_dir: directory that receives
            ``distance_distributions_<name>.png``.
        name: model name used in the plot title and file name.

    Returns:
        ``dict`` with 'mean_same_sim', 'mean_diff_sim' and 'separation_gap'
        (same minus diff), each rounded to 4 places; ``{}`` when there are
        fewer than 100 embeddings or no same- or different-kingdom pair.
    """
    print("\n=== Analysis 4: Distance Distributions ===")
    embs, kingdoms = data['benchmark_embs'], data['benchmark_kingdom']
    if len(embs) < 100:
        print(" Not enough data.")
        return {}

    n = min(2000, len(embs))
    idx = np.random.choice(len(embs), n, replace=False)
    embs_sub, labels = embs[idx], kingdoms[idx]
    embs_n = embs_sub / (np.linalg.norm(embs_sub, axis=1, keepdims=True) + 1e-8)
    sim = embs_n @ embs_n.T

    # All (i, j) with 1 <= j - i < window, built per diagonal offset instead
    # of a Python double loop over ~n * window pairs.
    window = 200
    offsets = range(1, min(window, n))
    rows = np.concatenate([np.arange(n - d) for d in offsets])
    cols = np.concatenate([np.arange(d, n) for d in offsets])
    pair_sims = sim[rows, cols].astype(np.float64)

    known = np.array([l != '' and l != 'nan' for l in labels], dtype=bool)
    both_known = known[rows] & known[cols]
    same_label = np.asarray(labels[rows] == labels[cols], dtype=bool)
    same = pair_sims[both_known & same_label]
    diff = pair_sims[both_known & ~same_label]
    if same.size == 0 or diff.size == 0:
        print(" Not enough labelled pairs.")
        return {}

    mean_same, mean_diff = float(np.mean(same)), float(np.mean(diff))
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.hist(same, bins=60, alpha=0.6, color='#2196F3', density=True, label=f'Same kingdom (n={len(same)})')
    ax.hist(diff, bins=60, alpha=0.6, color='#EF5350', density=True, label=f'Diff kingdom (n={len(diff)})')
    ax.axvline(mean_same, color='#1565C0', ls='--', alpha=0.7)
    ax.axvline(mean_diff, color='#C62828', ls='--', alpha=0.7)
    ax.set_xlabel('Cosine Similarity', fontsize=14)
    ax.set_ylabel('Density', fontsize=14)
    ax.set_title(f'{name}: Cosine Similarity Distribution', fontsize=16, fontweight='bold')
    ax.legend(fontsize=12)
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, f'distance_distributions_{name}.png'), dpi=200, bbox_inches='tight')
    plt.close(fig)

    gap = mean_same - mean_diff
    print(f" Same: {mean_same:.4f}, Diff: {mean_diff:.4f}, Gap: {gap:.4f}")
    return {'mean_same_sim': round(mean_same, 4), 'mean_diff_sim': round(mean_diff, 4), 'separation_gap': round(gap, 4)}
|
|
|
|
def analysis_5_knn_purity(data, output_dir, name, k=10):
    """Measure how often a test embedding's nearest train neighbours share its kingdom.

    For every test-split embedding whose kingdom is known (not '' / 'nan'),
    take the *k* train-split embeddings with the highest cosine similarity
    and record the fraction whose kingdom equals the test kingdom. Train
    neighbours with an unknown kingdom count as mismatches. Among exactly
    tied similarities, which neighbours are picked is unspecified.

    Args:
        data: mapping with 'benchmark_embs', 'benchmark_split' (entries
            compared against 'train' / 'test') and 'benchmark_kingdom'.
        output_dir: directory that receives ``knn_purity_<name>.png``.
        name: model name used in the plot title and file name.
        k: number of neighbours, at least 1; if it exceeds the train-set
            size, all train embeddings are used.

    Returns:
        ``dict`` with 'knn_purity_mean', 'knn_purity_median' (rounded to 4
        places) and 'n_test'; ``{}`` when there is no train/test data or no
        labelled test embedding.

    Raises:
        ValueError: if ``k < 1`` (``argsort(...)[-0:]`` would select every
            train point instead of none).
    """
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")
    print(f"\n=== Analysis 5: KNN Purity (K={k}) ===")
    embs, splits, kingdoms = data['benchmark_embs'], data['benchmark_split'], data['benchmark_kingdom']
    train_m = np.array([s == 'train' for s in splits], dtype=bool)
    test_m = np.array([s == 'test' for s in splits], dtype=bool)
    if train_m.sum() == 0 or test_m.sum() == 0:
        print(" No train/test data.")
        return {}

    tr_e, tr_l = embs[train_m], kingdoms[train_m]
    te_e, te_l = embs[test_m], kingdoms[test_m]
    # Only labelled test points are scored, so skip the similarity work for the rest.
    known = np.array([l != '' and l != 'nan' for l in te_l], dtype=bool)
    te_e, te_l = te_e[known], te_l[known]
    if len(te_l) == 0:
        print(" No labelled test data.")
        return {}

    tr_n = tr_e / (np.linalg.norm(tr_e, axis=1, keepdims=True) + 1e-8)
    te_n = te_e / (np.linalg.norm(te_e, axis=1, keepdims=True) + 1e-8)
    n_train = len(tr_n)
    # After argpartition, columns [kth:] hold the min(k, n_train) largest sims.
    kth = n_train - min(k, n_train)
    chunk = max(1, (1 << 22) // n_train)  # ~4M similarities per batch bounds memory

    purities = []
    for start in range(0, len(te_n), chunk):
        stop = start + chunk
        sims = te_n[start:stop] @ tr_n.T
        topk = np.argpartition(sims, kth, axis=1)[:, kth:]
        hits = np.asarray(tr_l[topk] == te_l[start:stop, None], dtype=bool)
        purities.extend(hits.mean(axis=1).tolist())

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.hist(purities, bins=30, color='#4CAF50', alpha=0.7, edgecolor='black')
    ax.axvline(np.mean(purities), color='red', ls='--', lw=2, label=f'Mean: {np.mean(purities):.3f}')
    ax.set_xlabel(f'KNN Purity (K={k})', fontsize=14)
    ax.set_ylabel('Count', fontsize=14)
    ax.set_title(f'{name}: KNN Purity (Generalization Test)', fontsize=14, fontweight='bold')
    ax.legend(fontsize=12)
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, f'knn_purity_{name}.png'), dpi=200, bbox_inches='tight')
    plt.close(fig)
    print(f" Mean: {np.mean(purities):.4f}, Median: {np.median(purities):.4f}")
    return {'knn_purity_mean': round(float(np.mean(purities)), 4), 'knn_purity_median': round(float(np.median(purities)), 4), 'n_test': len(purities)}
|
|
|
|
def analysis_6_v5_vs_v6(output_dir):
    """Compare the V5 and V6 embedding archives side by side.

    Reads ``embeddings_v5.npz`` and ``embeddings_v6.npz`` from *output_dir*
    and records, per version:

    * ``<ver>_mean_pairwise_sim`` / ``<ver>_std_pairwise_sim``: cosine
      similarity over all distinct pairs of the first (up to) 1000 train
      embeddings;
    * ``<ver>_valid_hard_sim``: mean cosine similarity between the first
      (up to) 500 train and 500 hard embeddings.

    It also writes ``v5_vs_v6_comparison.png``, with a projection of the first
    (up to) 3000 train vs hard embeddings for each version.

    Args:
        output_dir: directory holding both archives; the PNG is written here.

    Returns:
        The metrics ``dict``, or ``{}`` if either archive is missing or lacks
        at least 2 train and 1 hard embedding.
    """
    print("\n=== Analysis 6: V5 vs V6 Comparison ===")
    v5p, v6p = os.path.join(output_dir, 'embeddings_v5.npz'), os.path.join(output_dir, 'embeddings_v6.npz')
    if not os.path.exists(v5p) or not os.path.exists(v6p):
        print(" Need both.")
        return {}

    # Read each needed array exactly once (every NpzFile lookup re-reads it
    # from the archive) and close the files afterwards.
    versions = []
    for ver, path in (('v5', v5p), ('v6', v6p)):
        # NOTE(review): allow_pickle=True can run arbitrary code from a
        # crafted file; only use with trusted archives.
        with np.load(path, allow_pickle=True) as npz:
            versions.append((ver, npz['train_embs'], npz['hard_embs']))
    if any(len(train) < 2 or len(hard) == 0 for _, train, hard in versions):
        # Pairwise stats would be NaN and the projection of empty input fails.
        print(" Need at least 2 train and 1 hard embedding in both files.")
        return {}

    def _unit(x):
        # Row-wise L2 normalisation; epsilon avoids division by zero.
        return x / (np.linalg.norm(x, axis=1, keepdims=True) + 1e-8)

    metrics = {}
    for ver, train, _ in versions:
        e_n = _unit(train[:1000])
        upper = (e_n @ e_n.T)[np.triu_indices(len(e_n), k=1)]  # distinct pairs only
        metrics[f'{ver}_mean_pairwise_sim'] = round(float(np.mean(upper)), 4)
        metrics[f'{ver}_std_pairwise_sim'] = round(float(np.std(upper)), 4)

    for ver, train, hard in versions:
        n = min(500, len(train), len(hard))
        metrics[f'{ver}_valid_hard_sim'] = round(float(np.mean(_unit(train[:n]) @ _unit(hard[:n]).T)), 4)

    fig, axes = plt.subplots(1, 2, figsize=(24, 10))
    for ax, (ver, train, hard) in zip(axes, versions):
        n = min(3000, len(train), len(hard))
        proj = compute_umap(np.vstack([train[:n], hard[:n]]))
        is_valid = np.arange(2 * n) < n  # first n rows are train, last n are hard
        for mask, lab, col in ((is_valid, 'Valid', '#2196F3'), (~is_valid, 'Hard Impossible', '#EF5350')):
            ax.scatter(proj[mask, 0], proj[mask, 1], c=col, s=3, alpha=0.4, label=lab, rasterized=True)
        ax.set_title(f'{ver.upper()}: Valid vs Hard Impossible', fontsize=16, fontweight='bold')
        ax.legend(markerscale=5, fontsize=12)
        ax.set_xlabel('UMAP-1')
        ax.set_ylabel('UMAP-2')
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'v5_vs_v6_comparison.png'), dpi=200, bbox_inches='tight')
    plt.close(fig)

    for key, value in metrics.items():
        print(f" {key}: {value}")
    return metrics
|
|
|
|
def main():
    """Parse CLI arguments, run the analyses and write ``metrics_<name>.json``.

    Analyses 1-5 run on ``--input``; analysis 6 (V5 vs V6) runs only with
    ``--compare``. All metric dicts are merged into one JSON file in
    ``--output_dir``.
    """
    parser = argparse.ArgumentParser(description="Embedding space deep-dive analysis.")
    parser.add_argument('--input', required=True, help="Path to the .npz embeddings archive.")
    parser.add_argument('--name', required=True, help="Model name used in titles and output file names.")
    parser.add_argument('--output_dir', default='bert_v6_contrastive/analysis')
    parser.add_argument('--compare', action='store_true',
                        help="Also compare embeddings_v5.npz and embeddings_v6.npz found in --output_dir.")
    parser.add_argument('--seed', type=int, default=42,
                        help="Seed for NumPy's global RNG used by the analyses' random subsampling.")
    args = parser.parse_args()

    # Analyses subsample with np.random.choice; seeding once makes the
    # saved plots and metrics reproducible across runs.
    np.random.seed(args.seed)
    os.makedirs(args.output_dir, exist_ok=True)
    metrics = {'model': args.name}

    # np.load keeps the .npz open; close it once the per-archive analyses finish.
    with load_embeddings(args.input) as data:
        analysis_1_valid_vs_impossible(data, args.output_dir, args.name)
        metrics.update(analysis_2_train_vs_heldout(data, args.output_dir, args.name))
        metrics.update(analysis_3_taxonomy(data, args.output_dir, args.name))
        metrics.update(analysis_4_distances(data, args.output_dir, args.name))
        metrics.update(analysis_5_knn_purity(data, args.output_dir, args.name))
    if args.compare:
        metrics.update(analysis_6_v5_vs_v6(args.output_dir))

    out = os.path.join(args.output_dir, f'metrics_{args.name}.json')
    with open(out, 'w') as fh:  # closes (and flushes) the file deterministically
        json.dump(metrics, fh, indent=2)
    print(f"\nAll metrics saved to {out}")
|
|
if __name__ == '__main__':
    # Run the command-line entry point only when executed as a script, not on import.
    main()
|
|