Spaces:
Running
Running
| """Build a species tree from Carbon 3B embeddings. | |
| Pipeline: | |
| 1. Read /tmp/carbon-umap/viz.csv to get the species of each row. | |
| 2. Stream /tmp/carbon-umap/embeddings.npy in chunks (no full load). | |
| 3. Accumulate per-species sum + count -> 27 mean-pooled centroids. | |
| 4. Compute cosine distance matrix 27x27. | |
| 5. Hierarchical clustering (Ward + UPGMA), build dendrograms. | |
| 6. Write data/species_tree.json (linkage + species labels + matrix) | |
| and data/species_tree.png (preview). | |
| The 6.5 GB .npy is memory-mapped and never fully loaded, so RAM usage stays | |
| under 1 GB (one chunk plus the 27 centroid accumulators). | |
| """ | |
| import csv | |
| import json | |
| import os | |
| import sys | |
| import time | |
| import numpy as np | |
# Directory that contains this script, and the "data" directory located one
# level above it, which is where main() writes its outputs.
HERE = os.path.split(os.path.abspath(__file__))[0]
DATA = os.path.join(os.path.split(HERE)[0], "data")

# Input files read by main(): per-row metadata and the embedding matrix.
CSV_PATH = "/tmp/carbon-umap/viz.csv"
NPY_PATH = "/tmp/carbon-umap/embeddings.npy"

# Number of embedding rows pulled from the memory map per streaming step.
CHUNK = 20_000
# Kingdom -> member species. Iteration order matters: main() walks this
# mapping (kingdoms first, then members) to fix the order of centroids,
# distance-matrix rows and dendrogram labels.
KINGDOMS = {
    "vertebrates": [
        "human", "macaque", "mouse", "rat", "dog",
        "cow", "pig", "chicken", "frog", "zebrafish",
    ],
    "invertebrates": [
        "fly", "worm",
    ],
    "plants": [
        "arabidopsis", "soybean", "tomato", "maize", "rice",
    ],
    "fungi": [
        "yeast", "fission_yeast", "candida", "aspergillus", "neurospora",
    ],
    "bacteria": [
        "ecoli", "bsubtilis", "saureus",
    ],
}
# Canonical NCBI clade of every species, written clade-first. Species listed
# under the same clade are sister (or near-sister) groups in standard
# taxonomy. A clade with a single member is "solo": that species cannot be
# evaluated for sister-level agreement.
_CLADE_MEMBERS = {
    "primates": ["human", "macaque"],
    "rodents": ["mouse", "rat"],
    "laurasiatheria": ["dog", "cow", "pig"],
    "sauropsida": ["chicken"],  # solo
    "amphibia": ["frog"],  # solo
    "actinopterygii": ["zebrafish"],  # solo
    "insects": ["fly"],  # solo
    "nematodes": ["worm"],  # solo
    "dicots": ["arabidopsis", "tomato", "soybean"],
    "monocots": ["rice", "maize"],
    "saccharomycetes": ["yeast", "candida"],
    "schizosaccharomycetes": ["fission_yeast"],  # solo
    "pezizomycotina": ["neurospora", "aspergillus"],
    "proteobacteria": ["ecoli"],  # solo
    "firmicutes": ["bsubtilis", "saureus"],
}

# Flat species -> clade lookup derived from the table above.
EXPECTED_CLADE = {
    species: clade
    for clade, members in _CLADE_MEMBERS.items()
    for species in members
}
def species_to_kingdom():
    """Invert KINGDOMS into a flat ``species -> kingdom`` lookup.

    Returns:
        dict: one entry per species listed in KINGDOMS, mapped to the key of
        the kingdom whose member list contains it. If a species were listed
        under several kingdoms, the last one encountered would win.
    """
    lookup = {}
    for kingdom, members in KINGDOMS.items():
        for species in members:
            lookup[species] = kingdom
    return lookup
def _read_species_column(csv_path):
    """Return the ``species`` value of every data row of *csv_path*, in file order.

    Raises:
        KeyError: if the CSV has no ``species`` column.
    """
    # The csv module expects file objects opened with newline="".
    with open(csv_path, newline="", encoding="utf-8") as f:
        return [row["species"] for row in csv.DictReader(f)]


def _accumulate_centroids(arr, species_idx, n_species, chunk_rows):
    """Stream *arr* in row chunks and return per-species ``(sums, counts)``.

    Args:
        arr: 2-D array-like of embeddings (rows x dim), sliced chunk by chunk.
        species_idx: integer array with one species index per row of *arr*.
        n_species: number of distinct species indices.
        chunk_rows: maximum number of rows converted and accumulated at once.

    Returns:
        tuple: ``sums`` (n_species x dim, float64) and ``counts``
        (n_species, int64).
    """
    n_rows, dim = arr.shape
    sums = np.zeros((n_species, dim), dtype=np.float64)
    counts = np.zeros(n_species, dtype=np.int64)
    t_start = time.perf_counter()
    for start in range(0, n_rows, chunk_rows):
        end = min(start + chunk_rows, n_rows)
        chunk = np.asarray(arr[start:end], dtype=np.float32)
        sp_chunk = species_idx[start:end]
        # np.add.at is unbuffered, so rows sharing a species index all add up.
        np.add.at(sums, sp_chunk, chunk)
        counts += np.bincount(sp_chunk, minlength=n_species)
        # Progress line on every fifth chunk only.
        if (start // chunk_rows) % 5 == 0:
            elapsed = time.perf_counter() - t_start
            pct = end / n_rows * 100
            print(f" {end:>8,}/{n_rows:,} ({pct:5.1f}%) · {elapsed:.1f}s elapsed")
    print(f" done in {time.perf_counter() - t_start:.1f}s")
    return sums, counts


def _cosine_distance_matrix(centroids):
    """Return the square cosine-distance matrix between the rows of *centroids*.

    Raises:
        ValueError: if a centroid has a zero or non-finite norm, which would
            otherwise put NaN into the matrix and the exported JSON.
    """
    norms = np.linalg.norm(centroids, axis=1, keepdims=True)
    if not np.all(np.isfinite(norms) & (norms > 0)):
        raise ValueError(
            "cannot compute cosine distances: a centroid has zero or non-finite norm"
        )
    unit = centroids / norms
    # Clip rounding noise so similarities stay within [-1, 1].
    sim = np.clip(unit @ unit.T, -1.0, 1.0)
    cos_dist = 1.0 - sim
    np.fill_diagonal(cos_dist, 0.0)
    return cos_dist


def _dendrogram_layout(Z, labels):
    """Return scipy's dendrogram layout for linkage *Z* without plotting.

    The result holds the leaf order and the icoord/dcoord segment coordinates
    so the frontend can draw the tree in SVG without re-implementing scipy's
    traversal.
    """
    from scipy.cluster.hierarchy import dendrogram

    d = dendrogram(Z, no_plot=True, labels=labels)
    return {
        "leaf_order": d["ivl"],
        "icoord": d["icoord"],
        "dcoord": d["dcoord"],
    }


def _write_png_preview(png_path, linkage_ward, linkage_upgma, labels, s2k):
    """Draw the Ward and UPGMA dendrograms side by side and save them as a PNG.

    Skips the preview with a message when matplotlib cannot be imported.
    """
    # Only the imports are guarded: an ImportError raised later while plotting
    # must not be misreported as "matplotlib not available".
    try:
        import matplotlib
        matplotlib.use("Agg")
        import matplotlib.pyplot as plt
    except ImportError:
        print(" (matplotlib not available, skipped PNG preview)")
        return
    from scipy.cluster.hierarchy import dendrogram

    kingdom_color = {
        "vertebrates": "#1f1f1d",
        "invertebrates": "#7a6242",
        "plants": "#317f3f",
        "fungi": "#a9762f",
        "bacteria": "#b00020",
    }
    fig, axes = plt.subplots(1, 2, figsize=(20, 10))
    for ax, lnk, title in zip(
        axes,
        (linkage_ward, linkage_upgma),
        ("Ward (cosine)", "UPGMA (cosine)"),
    ):
        dendrogram(
            lnk, labels=labels, ax=ax,
            orientation="right", leaf_font_size=11,
            color_threshold=0,
            above_threshold_color="#888",
        )
        ax.set_title(title, fontsize=14)
        ax.set_xlabel("cosine distance")
        # Colour each leaf label by kingdom; unknown kingdoms fall back to grey.
        for tick in ax.get_yticklabels():
            kingdom = s2k.get(tick.get_text(), "?")
            tick.set_color(kingdom_color.get(kingdom, "#666"))
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)
    plt.tight_layout()
    plt.savefig(png_path, dpi=120, bbox_inches="tight", facecolor="white")
    plt.close(fig)
    print(f" {png_path}")


def main():
    """Build per-species centroids, cluster them and write the tree outputs.

    Reads the species column of CSV_PATH, streams NPY_PATH in chunks of CHUNK
    rows to mean-pool one centroid per species, computes the cosine distance
    matrix, runs Ward and UPGMA linkage, then writes ``species_tree.json``
    (and ``species_tree.png`` when matplotlib is importable) into DATA.

    Raises:
        ValueError: if the CSV and .npy row counts differ, if fewer than two
            species are present, or if a centroid has a zero/non-finite norm.
    """
    t0 = time.perf_counter()

    print(f"[1/5] reading species column from {CSV_PATH} ...")
    species_per_row = _read_species_column(CSV_PATH)
    n = len(species_per_row)
    print(f" {n:,} rows")

    s2k = species_to_kingdom()
    present = set(species_per_row)  # built once, reused for both checks below
    unknown = sorted(present - set(s2k))
    if unknown:
        print(f" WARNING: {len(unknown)} species not in KINGDOMS: {unknown[:5]} ...")
    species_order = [
        sp for members in KINGDOMS.values() for sp in members if sp in present
    ]
    # Species missing from KINGDOMS are kept (after the known ones) instead of
    # triggering a KeyError in the index lookup below; they are exported with
    # the "?" kingdom / clade fallbacks.
    species_order.extend(unknown)
    sp_to_idx = {sp: i for i, sp in enumerate(species_order)}
    K = len(species_order)
    print(f" {K} species in this dataset")
    if K < 2:
        raise ValueError(
            f"need at least 2 species to build a tree, found {K} in {CSV_PATH}"
        )
    species_idx = np.fromiter(
        (sp_to_idx[sp] for sp in species_per_row), dtype=np.int32, count=n
    )

    print(f"\n[2/5] memory-mapping {NPY_PATH} ...")
    arr = np.lib.format.open_memmap(NPY_PATH, mode="r")
    n_rows, dim = arr.shape
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if n_rows != n:
        raise ValueError(f"row mismatch: npy={n_rows} csv={n}")
    print(f" shape={arr.shape} dtype={arr.dtype}")

    print(f"\n[3/5] streaming {n_rows:,} rows in chunks of {CHUNK:,} "
          f"-> {K} centroids of dim {dim} ...")
    sums, counts = _accumulate_centroids(arr, species_idx, K, CHUNK)
    # Every species in species_order occurs in the CSV, so counts are >= 1.
    centroids = sums / counts[:, None]
    print(f" counts per species: min={counts.min():,} "
          f"max={counts.max():,} median={int(np.median(counts)):,}")

    print(f"\n[4/5] computing cosine distance matrix {K}x{K} ...")
    cos_dist = _cosine_distance_matrix(centroids)

    from scipy.spatial.distance import squareform
    from scipy.cluster.hierarchy import linkage

    condensed = squareform(cos_dist, checks=False)
    # NOTE(review): SciPy documents Ward linkage as correctly defined only for
    # Euclidean distances, yet it is fed cosine distances here. Left unchanged
    # because the exported heights are consumed by the frontend; confirm intent.
    linkage_ward = linkage(condensed, method="ward")
    linkage_upgma = linkage(condensed, method="average")
    print(" linkage matrices ready (Ward + UPGMA)")

    layout_ward = _dendrogram_layout(linkage_ward, species_order)
    layout_upgma = _dendrogram_layout(linkage_upgma, species_order)

    print("\n[5/5] writing outputs ...")
    out = {
        "species": species_order,
        "kingdom": [s2k.get(sp, "?") for sp in species_order],
        "expected_clade": [EXPECTED_CLADE.get(sp, "?") for sp in species_order],
        "counts": counts.tolist(),
        "distance_matrix": cos_dist.tolist(),
        "linkage_ward": linkage_ward.tolist(),
        "linkage_upgma": linkage_upgma.tolist(),
        "layout_ward": layout_ward,
        "layout_upgma": layout_upgma,
        "dim": int(dim),
        "n_total_points": int(n_rows),
    }
    # The output directory may not exist yet on a fresh checkout.
    os.makedirs(DATA, exist_ok=True)
    json_path = os.path.join(DATA, "species_tree.json")
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(out, f, indent=1)
    print(f" {json_path} ({os.path.getsize(json_path):,} bytes)")

    _write_png_preview(
        os.path.join(DATA, "species_tree.png"),
        linkage_ward, linkage_upgma, species_order, s2k,
    )

    print(f"\nTotal: {time.perf_counter() - t0:.1f}s")


if __name__ == "__main__":
    main()