| |
| """ |
| Precompute species embeddings for CodonTranslator training. |
| Protein embeddings are now computed on-the-fly using integrated ESM-C model. |
| |
| Steps: |
| 1. Build taxonomy database from GBIF API |
| 2. Generate species embeddings using Qwen3-Embedding-0.6B |
| """ |
|
|
| import os |
| import json |
| import logging |
| import argparse |
| from pathlib import Path |
| from typing import Dict, List, Optional, Tuple |
| import glob |
| import requests |
| import time |
| from collections import defaultdict |
|
|
| import numpy as np |
| import pandas as pd |
| import torch |
| from tqdm import tqdm |
|
|
| logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') |
| logger = logging.getLogger(__name__) |
|
|
|
|
def build_taxonomy_database(species_list: List[str]) -> Dict[str, str]:
    """Query GBIF API for comprehensive phylogenetic taxonomy of species.

    Creates detailed taxonomic descriptions for better species embeddings.

    Args:
        species_list: Scientific names to look up (empty names and
            duplicates are skipped).

    Returns:
        Mapping from species name to a human-readable taxonomy description.
        Lookups that fail or do not match still receive a fallback string,
        so every non-empty input species appears in the result.
    """
    taxonomy_db = {}
    base_url = "https://api.gbif.org/v1/species/match"

    logger.info(f"Building taxonomy database for {len(species_list)} species...")
    for species in tqdm(species_list, desc="Querying GBIF"):
        if not species or species in taxonomy_db:
            continue

        try:
            # Timeout prevents an unresponsive GBIF endpoint from hanging the run.
            response = requests.get(base_url, params={"name": species}, timeout=10)
            if response.status_code == 200:
                data = response.json()
                if data.get("matchType") != "NONE":
                    parts = []

                    # Assemble the lineage from highest to lowest rank.
                    taxonomy = []
                    for rank in ["kingdom", "phylum", "class", "order", "family", "genus", "species"]:
                        if rank in data and data[rank]:
                            taxonomy.append(data[rank])

                    if taxonomy:
                        parts.append("Taxonomy: " + " > ".join(taxonomy))

                    if "vernacularName" in data and data["vernacularName"]:
                        parts.append(f"Common name: {data['vernacularName']}")

                    if "confidence" in data:
                        parts.append(f"Match confidence: {data['confidence']}%")

                    if "status" in data:
                        parts.append(f"Status: {data['status']}")

                    taxonomy_db[species] = ". ".join(parts) if parts else species
                else:
                    # Matched nothing in GBIF; keep a descriptive fallback.
                    taxonomy_db[species] = f"Species: {species} (no GBIF match)"
            else:
                taxonomy_db[species] = f"Species: {species} (query failed)"
        except Exception as e:
            logger.warning(f"Error querying GBIF for {species}: {e}")
            taxonomy_db[species] = f"Species: {species} (error)"
        finally:
            # Gentle rate limiting between queries; in a finally-block so the
            # error path is throttled too (previously skipped on exceptions).
            time.sleep(0.1)

    logger.info(f"Taxonomy database built with {len(taxonomy_db)} entries")
    return taxonomy_db
|
|
|
|
def generate_species_embeddings_qwen(
    species_list: List[str],
    taxonomy_db: Dict[str, str],
    device: str = "cuda",
    pooling: str = "last"
) -> Tuple[Dict[str, int], Dict[int, np.ndarray]]:
    """
    Generate species embeddings using Qwen3-Embedding-0.6B.
    - pooling='last': returns one vector per species (fixed size)
    - pooling='none': returns variable-length token embeddings per species
    """
    import torch.nn.functional as F
    from transformers import AutoTokenizer, AutoModel

    def _pool_last_token(hidden_states: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Select the embedding of each sequence's final non-padding token."""
        # With left padding, every sequence ends at position -1.
        if bool(mask[:, -1].sum() == mask.shape[0]):
            return hidden_states[:, -1]
        last_positions = mask.sum(dim=1) - 1
        rows = torch.arange(hidden_states.shape[0], device=hidden_states.device)
        return hidden_states[rows, last_positions]

    def _with_instruction(task_description: str, query: str) -> str:
        """Prefix the query with an instruction for better embedding quality."""
        return f'Instruct: {task_description}\nQuery: {query}'

    logger.info("Loading Qwen3-Embedding-0.6B model...")
    model_name = "Qwen/Qwen3-Embedding-0.6B"

    # Left padding keeps the last token position meaningful for pooling.
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, padding_side='left')
    model = AutoModel.from_pretrained(model_name, torch_dtype=torch.float16, trust_remote_code=True).to(device).eval()

    species_vocab: Dict[str, int] = {}
    species_embeddings: Dict[int, np.ndarray] = {}

    task = "Given a species taxonomy information, generate a biological embedding representing its taxonomic and evolutionary characteristics"

    for idx, species in enumerate(tqdm(species_list, desc="Generating embeddings")):
        # Fall back to the bare species name when no taxonomy entry exists.
        prompt = _with_instruction(task, taxonomy_db.get(species, species))

        with torch.no_grad():
            encoded = tokenizer(
                prompt,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512
            )
            encoded = {key: tensor.to(device) for key, tensor in encoded.items()}

            hidden = model(**encoded).last_hidden_state
            if pooling == 'last':
                # One L2-normalized vector per species.
                pooled = _pool_last_token(hidden, encoded['attention_mask'])
                embedding = F.normalize(pooled, p=2, dim=1).squeeze(0).cpu().numpy()
            else:
                # Keep per-token embeddings, each L2-normalized.
                per_token = F.normalize(hidden.squeeze(0), p=2, dim=-1)
                embedding = per_token.cpu().numpy()

        species_vocab[species] = idx
        species_embeddings[idx] = embedding

    logger.info(f"Generated {'fixed-size' if pooling=='last' else 'variable-length'} embeddings for {len(species_vocab)} species")
    return species_vocab, species_embeddings
|
|
|
|
def save_species_embeddings_memmap(
    species_vocab: Dict[str, int],
    species_embeddings: Dict[int, np.ndarray],
    output_dir: str
) -> None:
    """Save fixed-size species embeddings as memory-mapped file."""
    os.makedirs(output_dir, exist_ok=True)

    # Persist the name -> integer id mapping alongside the embeddings.
    vocab_path = os.path.join(output_dir, "species_vocab.json")
    with open(vocab_path, 'w') as f:
        json.dump(species_vocab, f, indent=2)

    num_species = len(species_embeddings)
    # All vectors share one dimensionality; read it off an arbitrary entry.
    embed_dim = next(iter(species_embeddings.values())).shape[0]

    emb_path = os.path.join(output_dir, "species_embeddings.bin")
    writer = np.memmap(emb_path, dtype=np.float32, mode='w+', shape=(num_species, embed_dim))
    # Row index == species id, so lookups at train time are a direct slice.
    for species_id, vector in species_embeddings.items():
        writer[species_id] = vector.astype(np.float32)
    # Releasing the memmap flushes buffered writes to disk.
    del writer

    metadata = {
        "num_species": num_species,
        "embedding_dim": embed_dim,
        "embedding_type": "fixed_size",
        "pooling_method": "last_token",
        "normalization": "L2",
        "model": "Qwen/Qwen3-Embedding-0.6B"
    }
    metadata_path = os.path.join(output_dir, "species_metadata.json")
    with open(metadata_path, 'w') as f:
        json.dump(metadata, f, indent=2)

    logger.info(f"Saved {num_species} fixed-size species embeddings to {emb_path}")
    logger.info(f"Embedding dimension: {embed_dim}")
    logger.info(f"Saved metadata to {metadata_path}")
|
|
|
|
def save_species_token_embeddings_memmap(
    species_vocab: Dict[str, int],
    species_tok_embeddings: Dict[int, np.ndarray],
    output_dir: str,
    dtype: str = 'float32'
) -> None:
    """Save variable-length token embeddings into a flat memmap with index.

    Args:
        species_vocab: Mapping from species name to integer id.
        species_tok_embeddings: Mapping from species id to a
            (num_tokens, embed_dim) array; num_tokens may differ per species.
        output_dir: Directory receiving vocab, memmap, index, and metadata.
        dtype: On-disk precision, 'float32' (default) or 'float16'.
    """
    os.makedirs(output_dir, exist_ok=True)

    vocab_path = os.path.join(output_dir, "species_vocab.json")
    with open(vocab_path, 'w') as f:
        json.dump(species_vocab, f, indent=2)

    embed_dim = next(iter(species_tok_embeddings.values())).shape[1]
    total_tokens = int(sum(v.shape[0] for v in species_tok_embeddings.values()))

    # Resolve the storage dtype once, instead of re-evaluating it for the
    # memmap and again for every species inside the copy loop.
    np_dtype = np.float32 if dtype == 'float32' else np.float16

    emb_path = os.path.join(output_dir, "species_tok_emb.bin")
    mmap = np.memmap(emb_path, dtype=np_dtype, mode='w+', shape=(total_tokens, embed_dim))

    # Pack all species back-to-back; the index records where each one lives.
    index = {}
    offset = 0
    for sid, arr in species_tok_embeddings.items():
        L = int(arr.shape[0])
        mmap[offset: offset + L] = arr.astype(np_dtype)
        index[str(sid)] = {"offset": offset, "length": L}
        offset += L

    # Explicitly flush buffered writes before releasing the memmap, rather
    # than relying on garbage collection to do it.
    mmap.flush()
    del mmap

    with open(os.path.join(output_dir, "species_index.json"), 'w') as f:
        json.dump(index, f, indent=2)

    meta = {
        "embedding_dim": embed_dim,
        "dtype": dtype,
        "total_tokens": total_tokens,
        "embedding_type": "variable_length",
        "pooling_method": "none",
        "model": "Qwen/Qwen3-Embedding-0.6B"
    }
    with open(os.path.join(output_dir, "metadata.json"), 'w') as f:
        json.dump(meta, f, indent=2)
    logger.info(f"Saved variable-length species token embeddings to {emb_path} with {total_tokens} tokens total")
|
|
|
|
def filter_sequences_by_length(df: pd.DataFrame, max_protein_length: int = 2048) -> pd.DataFrame:
    """Filter sequences to prevent CUDA OOM during training."""
    initial_count = len(df)

    # Drop rows whose protein sequence exceeds the cap.
    if 'protein_seq' in df.columns:
        keep = df['protein_seq'].str.len() <= max_protein_length
        df = df[keep]

    # CDS length is bounded by 3 nucleotides per residue.
    if 'cds_DNA' in df.columns:
        keep = df['cds_DNA'].str.len() <= max_protein_length * 3
        df = df[keep]

    final_count = len(df)
    if final_count < initial_count:
        logger.info(f"Filtered from {initial_count} to {final_count} sequences (max_protein_length={max_protein_length})")

    return df
|
|
|
|
def collect_unique_values_from_shards(
    shards_glob: str,
    column: str,
    max_items: Optional[int] = None
) -> List[str]:
    """Stream over Parquet shards to collect unique values from a column."""
    seen = set()
    shard_files = sorted(glob.glob(shards_glob))

    if not shard_files:
        raise ValueError(f"No parquet files found matching {shards_glob}")

    logger.info(f"Scanning {len(shard_files)} shards for unique {column} values...")

    for shard_file in tqdm(shard_files, desc=f"Collecting {column}"):
        # Resolve the column name case-insensitively against the shard's
        # schema; fall back to the requested name if inspection fails.
        resolved = column
        try:
            import pyarrow.parquet as pq
            schema_names = set(pq.ParquetFile(shard_file).schema.names)
            if resolved not in schema_names:
                by_lower = {name.lower(): name for name in schema_names}
                resolved = by_lower.get(column.lower(), column)
        except Exception:
            resolved = column

        # Load only the single column we need from this shard.
        df = pd.read_parquet(shard_file, columns=[resolved])
        if resolved != column and resolved in df.columns and column not in df.columns:
            df = df.rename(columns={resolved: column})
        seen.update(df[column].dropna().unique())

        # Stop scanning once enough unique values have been gathered.
        if max_items and len(seen) >= max_items:
            break

    result = sorted(seen)[:max_items] if max_items else sorted(seen)
    logger.info(f"Collected {len(result)} unique {column} values")
    return result
|
|
|
|
def collect_stage1_species(shards_glob: str) -> List[str]:
    """Extract unique species from Stage-1 shards."""
    # Stage-1 shards store the species name in the "Taxon" column.
    species_column = "Taxon"
    return collect_unique_values_from_shards(shards_glob, species_column)
|
|
|
|
def prepare_species_from_stage1_shards(
    shards_glob: str,
    output_dir: str,
    device: str = "cuda",
    resume: bool = False,
    species_pooling: str = "last"
) -> None:
    """End-to-end species embedding generation from Stage-1 shards."""
    os.makedirs(output_dir, exist_ok=True)

    # A saved vocab means a previous run already finished; honor resume mode.
    vocab_path = os.path.join(output_dir, "species_vocab.json")
    if resume and os.path.exists(vocab_path):
        logger.info("Species embeddings already exist. Skipping generation.")
        return

    species_list = collect_stage1_species(shards_glob)
    logger.info(f"Found {len(species_list)} unique species in shards")

    # Reuse cached GBIF lookups when resuming; otherwise query and cache.
    taxonomy_cache_path = os.path.join(output_dir, "taxonomy_database.json")
    if resume and os.path.exists(taxonomy_cache_path):
        logger.info("Loading cached taxonomy database...")
        with open(taxonomy_cache_path, 'r') as f:
            taxonomy_db = json.load(f)
    else:
        taxonomy_db = build_taxonomy_database(species_list)
        with open(taxonomy_cache_path, 'w') as f:
            json.dump(taxonomy_db, f, indent=2)

    species_vocab, species_embeddings = generate_species_embeddings_qwen(
        species_list, taxonomy_db, device, pooling=species_pooling
    )

    # Fixed-size vectors and variable-length token matrices use different
    # on-disk layouts; pick the matching writer.
    if species_pooling == 'last':
        saver = save_species_embeddings_memmap
    else:
        saver = save_species_token_embeddings_memmap
    saver(species_vocab, species_embeddings, output_dir)

    logger.info("Species embedding preparation complete")
|
|
|
|
def create_precomputed_dataset(
    input_csv: Optional[str],
    output_dir: str,
    device: str = "cuda",
    batch_size: int = 50,
    max_protein_length: int = 2048,
    resume: bool = False,
    species_pooling: str = "last"
):
    """
    Create embedding dataset with species-only precomputation.
    Protein embeddings will be computed on-the-fly during training.

    Args:
        input_csv: Path to a CSV or Parquet file containing a "Taxon" column.
        output_dir: Directory receiving vocab, embeddings, and metadata.
        device: Device for embedding-model inference.
        batch_size: Accepted for CLI compatibility (currently unused here).
        max_protein_length: Sequences longer than this are filtered out.
        resume: Skip generation when precomputed outputs already exist.
        species_pooling: 'last' for one vector per species; any other value
            stores variable-length token embeddings.

    Raises:
        ValueError: If input_csv is None.
    """
    if input_csv is None:
        # Fail fast with a clear message instead of AttributeError on endswith.
        raise ValueError("input_csv is required for create_precomputed_dataset")

    os.makedirs(output_dir, exist_ok=True)

    if resume and os.path.exists(os.path.join(output_dir, "species_vocab.json")):
        logger.info("Precomputed dataset already exists. Use --resume=False to regenerate.")
        return

    logger.info(f"Loading data from {input_csv}...")
    if input_csv.endswith('.parquet'):
        df = pd.read_parquet(input_csv)
    else:
        df = pd.read_csv(input_csv)

    # Normalize the species column name.
    if "Taxon" not in df.columns and "taxon" in df.columns:
        df = df.rename(columns={"taxon": "Taxon"})

    df = filter_sequences_by_length(df, max_protein_length)

    logger.info("=== Generating Species Embeddings ===")
    unique_species = df["Taxon"].dropna().unique().tolist()
    logger.info(f"Found {len(unique_species)} unique species")

    taxonomy_db = build_taxonomy_database(unique_species)

    taxonomy_path = os.path.join(output_dir, "taxonomy_database.json")
    with open(taxonomy_path, 'w') as f:
        json.dump(taxonomy_db, f, indent=2)

    species_vocab, species_embeddings = generate_species_embeddings_qwen(
        unique_species, taxonomy_db, device, pooling=species_pooling
    )

    if species_pooling == 'last':
        save_species_embeddings_memmap(species_vocab, species_embeddings, output_dir)
        # Fixed-size: each value is a 1-D vector of the embedding dimension.
        embed_dim = next(iter(species_embeddings.values())).shape[0] if species_embeddings else 0
    else:
        save_species_token_embeddings_memmap(species_vocab, species_embeddings, output_dir)
        # Variable-length: each value is (num_tokens, embedding_dim).
        embed_dim = next(iter(species_embeddings.values())).shape[1] if species_embeddings else 0

    # Merge with any metadata.json the token-embedding saver already wrote so
    # its fields (dtype, total_tokens, ...) are not silently clobbered.
    metadata_path = os.path.join(output_dir, "metadata.json")
    metadata = {}
    if os.path.exists(metadata_path):
        try:
            with open(metadata_path, 'r') as f:
                metadata = json.load(f)
        except (OSError, json.JSONDecodeError):
            metadata = {}

    metadata.update({
        "num_sequences": len(df),
        "num_species": len(unique_species),
        "species_embedding_model": "Qwen/Qwen3-Embedding-0.6B",
        # Derived from the actual embeddings instead of a hard-coded 1024.
        "species_embedding_dim": int(embed_dim),
        "max_protein_length": max_protein_length,
    })

    with open(metadata_path, 'w') as f:
        json.dump(metadata, f, indent=2)

    logger.info(f"Dataset creation completed. Species embeddings are precomputed.")
    logger.info("Protein embeddings will be computed on-the-fly during training using integrated ESM-C.")
|
|
|
|
def main():
    """CLI entry point: parse arguments and dispatch to the chosen pipeline."""
    parser = argparse.ArgumentParser(description="Precompute species embeddings for CodonTranslator")

    # Input source: either a single CSV/Parquet file or Stage-1 shards.
    parser.add_argument("--input_csv", type=str,
                        help="Path to input CSV/Parquet file")
    parser.add_argument("--from_stage1_shards", action="store_true",
                        help="Generate from Stage-1 Parquet shards instead of CSV")
    parser.add_argument("--stage1_shards_glob", type=str, default="./data/shards/*.parquet",
                        help="Glob pattern for Stage-1 shards")

    # Output location.
    parser.add_argument("--output_dir", type=str, required=True,
                        help="Output directory for precomputed embeddings")

    # Generation knobs.
    parser.add_argument("--device", type=str, default="cuda",
                        help="Device for model inference")
    parser.add_argument("--batch_size", type=int, default=50,
                        help="Batch size for embedding generation")
    parser.add_argument("--max_protein_length", type=int, default=2048,
                        help="Maximum protein sequence length")
    parser.add_argument("--resume", action="store_true",
                        help="Resume from checkpoint if available")
    parser.add_argument("--species_pooling", type=str, choices=["last", "sequence", "none"], default="last",
                        help="'last' for single-token; 'sequence' for variable-length token embeddings")

    args = parser.parse_args()

    # Shard mode takes precedence when both inputs are supplied.
    if args.from_stage1_shards:
        prepare_species_from_stage1_shards(
            args.stage1_shards_glob,
            args.output_dir,
            device=args.device,
            resume=args.resume,
            species_pooling=args.species_pooling,
        )
    elif args.input_csv:
        create_precomputed_dataset(
            args.input_csv,
            args.output_dir,
            device=args.device,
            batch_size=args.batch_size,
            max_protein_length=args.max_protein_length,
            resume=args.resume,
            species_pooling=args.species_pooling,
        )
    else:
        raise ValueError("Must specify either --input_csv or --from_stage1_shards")

    logger.info("Precomputation complete!")
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|