| import os |
| import argparse |
| import numpy as np |
| import umap |
| import matplotlib.pyplot as plt |
| import seaborn as sns |
| import torch |
| from dataclasses import dataclass, field |
| from sklearn.decomposition import PCA as SklearnPCA |
| from sklearn.manifold import TSNE as SklearnTSNE |
| from typing import Optional, Union, List |
| from matplotlib.colors import LinearSegmentedColormap |
|
|
| try: |
| from utils import torch_load, print_message |
| from seed_utils import get_global_seed, set_global_seed, set_determinism |
| from data.data_mixin import DataMixin, DataArguments |
| from embedder import Embedder, EmbeddingArguments, get_embedding_filename |
| except ImportError: |
| from ..utils import torch_load, print_message |
| from ..seed_utils import get_global_seed, set_global_seed, set_determinism |
| from ..data.data_mixin import DataMixin, DataArguments |
| from ..embedder import Embedder, EmbeddingArguments, get_embedding_filename |
|
|
|
|
# Process-wide environment configuration, applied once when this module is
# imported.
# NOTE(review): CUBLAS_WORKSPACE_CONFIG is presumably set to support the
# deterministic mode enabled via set_determinism() — confirm.
os.environ.update(
    {
        "CUBLAS_WORKSPACE_CONFIG": ":4096:8",
        "TF_CPP_MIN_LOG_LEVEL": "2",
        "TF_ENABLE_ONEDNN_OPTS": "0",
        "TOKENIZERS_PARALLELISM": "true",
    }
)
|
|
|
|
@dataclass
class VisualizationArguments:
    """Settings for embedding lookup, dimensionality reduction and plotting."""

    # --- Output locations -------------------------------------------------
    # Directory embeddings are saved to / loaded from.
    embedding_save_dir: str = "embeddings"
    # Directory figures are written to.
    fig_dir: str = "figures"

    # --- Embedding model and storage --------------------------------------
    model_name: str = "ESM2-8"
    # Use matrix (per-residue) embeddings instead of pooled vectors.
    matrix_embed: bool = False
    # Use SQL storage for embeddings instead of a .pth file.
    sql: bool = False

    # --- Embedding generation ---------------------------------------------
    embedding_batch_size: int = 16
    num_workers: int = 0
    download_embeddings: bool = False
    download_dir: str = "Synthyra/vector_embeddings"
    # default_factory gives every instance its own list.
    embedding_pooling_types: List[str] = field(default_factory=lambda: ["mean"])
    save_embeddings: bool = False
    # One of "float32", "float16", "bfloat16".
    embed_dtype: str = "float32"

    # --- Reducer hyperparameters ------------------------------------------
    n_components: int = 2
    # t-SNE perplexity.
    perplexity: float = 30.0
    # UMAP neighbourhood size and minimum distance.
    n_neighbors: int = 15
    min_dist: float = 0.1

    # --- Reproducibility and plotting -------------------------------------
    seed: Optional[int] = None
    deterministic: bool = False
    # Figure size as (width, height).
    fig_size: tuple = (10, 10)
    save_fig: bool = True
    # One of "singlelabel", "multilabel", "regression".
    task_type: str = "singlelabel"
|
|
|
|
class DimensionalityReducer(DataMixin):
    """Base class for dimensionality reduction techniques.

    Subclasses implement :meth:`fit_transform`. This class makes sure the
    requested sequences have embeddings on disk, loads them into a single
    array (``self.embeddings``) and plots the reduced coordinates, coloured
    according to ``args.task_type``.
    """

    def __init__(self, args: VisualizationArguments):
        # No dataset arguments are needed here; only the embedding lookup
        # helpers (`_select_from_sql` / `_select_from_pth`) are used.
        super().__init__(data_args=None)
        self.args = args
        self.embeddings = None
        self.labels = None
        # NOTE(review): these private flags are presumably also read by the
        # DataMixin lookup helpers — confirm before renaming.
        self._sql = args.sql
        self._full = args.matrix_embed

    def _embedding_file_path(self, extension: str) -> str:
        """Return the path of the embedding file with the given extension."""
        filename = get_embedding_filename(
            self.args.model_name,
            self.args.matrix_embed,
            self.args.embedding_pooling_types,
            extension,
        )
        return os.path.join(self.args.embedding_save_dir, filename)

    def _as_vector(self, embedding):
        """Collapse an embedding with more than one axis along its first axis."""
        if len(embedding.shape) > 1:
            if self._full:
                # Matrix embedding: average over the first axis.
                return embedding.mean(axis=0)
            # Otherwise drop the leading singleton axis.
            return embedding.squeeze(0)
        return embedding

    def _check_and_embed(self, sequences: List[str]):
        """Check if embeddings exist, and embed sequences if they don't.

        Args:
            sequences: Sequences that must have an embedding on disk.
        """
        os.makedirs(self.args.embedding_save_dir, exist_ok=True)

        if self._sql:
            import sqlite3
            from contextlib import closing

            db_path = self._embedding_file_path('db')
            if os.path.exists(db_path):
                # `closing` guarantees the connection is released even if a
                # query raises.
                with closing(sqlite3.connect(db_path)) as conn:
                    c = conn.cursor()
                    c.execute('CREATE TABLE IF NOT EXISTS embeddings (sequence text PRIMARY KEY, embedding blob)')
                    c.execute("SELECT sequence FROM embeddings")
                    already_embedded = {row[0] for row in c.fetchall()}
                to_embed = [seq for seq in sequences if seq not in already_embedded]
            else:
                to_embed = sequences
        else:
            save_path = self._embedding_file_path('pth')
            if os.path.exists(save_path):
                emb_dict = torch_load(save_path)
                to_embed = [seq for seq in sequences if seq not in emb_dict]
            else:
                to_embed = sequences

        if len(to_embed) > 0:
            print_message(f"Embedding {len(to_embed)} sequences that are not yet embedded")

            dtype_map = {
                "float32": torch.float32,
                "float16": torch.float16,
                "bfloat16": torch.bfloat16,
            }
            # Unknown dtype names fall back to float32.
            embed_dtype = dtype_map.get(self.args.embed_dtype, torch.float32)

            embedding_args = EmbeddingArguments(
                embedding_batch_size=self.args.embedding_batch_size,
                embedding_num_workers=self.args.num_workers,
                download_embeddings=self.args.download_embeddings,
                download_dir=self.args.download_dir,
                matrix_embed=self.args.matrix_embed,
                embedding_pooling_types=self.args.embedding_pooling_types,
                # Always saved: load_embeddings() reads the result from disk.
                save_embeddings=True,
                embed_dtype=embed_dtype,
                sql=self.args.sql,
                embedding_save_dir=self.args.embedding_save_dir,
            )

            # The full sequence list is passed on.
            # NOTE(review): presumably Embedder skips already-stored
            # sequences itself — confirm.
            embedder = Embedder(embedding_args, sequences)
            embedder(self.args.model_name)
            print_message("Finished embedding sequences")
        else:
            print_message(f"All {len(sequences)} sequences are already embedded")

    def load_embeddings(self, sequences: List[str], labels: Optional[List[Union[int, float, List[int]]]] = None):
        """Load embeddings from file using DataMixin functionality.

        Missing embeddings are computed first. The result is stored stacked in
        ``self.embeddings``; ``self.labels`` becomes ``np.array(labels)`` or
        ``None``.

        Args:
            sequences: Sequences whose embeddings should be loaded, in order.
            labels: Optional labels aligned with ``sequences``.
        """
        self._check_and_embed(sequences)

        if self._sql:
            import sqlite3
            from contextlib import closing

            # A bare `with sqlite3.connect(...)` only manages the transaction
            # and leaves the connection open; `closing` actually closes it.
            with closing(sqlite3.connect(self._embedding_file_path('db'))) as conn:
                c = conn.cursor()
                embeddings = [
                    self._as_vector(self._select_from_sql(c, seq, cast_to_torch=False))
                    for seq in sequences
                ]
        else:
            emb_dict = torch_load(self._embedding_file_path('pth'))
            embeddings = [
                self._as_vector(self._select_from_pth(emb_dict, seq, cast_to_np=True))
                for seq in sequences
            ]

        print_message(f"Loaded {len(embeddings)} embeddings")
        self.embeddings = np.stack(embeddings)
        self.labels = np.array(labels) if labels is not None else None

    def fit_transform(self):
        """Implement in child class"""
        raise NotImplementedError

    def plot(self, save_name: Optional[str] = None):
        """Plot the reduced dimensionality embeddings with appropriate coloring scheme.

        Only the first two reduced components are drawn.

        Args:
            save_name: File name (inside ``args.fig_dir``) to save the figure
                under; nothing is saved when ``None`` or when
                ``args.save_fig`` is false.

        Raises:
            ValueError: If no embeddings have been loaded.
        """
        if self.embeddings is None:
            raise ValueError("No embeddings loaded. Call load_embeddings() first.")

        print_message("Fitting and transforming")
        reduced = self.fit_transform()
        print_message("Plotting")
        plt.figure(figsize=self.args.fig_size)
        xs, ys = reduced[:, 0], reduced[:, 1]
        task_type = self.args.task_type

        if self.labels is None:
            plt.scatter(xs, ys, alpha=0.6)

        elif task_type == "singlelabel":
            unique_labels = np.unique(self.labels)
            n_classes = len(unique_labels)

            if n_classes == 2:
                # Two classes: a two-colour map.
                cmap = LinearSegmentedColormap.from_list('binary', ['#ff7f0e', '#1f77b4'])
                ticks = [0, 1]
            else:
                if n_classes <= 10:
                    cmap = 'tab10'
                elif n_classes <= 20:
                    cmap = 'tab20'
                else:
                    # More classes than the tab palettes hold.
                    colors = sns.color_palette('husl', n_colors=n_classes)
                    cmap = LinearSegmentedColormap.from_list('custom', colors)
                ticks = unique_labels

            scatter = plt.scatter(xs, ys, c=self.labels, cmap=cmap, alpha=0.6)
            plt.colorbar(scatter, ticks=ticks)

        elif task_type == "multilabel":
            # Colour encodes the mean index of the entries equal to 1, scaled
            # to [0, 1]; marker size encodes the row sum.
            label_counts = np.sum(self.labels, axis=1)
            # Clamp the divisor: a single label column would otherwise divide
            # by zero and give NaN colours.
            last_index = max(self.labels.shape[1] - 1, 1)

            label_colors = np.zeros(len(self.labels))
            for i, label_row in enumerate(self.labels):
                if label_counts[i] > 0:
                    positive_indices = np.where(label_row == 1)[0]
                    # Skip rows with no entry exactly equal to 1 so the mean
                    # of an empty array never produces NaN.
                    if positive_indices.size > 0:
                        label_colors[i] = np.mean(positive_indices) / last_index

            blue_red_cmap = LinearSegmentedColormap.from_list('blue_red', ['blue', 'red'])
            scatter = plt.scatter(xs, ys,
                                  c=label_colors, cmap=blue_red_cmap,
                                  s=30 + 20 * label_counts, alpha=0.6)
            plt.colorbar(scatter, label='Label Position (blue=first, red=last)')

            # Size legend: one empty proxy scatter per distinct row sum.
            legend_handles, legend_labels = [], []
            for count in sorted(set(label_counts)):
                legend_handles.append(plt.scatter([], [], s=30 + 20 * count, color='gray'))
                legend_labels.append(f'{int(count)} labels')
            plt.legend(legend_handles, legend_labels, title='Label Count', loc='upper right')

        elif task_type == "regression":
            # Colour range is taken from the 2nd-98th percentile of the labels.
            vmin, vmax = np.percentile(self.labels, [2, 98])
            norm = plt.Normalize(vmin=vmin, vmax=vmax)
            scatter = plt.scatter(xs, ys,
                                  c=self.labels, cmap='viridis',
                                  norm=norm, alpha=0.6)
            plt.colorbar(scatter, label='Value')

        else:
            # An unrecognised task type used to leave the figure empty; draw
            # the points uncoloured instead.
            print_message(f"Unknown task type: {task_type}, plotting without label colors")
            plt.scatter(xs, ys, alpha=0.6)

        plt.title(f'{self.__class__.__name__} visualization of {self.args.model_name} embeddings')
        plt.xlabel('Component 1')
        plt.ylabel('Component 2')

        if save_name is not None and self.args.save_fig:
            os.makedirs(self.args.fig_dir, exist_ok=True)
            plt.savefig(os.path.join(self.args.fig_dir, save_name),
                        dpi=300, bbox_inches='tight')
        plt.show()
        plt.close()
|
|
|
|
class PCA(DimensionalityReducer):
    """Dimensionality reducer backed by scikit-learn's PCA."""

    def __init__(self, args: VisualizationArguments):
        super().__init__(args)
        # Fall back to args.seed only when no global seed is set; a plain
        # `or` would also discard a legitimate global seed of 0.
        # NOTE(review): assumes get_global_seed() returns None when unset.
        seed = get_global_seed()
        if seed is None:
            seed = args.seed
        self.pca = SklearnPCA(n_components=args.n_components, random_state=seed)

    def fit_transform(self):
        """Fit PCA on the loaded embeddings and return the reduced array."""
        return self.pca.fit_transform(self.embeddings)
|
|
|
|
class TSNE(DimensionalityReducer):
    """Dimensionality reducer backed by scikit-learn's t-SNE."""

    def __init__(self, args: VisualizationArguments):
        super().__init__(args)
        # Fall back to args.seed only when no global seed is set; a plain
        # `or` would also discard a legitimate global seed of 0.
        # NOTE(review): assumes get_global_seed() returns None when unset.
        seed = get_global_seed()
        if seed is None:
            seed = self.args.seed
        self.tsne = SklearnTSNE(
            n_components=self.args.n_components,
            perplexity=self.args.perplexity,
            random_state=seed,
        )

    def fit_transform(self):
        """Fit t-SNE on the loaded embeddings and return the reduced array."""
        return self.tsne.fit_transform(self.embeddings)
|
|
|
|
class UMAP(DimensionalityReducer):
    """Dimensionality reducer backed by umap-learn's UMAP."""

    def __init__(self, args: VisualizationArguments):
        super().__init__(args)
        # Fall back to args.seed only when no global seed is set; a plain
        # `or` would also discard a legitimate global seed of 0.
        # NOTE(review): assumes get_global_seed() returns None when unset.
        seed = get_global_seed()
        if seed is None:
            seed = self.args.seed
        self.umap = umap.UMAP(
            n_components=self.args.n_components,
            n_neighbors=self.args.n_neighbors,
            min_dist=self.args.min_dist,
            random_state=seed,
        )

    def fit_transform(self):
        """Fit UMAP on the loaded embeddings and return the reduced array."""
        return self.umap.fit_transform(self.embeddings)
|
|
|
|
def parse_arguments(argv: Optional[List[str]] = None):
    """Parse command line arguments for visualization.

    Args:
        argv: Argument strings to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, matching the previous behaviour.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="Dimensionality reduction visualization for protein embeddings")

    # Output locations
    parser.add_argument("--embedding_save_dir", type=str, default="embeddings",
                        help="Directory to save/load embeddings.")
    parser.add_argument("--fig_dir", type=str, default="figures",
                        help="Directory to save figures.")

    # Embedding model and storage
    parser.add_argument("--model_name", type=str, default="ESM2-8",
                        help="Model name to use for embeddings.")
    parser.add_argument("--matrix_embed", action="store_true", default=False,
                        help="Use matrix embedding (per-residue embeddings).")
    parser.add_argument("--sql", action="store_true", default=False,
                        help="Use SQL storage for embeddings.")

    # Embedding generation
    parser.add_argument("--embedding_batch_size", type=int, default=16,
                        help="Batch size for embedding generation.")
    parser.add_argument("--num_workers", type=int, default=0,
                        help="Number of worker processes for data loading.")
    parser.add_argument("--download_embeddings", action="store_true", default=False,
                        help="Download embeddings from HuggingFace hub.")
    parser.add_argument("--download_dir", type=str, default="Synthyra/vector_embeddings",
                        help="Directory to download embeddings from.")
    parser.add_argument("--embedding_pooling_types", nargs="+", default=["mean", "var"],
                        help="Pooling types for embeddings.")
    parser.add_argument("--save_embeddings", action="store_true", default=False,
                        help="Save computed embeddings (auto-enabled when embedding).")
    parser.add_argument("--embed_dtype", type=str, default="float32",
                        choices=["float32", "float16", "bfloat16"],
                        help="Data type for embeddings.")

    # Dataset selection
    parser.add_argument("--data_names", nargs="+", default=["EC"],
                        help="List of dataset names to visualize.")
    parser.add_argument("--max_length", type=int, default=1024,
                        help="Maximum sequence length.")
    parser.add_argument("--trim", action="store_true", default=False,
                        help="Trim sequences to max_length instead of removing them.")

    # Reducer hyperparameters
    parser.add_argument("--n_components", type=int, default=2,
                        help="Number of components for dimensionality reduction.")
    parser.add_argument("--perplexity", type=float, default=30.0,
                        help="Perplexity parameter for t-SNE.")
    parser.add_argument("--n_neighbors", type=int, default=15,
                        help="Number of neighbors for UMAP.")
    parser.add_argument("--min_dist", type=float, default=0.1,
                        help="Minimum distance for UMAP.")

    # Reproducibility and plotting
    parser.add_argument("--seed", type=int, default=None,
                        help="Seed for reproducibility (if omitted, current time is used).")
    parser.add_argument("--deterministic", action="store_true", default=False,
                        help="Enable deterministic behavior (slower but reproducible).")
    parser.add_argument("--fig_size", nargs=2, type=int, default=[10, 10],
                        help="Figure size (width height).")
    # `--save_fig` is kept for backward compatibility, but because it is a
    # store_true flag defaulting to True it cannot switch saving off.
    parser.add_argument("--save_fig", action="store_true", default=True,
                        help="Save figures to disk.")
    # Companion flag writing to the same destination so saving can be disabled.
    parser.add_argument("--no_save_fig", dest="save_fig", action="store_false",
                        help="Do not save figures to disk.")
    parser.add_argument("--task_type", type=str, default=None,
                        choices=["singlelabel", "multilabel", "regression"],
                        help="Task type (auto-detected from dataset if not specified).")

    # Reduction methods
    parser.add_argument("--methods", nargs="+",
                        choices=["PCA", "TSNE", "UMAP"],
                        default=["PCA", "TSNE", "UMAP"],
                        help="Dimensionality reduction methods to use.")

    return parser.parse_args(argv)
|
|
|
|
if __name__ == "__main__":
    # Script entry point: parse the CLI, load the requested datasets and
    # produce one figure per (dataset, reduction method) pair.
    args = parse_arguments()

    if args.deterministic:
        set_determinism()

    # set_global_seed returns the seed actually used; store it back so the
    # reducers built from these arguments see the same value.
    chosen_seed = set_global_seed(args.seed)
    args.seed = chosen_seed
    print_message(f"Using seed: {chosen_seed}")

    data_args = DataArguments(
        data_names=args.data_names,
        max_length=args.max_length,
        trim=args.trim,
    )
    data_mixin = DataMixin(data_args=data_args)
    datasets, _all_seqs = data_mixin.get_data()

    method_map = {
        "PCA": PCA,
        "TSNE": TSNE,
        "UMAP": UMAP,
    }

    if not datasets:
        print_message("No datasets loaded; nothing to visualize")

    # --data_names accepts several datasets, so visualize every one that was
    # loaded instead of only the first.
    for dataset_name in datasets:
        train_set, _valid_set, _test_set, _num_labels, label_type, _ppi = datasets[dataset_name]

        # An explicit --task_type wins; otherwise derive it from the dataset's
        # label type.
        if args.task_type is None:
            if label_type == "multilabel":
                task_type = "multilabel"
            elif label_type in ["regression", "sigmoid_regression"]:
                task_type = "regression"
            else:
                task_type = "singlelabel"
        else:
            task_type = args.task_type

        # Only the training split is visualized.
        sequences = list(train_set["seqs"])
        labels = list(train_set["labels"])

        vis_args = VisualizationArguments(
            embedding_save_dir=args.embedding_save_dir,
            fig_dir=args.fig_dir,
            model_name=args.model_name,
            matrix_embed=args.matrix_embed,
            sql=args.sql,
            embedding_batch_size=args.embedding_batch_size,
            num_workers=args.num_workers,
            download_embeddings=args.download_embeddings,
            download_dir=args.download_dir,
            embedding_pooling_types=args.embedding_pooling_types,
            save_embeddings=args.save_embeddings,
            embed_dtype=args.embed_dtype,
            n_components=args.n_components,
            perplexity=args.perplexity,
            n_neighbors=args.n_neighbors,
            min_dist=args.min_dist,
            seed=args.seed,
            deterministic=args.deterministic,
            fig_size=tuple(args.fig_size),
            save_fig=args.save_fig,
            task_type=task_type,
        )

        for method_name in args.methods:
            if method_name not in method_map:
                print_message(f"Unknown method: {method_name}, skipping")
                continue

            Reducer = method_map[method_name]
            print_message(f"Running {Reducer.__name__}")
            reducer = Reducer(vis_args)
            print_message("Loading embeddings")
            reducer.load_embeddings(sequences, labels)
            reducer.plot(f"{dataset_name}_{Reducer.__name__}.png")
|
|