#!/usr/bin/env python
"""
Precompute species embeddings for CodonTranslator training.

Protein embeddings are now computed on-the-fly using integrated ESM-C model.

Steps:
1. Build taxonomy database from GBIF API
2. Generate species embeddings using Qwen3-Embedding-0.6B
"""

import argparse
import glob
import json
import logging
import os
import time
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
import requests
import torch
from tqdm import tqdm

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Hugging Face model id used for species embeddings (also recorded in metadata files).
QWEN_MODEL_NAME = "Qwen/Qwen3-Embedding-0.6B"

# GBIF taxonomic ranks, in order from broadest to most specific, joined with " > ".
_GBIF_RANKS = ("kingdom", "phylum", "class", "order", "family", "genus", "species")


def _describe_gbif_match(species: str, data: Dict) -> str:
    """Turn one GBIF species-match JSON payload into a description string.

    Args:
        species: The queried species name (used for fallbacks).
        data: Decoded JSON body returned by the GBIF ``species/match`` endpoint.

    Returns:
        "Taxonomy: ... > .... Common name: .... Match confidence: N%. Status: ..."
        (only the parts present in ``data``), the bare species name if no part
        is present, or "Species: <name> (no GBIF match)" when GBIF reports
        ``matchType == "NONE"``.
    """
    if data.get("matchType") == "NONE":
        # No match found - use species name with indicator
        return f"Species: {species} (no GBIF match)"

    parts = []

    # Scientific classification, skipping ranks GBIF left empty
    taxonomy = [data[rank] for rank in _GBIF_RANKS if data.get(rank)]
    if taxonomy:
        parts.append("Taxonomy: " + " > ".join(taxonomy))

    if data.get("vernacularName"):
        parts.append(f"Common name: {data['vernacularName']}")

    if "confidence" in data:
        parts.append(f"Match confidence: {data['confidence']}%")

    # Status (accepted, synonym, etc.)
    if "status" in data:
        parts.append(f"Status: {data['status']}")

    return ". ".join(parts) if parts else species


def build_taxonomy_database(
    species_list: List[str],
    timeout: float = 10.0,
    request_delay: float = 0.1,
) -> Dict[str, str]:
    """Query the GBIF species-match API and build a text description per species.

    The descriptions are later used as the input text for species embeddings.

    Args:
        species_list: Species names to look up. Falsy names and duplicates are
            skipped (so they get no entry).
        timeout: Per-request timeout in seconds, so a stalled connection cannot
            hang the whole run.
        request_delay: Seconds to sleep after every request (simple client-side
            rate limiting against the public API).

    Returns:
        Mapping species name -> description. Failed lookups map to
        "Species: <name> (query failed)" (non-200 response) or
        "Species: <name> (error)" (exception while querying/parsing).
    """
    taxonomy_db: Dict[str, str] = {}
    base_url = "https://api.gbif.org/v1/species/match"

    logger.info(f"Building taxonomy database for {len(species_list)} species...")

    # One session reuses the HTTPS connection across all queries.
    with requests.Session() as session:
        for species in tqdm(species_list, desc="Querying GBIF"):
            if not species or species in taxonomy_db:
                continue

            try:
                response = session.get(base_url, params={"name": species}, timeout=timeout)
                if response.status_code == 200:
                    taxonomy_db[species] = _describe_gbif_match(species, response.json())
                else:
                    taxonomy_db[species] = f"Species: {species} (query failed)"
            except Exception as e:
                # Deliberate best effort: one bad lookup must not abort the run.
                logger.warning(f"Error querying GBIF for {species}: {e}")
                taxonomy_db[species] = f"Species: {species} (error)"

            # Rate limiting, applied after failures as well as successes.
            time.sleep(request_delay)

    logger.info(f"Taxonomy database built with {len(taxonomy_db)} entries")
    return taxonomy_db


def generate_species_embeddings_qwen(
    species_list: List[str],
    taxonomy_db: Dict[str, str],
    device: str = "cuda",
    pooling: str = "last"  # 'last' -> single vector; 'sequence'/'none' -> variable-length tokens
) -> Tuple[Dict[str, int], Dict[int, np.ndarray]]:
    """Generate species embeddings using Qwen3-Embedding-0.6B.

    Each species is embedded one at a time from an instruction-formatted prompt
    built around its ``taxonomy_db`` description (or the bare name if missing).

    Args:
        species_list: Species to embed. The position in this list becomes the
            species id.
        taxonomy_db: Species name -> description text (see
            ``build_taxonomy_database``).
        device: Torch device string for inference. On CPU the model is loaded
            in float32 instead of float16.
        pooling: 'last' returns one L2-normalized vector per species (fixed
            size). Any other value ('sequence'/'none') returns all token
            embeddings, each L2-normalized, with shape [num_tokens, dim].

    Returns:
        (species_vocab, species_embeddings), where species_vocab maps
        name -> id and species_embeddings maps id -> numpy array.
    """
    import torch.nn.functional as F
    from transformers import AutoTokenizer, AutoModel

    def last_token_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Pool by taking the last valid (non-padding) token's embedding."""
        left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
        if left_padding:
            return last_hidden_states[:, -1]
        sequence_lengths = attention_mask.sum(dim=1) - 1
        batch_size = last_hidden_states.shape[0]
        return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]

    def get_detailed_instruct(task_description: str, query: str) -> str:
        """Format the input with instruction for better embedding quality."""
        return f'Instruct: {task_description}\nQuery: {query}'

    logger.info("Loading Qwen3-Embedding-0.6B model...")
    model_name = QWEN_MODEL_NAME

    # fp16 matmuls are unsupported or very slow on CPU in many PyTorch builds,
    # so use fp16 only on accelerators.
    model_dtype = torch.float32 if torch.device(device).type == "cpu" else torch.float16

    # Initialize with left padding for last token pooling
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, padding_side='left')
    model = AutoModel.from_pretrained(model_name, torch_dtype=model_dtype, trust_remote_code=True).to(device).eval()

    species_vocab: Dict[str, int] = {}
    species_embeddings: Dict[int, np.ndarray] = {}

    # Task description for species embedding
    task = "Given a species taxonomy information, generate a biological embedding representing its taxonomic and evolutionary characteristics"

    for idx, species in enumerate(tqdm(species_list, desc="Generating embeddings")):
        # Comprehensive taxonomy string from the GBIF query results
        taxonomy_str = taxonomy_db.get(species, species)
        input_text = get_detailed_instruct(task, taxonomy_str)

        with torch.no_grad():
            inputs = tokenizer(
                input_text,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512
            )
            inputs = {k: v.to(device) for k, v in inputs.items()}
            outputs = model(**inputs)
            hidden = outputs.last_hidden_state  # [1, L, D] (single-sequence batch)

            if pooling == 'last':
                pooled_embedding = last_token_pool(hidden, inputs['attention_mask'])
                normalized_embedding = F.normalize(pooled_embedding, p=2, dim=1)
                species_embedding = normalized_embedding.squeeze(0).cpu().numpy()  # [D]
            else:
                # Variable-length token embeddings (normalize per token)
                tok = hidden.squeeze(0)  # [L, D]
                tok = F.normalize(tok, p=2, dim=-1)
                species_embedding = tok.cpu().numpy()  # [L, D]

        species_vocab[species] = idx
        species_embeddings[idx] = species_embedding

    logger.info(f"Generated {'fixed-size' if pooling=='last' else 'variable-length'} embeddings for {len(species_vocab)} species")
    return species_vocab, species_embeddings


def save_species_embeddings_memmap(
    species_vocab: Dict[str, int],
    species_embeddings: Dict[int, np.ndarray],
    output_dir: str
) -> None:
    """Save fixed-size species embeddings as a float32 memory-mapped matrix.

    Writes ``species_vocab.json``, ``species_embeddings.bin`` (shape
    [num_species, embed_dim], row = species id) and ``species_metadata.json``
    into ``output_dir``.

    Species ids are used directly as row indices, so they are assumed to be
    contiguous in 0..num_species-1 (as produced by
    ``generate_species_embeddings_qwen``).

    Raises:
        ValueError: if ``species_embeddings`` is empty (no dimension to infer).
    """
    if not species_embeddings:
        raise ValueError("No species embeddings to save")

    os.makedirs(output_dir, exist_ok=True)

    vocab_path = os.path.join(output_dir, "species_vocab.json")
    with open(vocab_path, 'w') as f:
        json.dump(species_vocab, f, indent=2)

    # All embeddings share the same dimension in fixed-size mode
    num_species = len(species_embeddings)
    embed_dim = next(iter(species_embeddings.values())).shape[0]

    emb_path = os.path.join(output_dir, "species_embeddings.bin")
    mmap = np.memmap(emb_path, dtype=np.float32, mode='w+', shape=(num_species, embed_dim))
    for species_id, emb in species_embeddings.items():
        mmap[species_id] = emb.astype(np.float32)
    mmap.flush()
    del mmap

    metadata = {
        "num_species": num_species,
        "embedding_dim": embed_dim,
        "embedding_type": "fixed_size",
        "pooling_method": "last_token",
        "normalization": "L2",
        "model": QWEN_MODEL_NAME
    }
    metadata_path = os.path.join(output_dir, "species_metadata.json")
    with open(metadata_path, 'w') as f:
        json.dump(metadata, f, indent=2)

    logger.info(f"Saved {num_species} fixed-size species embeddings to {emb_path}")
    logger.info(f"Embedding dimension: {embed_dim}")
    logger.info(f"Saved metadata to {metadata_path}")


def save_species_token_embeddings_memmap(
    species_vocab: Dict[str, int],
    species_tok_embeddings: Dict[int, np.ndarray],
    output_dir: str,
    dtype: str = 'float32'
) -> None:
    """Save variable-length token embeddings into a flat memmap with an index.

    All species' [L_i, D] arrays are concatenated row-wise into
    ``species_tok_emb.bin`` (shape [sum(L_i), D]). ``species_index.json`` maps
    str(species_id) -> {"offset", "length"} rows. Also writes
    ``species_vocab.json`` and ``metadata.json``.

    Args:
        dtype: 'float32' stores float32; any other value stores float16.

    Raises:
        ValueError: if ``species_tok_embeddings`` is empty.
    """
    if not species_tok_embeddings:
        raise ValueError("No species token embeddings to save")

    os.makedirs(output_dir, exist_ok=True)

    vocab_path = os.path.join(output_dir, "species_vocab.json")
    with open(vocab_path, 'w') as f:
        json.dump(species_vocab, f, indent=2)

    np_dtype = np.float32 if dtype == 'float32' else np.float16
    embed_dim = next(iter(species_tok_embeddings.values())).shape[1]
    total_tokens = int(sum(v.shape[0] for v in species_tok_embeddings.values()))

    emb_path = os.path.join(output_dir, "species_tok_emb.bin")
    mmap = np.memmap(emb_path, dtype=np_dtype, mode='w+', shape=(total_tokens, embed_dim))

    index = {}
    offset = 0
    for sid, arr in species_tok_embeddings.items():
        length = int(arr.shape[0])
        mmap[offset: offset + length] = arr.astype(np_dtype)
        index[str(sid)] = {"offset": offset, "length": length}
        offset += length
    mmap.flush()
    del mmap

    with open(os.path.join(output_dir, "species_index.json"), 'w') as f:
        json.dump(index, f, indent=2)

    meta = {
        "embedding_dim": embed_dim,
        "dtype": dtype,
        "total_tokens": total_tokens,
        "embedding_type": "variable_length",
        "pooling_method": "none",
        "model": QWEN_MODEL_NAME
    }
    with open(os.path.join(output_dir, "metadata.json"), 'w') as f:
        json.dump(meta, f, indent=2)

    logger.info(f"Saved variable-length species token embeddings to {emb_path} with {total_tokens} tokens total")


def filter_sequences_by_length(df: pd.DataFrame, max_protein_length: int = 2048) -> pd.DataFrame:
    """Drop rows whose sequences are too long (to prevent CUDA OOM during training).

    Keeps rows with ``len(protein_seq) <= max_protein_length`` and
    ``len(cds_DNA) <= 3 * max_protein_length``; each filter applies only if the
    column exists. Rows with a missing value in a filtered column are dropped,
    because the NaN comparison evaluates to False.
    """
    initial_count = len(df)

    if 'protein_seq' in df.columns:
        df = df[df['protein_seq'].str.len() <= max_protein_length]

    # CDS length limit is 3 nucleotides per amino acid
    if 'cds_DNA' in df.columns:
        max_cds_length = max_protein_length * 3
        df = df[df['cds_DNA'].str.len() <= max_cds_length]

    final_count = len(df)
    if final_count < initial_count:
        logger.info(f"Filtered from {initial_count} to {final_count} sequences (max_protein_length={max_protein_length})")

    return df


def _resolve_parquet_column(shard_file: str, column: str) -> str:
    """Return ``column`` as spelled in the shard's schema (case-insensitive match).

    Falls back to ``column`` unchanged if pyarrow is unavailable, the schema
    cannot be read, or no case-insensitive match exists.
    """
    try:
        import pyarrow.parquet as pq  # type: ignore
        names = set(pq.ParquetFile(shard_file).schema.names)
    except Exception:
        return column
    if column in names:
        return column
    lower_map = {n.lower(): n for n in names}
    return lower_map.get(column.lower(), column)


def collect_unique_values_from_shards(
    shards_glob: str,
    column: str,
    max_items: Optional[int] = None
) -> List[str]:
    """Stream over Parquet shards and collect the sorted unique values of a column.

    The column name is matched case-insensitively per shard (e.g. 'taxon' vs
    'Taxon'). If ``max_items`` is truthy, scanning stops once at least that
    many values have been seen, and the result is truncated to ``max_items``.

    Raises:
        ValueError: if no files match ``shards_glob``.
    """
    unique_values = set()
    shard_files = sorted(glob.glob(shards_glob))

    if not shard_files:
        raise ValueError(f"No parquet files found matching {shards_glob}")

    logger.info(f"Scanning {len(shard_files)} shards for unique {column} values...")

    for shard_file in tqdm(shard_files, desc=f"Collecting {column}"):
        resolved = _resolve_parquet_column(shard_file, column)
        df = pd.read_parquet(shard_file, columns=[resolved])
        # Canonicalize to the requested column name for downstream logic.
        if resolved != column and resolved in df.columns and column not in df.columns:
            df = df.rename(columns={resolved: column})
        unique_values.update(df[column].dropna().unique())

        if max_items and len(unique_values) >= max_items:
            break

    result = sorted(unique_values)
    if max_items:
        result = result[:max_items]
    logger.info(f"Collected {len(result)} unique {column} values")
    return result


def collect_stage1_species(shards_glob: str) -> List[str]:
    """Extract the sorted unique species ('Taxon' column) from Stage-1 shards."""
    return collect_unique_values_from_shards(shards_glob, "Taxon")


def prepare_species_from_stage1_shards(
    shards_glob: str,
    output_dir: str,
    device: str = "cuda",
    resume: bool = False,
    species_pooling: str = "last"
) -> None:
    """End-to-end species embedding generation from Stage-1 shards.

    With ``resume``, this returns immediately if ``species_vocab.json``
    already exists, and reuses a cached ``taxonomy_database.json`` instead of
    querying GBIF again.
    """
    os.makedirs(output_dir, exist_ok=True)

    vocab_path = os.path.join(output_dir, "species_vocab.json")
    if resume and os.path.exists(vocab_path):
        logger.info("Species embeddings already exist. Skipping generation.")
        return

    species_list = collect_stage1_species(shards_glob)
    logger.info(f"Found {len(species_list)} unique species in shards")

    taxonomy_cache_path = os.path.join(output_dir, "taxonomy_database.json")
    if resume and os.path.exists(taxonomy_cache_path):
        logger.info("Loading cached taxonomy database...")
        with open(taxonomy_cache_path, 'r') as f:
            taxonomy_db = json.load(f)
    else:
        taxonomy_db = build_taxonomy_database(species_list)
        with open(taxonomy_cache_path, 'w') as f:
            json.dump(taxonomy_db, f, indent=2)

    species_vocab, species_embeddings = generate_species_embeddings_qwen(
        species_list, taxonomy_db, device, pooling=species_pooling
    )

    if species_pooling == 'last':
        save_species_embeddings_memmap(species_vocab, species_embeddings, output_dir)
    else:
        save_species_token_embeddings_memmap(species_vocab, species_embeddings, output_dir)

    logger.info("Species embedding preparation complete")


def create_precomputed_dataset(
    input_csv: Optional[str],
    output_dir: str,
    device: str = "cuda",
    batch_size: int = 50,
    max_protein_length: int = 2048,
    resume: bool = False,
    species_pooling: str = "last"
):
    """Create an embedding dataset with species-only precomputation.

    Protein embeddings will be computed on-the-fly during training.

    Args:
        input_csv: Path to a CSV or ``.parquet`` file with a 'Taxon' (or
            'taxon') column.
        batch_size: Currently unused; embeddings are generated one species at
            a time.
        species_pooling: 'last' writes fixed-size embeddings; otherwise
            variable-length token embeddings are written. In that case the
            token-embedding ``metadata.json`` is extended with the dataset
            fields rather than overwritten.

    Raises:
        ValueError: if ``input_csv`` is None.
    """
    os.makedirs(output_dir, exist_ok=True)

    if resume and os.path.exists(os.path.join(output_dir, "species_vocab.json")):
        # --resume is a store_true flag; regenerating means omitting it.
        logger.info("Precomputed dataset already exists. Run without --resume to regenerate.")
        return

    if input_csv is None:
        raise ValueError("input_csv is required")

    logger.info(f"Loading data from {input_csv}...")
    if input_csv.endswith('.parquet'):
        df = pd.read_parquet(input_csv)
    else:
        df = pd.read_csv(input_csv)

    # Accept either 'Taxon' or 'taxon' as the species column.
    if "Taxon" not in df.columns and "taxon" in df.columns:
        df = df.rename(columns={"taxon": "Taxon"})

    df = filter_sequences_by_length(df, max_protein_length)

    # === Species Embeddings ===
    logger.info("=== Generating Species Embeddings ===")
    unique_species = df["Taxon"].dropna().unique().tolist()
    logger.info(f"Found {len(unique_species)} unique species")

    taxonomy_db = build_taxonomy_database(unique_species)

    taxonomy_path = os.path.join(output_dir, "taxonomy_database.json")
    with open(taxonomy_path, 'w') as f:
        json.dump(taxonomy_db, f, indent=2)

    species_vocab, species_embeddings = generate_species_embeddings_qwen(
        unique_species, taxonomy_db, device, pooling=species_pooling
    )

    if species_pooling == 'last':
        save_species_embeddings_memmap(species_vocab, species_embeddings, output_dir)
    else:
        save_species_token_embeddings_memmap(species_vocab, species_embeddings, output_dir)

    # The savers raise on empty input, so at least one embedding exists here.
    # The last axis is the embedding dim in both modes ([D] or [L, D]).
    species_embedding_dim = int(next(iter(species_embeddings.values())).shape[-1])

    metadata_path = os.path.join(output_dir, "metadata.json")
    metadata: Dict[str, object] = {}
    if species_pooling != 'last' and os.path.exists(metadata_path):
        # save_species_token_embeddings_memmap just wrote metadata.json
        # (embedding_dim, dtype, total_tokens, ...); merge instead of clobbering it.
        with open(metadata_path, 'r') as f:
            metadata.update(json.load(f))
    metadata.update({
        "num_sequences": len(df),
        "num_species": len(unique_species),
        "species_embedding_model": QWEN_MODEL_NAME,
        "species_embedding_dim": species_embedding_dim,
        "max_protein_length": max_protein_length,
    })
    with open(metadata_path, 'w') as f:
        json.dump(metadata, f, indent=2)

    logger.info(f"Dataset creation completed. Species embeddings are precomputed.")
    logger.info("Protein embeddings will be computed on-the-fly during training using integrated ESM-C.")


def main():
    """Parse CLI arguments and run either the Stage-1 shard or the CSV/Parquet pipeline."""
    parser = argparse.ArgumentParser(description="Precompute species embeddings for CodonTranslator")

    # Data source options
    parser.add_argument("--input_csv", type=str, help="Path to input CSV/Parquet file")
    parser.add_argument("--from_stage1_shards", action="store_true",
                        help="Generate from Stage-1 Parquet shards instead of CSV")
    parser.add_argument("--stage1_shards_glob", type=str, default="./data/shards/*.parquet",
                        help="Glob pattern for Stage-1 shards")

    # Output
    parser.add_argument("--output_dir", type=str, required=True,
                        help="Output directory for precomputed embeddings")

    # Processing options
    parser.add_argument("--device", type=str, default="cuda", help="Device for model inference")
    parser.add_argument("--batch_size", type=int, default=50, help="Batch size for embedding generation")
    parser.add_argument("--max_protein_length", type=int, default=2048, help="Maximum protein sequence length")
    parser.add_argument("--resume", action="store_true", help="Resume from checkpoint if available")
    parser.add_argument("--species_pooling", type=str, choices=["last", "sequence", "none"], default="last",
                        help="'last' for single-token; 'sequence' for variable-length token embeddings")

    args = parser.parse_args()

    if args.from_stage1_shards:
        prepare_species_from_stage1_shards(
            args.stage1_shards_glob,
            args.output_dir,
            args.device,
            args.resume,
            args.species_pooling
        )
    elif args.input_csv:
        create_precomputed_dataset(
            args.input_csv,
            args.output_dir,
            args.device,
            args.batch_size,
            args.max_protein_length,
            args.resume,
            args.species_pooling
        )
    else:
        raise ValueError("Must specify either --input_csv or --from_stage1_shards")

    logger.info("Precomputation complete!")


if __name__ == "__main__":
    main()