# carbon-demo / scripts/build_species_tree.py
# Last change a54539b (lvwerra): "Species tree: drop viruses + retitle
# (25 species, 571,789 sequences)"
"""Build a species tree from Carbon 3B embeddings.

Pipeline:
1. Read /tmp/carbon-umap/viz.csv to get the species of each row.
2. Stream /tmp/carbon-umap/embeddings.npy in chunks (no full load).
3. Accumulate per-species sum + count -> one mean-pooled centroid per
   species (K = 25 with the current KINGDOMS table).
4. Compute the K x K cosine distance matrix.
5. Hierarchical clustering (Ward + UPGMA), build dendrograms.
6. Write data/species_tree.json (linkage + species labels + matrix)
   and data/species_tree.png (preview).

The 6.5 GB .npy is memory-mapped and never fully loaded, so RAM usage stays
under 1 GB (one chunk + K centroid accumulators).
"""
import csv
import json
import os
import sys
import time
import numpy as np
HERE = os.path.dirname(os.path.abspath(__file__))
# Output directory: a "data" folder next to this script's parent directory.
DATA = os.path.join(os.path.dirname(HERE), "data")
# Inputs: a per-row metadata CSV (must have a "species" column) and the
# embedding matrix whose rows line up one-to-one with the CSV rows.
CSV_PATH = "/tmp/carbon-umap/viz.csv"
NPY_PATH = "/tmp/carbon-umap/embeddings.npy"
CHUNK = 20000  # rows per streaming chunk

# Species that make up the tree, grouped by kingdom. The iteration order of
# this table fixes the order of species (rows/columns) in every output;
# species found in the data but missing here are left out of the tree.
KINGDOMS = {
    "vertebrates": ["human", "macaque", "mouse", "rat", "dog", "cow", "pig",
                    "chicken", "frog", "zebrafish"],
    "invertebrates": ["fly", "worm"],
    "plants": ["arabidopsis", "soybean", "tomato", "maize", "rice"],
    "fungi": ["yeast", "fission_yeast", "candida", "aspergillus",
              "neurospora"],
    "bacteria": ["ecoli", "bsubtilis", "saureus"],
}

# Canonical NCBI clade for each species. Two species sharing the same
# value are sister (or near-sister) groups in standard taxonomy.
# A clade with a single member among the species in KINGDOMS (25 at
# present) -> the species is "solo" and not evaluable for sister-level
# agreement.
EXPECTED_CLADE = {
    "human": "primates",
    "macaque": "primates",
    "mouse": "rodents",
    "rat": "rodents",
    "dog": "laurasiatheria",
    "cow": "laurasiatheria",
    "pig": "laurasiatheria",
    "chicken": "sauropsida",  # solo
    "frog": "amphibia",  # solo
    "zebrafish": "actinopterygii",  # solo
    "fly": "insects",  # solo
    "worm": "nematodes",  # solo
    "arabidopsis": "dicots",
    "tomato": "dicots",
    "soybean": "dicots",
    "rice": "monocots",
    "maize": "monocots",
    "yeast": "saccharomycetes",
    "candida": "saccharomycetes",
    "fission_yeast": "schizosaccharomycetes",  # solo
    "neurospora": "pezizomycotina",
    "aspergillus": "pezizomycotina",
    "ecoli": "proteobacteria",  # solo
    "bsubtilis": "firmicutes",
    "saureus": "firmicutes",
}
def species_to_kingdom():
    """Invert KINGDOMS into a flat ``{species: kingdom}`` lookup table.

    Were a species listed under more than one kingdom, the kingdom that
    comes last in KINGDOMS would be the one recorded.
    """
    lookup = {}
    for kingdom_name, member_species in KINGDOMS.items():
        for species in member_species:
            lookup[species] = kingdom_name
    return lookup
def _read_species_column(csv_path):
    """Return the "species" field of every data row in *csv_path*, in file order."""
    # newline="" is the mode the csv module documents for files it parses.
    with open(csv_path, newline="") as f:
        return [row["species"] for row in csv.DictReader(f)]


def _accumulate_species_sums(arr, species_idx, n_species, chunk_rows):
    """Stream *arr* in row chunks and return per-species ``(sums, counts)``.

    *species_idx* gives, for each row of *arr*, its species index in
    ``[0, n_species)``; rows with a negative index are skipped. Sums are
    accumulated in float64 to limit rounding error across many rows. Only one
    chunk of *arr* is materialised at a time.
    """
    n_rows, dim = arr.shape
    sums = np.zeros((n_species, dim), dtype=np.float64)
    counts = np.zeros(n_species, dtype=np.int64)
    t_start = time.perf_counter()
    for start in range(0, n_rows, chunk_rows):
        end = min(start + chunk_rows, n_rows)
        chunk = np.asarray(arr[start:end], dtype=np.float32)
        idx = species_idx[start:end]
        keep = idx >= 0
        if not keep.all():
            chunk, idx = chunk[keep], idx[keep]
        # np.add.at is unbuffered, so repeated species indices all accumulate.
        np.add.at(sums, idx, chunk)
        counts += np.bincount(idx, minlength=n_species)
        if (start // chunk_rows) % 5 == 0:
            elapsed = time.perf_counter() - t_start
            pct = end / n_rows * 100
            print(f" {end:>8,}/{n_rows:,} ({pct:5.1f}%) · {elapsed:.1f}s elapsed")
    return sums, counts


def _cosine_distance_matrix(vectors):
    """Return the pairwise cosine distance (1 - cosine similarity) between rows."""
    unit = vectors / np.linalg.norm(vectors, axis=1, keepdims=True)
    # Clipping keeps rounding from pushing |similarity| past 1.
    dist = 1.0 - np.clip(unit @ unit.T, -1.0, 1.0)
    # Self-distance is 0 by definition; drop any rounding residue.
    np.fill_diagonal(dist, 0.0)
    return dist


def _dendrogram_layout(Z, labels):
    """Return scipy's dendrogram drawing data (leaf order + segment coords).

    Pre-computing icoord/dcoord/leaf order lets the frontend render the tree
    spine in SVG without re-implementing scipy's traversal algorithm.
    """
    from scipy.cluster.hierarchy import dendrogram

    d = dendrogram(Z, no_plot=True, labels=labels)
    return {
        "leaf_order": d["ivl"],
        "icoord": d["icoord"],
        "dcoord": d["dcoord"],
    }


def _write_png_preview(png_path, panels, labels, s2k):
    """Render one dendrogram per ``(title, linkage)`` in *panels* to *png_path*.

    Leaf labels are coloured by kingdom via *s2k*. The preview is optional:
    when matplotlib is not installed a message is printed and nothing is
    written.
    """
    try:
        import matplotlib
        matplotlib.use("Agg")
        import matplotlib.pyplot as plt
    except ImportError:
        print(" (matplotlib not available, skipped PNG preview)")
        return
    from scipy.cluster.hierarchy import dendrogram

    kingdom_color = {
        "vertebrates": "#1f1f1d",
        "invertebrates": "#7a6242",
        "plants": "#317f3f",
        "fungi": "#a9762f",
        "bacteria": "#b00020",
    }
    fig, axes = plt.subplots(1, len(panels), figsize=(20, 10), squeeze=False)
    try:
        for ax, (title, lnk) in zip(axes[0], panels):
            dendrogram(
                lnk, labels=labels, ax=ax,
                orientation="right", leaf_font_size=11,
                color_threshold=0,
                above_threshold_color="#888",
            )
            ax.set_title(title, fontsize=14)
            ax.set_xlabel("cosine distance")
            for tick in ax.get_yticklabels():
                kingdom = s2k.get(tick.get_text(), "?")
                tick.set_color(kingdom_color.get(kingdom, "#666"))
            ax.spines["top"].set_visible(False)
            ax.spines["right"].set_visible(False)
        fig.tight_layout()
        fig.savefig(png_path, dpi=120, bbox_inches="tight", facecolor="white")
    finally:
        # Release the figure even if rendering or saving fails.
        plt.close(fig)
    print(f" {png_path}")


def main():
    """Build species_tree.json (and a PNG preview) from the Carbon embeddings.

    Reads CSV_PATH and NPY_PATH and writes both outputs into DATA (created if
    missing). Rows whose species is not listed in KINGDOMS are reported and
    left out of the centroids.

    Raises:
        ValueError: if the CSV and .npy row counts differ, or fewer than two
            KINGDOMS species are present (a tree needs at least two leaves).
    """
    t0 = time.perf_counter()
    print(f"[1/5] reading species column from {CSV_PATH} ...")
    species_per_row = _read_species_column(CSV_PATH)
    n = len(species_per_row)
    print(f" {n:,} rows")

    s2k = species_to_kingdom()
    present = set(species_per_row)  # built once, not once per species
    unknown = sorted(present - set(s2k))
    if unknown:
        more = " ..." if len(unknown) > 5 else ""
        print(f" WARNING: {len(unknown)} species not in KINGDOMS: "
              f"{unknown[:5]}{more} (rows skipped)")
    species_order = [sp for members in KINGDOMS.values()
                     for sp in members if sp in present]
    sp_to_idx = {sp: i for i, sp in enumerate(species_order)}
    K = len(species_order)
    print(f" {K} species in this dataset")
    if K < 2:
        raise ValueError(f"need at least 2 KINGDOMS species to build a tree, found {K}")
    # -1 marks rows of species outside KINGDOMS; they are skipped when
    # accumulating (previously they crashed here with a KeyError).
    species_idx = np.array([sp_to_idx.get(sp, -1) for sp in species_per_row],
                           dtype=np.int32)

    print(f"\n[2/5] memory-mapping {NPY_PATH} ...")
    arr = np.lib.format.open_memmap(NPY_PATH, mode="r")
    n_rows, dim = arr.shape
    # Explicit check: an assert would be stripped under `python -O`.
    if n_rows != n:
        raise ValueError(f"row mismatch: npy={n_rows} csv={n}")
    print(f" shape={arr.shape} dtype={arr.dtype}")

    print(f"\n[3/5] streaming {n_rows:,} rows in chunks of {CHUNK:,} "
          f"-> {K} centroids of dim {dim} ...")
    t_chunk = time.perf_counter()
    sums, counts = _accumulate_species_sums(arr, species_idx, K, CHUNK)
    # Every species in species_order occurs in the data, so counts > 0.
    centroids = sums / counts[:, None]
    print(f" done in {time.perf_counter() - t_chunk:.1f}s")
    print(f" counts per species: min={counts.min():,} "
          f"max={counts.max():,} median={int(np.median(counts)):,}")

    print(f"\n[4/5] computing cosine distance matrix {K}x{K} ...")
    cos_dist = _cosine_distance_matrix(centroids)

    from scipy.cluster.hierarchy import linkage
    from scipy.spatial.distance import squareform

    condensed = squareform(cos_dist, checks=False)
    # NOTE(review): scipy documents "ward" as correctly defined only for
    # Euclidean distances, and cosine distance is not Euclidean. For unit
    # vectors, sqrt(2 * cos_dist) is the Euclidean chord distance and would
    # be the valid Ward input. Kept as-is to preserve the published tree.
    linkage_ward = linkage(condensed, method="ward")
    linkage_upgma = linkage(condensed, method="average")
    print(" linkage matrices ready (Ward + UPGMA)")
    layout_ward = _dendrogram_layout(linkage_ward, species_order)
    layout_upgma = _dendrogram_layout(linkage_upgma, species_order)

    print("\n[5/5] writing outputs ...")
    out = {
        "species": species_order,
        "kingdom": [s2k.get(sp, "?") for sp in species_order],
        "expected_clade": [EXPECTED_CLADE.get(sp, "?") for sp in species_order],
        "counts": counts.tolist(),
        "distance_matrix": cos_dist.tolist(),
        "linkage_ward": linkage_ward.tolist(),
        "linkage_upgma": linkage_upgma.tolist(),
        "layout_ward": layout_ward,
        "layout_upgma": layout_upgma,
        "dim": int(dim),
        "n_total_points": int(n_rows),
    }
    os.makedirs(DATA, exist_ok=True)  # a fresh checkout may lack data/
    json_path = os.path.join(DATA, "species_tree.json")
    with open(json_path, "w") as f:
        json.dump(out, f, indent=1)
    print(f" {json_path} ({os.path.getsize(json_path):,} bytes)")

    _write_png_preview(
        os.path.join(DATA, "species_tree.png"),
        [("Ward (cosine)", linkage_ward), ("UPGMA (cosine)", linkage_upgma)],
        species_order,
        s2k,
    )
    print(f"\nTotal: {time.perf_counter() - t0:.1f}s")


if __name__ == "__main__":
    main()