import spaces
import torch
import numpy as np
from utils.handle_files import parse_fasta_files
import gradio as gr
import time
import random
import os
import pandas as pd


def get_duration_embeddings(sequences_batch, model, tokenizer, max_duration):
    """Return the GPU-time budget (seconds) for `generate_embeddings`.

    Used as the dynamic `duration` callback of the @spaces.GPU decorator;
    it simply forwards the caller-supplied cap.
    """
    return max_duration


@spaces.GPU(duration=get_duration_embeddings)
def generate_embeddings(sequences_batch, model, tokenizer, max_duration):
    """Generate embeddings for ESM models using the transformers library.

    Parameters:
    -----------
    sequences_batch : list of str
        A batch of sequences for which to generate embeddings.
    model : AutoModel
        The pre-loaded ESM model.
        must already be on the correct device (CPU or GPU).
    tokenizer : AutoTokenizer
        The pre-loaded tokenizer corresponding to the ESM model.
    max_duration : int
        GPU-time budget in seconds (consumed by the @spaces.GPU decorator).

    Returns:
    --------
    sequence_embeddings : 2D np.array of shape (batch_size, embedding_dim)
        A list of sequence-level embeddings (mean-pooled) for each input sequence.
    """
    # Tokenize sequences on the model's device.
    device = model.device
    tokens = tokenizer(
        sequences_batch,
        return_tensors="pt",
        padding=True,
        truncation=True,
        add_special_tokens=True,
    ).to(device)

    # Generate embeddings. Request hidden states explicitly: unless the model
    # config enables them, `results.hidden_states` would be None and indexing
    # it below would crash.
    with torch.no_grad():
        results = model(**tokens, output_hidden_states=True)

    token_embeddings = results.hidden_states[-1]  # last-layer embeddings
    attention_mask = tokens["attention_mask"]

    # Sequence-level embeddings: mean pooling over residue positions only.
    sequence_embeddings = []
    for i, seq in enumerate(sequences_batch):
        # Number of non-padding tokens for this sequence, including the two
        # special tokens added by the tokenizer.
        n_tokens = int(attention_mask[i].sum().item())
        # Skip the leading special token and stop before the trailing one.
        # Clamping with the attention mask keeps the slice valid when the
        # tokenizer truncated the sequence (len(seq) + 1 could otherwise run
        # past the real tokens into EOS/padding).
        end = min(len(seq) + 1, n_tokens - 1)
        seq_embedding = token_embeddings[i, 1:end].mean(dim=0)
        # Per-sequence slicing might seem inefficient compared to one batched
        # slice, but it is necessary to account for variable sequence lengths
        # and ensure we only average over the actual sequence tokens, not the
        # padding or special tokens.
        sequence_embeddings.append(seq_embedding.cpu().numpy())

    return np.array(sequence_embeddings)


def _finalize_ppl_scores(log_prob_sums, token_counts):
    """Convert accumulated log-probs and token counts into pseudo-perplexities.

    Sequences with no scored tokens get float('inf'). Shared by the exact and
    approximate PPL implementations so the two cannot drift apart.
    """
    ppl_scores = []
    for total_log_prob, count in zip(log_prob_sums, token_counts):
        if count == 0:
            ppl_scores.append(float("inf"))
        else:
            avg_neg_log_prob = -total_log_prob / count
            ppl_scores.append(float(torch.exp(avg_neg_log_prob).item()))
    return ppl_scores


def get_duration_ppl(sequences_batch, model, tokenizer, max_duration):
    """Return the GPU-time budget (seconds) for `generate_ppl_scores`."""
    return max_duration


@spaces.GPU(duration=get_duration_ppl)
def generate_ppl_scores(sequences_batch, model, tokenizer, max_duration):
    """Generate pseudo-perplexity scores for ESM models using batched masking across all sequences.

    Masks one position at a time, but for every sequence in the batch
    simultaneously, so the number of forward passes equals the padded
    sequence length rather than batch_size * length.

    Parameters:
    -----------
    sequences_batch : list of str
        A batch of sequences for which to generate pseudo-perplexity scores.
    model : AutoModel
        The pre-loaded ESM model.
        must already be on the correct device (CPU or GPU).
    tokenizer : AutoTokenizer
        The pre-loaded tokenizer corresponding to the ESM model.
    max_duration : int
        GPU-time budget in seconds (consumed by the @spaces.GPU decorator).

    Returns:
    --------
    ppl_scores : list of float
        A list of perplexity scores for each input sequence.

    Raises:
    -------
    ValueError
        If the tokenizer does not define a mask token.
    """
    device = model.device
    mask_token_id = tokenizer.mask_token_id
    if mask_token_id is None:
        raise ValueError("Tokenizer does not define a mask token; cannot compute pseudo-perplexity.")

    tokens = tokenizer(
        sequences_batch,
        return_tensors="pt",
        padding=True,
        truncation=True,
        add_special_tokens=True,
    ).to(device)
    input_ids = tokens["input_ids"]
    attention_mask = tokens["attention_mask"]
    batch_size = input_ids.size(0)
    seq_len = input_ids.size(1)

    # Per-sequence accumulators for summed log-probs and scored-token counts.
    log_prob_sums = torch.zeros(batch_size, device=device)
    token_counts = torch.zeros(batch_size, device=device)

    # Precompute which positions to score for each sequence (exclude special tokens).
    positions_to_score = []
    for i in range(batch_size):
        valid_positions = torch.nonzero(attention_mask[i], as_tuple=False).squeeze(-1)
        if valid_positions.numel() < 3:
            # Too short to score (less than 1 real token after excluding special tokens).
            positions_to_score.append(set())
        else:
            # Exclude first and last positions (special tokens).
            positions_to_score.append(set(valid_positions[1:-1].tolist()))

    with torch.no_grad():
        # Process one position at a time across all sequences.
        for pos in range(1, seq_len - 1):
            # Find which sequences have a valid token at this position.
            active_indices = [i for i in range(batch_size) if pos in positions_to_score[i]]
            if not active_indices:
                continue

            # Clone input_ids and mask the current position for all active sequences.
            masked_batch = input_ids.clone()
            true_token_ids = masked_batch[active_indices, pos].clone()
            masked_batch[active_indices, pos] = mask_token_id

            # Single forward pass for all sequences.
            outputs = model(masked_batch, attention_mask=attention_mask)
            logits = outputs.logits  # (batch_size, seq_len, vocab_size)

            # Log-probability of the true token at the masked position.
            log_probs = torch.log_softmax(logits[active_indices, pos], dim=-1)
            true_log_probs = log_probs.gather(1, true_token_ids.unsqueeze(-1)).squeeze(-1)

            # Accumulate for each active sequence.
            for idx, seq_idx in enumerate(active_indices):
                log_prob_sums[seq_idx] += true_log_probs[idx]
                token_counts[seq_idx] += 1

    return _finalize_ppl_scores(log_prob_sums, token_counts)


def get_duration_ppl_approx(sequences_batch, model, tokenizer, mask_percentage, max_duration):
    """Return the GPU-time budget (seconds) for `generate_ppl_scores_approx`."""
    return max_duration


@spaces.GPU(duration=get_duration_ppl_approx)
def generate_ppl_scores_approx(sequences_batch, model, tokenizer, mask_percentage=0.15, max_duration=240):
    """Generate approximate pseudo-perplexity scores for ESM models using chunked masking.

    Masks `mask_percentage` of the positions per forward pass instead of one
    at a time, trading exactness for fewer forward passes.

    Parameters:
    -----------
    sequences_batch : list of str
        A batch of sequences for which to generate approximate pseudo-perplexity scores.
    model : AutoModel
        The pre-loaded ESM model.
        must already be on the correct device (CPU or GPU).
    tokenizer : AutoTokenizer
        The pre-loaded tokenizer corresponding to the ESM model.
    mask_percentage : float, default=0.15
        Percentage of positions to mask in each forward pass (0 < mask_percentage <= 1).
    max_duration : int, default=240
        GPU-time budget in seconds (consumed by the @spaces.GPU decorator).

    Returns:
    --------
    ppl_scores : list of float
        A list of approximate perplexity scores for each input sequence.

    Raises:
    -------
    ValueError
        If the tokenizer does not define a mask token.
    """
    device = model.device
    mask_token_id = tokenizer.mask_token_id
    if mask_token_id is None:
        raise ValueError("Tokenizer does not define a mask token; cannot compute pseudo-perplexity.")

    tokens = tokenizer(
        sequences_batch,
        return_tensors="pt",
        padding=True,
        truncation=True,
        add_special_tokens=True,
    ).to(device)
    input_ids = tokens["input_ids"]
    attention_mask = tokens["attention_mask"]
    batch_size = input_ids.size(0)
    seq_len = input_ids.size(1)

    # Per-sequence accumulators for summed log-probs and scored-token counts.
    log_prob_sums = torch.zeros(batch_size, device=device)
    token_counts = torch.zeros(batch_size, device=device)

    # Precompute which positions to score for each sequence (exclude special tokens).
    positions_to_score = []
    for i in range(batch_size):
        valid_positions = torch.nonzero(attention_mask[i], as_tuple=False).squeeze(-1)
        if valid_positions.numel() < 3:
            positions_to_score.append([])
        else:
            # Exclude first and last positions (special tokens).
            positions_to_score.append(valid_positions[1:-1].tolist())

    # Parallel sets for O(1) membership tests in the chunk loops below
    # (list membership would be O(n) inside a doubly nested loop).
    position_sets = [set(pos_list) for pos_list in positions_to_score]

    # Calculate chunk size based on mask percentage.
    max_positions = max(len(pos) for pos in positions_to_score) if positions_to_score else 0
    if max_positions == 0:
        return [float("inf")] * batch_size
    chunk_size = max(1, int(max_positions * mask_percentage))

    with torch.no_grad():
        # Determine all unique positions across sequences.
        all_positions = set()
        for pos_list in positions_to_score:
            all_positions.update(pos_list)
        all_positions = sorted(all_positions)

        # Process positions in chunks.
        for chunk_start in range(0, len(all_positions), chunk_size):
            chunk_positions = all_positions[chunk_start:chunk_start + chunk_size]

            # Clone input_ids and mask all positions in this chunk, tracking
            # which sequences have tokens at positions in this chunk.
            masked_batch = input_ids.clone()
            seq_positions = {i: [] for i in range(batch_size)}
            for pos in chunk_positions:
                for seq_idx in range(batch_size):
                    if pos in position_sets[seq_idx]:
                        seq_positions[seq_idx].append(pos)
                        masked_batch[seq_idx, pos] = mask_token_id

            # Skip if no sequences have tokens in this chunk.
            active_sequences = [i for i, pos_list in seq_positions.items() if pos_list]
            if not active_sequences:
                continue

            # Single forward pass for the entire batch with the chunk masked.
            outputs = model(masked_batch, attention_mask=attention_mask)
            logits = outputs.logits  # (batch_size, seq_len, vocab_size)

            # Compute log-probs for each sequence and position in the chunk.
            for seq_idx in active_sequences:
                for pos in seq_positions[seq_idx]:
                    true_token_id = input_ids[seq_idx, pos]
                    log_probs = torch.log_softmax(logits[seq_idx, pos], dim=-1)
                    log_prob_sums[seq_idx] += log_probs[true_token_id]
                    token_counts[seq_idx] += 1

    return _finalize_ppl_scores(log_prob_sums, token_counts)


def _make_output_dir():
    """Create and return a unique per-session output directory."""
    session_hash = random.getrandbits(128)  # random identifier for this session
    time_stamp = time.strftime("%Y-%m-%d-%H-%M-%S")
    out_dir = f"./outputs/session_{session_hash}_{time_stamp}"
    os.makedirs(out_dir, exist_ok=True)
    return out_dir


def full_embedding_pipeline(fasta_files, model, tokenizer, batch_size, max_duration):
    """Full pipeline to process FASTA files and generate embeddings from desired model.

    Parameters:
    -----------
    fasta_files : list of str, obtained from gradio file input
        List of paths to FASTA files to be parsed.
    model : AutoModel
        The pre-loaded ESM model.
        must already be on the correct device (CPU or GPU).
    tokenizer : AutoTokenizer
        The pre-loaded tokenizer corresponding to the ESM model.
    batch_size : int
        The number of sequences to process in each batch when generating embeddings.
    max_duration : int
        GPU-time budget in seconds, forwarded to `generate_embeddings`.

    Returns:
    --------
    all_file_paths : list of str
        List of file paths where the per-file embeddings were saved.
        To be passed to gradio for download.
    status_string : str
        A string summarizing the processing steps and output files generated,
        to be displayed in the gradio interface.
    """
    # Parse FASTA files. Each entry of sequences_info is (description, sequence, file_name).
    sequences_info, file_info = parse_fasta_files(fasta_files)

    # Generate embeddings in batches.
    all_embeddings = []
    n_batches = (len(sequences_info) + batch_size - 1) // batch_size  # ceil division
    status_string = (
        f"Processing {len(sequences_info)} sequences from {len(file_info)} file(s) "
        f"in {n_batches} batches of {batch_size} sequences...\n"
    )
    for i in range(0, len(sequences_info), batch_size):
        batch = sequences_info[i:i + batch_size]
        batch_sequences = [seq for _, seq, _ in batch]
        embeddings = generate_embeddings(batch_sequences, model, tokenizer, max_duration)
        status_string += f"Generated {len(embeddings)} embeddings for batch {i // batch_size + 1}/{n_batches}\n"
        all_embeddings.extend(embeddings)
    status_string += f"Generated embeddings for all {len(sequences_info)} sequences.\n"

    # Save one compressed .npz per input file, preserving sequence IDs.
    unique_files = file_info.keys()
    out_dir = _make_output_dir()
    all_file_paths = []
    for file_name in unique_files:
        indices = [i for i, (_, _, f) in enumerate(sequences_info) if f == file_name]
        file_embeddings = np.array([all_embeddings[i] for i in indices])
        sequence_ids = [sequences_info[i][0] for i in indices]  # sequence IDs for this file
        file_path = os.path.join(out_dir, f"{file_name}_embeddings.npz")
        np.savez_compressed(file_path, embeddings=file_embeddings, sequence_ids=sequence_ids)
        status_string += f"Saved compressed embeddings to {file_name}_embeddings.npz\n"
        all_file_paths.append(file_path)

    return all_file_paths, status_string


def full_ppl_pipeline(fasta_files, model, tokenizer, batch_size, mask_percentage=None, max_duration=240):
    """Full pipeline to process FASTA files and generate pseudo-perplexity scores.

    Parameters:
    -----------
    fasta_files : list of str, obtained from gradio file input
        List of paths to FASTA files to be parsed.
    model : AutoModel
        The pre-loaded ESM model.
        must already be on the correct device (CPU or GPU).
    tokenizer : AutoTokenizer
        The pre-loaded tokenizer corresponding to the ESM model.
    batch_size : int
        The number of sequences to process in each batch when generating scores.
    mask_percentage : float or None
        If None, use the exact PPL calculation (masking one token at a time).
        If a float between 0 and 1, use the approximate chunked masking method
        with the specified percentage of tokens masked per forward pass.
    max_duration : int, default=240
        GPU-time budget in seconds, forwarded to the scoring functions.

    Returns:
    --------
    all_file_paths : list of str
        List of file paths where the per-file scores were saved.
        To be passed to gradio for download.
    status_string : str
        A string summarizing the processing steps and output files generated,
        to be displayed in the gradio interface.
    """
    # Parse FASTA files. Each entry of sequences_info is (description, sequence, file_name).
    sequences_info, file_info = parse_fasta_files(fasta_files)

    # Generate pseudo-perplexity scores in batches.
    all_ppl = []
    n_batches = (len(sequences_info) + batch_size - 1) // batch_size  # ceil division
    status_string = (
        f"Processing {len(sequences_info)} sequences from {len(file_info)} file(s) "
        f"in {n_batches} batches of {batch_size} sequences...\n"
    )
    for i in range(0, len(sequences_info), batch_size):
        batch = sequences_info[i:i + batch_size]
        batch_sequences = [seq for _, seq, _ in batch]
        if mask_percentage is None:
            ppl_scores = generate_ppl_scores(batch_sequences, model, tokenizer, max_duration)
            status_string += f"Generated {len(ppl_scores)} pseudo-perplexity scores for batch {i // batch_size + 1}/{n_batches}\n"
        else:
            ppl_scores = generate_ppl_scores_approx(
                batch_sequences, model, tokenizer,
                mask_percentage=mask_percentage, max_duration=max_duration,
            )
            status_string += f"Generated {len(ppl_scores)} approximate pseudo-perplexity scores for batch {i // batch_size + 1}/{n_batches} with mask percentage {mask_percentage*100:.1f}%\n"
        all_ppl.extend(ppl_scores)
    status_string += f"Generated scores for all {len(sequences_info)} sequences.\n"

    # Save one CSV per input file with description, sequence, and score.
    unique_files = file_info.keys()
    out_dir = _make_output_dir()
    all_file_paths = []
    for file_name in unique_files:
        indices = [i for i, (_, _, f) in enumerate(sequences_info) if f == file_name]
        file_path = os.path.join(out_dir, f"{file_name}_ppl.csv")
        rows = []
        for idx in indices:
            description, sequence, _ = sequences_info[idx]
            rows.append({
                "description": description,
                "sequence": sequence,
                "ppl_score": all_ppl[idx]
            })
        df = pd.DataFrame(rows)
        df.to_csv(file_path, index=False)
        status_string += f"Saved PPL scores to {file_name}_ppl.csv\n"
        all_file_paths.append(file_path)

    # Report the best-scoring sequence. Guarded so that empty input does not
    # crash on min(); index lookup hoisted instead of being recomputed.
    if all_ppl:
        lowest_ppl = min(all_ppl)
        best_idx = all_ppl.index(lowest_ppl)
        best_description, best_sequence, best_file = sequences_info[best_idx]
        status_string += f"Lowest PPL score across all sequences: {lowest_ppl:.4f}:\n for sequence in file {best_file}:\n"
        status_string += f">{best_description}\n"
        status_string += f"{best_sequence}\n"

    return all_file_paths, status_string