gabboud committed on
Commit
ecb4e6c
·
1 Parent(s): 48ea20b

introduce approximate PPL through mask batching

Browse files
Files changed (2) hide show
  1. app.py +12 -2
  2. utils/pipelines.py +123 -7
app.py CHANGED
@@ -87,7 +87,9 @@ with gr.Blocks(title="ESM2 Protein Embeddings") as demo:
87
  file_count="multiple"
88
  )
89
  with gr.TabItem("Calculate Pseudo-Perplexity scores"):
90
- ppl_button = gr.Button("Calculate Pseudo-Perplexity", variant="primary", size="lg")
 
 
91
  ppl_status = gr.Textbox(
92
  label="Waiting for pseudo-perplexity calculation...",
93
  interactive=False,
@@ -108,7 +110,9 @@ with gr.Blocks(title="ESM2 Protein Embeddings") as demo:
108
  if task == "embedding":
109
  return full_embedding_pipeline(fasta_files, model, tokenizer, batch_size_value)
110
  elif task == "ppl":
111
- return full_ppl_pipeline(fasta_files, model, tokenizer, batch_size_value)
 
 
112
 
113
  submit_btn.click(
114
  fn=run_pipeline_with_selected_model,
@@ -122,6 +126,12 @@ with gr.Blocks(title="ESM2 Protein Embeddings") as demo:
122
  outputs=[ppl_download, ppl_status]
123
  )
124
 
 
 
 
 
 
 
125
 
126
 
127
  gr.Markdown("""
 
87
  file_count="multiple"
88
  )
89
  with gr.TabItem("Calculate Pseudo-Perplexity scores"):
90
+ with gr.Row():
91
+ ppl_button = gr.Button("Calculate Exact Pseudo-Perplexity", variant="primary", size="lg")
92
+ ppl_approx_button = gr.Button("Calculate Approximate Pseudo-Perplexity", variant="primary", size="lg")
93
  ppl_status = gr.Textbox(
94
  label="Waiting for pseudo-perplexity calculation...",
95
  interactive=False,
 
110
  if task == "embedding":
111
  return full_embedding_pipeline(fasta_files, model, tokenizer, batch_size_value)
112
  elif task == "ppl":
113
+ return full_ppl_pipeline(fasta_files, model, tokenizer, batch_size_value, mask_percentage=None)
114
+ elif task == "ppl-approx":
115
+ return full_ppl_pipeline(fasta_files, model, tokenizer, batch_size_value, mask_percentage=0.1)
116
 
117
  submit_btn.click(
118
  fn=run_pipeline_with_selected_model,
 
126
  outputs=[ppl_download, ppl_status]
127
  )
128
 
129
+ ppl_approx_button.click(
130
+ fn=run_pipeline_with_selected_model,
131
+ inputs=[input_files, model_dropdown, batch_size, gr.State("ppl-approx")],
132
+ outputs=[ppl_download, ppl_status]
133
+ )
134
+
135
 
136
 
137
  gr.Markdown("""
utils/pipelines.py CHANGED
@@ -146,6 +146,119 @@ def generate_ppl_scores(sequences_batch, model, tokenizer):
146
 
147
  return ppl_scores
148
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
  def full_embedding_pipeline(fasta_files, model, tokenizer, batch_size):
150
  """Full pipeline to process FASTA files and generate embeddings from desired model.
151
 
@@ -202,7 +315,7 @@ def full_embedding_pipeline(fasta_files, model, tokenizer, batch_size):
202
 
203
  return all_file_paths, status_string
204
 
205
- def full_ppl_pipeline(fasta_files, model, tokenizer, batch_size):
206
  """Full pipeline to process FASTA files and generate embeddings from desired model.
207
 
208
  Parameters:
@@ -215,6 +328,8 @@ def full_ppl_pipeline(fasta_files, model, tokenizer, batch_size):
215
  The pre-loaded tokenizer corresponding to the ESM model.
216
  batch_size : int
217
  The number of sequences to process in each batch when generating embeddings.
 
 
218
 
219
  Returns:
220
  --------
@@ -222,6 +337,7 @@ def full_ppl_pipeline(fasta_files, model, tokenizer, batch_size):
222
  List of file paths where the per-file embeddings were saved. To be passed to gradio for download.
223
  status_string : str
224
  A string summarizing the processing steps and output files generated, to be displayed in the gradio interface.
 
225
  """
226
  # Parse FASTA files
227
  sequences_info, file_info = parse_fasta_files(fasta_files)
@@ -234,9 +350,12 @@ def full_ppl_pipeline(fasta_files, model, tokenizer, batch_size):
234
  for i in range(0, len(sequences_info), batch_size):
235
  batch = sequences_info[i:i + batch_size]
236
  batch_sequences = [seq for _, seq, _ in batch]
237
-
238
- ppl_scores = generate_ppl_scores(batch_sequences, model, tokenizer)
239
- status_string += f"Generated {len(ppl_scores)} pseudo-perplexity scores for batch {i // batch_size + 1}/{n_batches}\n"
 
 
 
240
  all_ppl.extend(ppl_scores)
241
 
242
  status_string += f"Generated scores for all {len(sequences_info)} sequences.\n"
@@ -269,6 +388,3 @@ def full_ppl_pipeline(fasta_files, model, tokenizer, batch_size):
269
  status_string += f">{sequences_info[all_ppl.index(lowest_ppl)][0]}\n"
270
  status_string += f"{sequences_info[all_ppl.index(lowest_ppl)][1]}\n"
271
 
272
-
273
-
274
- return all_file_paths, status_string
 
146
 
147
  return ppl_scores
148
 
149
@spaces.GPU(duration=240)
def generate_ppl_scores_approx(sequences_batch, model, tokenizer, mask_percentage=0.15):
    """Generate approximate pseudo-perplexity scores for ESM models using chunked masking.

    Instead of masking one position per forward pass (exact pseudo-perplexity),
    this masks a chunk of positions simultaneously, trading accuracy for speed:
    roughly 1 / mask_percentage forward passes per batch instead of one per token.

    Parameters:
    -----------
    sequences_batch : list of str
        A batch of sequences for which to compute pseudo-perplexity scores.
    model : AutoModel
        The pre-loaded ESM model. Must already be on the correct device (CPU or GPU).
    tokenizer : AutoTokenizer
        The pre-loaded tokenizer corresponding to the ESM model.
    mask_percentage : float, default=0.15
        Fraction of positions to mask in each forward pass (0 < mask_percentage <= 1).

    Returns:
    --------
    ppl_scores : list of float
        A list of approximate pseudo-perplexity scores for each input sequence.
        Sequences with no scorable positions receive float("inf").

    Raises:
    -------
    ValueError
        If the tokenizer does not define a mask token.
    """
    device = model.device
    mask_token_id = tokenizer.mask_token_id
    if mask_token_id is None:
        raise ValueError("Tokenizer does not define a mask token; cannot compute pseudo-perplexity.")

    tokens = tokenizer(
        sequences_batch,
        return_tensors="pt",
        padding=True,
        truncation=True,
        add_special_tokens=True
    ).to(device)

    input_ids = tokens["input_ids"]
    attention_mask = tokens["attention_mask"]
    batch_size = input_ids.size(0)

    # Per-sequence accumulators: summed log-probs and number of scored tokens.
    log_prob_sums = torch.zeros(batch_size, device=device)
    token_counts = torch.zeros(batch_size, device=device)

    # Positions to score per sequence: all attended positions except the first
    # and last (special tokens). Stored as sets for O(1) membership tests in
    # the chunk loop below (a list here made that test O(seq_len) per token).
    position_sets = []
    for i in range(batch_size):
        valid_positions = torch.nonzero(attention_mask[i], as_tuple=False).squeeze(-1)
        if valid_positions.numel() < 3:
            # Too short to contain any real residue between the special tokens.
            position_sets.append(set())
        else:
            position_sets.append(set(valid_positions[1:-1].tolist()))

    max_positions = max((len(s) for s in position_sets), default=0)
    if max_positions == 0:
        return [float("inf")] * batch_size

    # Number of positions masked together per forward pass.
    chunk_size = max(1, int(max_positions * mask_percentage))

    # Union of all scorable positions across the batch, processed in order.
    all_positions = sorted(set().union(*position_sets))

    with torch.no_grad():
        for chunk_start in range(0, len(all_positions), chunk_size):
            chunk_positions = all_positions[chunk_start:chunk_start + chunk_size]

            # Clone the inputs and mask, per sequence, every chunk position
            # that is scorable for that sequence.
            masked_batch = input_ids.clone()
            seq_positions = {i: [] for i in range(batch_size)}
            for pos in chunk_positions:
                for seq_idx in range(batch_size):
                    if pos in position_sets[seq_idx]:
                        seq_positions[seq_idx].append(pos)
                        masked_batch[seq_idx, pos] = mask_token_id

            # Skip forward pass if no sequence has a token in this chunk.
            active_sequences = [i for i, pos_list in seq_positions.items() if pos_list]
            if not active_sequences:
                continue

            # One forward pass scores every masked position in the chunk.
            outputs = model(masked_batch, attention_mask=attention_mask)
            logits = outputs.logits  # (batch_size, seq_len, vocab_size)

            for seq_idx in active_sequences:
                pos_tensor = torch.tensor(seq_positions[seq_idx], device=device)
                true_tokens = input_ids[seq_idx, pos_tensor]
                # Log-probs for all masked positions of this sequence at once,
                # then gather each true token's log-prob at its own position.
                log_probs = torch.log_softmax(logits[seq_idx, pos_tensor], dim=-1)
                true_log_probs = log_probs.gather(-1, true_tokens.unsqueeze(-1)).squeeze(-1)
                log_prob_sums[seq_idx] += true_log_probs.sum()
                token_counts[seq_idx] += pos_tensor.numel()

    # Pseudo-perplexity = exp(mean negative log-likelihood over scored tokens).
    ppl_scores = []
    for i in range(batch_size):
        if token_counts[i] == 0:
            ppl_scores.append(float("inf"))
        else:
            avg_neg_log_prob = -log_prob_sums[i] / token_counts[i]
            ppl_scores.append(torch.exp(avg_neg_log_prob).item())

    return ppl_scores
260
+
261
+
262
  def full_embedding_pipeline(fasta_files, model, tokenizer, batch_size):
263
  """Full pipeline to process FASTA files and generate embeddings from desired model.
264
 
 
315
 
316
  return all_file_paths, status_string
317
 
318
+ def full_ppl_pipeline(fasta_files, model, tokenizer, batch_size, mask_percentage=None):
319
  """Full pipeline to process FASTA files and generate embeddings from desired model.
320
 
321
  Parameters:
 
328
  The pre-loaded tokenizer corresponding to the ESM model.
329
  batch_size : int
330
  The number of sequences to process in each batch when generating embeddings.
331
+ mask_percentage : float or None
332
+ If None, use the exact PPL calculation (masking one token at a time). If a float between 0 and 1, use the approximate chunked masking method with the specified percentage of tokens masked per forward pass.
333
 
334
  Returns:
335
  --------
 
337
  List of file paths where the per-file embeddings were saved. To be passed to gradio for download.
338
  status_string : str
339
  A string summarizing the processing steps and output files generated, to be displayed in the gradio interface.
340
+
341
  """
342
  # Parse FASTA files
343
  sequences_info, file_info = parse_fasta_files(fasta_files)
 
350
  for i in range(0, len(sequences_info), batch_size):
351
  batch = sequences_info[i:i + batch_size]
352
  batch_sequences = [seq for _, seq, _ in batch]
353
+ if mask_percentage is None:
354
+ ppl_scores = generate_ppl_scores(batch_sequences, model, tokenizer)
355
+ status_string += f"Generated {len(ppl_scores)} pseudo-perplexity scores for batch {i // batch_size + 1}/{n_batches}\n"
356
+ else:
357
+ ppl_scores = generate_ppl_scores_approx(batch_sequences, model, tokenizer, mask_percentage=mask_percentage)
358
+ status_string += f"Generated {len(ppl_scores)} approximate pseudo-perplexity scores for batch {i // batch_size + 1}/{n_batches} with mask percentage {mask_percentage*100:.1f}%\n"
359
  all_ppl.extend(ppl_scores)
360
 
361
  status_string += f"Generated scores for all {len(sequences_info)} sequences.\n"
 
388
  status_string += f">{sequences_info[all_ppl.index(lowest_ppl)][0]}\n"
389
  status_string += f"{sequences_info[all_ppl.index(lowest_ppl)][1]}\n"
390