Spaces:
Running
Running
| import spaces | |
| import torch | |
| import numpy as np | |
| from utils.handle_files import parse_fasta_files | |
| import gradio as gr | |
| import time | |
| import random | |
| import os | |
| import pandas as pd | |
def get_duration_embeddings(sequences_batch, model, tokenizer, max_duration):
    """Return the GPU-time budget (seconds) for an embedding call.

    Mirrors the signature of ``generate_embeddings`` so it can be used as a
    dynamic-duration callback (e.g. ``@spaces.GPU(duration=...)``); only
    ``max_duration`` is consulted — the other arguments are accepted so the
    signatures line up.
    """
    return max_duration
def generate_embeddings(sequences_batch, model, tokenizer, max_duration):
    """Generate mean-pooled sequence embeddings for ESM models via transformers.

    Parameters
    ----------
    sequences_batch : list of str
        A batch of sequences for which to generate embeddings.
    model : AutoModel
        The pre-loaded ESM model. Must already be on the correct device (CPU or GPU).
    tokenizer : AutoTokenizer
        The pre-loaded tokenizer corresponding to the ESM model.
    max_duration : int
        GPU-time budget carried for the spaces duration API; unused here.

    Returns
    -------
    sequence_embeddings : 2D np.array of shape (batch_size, embedding_dim)
        Sequence-level embeddings (mean-pooled over real residue tokens,
        excluding special and padding tokens) for each input sequence.
    """
    device = model.device
    tokens = tokenizer(
        sequences_batch,
        return_tensors="pt",
        padding=True,
        truncation=True,
        add_special_tokens=True
    ).to(device)

    with torch.no_grad():
        # Request hidden states explicitly: without output_hidden_states=True the
        # model output's .hidden_states is None unless the config enables it.
        results = model(**tokens, output_hidden_states=True)

    # Last-layer token embeddings, shape (batch_size, seq_len, embedding_dim).
    token_embeddings = results.hidden_states[-1]
    attention_mask = tokens["attention_mask"]

    # Mean-pool per sequence over its real tokens only. Looping per sequence
    # (rather than one batched slice) is necessary because sequences have
    # variable lengths and we must exclude padding and special tokens.
    sequence_embeddings = []
    for i in range(len(sequences_batch)):
        # Select the non-padding positions, then drop the first and last
        # (BOS/EOS special tokens). Using the attention mask instead of
        # len(seq) stays correct when the tokenizer truncated the sequence.
        valid = attention_mask[i].bool()
        real_tokens = token_embeddings[i][valid][1:-1]
        sequence_embeddings.append(real_tokens.mean(dim=0).cpu().numpy())
    return np.array(sequence_embeddings)
def get_duration_ppl(sequences_batch, model, tokenizer, max_duration):
    """Return the GPU-time budget (seconds) for an exact pseudo-perplexity call.

    Matches the signature of ``generate_ppl_scores`` so it can serve as a
    dynamic-duration callback; only ``max_duration`` is used.
    """
    return max_duration
def generate_ppl_scores(sequences_batch, model, tokenizer, max_duration):
    """Generate exact pseudo-perplexity scores for ESM models using batched masking.

    For every scorable position, that position is replaced by the mask token and
    the model's log-probability of recovering the true token is accumulated; the
    pseudo-perplexity is exp of the mean negative log-likelihood. Positions are
    processed one at a time, but across all sequences of the batch in a single
    forward pass per position.

    Parameters:
    -----------
    sequences_batch : list of str
        A batch of sequences for which to compute pseudo-perplexity.
    model : AutoModel
        The pre-loaded ESM masked-LM model. Must already be on the correct device
        (CPU or GPU); its output must expose ``.logits``.
    tokenizer : AutoTokenizer
        The pre-loaded tokenizer corresponding to the ESM model.
    max_duration : int
        GPU-time budget carried for the spaces duration API; unused here.

    Returns:
    --------
    ppl_scores : list of float
        A pseudo-perplexity score for each input sequence (``float("inf")`` for
        sequences with no scorable tokens).

    Raises:
    -------
    ValueError
        If the tokenizer does not define a mask token.
    """
    device = model.device
    mask_token_id = tokenizer.mask_token_id
    if mask_token_id is None:
        raise ValueError("Tokenizer does not define a mask token; cannot compute pseudo-perplexity.")
    tokens = tokenizer(
        sequences_batch,
        return_tensors="pt",
        padding=True,
        truncation=True,
        add_special_tokens=True
    ).to(device)
    input_ids = tokens["input_ids"]
    attention_mask = tokens["attention_mask"]
    batch_size = input_ids.size(0)
    seq_len = input_ids.size(1)
    # Initialize per-sequence accumulators: summed log-probs and scored-token counts.
    log_prob_sums = torch.zeros(batch_size, device=device)
    token_counts = torch.zeros(batch_size, device=device)
    # Precompute which positions to score for each sequence (real tokens only,
    # excluding the first and last special tokens). Sets give O(1) membership
    # tests in the per-position loop below.
    positions_to_score = []
    for i in range(batch_size):
        valid_positions = torch.nonzero(attention_mask[i], as_tuple=False).squeeze(-1)
        if valid_positions.numel() < 3:
            # Too short to score (less than 1 real token after excluding special tokens)
            positions_to_score.append(set())
        else:
            # Exclude first and last positions (special tokens)
            positions_to_score.append(set(valid_positions[1:-1].tolist()))
    with torch.no_grad():
        # Process one position at a time across all sequences
        for pos in range(1, seq_len - 1):
            # Find which sequences have a valid (scorable) token at this position
            active_indices = [i for i in range(batch_size) if pos in positions_to_score[i]]
            if not active_indices:
                continue
            # Clone input_ids and mask the current position for the active sequences,
            # remembering the true token ids before they are overwritten.
            masked_batch = input_ids.clone()
            true_token_ids = masked_batch[active_indices, pos].clone()
            masked_batch[active_indices, pos] = mask_token_id
            # Single forward pass for all sequences
            outputs = model(masked_batch, attention_mask=attention_mask)
            logits = outputs.logits  # (batch_size, seq_len, vocab_size)
            # Log-probabilities over the vocabulary at the masked position,
            # restricted to the active sequences.
            log_probs = torch.log_softmax(logits[active_indices, pos], dim=-1)
            # Gather the log-probability the model assigned to each true token.
            true_log_probs = log_probs.gather(1, true_token_ids.unsqueeze(-1)).squeeze(-1)
            # Accumulate for each active sequence
            for idx, seq_idx in enumerate(active_indices):
                log_prob_sums[seq_idx] += true_log_probs[idx]
                token_counts[seq_idx] += 1
    # Final score: pseudo-perplexity = exp(mean negative log-likelihood).
    ppl_scores = []
    for i in range(batch_size):
        if token_counts[i] == 0:
            # No scorable tokens (e.g. empty sequence) — report infinity.
            ppl_scores.append(float("inf"))
        else:
            avg_neg_log_prob = -log_prob_sums[i] / token_counts[i]
            ppl_scores.append(float(torch.exp(avg_neg_log_prob).item()))
    return ppl_scores
def get_duration_ppl_approx(sequences_batch, model, tokenizer, mask_percentage, max_duration):
    """Return the GPU-time budget (seconds) for an approximate pseudo-perplexity call.

    Mirrors the signature of ``generate_ppl_scores_approx`` so it can act as a
    dynamic-duration callback; everything except ``max_duration`` is ignored.
    """
    return max_duration
def generate_ppl_scores_approx(sequences_batch, model, tokenizer, mask_percentage=0.15, max_duration=240):
    """Generate approximate pseudo-perplexity scores for ESM models using chunked masking.

    Instead of masking one position per forward pass (exact pseudo-perplexity),
    a chunk of positions is masked simultaneously, trading some accuracy for a
    large reduction in the number of forward passes.

    Parameters:
    -----------
    sequences_batch : list of str
        A batch of sequences for which to compute pseudo-perplexity.
    model : AutoModel
        The pre-loaded ESM masked-LM model. Must already be on the correct device
        (CPU or GPU); its output must expose ``.logits``.
    tokenizer : AutoTokenizer
        The pre-loaded tokenizer corresponding to the ESM model.
    mask_percentage : float, default=0.15
        Fraction of positions to mask in each forward pass (0 < mask_percentage <= 1).
    max_duration : int, default=240
        GPU-time budget carried for the spaces duration API; unused here.

    Returns:
    --------
    ppl_scores : list of float
        An approximate pseudo-perplexity score for each input sequence
        (``float("inf")`` for sequences with no scorable tokens).

    Raises:
    -------
    ValueError
        If the tokenizer does not define a mask token.
    """
    device = model.device
    mask_token_id = tokenizer.mask_token_id
    if mask_token_id is None:
        raise ValueError("Tokenizer does not define a mask token; cannot compute pseudo-perplexity.")
    tokens = tokenizer(
        sequences_batch,
        return_tensors="pt",
        padding=True,
        truncation=True,
        add_special_tokens=True
    ).to(device)
    input_ids = tokens["input_ids"]
    attention_mask = tokens["attention_mask"]
    batch_size = input_ids.size(0)

    # Per-sequence accumulators: summed log-probs and scored-token counts.
    log_prob_sums = torch.zeros(batch_size, device=device)
    token_counts = torch.zeros(batch_size, device=device)

    # Positions to score per sequence: real (non-padding) tokens, excluding the
    # first and last special tokens.
    positions_to_score = []
    for i in range(batch_size):
        valid_positions = torch.nonzero(attention_mask[i], as_tuple=False).squeeze(-1)
        if valid_positions.numel() < 3:
            # Fewer than one real token after excluding the two specials.
            positions_to_score.append([])
        else:
            positions_to_score.append(valid_positions[1:-1].tolist())
    # Set views for O(1) membership tests inside the chunk loops (the lists
    # above would make each test O(seq_len)).
    score_sets = [set(pos_list) for pos_list in positions_to_score]

    # Chunk size: how many positions are masked together in one forward pass.
    max_positions = max(len(pos) for pos in positions_to_score) if positions_to_score else 0
    if max_positions == 0:
        return [float("inf")] * batch_size
    chunk_size = max(1, int(max_positions * mask_percentage))

    with torch.no_grad():
        # Union of scorable positions across all sequences, processed in chunks.
        all_positions = sorted(set().union(*score_sets))
        for chunk_start in range(0, len(all_positions), chunk_size):
            chunk_positions = all_positions[chunk_start:chunk_start + chunk_size]
            # Clone input_ids and mask, per sequence, the tokens falling in this chunk.
            masked_batch = input_ids.clone()
            seq_positions = {i: [] for i in range(batch_size)}
            for pos in chunk_positions:
                for seq_idx in range(batch_size):
                    if pos in score_sets[seq_idx]:
                        seq_positions[seq_idx].append(pos)
                        masked_batch[seq_idx, pos] = mask_token_id
            # Skip if no sequence has a scorable token in this chunk.
            active_sequences = [i for i, pos_list in seq_positions.items() if pos_list]
            if not active_sequences:
                continue
            # One forward pass scores every masked position of the chunk.
            outputs = model(masked_batch, attention_mask=attention_mask)
            logits = outputs.logits  # (batch_size, seq_len, vocab_size)
            # Accumulate the log-prob of the true token at each masked position.
            for seq_idx in active_sequences:
                for pos in seq_positions[seq_idx]:
                    true_token_id = input_ids[seq_idx, pos]
                    log_probs = torch.log_softmax(logits[seq_idx, pos], dim=-1)
                    log_prob_sums[seq_idx] += log_probs[true_token_id]
                    token_counts[seq_idx] += 1

    # Final score: pseudo-perplexity = exp(mean negative log-likelihood).
    ppl_scores = []
    for i in range(batch_size):
        if token_counts[i] == 0:
            ppl_scores.append(float("inf"))
        else:
            avg_neg_log_prob = -log_prob_sums[i] / token_counts[i]
            ppl_scores.append(float(torch.exp(avg_neg_log_prob).item()))
    return ppl_scores
def full_embedding_pipeline(fasta_files, model, tokenizer, batch_size, max_duration):
    """Full pipeline to process FASTA files and generate embeddings from desired model.

    Parameters:
    -----------
    fasta_files : list of str, obtained from gradio file input
        List of paths to FASTA files to be parsed.
    model : AutoModel
        The pre-loaded ESM model. Must already be on the correct device (CPU or GPU).
    tokenizer : AutoTokenizer
        The pre-loaded tokenizer corresponding to the ESM model.
    batch_size : int
        The number of sequences to process in each batch when generating embeddings.
    max_duration : int
        GPU-time budget forwarded to ``generate_embeddings`` (spaces duration API).

    Returns:
    --------
    all_file_paths : list of str
        List of file paths where the per-file embeddings were saved. To be passed to gradio for download.
    status_string : str
        A string summarizing the processing steps and output files generated, to be displayed in the gradio interface.
    """
    # Parse FASTA files into (sequence_id, sequence, source_file) records.
    sequences_info, file_info = parse_fasta_files(fasta_files)

    # Generate embeddings in batches.
    all_embeddings = []
    n_batches = (len(sequences_info) + batch_size - 1) // batch_size  # ceil division
    status_string = f"Processing {len(sequences_info)} sequences from {len(file_info)} file(s) in {n_batches} batches of {batch_size} sequences...\n"
    for i in range(0, len(sequences_info), batch_size):
        batch = sequences_info[i:i + batch_size]
        batch_sequences = [seq for _, seq, _ in batch]
        embeddings = generate_embeddings(batch_sequences, model, tokenizer, max_duration)
        status_string += f"Generated {len(embeddings)} embeddings for batch {i // batch_size + 1}/{n_batches}\n"
        all_embeddings.extend(embeddings)
    status_string += f"Generated embeddings for all {len(sequences_info)} sequences.\n"

    # Group sequence indices by source file in one pass (instead of re-scanning
    # sequences_info once per file, which is O(files * sequences)).
    indices_by_file = {}
    for idx, (_, _, f) in enumerate(sequences_info):
        indices_by_file.setdefault(f, []).append(idx)

    # Unique per-session output directory so concurrent users don't collide.
    session_hash = random.getrandbits(128)
    time_stamp = time.strftime("%Y-%m-%d-%H-%M-%S")
    out_dir = f"./outputs/session_{session_hash}_{time_stamp}"
    os.makedirs(out_dir, exist_ok=True)

    # Save one compressed .npz per input file (embeddings + matching sequence ids).
    all_file_paths = []
    for file_name in file_info.keys():
        indices = indices_by_file.get(file_name, [])
        file_embeddings = np.array([all_embeddings[i] for i in indices])
        sequence_ids = [sequences_info[i][0] for i in indices]
        file_path = os.path.join(out_dir, f"{file_name}_embeddings.npz")
        np.savez_compressed(file_path, embeddings=file_embeddings, sequence_ids=sequence_ids)
        status_string += f"Saved compressed embeddings to {file_name}_embeddings.npz\n"
        all_file_paths.append(file_path)
    return all_file_paths, status_string
def full_ppl_pipeline(fasta_files, model, tokenizer, batch_size, mask_percentage=None, max_duration=240):
    """Full pipeline to process FASTA files and generate pseudo-perplexity scores.

    Parameters:
    -----------
    fasta_files : list of str, obtained from gradio file input
        List of paths to FASTA files to be parsed.
    model : AutoModel
        The pre-loaded ESM masked-LM model. Must already be on the correct device (CPU or GPU).
    tokenizer : AutoTokenizer
        The pre-loaded tokenizer corresponding to the ESM model.
    batch_size : int
        The number of sequences to process in each batch when generating scores.
    mask_percentage : float or None
        If None, use the exact PPL calculation (masking one token at a time). If a float between 0 and 1, use the approximate chunked masking method with the specified percentage of tokens masked per forward pass.
    max_duration : int, default=240
        GPU-time budget forwarded to the scoring functions (spaces duration API).

    Returns:
    --------
    all_file_paths : list of str
        List of file paths where the per-file CSVs of PPL scores were saved. To be passed to gradio for download.
    status_string : str
        A string summarizing the processing steps and output files generated, to be displayed in the gradio interface.
    """
    # Parse FASTA files into (description, sequence, source_file) records.
    sequences_info, file_info = parse_fasta_files(fasta_files)

    # Score sequences in batches, exact or approximate depending on mask_percentage.
    all_ppl = []
    n_batches = (len(sequences_info) + batch_size - 1) // batch_size  # ceil division
    status_string = f"Processing {len(sequences_info)} sequences from {len(file_info)} file(s) in {n_batches} batches of {batch_size} sequences...\n"
    for i in range(0, len(sequences_info), batch_size):
        batch = sequences_info[i:i + batch_size]
        batch_sequences = [seq for _, seq, _ in batch]
        if mask_percentage is None:
            ppl_scores = generate_ppl_scores(batch_sequences, model, tokenizer, max_duration)
            status_string += f"Generated {len(ppl_scores)} pseudo-perplexity scores for batch {i // batch_size + 1}/{n_batches}\n"
        else:
            ppl_scores = generate_ppl_scores_approx(batch_sequences, model, tokenizer, mask_percentage=mask_percentage, max_duration=max_duration)
            status_string += f"Generated {len(ppl_scores)} approximate pseudo-perplexity scores for batch {i // batch_size + 1}/{n_batches} with mask percentage {mask_percentage*100:.1f}%\n"
        all_ppl.extend(ppl_scores)
    status_string += f"Generated scores for all {len(sequences_info)} sequences.\n"

    # Group sequence indices by source file in one pass (instead of re-scanning
    # sequences_info once per file, which is O(files * sequences)).
    indices_by_file = {}
    for idx, (_, _, f) in enumerate(sequences_info):
        indices_by_file.setdefault(f, []).append(idx)

    # Unique per-session output directory so concurrent users don't collide.
    session_hash = random.getrandbits(128)
    time_stamp = time.strftime("%Y-%m-%d-%H-%M-%S")
    out_dir = f"./outputs/session_{session_hash}_{time_stamp}"
    os.makedirs(out_dir, exist_ok=True)

    # Save one CSV per input file with description, sequence and score columns.
    all_file_paths = []
    for file_name in file_info.keys():
        file_path = os.path.join(out_dir, f"{file_name}_ppl.csv")
        rows = []
        for idx in indices_by_file.get(file_name, []):
            description, sequence, _ = sequences_info[idx]
            rows.append({
                "description": description,
                "sequence": sequence,
                "ppl_score": all_ppl[idx]
            })
        df = pd.DataFrame(rows)
        df.to_csv(file_path, index=False)
        status_string += f"Saved PPL scores to {file_name}_ppl.csv\n"
        all_file_paths.append(file_path)

    # Report the best (lowest) scoring sequence; guard against empty input
    # (min() on an empty list would raise). The argmin is computed once instead
    # of repeated list.index() scans.
    if all_ppl:
        best_idx = min(range(len(all_ppl)), key=all_ppl.__getitem__)
        lowest_ppl = all_ppl[best_idx]
        best_description, best_sequence, best_file = sequences_info[best_idx]
        status_string += f"Lowest PPL score across all sequences: {lowest_ppl:.4f}:\n for sequence in file {best_file}:\n"
        status_string += f">{best_description}\n"
        status_string += f"{best_sequence}\n"
    return all_file_paths, status_string