gabboud committed on
Commit
ae38197
·
1 Parent(s): a749e8f

fix locally and implement PPL

Browse files
Files changed (4) hide show
  1. app.py +57 -42
  2. requirements.txt +1 -0
  3. utils/download_models.py +4 -3
  4. utils/pipelines.py +164 -3
app.py CHANGED
@@ -10,7 +10,7 @@ import zipfile
10
  import spaces
11
  from utils.download_models import *
12
  from utils.handle_files import parse_fasta_files
13
- from utils.pipelines import generate_embeddings, full_embedding_pipeline
14
 
15
  print("Downloading ESM2 models...")
16
 
@@ -49,32 +49,10 @@ with gr.Blocks(title="ESM2 Protein Embeddings") as demo:
49
  - `embeddings_[filename].npz`: Per-file embeddings
50
  """)
51
 
52
- with gr.Row():
53
- with gr.Column():
54
- input_files = gr.File(
55
- label="Upload FASTA files",
56
- file_count="multiple",
57
- file_types=[".fasta", ".fa", ".faa"]
58
- )
59
- submit_btn = gr.Button("Generate Embeddings", variant="primary", size="lg")
60
-
61
- with gr.Column():
62
- status_output = gr.Textbox(
63
- label="Processing Status",
64
- interactive=False,
65
- lines=6
66
- )
67
-
68
- with gr.Row():
69
- download_output = gr.File(
70
- label="Download Output Files",
71
- file_count="multiple"
72
- )
73
-
74
  with gr.Row():
75
  model_dropdown = gr.Dropdown(
76
- choices=list(MODELS.values()),
77
- value=list(MODELS.values())[0],
78
  label="Select Model"
79
  )
80
  batch_size = gr.Slider(
@@ -84,29 +62,66 @@ with gr.Blocks(title="ESM2 Protein Embeddings") as demo:
84
  value=32,
85
  label="Batch Size"
86
  )
 
 
 
 
 
 
 
 
87
 
88
-
89
- current_key = [key for key, value in MODELS.items() if value == model_dropdown.value][0]
90
- model_to_use = gr.State(value=models_and_tokenizers[current_key][0])
91
- tokenizer_to_use = gr.State(value=models_and_tokenizers[current_key][1])
92
-
93
- def pick_model(model_name):
94
- model_key = [key for key, value in MODELS.items() if value == model_name][0]
95
- print(f"Selected model: {model_name} ({model_key})")
96
- return models_and_tokenizers[model_key]
97
-
98
- model_dropdown.change(
99
- fn=pick_model,
100
- inputs=model_dropdown,
101
- outputs=[model_to_use, tokenizer_to_use]
102
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
 
104
  submit_btn.click(
105
- fn=full_embedding_pipeline,
106
- inputs=[input_files, model_to_use, tokenizer_to_use, batch_size],
107
  outputs=[download_output, status_output]
108
  )
109
 
 
 
 
 
 
 
110
 
111
 
112
  gr.Markdown("""
 
10
  import spaces
11
  from utils.download_models import *
12
  from utils.handle_files import parse_fasta_files
13
+ from utils.pipelines import generate_embeddings, full_embedding_pipeline, full_ppl_pipeline
14
 
15
  print("Downloading ESM2 models...")
16
 
 
49
  - `embeddings_[filename].npz`: Per-file embeddings
50
  """)
51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  with gr.Row():
53
  model_dropdown = gr.Dropdown(
54
+ choices=list(MODELS.keys()),
55
+ value=list(MODELS.keys())[0],
56
  label="Select Model"
57
  )
58
  batch_size = gr.Slider(
 
62
  value=32,
63
  label="Batch Size"
64
  )
65
+
66
+ with gr.Row():
67
+ with gr.Column():
68
+ input_files = gr.File(
69
+ label="Upload FASTA files",
70
+ file_count="multiple",
71
+ file_types=[".fasta", ".fa", ".faa"]
72
+ )
73
 
74
+
75
+ with gr.Column():
76
+ with gr.Tabs():
77
+ with gr.TabItem("Generate Embeddings"):
78
+ submit_btn = gr.Button("Generate Embeddings", variant="primary", size="lg")
79
+ status_output = gr.Textbox(
80
+ label="Waiting for embeddings generation...",
81
+ interactive=False,
82
+ lines=6
83
+ )
84
+
85
+ download_output = gr.File(
86
+ label="Download Output Files",
87
+ file_count="multiple"
88
+ )
89
+ with gr.TabItem("Calculate Pseudo-Perplexity scores"):
90
+ ppl_button = gr.Button("Calculate Pseudo-Perplexity", variant="primary", size="lg")
91
+ ppl_status = gr.Textbox(
92
+ label="Waiting for pseudo-perplexity calculation...",
93
+ interactive=False,
94
+ lines=6
95
+ )
96
+ ppl_download = gr.File(
97
+ label="Download Pseudo-Perplexity Output",
98
+ file_count="multiple"
99
+ )
100
+
101
+
102
def run_pipeline_with_selected_model(fasta_files, model_key, batch_size_value, task="embedding"):
    """Run the selected pipeline with the model chosen in the dropdown.

    Parameters
    ----------
    fasta_files : list or None
        FASTA file paths from the gradio file input; may be None/empty.
    model_key : str
        Dropdown value; key into the module-level ``models_and_tokenizers`` dict.
    batch_size_value : int
        Number of sequences to process per batch.
    task : str, optional
        Which pipeline to run: "embedding" (default) or "ppl".

    Returns
    -------
    tuple
        ``(files, status_string)`` suitable for the gradio File/Textbox outputs.
    """
    if not fasta_files:
        return gr.update(), "No FASTA files uploaded. Please upload at least one FASTA file for inference."
    model, tokenizer = models_and_tokenizers[model_key]
    if task == "embedding":
        return full_embedding_pipeline(fasta_files, model, tokenizer, batch_size_value)
    if task == "ppl":
        return full_ppl_pipeline(fasta_files, model, tokenizer, batch_size_value)
    # Bug fix: an unknown task previously fell through and returned None,
    # which gradio cannot unpack into the (files, status) output pair.
    return gr.update(), f"Unknown task '{task}'. Expected 'embedding' or 'ppl'."
112
 
113
  submit_btn.click(
114
+ fn=run_pipeline_with_selected_model,
115
+ inputs=[input_files, model_dropdown, batch_size, gr.State("embedding")],
116
  outputs=[download_output, status_output]
117
  )
118
 
119
+ ppl_button.click(
120
+ fn=run_pipeline_with_selected_model,
121
+ inputs=[input_files, model_dropdown, batch_size, gr.State("ppl")],
122
+ outputs=[ppl_download, ppl_status]
123
+ )
124
+
125
 
126
 
127
  gr.Markdown("""
requirements.txt CHANGED
@@ -3,4 +3,5 @@ biopython>=1.81
3
  numpy>=1.21.0
4
  huggingface_hub
5
  transformers
 
6
 
 
3
  numpy>=1.21.0
4
  huggingface_hub
5
  transformers
6
+ pandas
7
 
utils/download_models.py CHANGED
@@ -1,6 +1,6 @@
1
  import torch
2
  import huggingface_hub
3
- from transformers import AutoTokenizer, AutoModel
4
 
5
  def cache_model_weights(model_id):
6
  """
@@ -54,8 +54,8 @@ def load_model(model_id):
54
  """
55
  try:
56
  print(f"Loading {model_id} from local cache...")
57
- tokenizer = AutoTokenizer.from_pretrained(model_id)
58
- model = AutoModel.from_pretrained(
59
  model_id,
60
  output_hidden_states=True,
61
  )
@@ -87,6 +87,7 @@ def load_all_models(models):
87
  return loaded_models
88
 
89
 
 
90
  #def cache_models(models):
91
  # """
92
  # Download weights to ESM models in cache to be loaded later.
 
1
  import torch
2
  import huggingface_hub
3
+ from transformers import AutoTokenizer, AutoModel, EsmForMaskedLM, EsmTokenizer
4
 
5
  def cache_model_weights(model_id):
6
  """
 
54
  """
55
  try:
56
  print(f"Loading {model_id} from local cache...")
57
+ tokenizer = EsmTokenizer.from_pretrained(model_id)
58
+ model = EsmForMaskedLM.from_pretrained(
59
  model_id,
60
  output_hidden_states=True,
61
  )
 
87
  return loaded_models
88
 
89
 
90
+
91
  #def cache_models(models):
92
  # """
93
  # Download weights to ESM models in cache to be loaded later.
utils/pipelines.py CHANGED
@@ -1,12 +1,12 @@
1
  import spaces
2
  import torch
3
- import spaces
4
  import numpy as np
5
  from utils.handle_files import parse_fasta_files
6
  import gradio as gr
7
  import time
8
  import random
9
  import os
 
10
 
11
  @spaces.GPU(duration=240)
12
  def generate_embeddings(sequences_batch, model, tokenizer):
@@ -55,6 +55,97 @@ def generate_embeddings(sequences_batch, model, tokenizer):
55
 
56
  return np.array(sequence_embeddings)
57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  def full_embedding_pipeline(fasta_files, model, tokenizer, batch_size):
59
  """Full pipeline to process FASTA files and generate embeddings from desired model.
60
 
@@ -96,7 +187,7 @@ def full_embedding_pipeline(fasta_files, model, tokenizer, batch_size):
96
  unique_files = file_info.keys()
97
  session_hash = random.getrandbits(128) # Generate a random hash for this session
98
  time_stamp = time.strftime("%Y-%m-%d-%H-%M-%S")
99
- out_dir = f"./outputs/unconditional_generation/session_{session_hash}_{time_stamp}"
100
  os.makedirs(out_dir, exist_ok=True)
101
  all_file_paths = []
102
  for file_name in unique_files:
@@ -109,5 +200,75 @@ def full_embedding_pipeline(fasta_files, model, tokenizer, batch_size):
109
  all_file_paths.append(file_path)
110
 
111
 
112
- return all_file_paths, all_embeddings
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
 
 
 
1
  import spaces
2
  import torch
 
3
  import numpy as np
4
  from utils.handle_files import parse_fasta_files
5
  import gradio as gr
6
  import time
7
  import random
8
  import os
9
+ import pandas as pd
10
 
11
  @spaces.GPU(duration=240)
12
  def generate_embeddings(sequences_batch, model, tokenizer):
 
55
 
56
  return np.array(sequence_embeddings)
57
 
58
@spaces.GPU(duration=240)
def generate_ppl_scores(sequences_batch, model, tokenizer):
    """Compute pseudo-perplexity (PPL) scores for a batch of sequences.

    Every non-special position of every sequence is masked in turn; the
    model's log-probability of the true residue at that position is
    accumulated, and the pseudo-perplexity of a sequence is
    exp(mean negative log-probability). Positions are swept one at a time
    across the whole batch so each forward pass scores all sequences that
    still have a residue at that position.

    Parameters
    ----------
    sequences_batch : list of str
        A batch of sequences to score.
    model : EsmForMaskedLM
        The pre-loaded ESM masked-LM model; must already be on the correct
        device (CPU or GPU) and expose ``.logits`` on its output.
    tokenizer : EsmTokenizer
        The pre-loaded tokenizer corresponding to the ESM model.

    Returns
    -------
    ppl_scores : list of float
        One pseudo-perplexity score per input sequence; ``inf`` for
        sequences too short to have any scoreable position.

    Raises
    ------
    ValueError
        If the tokenizer defines no mask token.
    """
    device = model.device
    mask_token_id = tokenizer.mask_token_id
    if mask_token_id is None:
        raise ValueError("Tokenizer does not define a mask token; cannot compute pseudo-perplexity.")

    tokens = tokenizer(
        sequences_batch,
        return_tensors="pt",
        padding=True,
        truncation=True,
        add_special_tokens=True
    ).to(device)

    input_ids = tokens["input_ids"]
    attention_mask = tokens["attention_mask"]
    batch_size = input_ids.size(0)
    seq_len = input_ids.size(1)

    # Per-sequence accumulators: summed log-probs and scored-token counts.
    log_prob_sums = torch.zeros(batch_size, device=device)
    token_counts = torch.zeros(batch_size, device=device)

    # Boolean mask of positions to score: real tokens minus the first and
    # last valid positions (the CLS/EOS special tokens the tokenizer adds).
    # Vectorized replacement for the original per-sequence Python sets.
    score_mask = attention_mask.bool().clone()
    lengths = attention_mask.sum(dim=1)
    score_mask[:, 0] = False
    score_mask[torch.arange(batch_size, device=device), (lengths - 1).clamp(min=0)] = False
    # Fewer than 3 valid tokens means no real residue remains to score.
    score_mask[lengths < 3] = False

    with torch.no_grad():
        # Sweep one masked position at a time across the whole batch.
        for pos in range(1, seq_len - 1):
            active = torch.nonzero(score_mask[:, pos], as_tuple=False).squeeze(-1)
            if active.numel() == 0:
                continue

            # Clone input_ids and mask the current position for the
            # sequences that have a real token there.
            masked_batch = input_ids.clone()
            true_token_ids = masked_batch[active, pos].clone()
            masked_batch[active, pos] = mask_token_id

            # Single forward pass for all sequences at this position.
            outputs = model(masked_batch, attention_mask=attention_mask)
            logits = outputs.logits  # (batch_size, seq_len, vocab_size)

            # Log-probabilities of the true tokens at the masked position.
            log_probs = torch.log_softmax(logits[active, pos], dim=-1)
            true_log_probs = log_probs.gather(1, true_token_ids.unsqueeze(-1)).squeeze(-1)

            # Vectorized accumulation (replaces the per-sequence Python loop).
            log_prob_sums[active] += true_log_probs
            token_counts[active] += 1

    # Pseudo-perplexity = exp(mean negative log-likelihood); inf when a
    # sequence had no scoreable positions.
    ppl_scores = []
    for i in range(batch_size):
        if token_counts[i] == 0:
            ppl_scores.append(float("inf"))
        else:
            avg_neg_log_prob = -log_prob_sums[i] / token_counts[i]
            ppl_scores.append(float(torch.exp(avg_neg_log_prob).item()))

    return ppl_scores
148
+
149
  def full_embedding_pipeline(fasta_files, model, tokenizer, batch_size):
150
  """Full pipeline to process FASTA files and generate embeddings from desired model.
151
 
 
187
  unique_files = file_info.keys()
188
  session_hash = random.getrandbits(128) # Generate a random hash for this session
189
  time_stamp = time.strftime("%Y-%m-%d-%H-%M-%S")
190
+ out_dir = f"./outputs/session_{session_hash}_{time_stamp}"
191
  os.makedirs(out_dir, exist_ok=True)
192
  all_file_paths = []
193
  for file_name in unique_files:
 
200
  all_file_paths.append(file_path)
201
 
202
 
203
+ return all_file_paths, status_string
204
+
205
def full_ppl_pipeline(fasta_files, model, tokenizer, batch_size):
    """Full pipeline: parse FASTA files and compute pseudo-perplexity scores.

    Runs ``generate_ppl_scores`` over the parsed sequences in batches and
    writes one ``[filename]_ppl.csv`` per input file (columns: description,
    sequence, ppl_score).

    Parameters
    ----------
    fasta_files : list of str, obtained from gradio file input
        List of paths to FASTA files to be parsed.
    model : EsmForMaskedLM
        The pre-loaded ESM masked-LM model. Must already be on the correct
        device (CPU or GPU).
    tokenizer : EsmTokenizer
        The pre-loaded tokenizer corresponding to the ESM model.
    batch_size : int
        The number of sequences to process in each batch.

    Returns
    -------
    all_file_paths : list of str
        Paths of the per-file PPL CSVs, to be passed to gradio for download.
    status_string : str
        Summary of the processing steps and output files, shown in the UI.
    """
    # Parse FASTA files into (description, sequence, source_file) triples.
    sequences_info, file_info = parse_fasta_files(fasta_files)

    # Score sequences in batches.
    all_ppl = []
    n_batches = (len(sequences_info) + batch_size - 1) // batch_size
    status_string = f"Processing {len(sequences_info)} sequences from {len(file_info)} file(s) in {n_batches} batches of {batch_size} sequences...\n"

    for start in range(0, len(sequences_info), batch_size):
        batch = sequences_info[start:start + batch_size]
        batch_sequences = [seq for _, seq, _ in batch]
        ppl_scores = generate_ppl_scores(batch_sequences, model, tokenizer)
        status_string += f"Generated {len(ppl_scores)} pseudo-perplexity scores for batch {start // batch_size + 1}/{n_batches}\n"
        all_ppl.extend(ppl_scores)

    status_string += f"Generated scores for all {len(sequences_info)} sequences.\n"

    # Unique per-session output directory.
    session_hash = random.getrandbits(128)  # Random hash for this session
    time_stamp = time.strftime("%Y-%m-%d-%H-%M-%S")
    out_dir = f"./outputs/session_{session_hash}_{time_stamp}"
    os.makedirs(out_dir, exist_ok=True)

    # One CSV per input file, preserving each file's sequence order.
    all_file_paths = []
    for file_name in file_info.keys():
        indices = [i for i, (_, _, f) in enumerate(sequences_info) if f == file_name]
        file_path = os.path.join(out_dir, f"{file_name}_ppl.csv")
        rows = [
            {
                "description": sequences_info[idx][0],
                "sequence": sequences_info[idx][1],
                "ppl_score": all_ppl[idx],
            }
            for idx in indices
        ]
        pd.DataFrame(rows).to_csv(file_path, index=False)
        status_string += f"Saved PPL scores to {file_name}_ppl.csv\n"
        all_file_paths.append(file_path)

    # Bug fix: min() on an empty list raised ValueError when the FASTA
    # files parsed to zero sequences. Also compute the argmin once instead
    # of three list scans with all_ppl.index().
    if all_ppl:
        best_idx = min(range(len(all_ppl)), key=all_ppl.__getitem__)
        best_desc, best_seq, best_file = sequences_info[best_idx]
        status_string += f"Lowest PPL score across all sequences: {all_ppl[best_idx]:.4f}\nfor sequence in file {best_file}:\n"
        status_string += f">{best_desc}\n"
        status_string += f"{best_seq}\n"

    return all_file_paths, status_string