gabboud committed on
Commit
f0a8bfb
·
1 Parent(s): f2af1c6

dynamic max job duration

Browse files
Files changed (2) hide show
  1. app.py +14 -7
  2. utils/pipelines.py +19 -11
app.py CHANGED
@@ -63,6 +63,13 @@ with gr.Blocks(title="ESM2 Protein Embeddings") as demo:
63
  value=32,
64
  label="Batch Size"
65
  )
 
 
 
 
 
 
 
66
 
67
  with gr.Row():
68
  with gr.Column():
@@ -102,34 +109,34 @@ with gr.Blocks(title="ESM2 Protein Embeddings") as demo:
102
  )
103
 
104
 
105
- def run_pipeline_with_selected_model(fasta_files, model_key, batch_size_value, task="embedding"):
106
  """Wrapper to run pipeline with selected model from dropdown."""
107
 
108
  if not fasta_files:
109
  return gr.update(), "No FASTA files uploaded. Please upload at least one FASTA file for inference."
110
  model, tokenizer = models_and_tokenizers[model_key]
111
  if task == "embedding":
112
- return full_embedding_pipeline(fasta_files, model, tokenizer, batch_size_value)
113
  elif task == "ppl":
114
- return full_ppl_pipeline(fasta_files, model, tokenizer, batch_size_value, mask_percentage=None)
115
  elif task == "ppl-approx":
116
- return full_ppl_pipeline(fasta_files, model, tokenizer, batch_size_value, mask_percentage=0.1)
117
 
118
  submit_btn.click(
119
  fn=run_pipeline_with_selected_model,
120
- inputs=[input_files, model_dropdown, batch_size, gr.State("embedding")],
121
  outputs=[download_output, status_output]
122
  )
123
 
124
  ppl_button.click(
125
  fn=run_pipeline_with_selected_model,
126
- inputs=[input_files, model_dropdown, batch_size, gr.State("ppl")],
127
  outputs=[ppl_download, ppl_status]
128
  )
129
 
130
  ppl_approx_button.click(
131
  fn=run_pipeline_with_selected_model,
132
- inputs=[input_files, model_dropdown, batch_size, gr.State("ppl-approx")],
133
  outputs=[ppl_download, ppl_status]
134
  )
135
 
 
63
  value=32,
64
  label="Batch Size"
65
  )
66
+ max_duration = gr.Number(
67
+ value=3600,
68
+ label="Max Duration (seconds)",
69
+ precision=0,
70
+ minimum=1,
71
+ maximum=7199
72
+ )
73
 
74
  with gr.Row():
75
  with gr.Column():
 
109
  )
110
 
111
 
112
def run_pipeline_with_selected_model(fasta_files, model_key, batch_size_value, max_duration, task="embedding"):
    """Run the selected inference pipeline with the model chosen in the dropdown.

    Parameters:
        fasta_files: uploaded FASTA file paths from the Gradio File input.
        model_key: key into the module-level ``models_and_tokenizers`` mapping.
        batch_size_value: number of sequences per inference batch.
        max_duration: GPU job duration budget in seconds, forwarded to the pipelines.
        task: one of ``"embedding"``, ``"ppl"`` (exact pseudo-perplexity, no masking
            percentage) or ``"ppl-approx"`` (approximate, 10% masking).

    Returns:
        A ``(download update, status string)`` pair matching the two Gradio outputs.
    """
    if not fasta_files:
        return gr.update(), "No FASTA files uploaded. Please upload at least one FASTA file for inference."
    model, tokenizer = models_and_tokenizers[model_key]
    if task == "embedding":
        return full_embedding_pipeline(fasta_files, model, tokenizer, batch_size_value, max_duration)
    elif task == "ppl":
        return full_ppl_pipeline(fasta_files, model, tokenizer, batch_size_value, None, max_duration)
    elif task == "ppl-approx":
        return full_ppl_pipeline(fasta_files, model, tokenizer, batch_size_value, 0.1, max_duration)
    # Guard against an unrecognized task value: previously this fell through and
    # implicitly returned None, which breaks the two-output Gradio callback.
    return gr.update(), f"Unknown task: {task!r}"
124
 
125
  submit_btn.click(
126
  fn=run_pipeline_with_selected_model,
127
+ inputs=[input_files, model_dropdown, batch_size, max_duration, gr.State("embedding")],
128
  outputs=[download_output, status_output]
129
  )
130
 
131
  ppl_button.click(
132
  fn=run_pipeline_with_selected_model,
133
+ inputs=[input_files, model_dropdown, batch_size, max_duration, gr.State("ppl")],
134
  outputs=[ppl_download, ppl_status]
135
  )
136
 
137
  ppl_approx_button.click(
138
  fn=run_pipeline_with_selected_model,
139
+ inputs=[input_files, model_dropdown, batch_size, max_duration, gr.State("ppl-approx")],
140
  outputs=[ppl_download, ppl_status]
141
  )
142
 
utils/pipelines.py CHANGED
@@ -8,8 +8,11 @@ import random
8
  import os
9
  import pandas as pd
10
 
11
- @spaces.GPU(duration=240)
12
- def generate_embeddings(sequences_batch, model, tokenizer):
 
 
 
13
  """Generate embeddings for ESM models using the transformers library.
14
 
15
  Parameters:
@@ -55,8 +58,11 @@ def generate_embeddings(sequences_batch, model, tokenizer):
55
 
56
  return np.array(sequence_embeddings)
57
 
58
- @spaces.GPU(duration=240)
59
- def generate_ppl_scores(sequences_batch, model, tokenizer):
 
 
 
60
  """Generate pseudo-perplexity scores for ESM models using batched masking across all sequences.
61
 
62
  Parameters:
@@ -146,9 +152,11 @@ def generate_ppl_scores(sequences_batch, model, tokenizer):
146
 
147
  return ppl_scores
148
 
 
 
149
 
150
- @spaces.GPU(duration=240)
151
- def generate_ppl_scores_approx(sequences_batch, model, tokenizer, mask_percentage=0.15):
152
  """Generate approximate pseudo-perplexity scores for ESM models using chunked masking.
153
 
154
  Parameters:
@@ -258,7 +266,7 @@ def generate_ppl_scores_approx(sequences_batch, model, tokenizer, mask_percentag
258
 
259
  return ppl_scores
260
 
261
- def full_embedding_pipeline(fasta_files, model, tokenizer, batch_size):
262
  """Full pipeline to process FASTA files and generate embeddings from desired model.
263
 
264
  Parameters:
@@ -291,7 +299,7 @@ def full_embedding_pipeline(fasta_files, model, tokenizer, batch_size):
291
  batch = sequences_info[i:i + batch_size]
292
  batch_sequences = [seq for _, seq, _ in batch]
293
 
294
- embeddings = generate_embeddings(batch_sequences, model, tokenizer)
295
  status_string += f"Generated {len(embeddings)} embeddings for batch {i // batch_size + 1}/{n_batches}\n"
296
  all_embeddings.extend(embeddings)
297
 
@@ -314,7 +322,7 @@ def full_embedding_pipeline(fasta_files, model, tokenizer, batch_size):
314
 
315
  return all_file_paths, status_string
316
 
317
- def full_ppl_pipeline(fasta_files, model, tokenizer, batch_size, mask_percentage=None):
318
  """Full pipeline to process FASTA files and generate embeddings from desired model.
319
 
320
  Parameters:
@@ -350,10 +358,10 @@ def full_ppl_pipeline(fasta_files, model, tokenizer, batch_size, mask_percentage
350
  batch = sequences_info[i:i + batch_size]
351
  batch_sequences = [seq for _, seq, _ in batch]
352
  if mask_percentage is None:
353
- ppl_scores = generate_ppl_scores(batch_sequences, model, tokenizer)
354
  status_string += f"Generated {len(ppl_scores)} pseudo-perplexity scores for batch {i // batch_size + 1}/{n_batches}\n"
355
  else:
356
- ppl_scores = generate_ppl_scores_approx(batch_sequences, model, tokenizer, mask_percentage=mask_percentage)
357
  status_string += f"Generated {len(ppl_scores)} approximate pseudo-perplexity scores for batch {i // batch_size + 1}/{n_batches} with mask percentage {mask_percentage*100:.1f}%\n"
358
  all_ppl.extend(ppl_scores)
359
 
 
8
  import os
9
  import pandas as pd
10
 
11
def get_duration_embeddings(sequences_batch, model, tokenizer, max_duration):
    """Duration resolver for the ``@spaces.GPU`` embedding job.

    Receives the same arguments as ``generate_embeddings`` and simply hands
    back the caller-supplied ``max_duration`` (seconds) as the GPU budget.
    """
    return max_duration
13
+
14
+ @spaces.GPU(duration=get_duration_embeddings)
15
+ def generate_embeddings(sequences_batch, model, tokenizer, max_duration):
16
  """Generate embeddings for ESM models using the transformers library.
17
 
18
  Parameters:
 
58
 
59
  return np.array(sequence_embeddings)
60
 
61
def get_duration_ppl(sequences_batch, model, tokenizer, max_duration):
    """Duration resolver for the ``@spaces.GPU`` exact pseudo-perplexity job.

    Mirrors the signature of ``generate_ppl_scores`` and returns the
    user-chosen ``max_duration`` (seconds) unchanged.
    """
    return max_duration
63
+
64
+ @spaces.GPU(duration=get_duration_ppl)
65
+ def generate_ppl_scores(sequences_batch, model, tokenizer, max_duration):
66
  """Generate pseudo-perplexity scores for ESM models using batched masking across all sequences.
67
 
68
  Parameters:
 
152
 
153
  return ppl_scores
154
 
155
def get_duration_ppl_approx(sequences_batch, model, tokenizer, mask_percentage=0.15, max_duration=240):
    """Duration resolver for the ``@spaces.GPU`` approximate pseudo-perplexity job.

    Mirrors the signature of ``generate_ppl_scores_approx`` — including its
    defaults (``mask_percentage=0.15``, ``max_duration=240``) — so the resolver
    stays callable even when the decorated function is invoked without those
    keyword arguments. Returns ``max_duration`` (seconds) as the GPU budget.
    """
    return max_duration
157
 
158
+ @spaces.GPU(duration=get_duration_ppl_approx)
159
+ def generate_ppl_scores_approx(sequences_batch, model, tokenizer, mask_percentage=0.15, max_duration=240):
160
  """Generate approximate pseudo-perplexity scores for ESM models using chunked masking.
161
 
162
  Parameters:
 
266
 
267
  return ppl_scores
268
 
269
+ def full_embedding_pipeline(fasta_files, model, tokenizer, batch_size, max_duration):
270
  """Full pipeline to process FASTA files and generate embeddings from desired model.
271
 
272
  Parameters:
 
299
  batch = sequences_info[i:i + batch_size]
300
  batch_sequences = [seq for _, seq, _ in batch]
301
 
302
+ embeddings = generate_embeddings(batch_sequences, model, tokenizer, max_duration)
303
  status_string += f"Generated {len(embeddings)} embeddings for batch {i // batch_size + 1}/{n_batches}\n"
304
  all_embeddings.extend(embeddings)
305
 
 
322
 
323
  return all_file_paths, status_string
324
 
325
+ def full_ppl_pipeline(fasta_files, model, tokenizer, batch_size, mask_percentage=None, max_duration=240):
326
  """Full pipeline to process FASTA files and generate embeddings from desired model.
327
 
328
  Parameters:
 
358
  batch = sequences_info[i:i + batch_size]
359
  batch_sequences = [seq for _, seq, _ in batch]
360
  if mask_percentage is None:
361
+ ppl_scores = generate_ppl_scores(batch_sequences, model, tokenizer, max_duration)
362
  status_string += f"Generated {len(ppl_scores)} pseudo-perplexity scores for batch {i // batch_size + 1}/{n_batches}\n"
363
  else:
364
+ ppl_scores = generate_ppl_scores_approx(batch_sequences, model, tokenizer, mask_percentage=mask_percentage, max_duration=max_duration)
365
  status_string += f"Generated {len(ppl_scores)} approximate pseudo-perplexity scores for batch {i // batch_size + 1}/{n_batches} with mask percentage {mask_percentage*100:.1f}%\n"
366
  all_ppl.extend(ppl_scores)
367