# utils/pipelines.py — ESM2 embedding and pseudo-perplexity pipelines
import spaces
import torch
import numpy as np
from utils.handle_files import parse_fasta_files
import gradio as gr
import time
import random
import os
import pandas as pd
def get_duration_embeddings(sequences_batch, model, tokenizer, max_duration):
    """Duration hook for ``@spaces.GPU``: allot ``max_duration`` seconds to the call.

    Receives the same arguments as ``generate_embeddings`` so the decorator can
    forward them; only ``max_duration`` is used.
    """
    return max_duration
@spaces.GPU(duration=get_duration_embeddings)
def generate_embeddings(sequences_batch, model, tokenizer, max_duration):
    """Generate mean-pooled sequence embeddings for ESM models via transformers.

    Parameters:
    -----------
    sequences_batch : list of str
        A batch of sequences for which to generate embeddings.
    model : AutoModel
        The pre-loaded ESM model. Must already be on the correct device (CPU or GPU).
    tokenizer : AutoTokenizer
        The pre-loaded tokenizer corresponding to the ESM model.
    max_duration : int
        GPU time budget in seconds, consumed by the @spaces.GPU duration hook.

    Returns:
    --------
    sequence_embeddings : 2D np.array of shape (batch_size, embedding_dim)
        A sequence-level embedding (mean-pooled over residue tokens) per input.
    """
    device = model.device
    tokens = tokenizer(
        sequences_batch,
        return_tensors="pt",
        padding=True,
        truncation=True,
        add_special_tokens=True
    ).to(device)
    attention_mask = tokens["attention_mask"]
    with torch.no_grad():
        # Request hidden states explicitly: many model configs do not return
        # them by default, which would leave `results.hidden_states` as None.
        results = model(**tokens, output_hidden_states=True)
    token_embeddings = results.hidden_states[-1]  # last-layer embeddings
    # Mean-pool per sequence over real residue tokens only, excluding the
    # first/last special tokens and padding. Done row-by-row because sequence
    # lengths vary across the batch.
    sequence_embeddings = []
    for i, _ in enumerate(sequences_batch):
        # Count real tokens from the attention mask (minus the two special
        # tokens) rather than len(seq): stays correct if the tokenizer
        # truncated the sequence.
        n_real = int(attention_mask[i].sum().item()) - 2
        seq_embedding = token_embeddings[i, 1:n_real + 1].mean(dim=0)
        sequence_embeddings.append(seq_embedding.cpu().numpy())
    return np.array(sequence_embeddings)
def get_duration_ppl(sequences_batch, model, tokenizer, max_duration):
    """Duration hook for ``@spaces.GPU``: allot ``max_duration`` seconds to the call.

    Mirrors the signature of ``generate_ppl_scores``; only ``max_duration``
    is consulted.
    """
    return max_duration
@spaces.GPU(duration=get_duration_ppl)
def generate_ppl_scores(sequences_batch, model, tokenizer, max_duration):
    """Generate pseudo-perplexity scores for ESM models using batched masking across all sequences.

    For each scorable position, the token is replaced with the mask token and a
    forward pass recovers the model's log-probability of the true token; the
    per-sequence score is exp(mean negative log-probability) over all scored
    positions. One forward pass is made per position index, shared across the
    whole batch.

    Parameters:
    -----------
    sequences_batch : list of str
        A batch of sequences for which to generate pseudo-perplexity scores.
    model : AutoModel
        The pre-loaded ESM model. must already be on the correct device (CPU or GPU).
        NOTE(review): must be a masked-LM head variant — `outputs.logits` is read below.
    tokenizer : AutoTokenizer
        The pre-loaded tokenizer corresponding to the ESM model.
    max_duration : int
        GPU time budget in seconds, consumed by the @spaces.GPU duration hook.

    Returns:
    --------
    ppl_scores : list of float
        A list of perplexity scores for each input sequence. float("inf") for
        sequences too short to score.
    """
    device = model.device
    mask_token_id = tokenizer.mask_token_id
    if mask_token_id is None:
        raise ValueError("Tokenizer does not define a mask token; cannot compute pseudo-perplexity.")
    tokens = tokenizer(
        sequences_batch,
        return_tensors="pt",
        padding=True,
        truncation=True,
        add_special_tokens=True
    ).to(device)
    input_ids = tokens["input_ids"]
    attention_mask = tokens["attention_mask"]
    batch_size = input_ids.size(0)
    seq_len = input_ids.size(1)
    # Initialize accumulators for each sequence: summed log-probs of true
    # tokens, and how many positions were scored.
    log_prob_sums = torch.zeros(batch_size, device=device)
    token_counts = torch.zeros(batch_size, device=device)
    # Precompute which positions to score for each sequence (exclude special tokens)
    positions_to_score = []
    for i in range(batch_size):
        valid_positions = torch.nonzero(attention_mask[i], as_tuple=False).squeeze(-1)
        if valid_positions.numel() < 3:
            # Too short to score (less than 1 real token after excluding special tokens)
            positions_to_score.append(set())
        else:
            # Exclude first and last positions (special tokens)
            positions_to_score.append(set(valid_positions[1:-1].tolist()))
    with torch.no_grad():
        # Process one position at a time across all sequences
        for pos in range(1, seq_len - 1):
            # Find which sequences have a valid token at this position
            active_indices = [i for i in range(batch_size) if pos in positions_to_score[i]]
            if not active_indices:
                continue
            # Clone input_ids and mask the current position for all sequences
            masked_batch = input_ids.clone()
            true_token_ids = masked_batch[active_indices, pos].clone()
            masked_batch[active_indices, pos] = mask_token_id
            # Single forward pass for all sequences
            outputs = model(masked_batch, attention_mask=attention_mask)
            logits = outputs.logits  # (batch_size, seq_len, vocab_size)
            # Extract log-probs for each active sequence at this position
            log_probs = torch.log_softmax(logits[active_indices, pos], dim=-1)
            # Gather log-probs of the true tokens
            true_log_probs = log_probs.gather(1, true_token_ids.unsqueeze(-1)).squeeze(-1)
            # Accumulate for each active sequence
            for idx, seq_idx in enumerate(active_indices):
                log_prob_sums[seq_idx] += true_log_probs[idx]
                token_counts[seq_idx] += 1
    # Compute final pseudo-perplexity scores: exp of the average negative
    # log-probability per scored position.
    ppl_scores = []
    for i in range(batch_size):
        if token_counts[i] == 0:
            ppl_scores.append(float("inf"))
        else:
            avg_neg_log_prob = -log_prob_sums[i] / token_counts[i]
            ppl_scores.append(float(torch.exp(avg_neg_log_prob).item()))
    return ppl_scores
def get_duration_ppl_approx(sequences_batch, model, tokenizer, mask_percentage, max_duration):
    """Duration hook for ``@spaces.GPU``: allot ``max_duration`` seconds to the call.

    Mirrors the signature of ``generate_ppl_scores_approx``; only
    ``max_duration`` is consulted.
    """
    return max_duration
@spaces.GPU(duration=get_duration_ppl_approx)
def generate_ppl_scores_approx(sequences_batch, model, tokenizer, mask_percentage=0.15, max_duration=240):
    """Generate approximate pseudo-perplexity scores for ESM models using chunked masking.

    Instead of one forward pass per position (exact PPL), a chunk of positions
    is masked simultaneously in each pass, trading some accuracy for far fewer
    forward passes.

    Parameters:
    -----------
    sequences_batch : list of str
        A batch of sequences for which to generate pseudo-perplexity scores.
    model : AutoModel
        The pre-loaded ESM model. must already be on the correct device (CPU or GPU).
    tokenizer : AutoTokenizer
        The pre-loaded tokenizer corresponding to the ESM model.
    mask_percentage : float, default=0.15
        Percentage of positions to mask in each forward pass (0 < mask_percentage <= 1).
    max_duration : int, default=240
        GPU time budget in seconds, consumed by the @spaces.GPU duration hook.

    Returns:
    --------
    ppl_scores : list of float
        A list of approximate perplexity scores for each input sequence.
        float("inf") for sequences too short to score.
    """
    device = model.device
    mask_token_id = tokenizer.mask_token_id
    if mask_token_id is None:
        raise ValueError("Tokenizer does not define a mask token; cannot compute pseudo-perplexity.")
    tokens = tokenizer(
        sequences_batch,
        return_tensors="pt",
        padding=True,
        truncation=True,
        add_special_tokens=True
    ).to(device)
    input_ids = tokens["input_ids"]
    attention_mask = tokens["attention_mask"]
    batch_size = input_ids.size(0)
    # Per-sequence accumulators: summed true-token log-probs and scored counts.
    log_prob_sums = torch.zeros(batch_size, device=device)
    token_counts = torch.zeros(batch_size, device=device)
    # Positions to score per sequence, excluding the first/last special tokens.
    # Stored as sets: the chunk loop below does many membership tests, and a
    # set makes each test O(1) instead of O(len) on a list.
    positions_to_score = []
    for i in range(batch_size):
        valid_positions = torch.nonzero(attention_mask[i], as_tuple=False).squeeze(-1)
        if valid_positions.numel() < 3:
            # Too short to score (no real tokens besides the special tokens).
            positions_to_score.append(set())
        else:
            positions_to_score.append(set(valid_positions[1:-1].tolist()))
    # Chunk size is a fraction of the longest sequence's scorable length.
    max_positions = max(len(pos) for pos in positions_to_score) if positions_to_score else 0
    if max_positions == 0:
        return [float("inf")] * batch_size
    chunk_size = max(1, int(max_positions * mask_percentage))
    with torch.no_grad():
        # All unique scorable positions across the batch, in order.
        all_positions = sorted(set().union(*positions_to_score))
        # Process positions in chunks; each chunk costs one forward pass.
        for chunk_start in range(0, len(all_positions), chunk_size):
            chunk_positions = all_positions[chunk_start:chunk_start + chunk_size]
            # Clone input_ids and mask every chunk position each sequence owns.
            masked_batch = input_ids.clone()
            seq_positions = {i: [] for i in range(batch_size)}
            for pos in chunk_positions:
                for seq_idx in range(batch_size):
                    if pos in positions_to_score[seq_idx]:
                        seq_positions[seq_idx].append(pos)
                        masked_batch[seq_idx, pos] = mask_token_id
            # Skip the forward pass if no sequence has tokens in this chunk.
            active_sequences = [i for i, pos_list in seq_positions.items() if pos_list]
            if not active_sequences:
                continue
            # Single forward pass for the entire batch with the chunk masked.
            outputs = model(masked_batch, attention_mask=attention_mask)
            logits = outputs.logits  # (batch_size, seq_len, vocab_size)
            # Accumulate the log-prob of each true token at its masked position.
            for seq_idx in active_sequences:
                for pos in seq_positions[seq_idx]:
                    true_token_id = input_ids[seq_idx, pos]
                    log_probs = torch.log_softmax(logits[seq_idx, pos], dim=-1)
                    log_prob_sums[seq_idx] += log_probs[true_token_id]
                    token_counts[seq_idx] += 1
    # Final pseudo-perplexity: exp of mean negative log-prob per sequence.
    ppl_scores = []
    for i in range(batch_size):
        if token_counts[i] == 0:
            ppl_scores.append(float("inf"))
        else:
            avg_neg_log_prob = -log_prob_sums[i] / token_counts[i]
            ppl_scores.append(float(torch.exp(avg_neg_log_prob).item()))
    return ppl_scores
def full_embedding_pipeline(fasta_files, model, tokenizer, batch_size, max_duration):
    """Full pipeline to process FASTA files and generate embeddings from desired model.

    Parameters:
    -----------
    fasta_files : list of str, obtained from gradio file input
        List of paths to FASTA files to be parsed.
    model : AutoModel
        The pre-loaded ESM model. must already be on the correct device (CPU or GPU).
    tokenizer : AutoTokenizer
        The pre-loaded tokenizer corresponding to the ESM model.
    batch_size : int
        The number of sequences to process in each batch when generating embeddings.
    max_duration : int
        GPU time budget in seconds, forwarded to generate_embeddings.

    Returns:
    --------
    all_file_paths : list of str
        List of file paths where the per-file embeddings were saved. To be passed to gradio for download.
    status_string : str
        A string summarizing the processing steps and output files generated, to be displayed in the gradio interface.
    """
    # Parse FASTA files; sequences_info items are (id, sequence, source_file).
    sequences_info, file_info = parse_fasta_files(fasta_files)
    # Generate embeddings in batches
    all_embeddings = []
    n_batches = (len(sequences_info) + batch_size - 1) // batch_size
    status_string = f"Processing {len(sequences_info)} sequences from {len(file_info)} file(s) in {n_batches} batches of {batch_size} sequences...\n"
    for i in range(0, len(sequences_info), batch_size):
        batch = sequences_info[i:i + batch_size]
        batch_sequences = [seq for _, seq, _ in batch]
        embeddings = generate_embeddings(batch_sequences, model, tokenizer, max_duration)
        status_string += f"Generated {len(embeddings)} embeddings for batch {i // batch_size + 1}/{n_batches}\n"
        all_embeddings.extend(embeddings)
    status_string += f"Generated embeddings for all {len(sequences_info)} sequences.\n"
    # Group sequence indices by source file in a single pass, instead of
    # re-scanning sequences_info once per file.
    indices_by_file = {}
    for idx, (_, _, src_file) in enumerate(sequences_info):
        indices_by_file.setdefault(src_file, []).append(idx)
    session_hash = random.getrandbits(128)  # Generate a random hash for this session
    time_stamp = time.strftime("%Y-%m-%d-%H-%M-%S")
    out_dir = f"./outputs/session_{session_hash}_{time_stamp}"
    os.makedirs(out_dir, exist_ok=True)
    all_file_paths = []
    for file_name in file_info.keys():
        indices = indices_by_file.get(file_name, [])
        file_embeddings = np.array([all_embeddings[i] for i in indices])
        sequence_ids = [sequences_info[i][0] for i in indices]  # Extract sequence IDs for this file
        file_path = os.path.join(out_dir, f"{file_name}_embeddings.npz")
        np.savez_compressed(file_path, embeddings=file_embeddings, sequence_ids=sequence_ids)
        status_string += f"Saved compressed embeddings to {file_name}_embeddings.npz\n"
        all_file_paths.append(file_path)
    return all_file_paths, status_string
def full_ppl_pipeline(fasta_files, model, tokenizer, batch_size, mask_percentage=None, max_duration=240):
    """Full pipeline to process FASTA files and generate pseudo-perplexity scores.

    Parameters:
    -----------
    fasta_files : list of str, obtained from gradio file input
        List of paths to FASTA files to be parsed.
    model : AutoModel
        The pre-loaded ESM model. must already be on the correct device (CPU or GPU).
    tokenizer : AutoTokenizer
        The pre-loaded tokenizer corresponding to the ESM model.
    batch_size : int
        The number of sequences to process in each batch when generating scores.
    mask_percentage : float or None
        If None, use the exact PPL calculation (masking one token at a time). If a float between 0 and 1, use the approximate chunked masking method with the specified percentage of tokens masked per forward pass.
    max_duration : int, default=240
        GPU time budget in seconds, forwarded to the scoring functions.

    Returns:
    --------
    all_file_paths : list of str
        List of file paths where the per-file PPL CSVs were saved. To be passed to gradio for download.
    status_string : str
        A string summarizing the processing steps and output files generated, to be displayed in the gradio interface.
    """
    # Parse FASTA files; sequences_info items are (description, sequence, source_file).
    sequences_info, file_info = parse_fasta_files(fasta_files)
    # Generate PPL scores in batches
    all_ppl = []
    n_batches = (len(sequences_info) + batch_size - 1) // batch_size
    status_string = f"Processing {len(sequences_info)} sequences from {len(file_info)} file(s) in {n_batches} batches of {batch_size} sequences...\n"
    for i in range(0, len(sequences_info), batch_size):
        batch = sequences_info[i:i + batch_size]
        batch_sequences = [seq for _, seq, _ in batch]
        if mask_percentage is None:
            ppl_scores = generate_ppl_scores(batch_sequences, model, tokenizer, max_duration)
            status_string += f"Generated {len(ppl_scores)} pseudo-perplexity scores for batch {i // batch_size + 1}/{n_batches}\n"
        else:
            ppl_scores = generate_ppl_scores_approx(batch_sequences, model, tokenizer, mask_percentage=mask_percentage, max_duration=max_duration)
            status_string += f"Generated {len(ppl_scores)} approximate pseudo-perplexity scores for batch {i // batch_size + 1}/{n_batches} with mask percentage {mask_percentage*100:.1f}%\n"
        all_ppl.extend(ppl_scores)
    status_string += f"Generated scores for all {len(sequences_info)} sequences.\n"
    session_hash = random.getrandbits(128)  # Generate a random hash for this session
    time_stamp = time.strftime("%Y-%m-%d-%H-%M-%S")
    out_dir = f"./outputs/session_{session_hash}_{time_stamp}"
    os.makedirs(out_dir, exist_ok=True)
    all_file_paths = []
    for file_name in file_info.keys():
        indices = [i for i, (_, _, f) in enumerate(sequences_info) if f == file_name]
        file_path = os.path.join(out_dir, f"{file_name}_ppl.csv")
        rows = []
        for idx in indices:
            description, sequence, _ = sequences_info[idx]
            rows.append({
                "description": description,
                "sequence": sequence,
                "ppl_score": all_ppl[idx]
            })
        df = pd.DataFrame(rows)
        df.to_csv(file_path, index=False)
        status_string += f"Saved PPL scores to {file_name}_ppl.csv\n"
        all_file_paths.append(file_path)
    # Report the best-scoring sequence; guard against empty input, where
    # min() would raise ValueError.
    if all_ppl:
        lowest_ppl = min(all_ppl)
        best_idx = all_ppl.index(lowest_ppl)  # computed once instead of three times
        best_description, best_sequence, best_file = sequences_info[best_idx]
        status_string += f"Lowest PPL score across all sequences: {lowest_ppl:.4f}:\n for sequence in file {best_file}:\n"
        status_string += f">{best_description}\n"
        status_string += f"{best_sequence}\n"
    return all_file_paths, status_string