| | import torch |
| | import sys |
| | from sklearn.preprocessing import StandardScaler |
| | import pytorch_lightning as pl |
| | from torch.utils.data import DataLoader |
| | from lightning.pytorch.utilities.combined_loader import CombinedLoader |
| | import numpy as np |
| | from scipy.spatial import cKDTree |
| | import math |
| | from functools import partial |
| | import matplotlib.pyplot as plt |
| | import pandas as pd |
| | from torch.utils.data import TensorDataset |
| | from sklearn.neighbors import kneighbors_graph |
| | import igraph as ig |
| | from leidenalg import find_partition, ModularityVertexPartition |
| |
|
| | class WeightedBranchedVeresDataModule(pl.LightningDataModule): |
| |
|
| | def __init__(self, args): |
| | super().__init__() |
| | self.save_hyperparameters() |
| |
|
| | self.data_path = args.data_path |
| | self.batch_size = args.batch_size |
| | self.max_dim = args.dim |
| | self.whiten = args.whiten |
| | self.k = 20 |
| | self.num_timesteps = 8 |
| | |
| | self.num_branches = args.branches if hasattr(args, 'branches') else None |
| | self.split_ratios = args.split_ratios |
| | self.metric_clusters = args.metric_clusters |
| | self.discard_small = args.discard if hasattr(args, 'discard') else False |
| | self.args = args |
| | self._prepare_data() |
| |
|
| | def _prepare_data(self): |
| | print("Preparing Veres cell data with Leiden clustering in WeightedBranchedVeresLeidenDataModule") |
| | df = pd.read_csv(self.data_path) |
| |
|
| | |
| | coords_by_t = { |
| | t: df[df["samples"] == t].iloc[:, 1:].values |
| | for t in sorted(df["samples"].unique()) |
| | } |
| |
|
| | n0 = coords_by_t[0].shape[0] |
| | self.n_samples = n0 |
| |
|
| | print("Timepoint distribution:") |
| | for t in sorted(coords_by_t.keys()): |
| | print(f" t={t}: {coords_by_t[t].shape[0]} points") |
| |
|
| | |
| | final_t = max(coords_by_t.keys()) |
| | coords_final = coords_by_t[final_t] |
| | k = 20 |
| | knn_graph = kneighbors_graph(coords_final, k, mode='connectivity', include_self=False) |
| | sources, targets = knn_graph.nonzero() |
| | edgelist = list(zip(sources.tolist(), targets.tolist())) |
| | graph = ig.Graph(edgelist, directed=False) |
| | partition = find_partition(graph, ModularityVertexPartition) |
| | leiden_labels = np.array(partition.membership) |
| | n_leiden = len(np.unique(leiden_labels)) |
| | print(f"Leiden found {n_leiden} clusters at t={final_t}") |
| |
|
| | df_final = df[df["samples"] == final_t].copy() |
| | df_final["branch"] = leiden_labels |
| |
|
| | cluster_counts = df_final["branch"].value_counts().sort_index() |
| | print(f"Branch distribution at t={final_t} (pre-merge):") |
| | print(cluster_counts) |
| |
|
| | |
| | min_cells = 100 |
| | cluster_data_dict = {} |
| | cluster_sizes = [] |
| | for b in range(n_leiden): |
| | branch_data = df_final[df_final["branch"] == b].iloc[:, 1:-1].values |
| | cluster_data_dict[b] = branch_data |
| | cluster_sizes.append(branch_data.shape[0]) |
| |
|
| | large_clusters = [b for b, size in enumerate(cluster_sizes) if size >= min_cells] |
| | small_clusters = [b for b, size in enumerate(cluster_sizes) if size < min_cells] |
| |
|
| | |
| | if len(large_clusters) == 0: |
| | large_clusters = list(range(n_leiden)) |
| | small_clusters = [] |
| |
|
| | if self.discard_small: |
| | |
| | print(f"Discarding {len(small_clusters)} small clusters (< {min_cells} cells)") |
| | |
| | mask = np.isin(leiden_labels, large_clusters) |
| | df_final = df_final[mask].copy() |
| | merged_labels = leiden_labels[mask] |
| | |
| | |
| | new_ids = np.unique(merged_labels) |
| | id_map = {old: new for new, old in enumerate(new_ids)} |
| | merged_labels = np.array([id_map[x] for x in merged_labels]) |
| | n_merged = len(np.unique(merged_labels)) |
| | |
| | df_final["branch"] = merged_labels |
| | print(f"Kept {n_merged} large clusters") |
| | else: |
| | centroids = {b: np.mean(cluster_data_dict[b], axis=0) for b in range(n_leiden) if cluster_data_dict[b].shape[0] > 0} |
| |
|
| | merged_labels = leiden_labels.copy() |
| | for b in small_clusters: |
| | if cluster_data_dict[b].shape[0] == 0: |
| | continue |
| | |
| | dists = [np.linalg.norm(centroids[b] - centroids[bl]) for bl in large_clusters] |
| | nearest_large = large_clusters[int(np.argmin(dists))] |
| | merged_labels[leiden_labels == b] = nearest_large |
| |
|
| | |
| | new_ids = np.unique(merged_labels) |
| | id_map = {old: new for new, old in enumerate(new_ids)} |
| | merged_labels = np.array([id_map[x] for x in merged_labels]) |
| | n_merged = len(np.unique(merged_labels)) |
| |
|
| | df_final["branch"] = merged_labels |
| | print(f"Merged into {n_merged} clusters") |
| | |
| | cluster_counts_merged = df_final["branch"].value_counts().sort_index() |
| | print(f"Branch distribution at t={final_t} (post-merge):") |
| | print(cluster_counts_merged) |
| |
|
| | endpoints = {} |
| | cluster_sizes = [] |
| | for b in range(n_merged): |
| | branch_data = df_final[df_final["branch"] == b].iloc[:, 1:-1].values |
| | cluster_sizes.append(branch_data.shape[0]) |
| | replace = branch_data.shape[0] < n0 |
| | sampled_indices = np.random.choice(branch_data.shape[0], size=n0, replace=replace) |
| | endpoints[b] = branch_data[sampled_indices] |
| | total_t_final = sum(cluster_sizes) |
| |
|
| | x0 = torch.tensor(coords_by_t[0], dtype=torch.float32) |
| | self.coords_t0 = x0 |
| | |
| | self.coords_intermediate = {t: torch.tensor(coords_by_t[t], dtype=torch.float32) |
| | for t in coords_by_t.keys() if t != 0 and t != final_t} |
| |
|
| | self.branch_endpoints = {b: torch.tensor(endpoints[b], dtype=torch.float32) for b in range(n_merged)} |
| | self.num_branches = n_merged |
| |
|
| | |
| | time_labels_list = [np.zeros(len(self.coords_t0))] |
| | for t in sorted(self.coords_intermediate.keys()): |
| | time_labels_list.append(np.ones(len(self.coords_intermediate[t])) * t) |
| | for b in range(self.num_branches): |
| | time_labels_list.append(np.ones(len(self.branch_endpoints[b])) * final_t) |
| | self.time_labels = np.concatenate(time_labels_list) |
| |
|
| | |
| | split_index = int(n0 * self.split_ratios[0]) |
| | if n0 - split_index < self.batch_size: |
| | split_index = n0 - self.batch_size |
| |
|
| | train_x0 = x0[:split_index] |
| | val_x0 = x0[split_index:] |
| | self.val_x0 = val_x0 |
| |
|
| | train_x0_weights = torch.full((train_x0.shape[0], 1), fill_value=1.0) |
| | val_x0_weights = torch.full((val_x0.shape[0], 1), fill_value=1.0) |
| |
|
| | |
| | branch_weights = [size / total_t_final for size in cluster_sizes] |
| |
|
| | |
| | train_intermediate = {} |
| | val_intermediate = {} |
| | self.train_coords_intermediate = {} |
| | for t in sorted(self.coords_intermediate.keys()): |
| | coords_t = self.coords_intermediate[t] |
| | train_coords_t = coords_t[:split_index] |
| | val_coords_t = coords_t[split_index:] |
| | train_weights_t = torch.full((train_coords_t.shape[0], 1), fill_value=1.0) |
| | val_weights_t = torch.full((val_coords_t.shape[0], 1), fill_value=1.0) |
| | train_intermediate[f"x{t}"] = (train_coords_t, train_weights_t) |
| | val_intermediate[f"x{t}"] = (val_coords_t, val_weights_t) |
| | self.train_coords_intermediate[t] = train_coords_t |
| |
|
| | train_loaders = { |
| | "x0": DataLoader(TensorDataset(train_x0, train_x0_weights), batch_size=self.batch_size, shuffle=True, drop_last=True), |
| | } |
| | val_loaders = { |
| | "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.batch_size, shuffle=False, drop_last=True), |
| | } |
| |
|
| | |
| | for t_key in sorted(train_intermediate.keys()): |
| | train_coords_t, train_weights_t = train_intermediate[t_key] |
| | val_coords_t, val_weights_t = val_intermediate[t_key] |
| | train_loaders[t_key] = DataLoader( |
| | TensorDataset(train_coords_t, train_weights_t), |
| | batch_size=self.batch_size, |
| | shuffle=True, |
| | drop_last=True |
| | ) |
| | val_loaders[t_key] = DataLoader( |
| | TensorDataset(val_coords_t, val_weights_t), |
| | batch_size=self.batch_size, |
| | shuffle=False, |
| | drop_last=True |
| | ) |
| |
|
| | for b in range(self.num_branches): |
| | |
| | branch_size = self.branch_endpoints[b].shape[0] |
| | branch_split_index = int(branch_size * self.split_ratios[0]) |
| | if branch_size - branch_split_index < self.batch_size: |
| | branch_split_index = max(0, branch_size - self.batch_size) |
| | |
| | train_branch = self.branch_endpoints[b][:branch_split_index] |
| | val_branch = self.branch_endpoints[b][branch_split_index:] |
| | train_branch_weights = torch.full((train_branch.shape[0], 1), fill_value=branch_weights[b]) |
| | val_branch_weights = torch.full((val_branch.shape[0], 1), fill_value=branch_weights[b]) |
| | train_loaders[f"x1_{b+1}"] = DataLoader( |
| | TensorDataset(train_branch, train_branch_weights), |
| | batch_size=self.batch_size, |
| | shuffle=True, |
| | drop_last=True |
| | ) |
| | val_loaders[f"x1_{b+1}"] = DataLoader( |
| | TensorDataset(val_branch, val_branch_weights), |
| | batch_size=self.batch_size, |
| | shuffle=True, |
| | drop_last=True |
| | ) |
| |
|
| | self.train_dataloaders = train_loaders |
| | self.val_dataloaders = val_loaders |
| |
|
| | |
| | all_data_list = [coords_by_t[t] for t in sorted(coords_by_t.keys())] |
| | all_data = np.vstack(all_data_list) |
| | self.dataset = torch.tensor(all_data, dtype=torch.float32) |
| | self.tree = cKDTree(all_data) |
| |
|
| | self.test_dataloaders = { |
| | "x0": DataLoader(TensorDataset(self.val_x0, val_x0_weights), batch_size=self.val_x0.shape[0], shuffle=False, drop_last=False), |
| | "dataset": DataLoader(TensorDataset(self.dataset), batch_size=self.dataset.shape[0], shuffle=False, drop_last=False), |
| | } |
| |
|
| | |
| | cluster_0_data = self.coords_t0.cpu().numpy() |
| | cluster_1_list = [self.coords_intermediate[t].cpu().numpy() for t in sorted(self.coords_intermediate.keys())] |
| | cluster_1_list.extend([self.branch_endpoints[b].cpu().numpy() for b in range(self.num_branches)]) |
| | cluster_1_data = np.vstack(cluster_1_list) |
| |
|
| | self.metric_samples_dataloaders = [ |
| | DataLoader(torch.tensor(cluster_0_data, dtype=torch.float32), batch_size=cluster_0_data.shape[0], shuffle=False, drop_last=False), |
| | DataLoader(torch.tensor(cluster_1_data, dtype=torch.float32), batch_size=cluster_1_data.shape[0], shuffle=False, drop_last=False), |
| | ] |
| |
|
| | def train_dataloader(self): |
| | combined_loaders = { |
| | "train_samples": CombinedLoader(self.train_dataloaders, mode="min_size"), |
| | "metric_samples": CombinedLoader(self.metric_samples_dataloaders, mode="min_size"), |
| | } |
| | return CombinedLoader(combined_loaders, mode="max_size_cycle") |
| |
|
| | def val_dataloader(self): |
| | combined_loaders = { |
| | "val_samples": CombinedLoader(self.val_dataloaders, mode="min_size"), |
| | "metric_samples": CombinedLoader(self.metric_samples_dataloaders, mode="min_size"), |
| | } |
| | return CombinedLoader(combined_loaders, mode="max_size_cycle") |
| |
|
| | def test_dataloader(self): |
| | combined_loaders = { |
| | "test_samples": CombinedLoader(self.test_dataloaders, mode="min_size"), |
| | "metric_samples": CombinedLoader(self.metric_samples_dataloaders, mode="min_size"), |
| | } |
| | return CombinedLoader(combined_loaders, mode="max_size_cycle") |
| |
|
| | def get_manifold_proj(self, points): |
| | return partial(self.local_smoothing_op, tree=self.tree, dataset=self.dataset) |
| |
|
| | @staticmethod |
| | def local_smoothing_op(x, tree, dataset, k=10, temp=1e-3): |
| | points_np = x.detach().cpu().numpy() |
| | _, idx = tree.query(points_np, k=k) |
| | nearest_pts = dataset[idx] |
| | dists = (x.unsqueeze(1) - nearest_pts).pow(2).sum(-1, keepdim=True) |
| | weights = torch.exp(-dists / temp) |
| | weights = weights / weights.sum(dim=1, keepdim=True) |
| | smoothed = (weights * nearest_pts).sum(dim=1) |
| | alpha = 0.3 |
| | return (1 - alpha) * x + alpha * smoothed |
| |
|
| | def get_timepoint_data(self): |
| | result = { |
| | 't0': self.coords_t0, |
| | 'time_labels': self.time_labels |
| | } |
| | |
| | for t in sorted(self.coords_intermediate.keys()): |
| | result[f't{t}'] = self.coords_intermediate[t] |
| | final_t = max([0] + list(self.coords_intermediate.keys())) + 1 |
| | for b in range(self.num_branches): |
| | result[f't{final_t}_{b}'] = self.branch_endpoints[b] |
| | return result |
| |
|
| | def get_train_intermediate_data(self): |
| | if hasattr(self, 'train_coords_intermediate'): |
| | return self.train_coords_intermediate |
| | else: |
| | |
| | print("Warning: train_coords_intermediate not found, returning full intermediate data.") |
| | return self.coords_intermediate |
| |
|