gabboud committed on
Commit
4dcb469
·
1 Parent(s): c3ad370

replace fair-esm model access with huggingface hub, modularize and simplify post-processing

Browse files
Files changed (5) hide show
  1. app.py +37 -157
  2. requirements.txt +3 -1
  3. utils/download_models.py +157 -0
  4. utils/handle_files.py +90 -0
  5. utils/pipelines.py +113 -0
app.py CHANGED
@@ -8,164 +8,20 @@ import json
8
  from pathlib import Path
9
  import zipfile
10
  import spaces
 
 
 
11
 
12
- # Load ESM2 model
13
- print("Loading ESM2 model...")
14
- import esm
15
 
16
- # Load the model and alphabet
17
- model_name = "esm2_t33_650M_UR50D"
18
- try:
19
- model, alphabet = esm.pretrained.load_model_and_alphabet_local(model_name)
20
- except:
21
- print(f"Loading {model_name} from HuggingFace...")
22
- model, alphabet = esm.pretrained.load_model_and_alphabet_hub(model_name)
23
 
24
- model = model.eval()
25
- device = "cuda" if torch.cuda.is_available() else "cpu"
26
- model = model.to(device)
27
- batch_converter = alphabet.get_batch_converter()
28
-
29
- print(f"Model loaded on {device}")
30
-
31
-
32
- def parse_fasta_files(fasta_files):
33
- """Parse one or multiple FASTA files and return sequences."""
34
- sequences = []
35
- file_info = {}
36
-
37
- for fasta_file in fasta_files:
38
- file_name = Path(fasta_file.name).stem
39
- file_seqs = []
40
-
41
- try:
42
- for record in SeqIO.parse(fasta_file, "fasta"):
43
- sequences.append((record.id, str(record.seq), file_name))
44
- file_seqs.append(record.id)
45
- file_info[file_name] = file_seqs
46
- except Exception as e:
47
- raise ValueError(f"Error parsing {fasta_file.name}: {str(e)}")
48
-
49
- if not sequences:
50
- raise ValueError("No sequences found in the provided FASTA files.")
51
-
52
- return sequences, file_info
53
-
54
- @spaces.GPU(duration=240)
55
- def generate_embeddings(sequences_batch):
56
- """Generate embeddings for a batch of sequences."""
57
- # Prepare batch for ESM2
58
- batch_labels, batch_strs, batch_tokens = batch_converter(sequences_batch)
59
-
60
- # Move to device
61
- batch_tokens = batch_tokens.to(device)
62
-
63
- # Generate embeddings
64
- with torch.no_grad():
65
- results = model(batch_tokens, repr_layers=[33], return_contacts=False)
66
-
67
- # Extract embeddings (token representations from layer 33)
68
- token_embeddings = results["representations"][33]
69
-
70
- # Get sequence-level embeddings (mean pooling of token embeddings, excluding special tokens)
71
- sequence_embeddings = []
72
- for i, (label, seq) in enumerate(zip(batch_labels, batch_strs)):
73
- # Remove special tokens (first and last)
74
- seq_embedding = token_embeddings[i, 1:len(seq) + 1].mean(dim=0)
75
- sequence_embeddings.append(seq_embedding.cpu().numpy())
76
-
77
- return sequence_embeddings
78
-
79
-
80
- def process_embeddings(fasta_files):
81
- """Main function to process FASTA files and generate embeddings."""
82
- try:
83
- # Parse FASTA files
84
- sequences, file_info = parse_fasta_files(fasta_files)
85
-
86
- # Generate embeddings in batches
87
- batch_size = 8
88
- all_embeddings = {}
89
- status_updates = [f"Processing {len(sequences)} sequences from {len(file_info)} file(s)..."]
90
-
91
- for i in range(0, len(sequences), batch_size):
92
- batch = sequences[i:i + batch_size]
93
- batch_labels = [(seq_id, seq, file_name) for seq_id, seq, file_name in batch]
94
-
95
- status_updates.append(f"Generating embeddings for sequences {i + 1}-{min(i + batch_size, len(sequences))}...")
96
-
97
- # Generate embeddings
98
- embeddings = generate_embeddings([(label, seq) for label, seq, _ in batch_labels])
99
-
100
- # Store embeddings
101
- for (seq_id, seq, file_name), embedding in zip(batch_labels, embeddings):
102
- key = f"{file_name}_{seq_id}"
103
- all_embeddings[key] = {
104
- "sequence_id": seq_id,
105
- "file": file_name,
106
- "sequence_length": len(seq),
107
- "embedding": embedding.tolist()
108
- }
109
-
110
- # Create output files
111
- output_files = []
112
- temp_dir = tempfile.mkdtemp()
113
-
114
- # Save embeddings as NPZ (numpy compressed format)
115
- npz_path = os.path.join(temp_dir, "embeddings.npz")
116
- embeddings_array = {k: np.array(v["embedding"]) for k, v in all_embeddings.items()}
117
- np.savez_compressed(npz_path, **embeddings_array)
118
- output_files.append(npz_path)
119
- status_updates.append(f"Saved compressed embeddings to embeddings.npz")
120
-
121
- # Save metadata as JSON
122
- metadata_path = os.path.join(temp_dir, "metadata.json")
123
- metadata = {
124
- "num_sequences": len(all_embeddings),
125
- "embedding_dim": 1280, # ESM2-650M has 1280-dimensional embeddings
126
- "model": model_name,
127
- "sequences": {k: {
128
- "sequence_id": v["sequence_id"],
129
- "file": v["file"],
130
- "sequence_length": v["sequence_length"]
131
- } for k, v in all_embeddings.items()}
132
- }
133
- with open(metadata_path, 'w') as f:
134
- json.dump(metadata, f, indent=2)
135
- output_files.append(metadata_path)
136
- status_updates.append(f"Saved metadata to metadata.json")
137
-
138
- # Create per-file embedding files
139
- for file_name in file_info.keys():
140
- file_embeddings = {k: v for k, v in embeddings_array.items() if k.startswith(file_name)}
141
- if file_embeddings:
142
- file_npz_path = os.path.join(temp_dir, f"embeddings_{file_name}.npz")
143
- np.savez_compressed(file_npz_path, **file_embeddings)
144
- output_files.append(file_npz_path)
145
- status_updates.append(f"Saved {len(file_embeddings)} embeddings for {file_name}")
146
-
147
- # Create a summary report
148
- summary_path = os.path.join(temp_dir, "summary.txt")
149
- with open(summary_path, 'w') as f:
150
- f.write("ESM2 Protein Sequence Embedding Summary\n")
151
- f.write("=" * 50 + "\n\n")
152
- f.write(f"Model: {model_name}\n")
153
- f.write(f"Device: {device}\n")
154
- f.write(f"Embedding Dimension: 1280\n\n")
155
- f.write(f"Input Files: {len(file_info)}\n")
156
- f.write(f"Total Sequences: {len(all_embeddings)}\n\n")
157
- f.write("Sequences per file:\n")
158
- for file_name, seq_ids in file_info.items():
159
- f.write(f" - {file_name}: {len(seq_ids)} sequences\n")
160
- output_files.append(summary_path)
161
-
162
- status_message = "\n".join(status_updates)
163
- status_message += f"\n\nSuccessfully generated embeddings for {len(all_embeddings)} sequences!"
164
-
165
- return output_files, status_message
166
-
167
- except Exception as e:
168
- raise gr.Error(f"Error processing sequences: {str(e)}")
169
 
170
 
171
  # Create Gradio interface
@@ -214,12 +70,36 @@ with gr.Blocks(title="ESM2 Protein Embeddings") as demo:
214
  label="Download Output Files",
215
  file_count="multiple"
216
  )
217
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
218
  submit_btn.click(
219
- fn=process_embeddings,
220
  inputs=[input_files],
221
  outputs=[download_output, status_output]
222
  )
 
 
223
 
224
  gr.Markdown("""
225
  ### How to use the embeddings:
 
8
  from pathlib import Path
9
  import zipfile
10
  import spaces
11
# Explicit imports instead of the original wildcard import: only these two
# helpers from download_models are used, and star-imports hide provenance.
from utils.download_models import cache_all_models, load_all_models
from utils.handle_files import parse_fasta_files
from utils.pipelines import generate_embeddings, full_embedding_pipeline

print("Downloading ESM2 models...")

# Model id (HF hub repo) -> human-readable name shown in the UI dropdown.
MODELS = {
    "facebook/esm2_t6_8M_UR50D": "ESM2-8M",
    "facebook/esm2_t12_35M_UR50D": "ESM2-35M",
    # "facebook/esm2_t33_650M_UR50D": "ESM2-650M",  # enable when memory allows
}

# Pre-download weights, then load every model into memory keyed by model id.
cache_dirs = cache_all_models(MODELS)
models_and_tokenizers = load_all_models(MODELS)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
 
27
  # Create Gradio interface
 
70
  label="Download Output Files",
71
  file_count="multiple"
72
  )
73
+
74
+ with gr.Row():
75
+ model_dropdown = gr.Dropdown(
76
+ choices=list(MODELS.values()),
77
+ value=list(MODELS.values())[0],
78
+ label="Select Model"
79
+ )
80
+
81
+
82
+ model_to_use = gr.State(value=models_and_tokenizers[model_dropdown.value][0])
83
+ tokenizer_to_use = gr.State(value=models_and_tokenizers[model_dropdown.value][1])
84
+
85
+ def pick_model(model_name):
86
+ model_key = [key for key, value in MODELS.items() if value == model_name][0]
87
+ print(f"Selected model: {model_name} ({model_key})")
88
+ return models_and_tokenizers[model_key]
89
+
90
+ model_dropdown.change(
91
+ fn=pick_model,
92
+ inputs=model_dropdown,
93
+ outputs=[model_to_use, tokenizer_to_use]
94
+ )
95
+
96
  submit_btn.click(
97
+ fn=full_embedding_pipeline,
98
  inputs=[input_files],
99
  outputs=[download_output, status_output]
100
  )
101
+
102
+
103
 
104
  gr.Markdown("""
105
  ### How to use the embeddings:
requirements.txt CHANGED
@@ -1,4 +1,6 @@
1
  torch>=2.0.0
2
- fair-esm>=2.0.0
3
  biopython>=1.81
4
  numpy>=1.21.0
 
 
 
 
1
  torch>=2.0.0
 
2
  biopython>=1.81
3
  numpy>=1.21.0
4
+ huggingface_hub
5
+ transformers
6
+
utils/download_models.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import esm
2
+ import torch
3
+ import huggingface_hub
4
+ from transformers import AutoTokenizer, AutoModel
5
+
6
def cache_model_weights(model_id):
    """
    Download ESM2 model weights to the local HF cache without loading them into memory.

    Called on Space (re)start so the weights are already on disk when inference
    is first requested, without consuming memory for models that may never be
    selected in the current session.

    Parameters:
    -----------
    model_id : str
        Model identifier (e.g., "facebook/esm2_t6_8M_UR50D")

    Returns:
    --------
    str : Path to cached model directory
    """
    cache_dir = huggingface_hub.snapshot_download(model_id)
    print(f"Model {model_id} cached at: {cache_dir}")
    return cache_dir
23
+
24
def cache_all_models(models):
    """
    Pre-download the weights of every model in *models* to the local cache.

    Parameters:
    -----------
    models : dict
        Maps model identifiers (e.g., "facebook/esm2_t6_8M_UR50D") to
        human-readable model names (e.g., "ESM2-8M").

    Returns:
    --------
    dict : model identifier -> cache directory path.
    """
    # One cache_model_weights call per configured model, keyed by model id.
    return {model_id: cache_model_weights(model_id) for model_id in models}
41
+
42
def load_model(model_id):
    """
    Load an ESM model and tokenizer via ``from_pretrained``.

    Initializes from the default HF cache directory, downloading any missing
    files. Intended to be used after ``cache_model_weights`` for control over
    when the (slow) downloads actually happen.

    Parameters:
    -----------
    model_id : str
        Model identifier (e.g., "facebook/esm2_t6_8M_UR50D")

    Returns:
    --------
    tuple : (model, tokenizer), with the model in eval mode on the best device

    Raises:
    -------
    RuntimeError
        If the model or tokenizer cannot be loaded (original cause chained).
    """
    try:
        print(f"Loading {model_id}...")
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        # output_hidden_states=True so embedding code can read hidden_states[-1].
        model = AutoModel.from_pretrained(
            model_id,
            output_hidden_states=True,
        )
    except Exception as e:
        # Chain the original exception so the root cause stays visible; the
        # old message claimed "from cache" although from_pretrained may download.
        raise RuntimeError(f"Failed to load model {model_id}: {e}") from e

    model = model.eval()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)
    print(f"{model_id} loaded on {device}")
    return model, tokenizer
71
+
72
def load_all_models(models):
    """
    Load every model listed in *models* into memory.

    Parameters:
    -----------
    models : dict
        Maps model identifiers (e.g., "facebook/esm2_t6_8M_UR50D") to
        human-readable model names (e.g., "ESM2-8M").

    Returns:
    --------
    dict : model identifier -> (model, tokenizer) tuple.
    """
    # Delegate to load_model for each configured model id.
    return {model_id: load_model(model_id) for model_id in models}
89
+
90
+
91
+ #def cache_models(models):
92
+ # """
93
+ # Download weights to ESM models in cache to be loaded later.
94
+ # We do not load the models into memory at this stage to avoid using GPU memory for models that are not used in the current session.
95
+ #
96
+ # Parameters:
97
+ # ----------
98
+ # models: dict
99
+ # A dictionary where keys are model identifiers (e.g., "esm2_t6_8M_UR50D") and values are human-readable model names (e.g., "ESM2-8M").
100
+ #
101
+ # Returns:
102
+ # -------
103
+ #
104
+ # """
105
+ # loaded_models = {}
106
+ # for model_id, model_name in models.items():
107
+ # print(f"Loading {model_name}...")
108
+ # try:
109
+ # #load from local cache if avilable, upon startup of space will fail and load from HF
110
+ # model, alphabet = esm.pretrained.load_model_and_alphabet_local(model_id)
111
+ # except:
112
+ # print(f"Loading {model_name} from HuggingFace...")
113
+ # model, alphabet = esm.pretrained.load_model_and_alphabet_hub(model_id)
114
+ #
115
+ # model = model.eval()
116
+ # device = "cuda" if torch.cuda.is_available() else "cpu"
117
+ # model = model.to(device)
118
+ # loaded_models[model_id] = {
119
+ # "model": model,
120
+ # "alphabet": alphabet,
121
+ # "batch_converter": alphabet.get_batch_converter()
122
+ # }
123
+ # print(f"{model_name} loaded on {device}")
124
+ #
125
+ #def download_models(models):
126
+ # """
127
+ # Download weights to ESM models in cache to be loaded later.
128
+ # We do not load the models into memory at this stage to avoid using GPU memory for models that are not used in the current session.
129
+ #
130
+ # Parameters:
131
+ # ----------
132
+ # models: dict
133
+ # A dictionary where keys are model identifiers (e.g., "esm2_t6_8M_UR50D") and values are human-readable model names (e.g., "ESM2-8M").
134
+ #
135
+ # Returns:
136
+ # -------
137
+ #
138
+ # """
139
+ # loaded_models = {}
140
+ # for model_id, model_name in models.items():
141
+ # print(f"Loading {model_name}...")
142
+ # try:
143
+ # #load from local cache if avilable, upon startup of space will fail and load from HF
144
+ # model, alphabet = esm.pretrained.load_model_and_alphabet_local(model_id)
145
+ # except:
146
+ # print(f"Loading {model_name} from HuggingFace...")
147
+ # model, alphabet = esm.pretrained.load_model_and_alphabet_hub(model_id)
148
+ #
149
+ # model = model.eval()
150
+ # device = "cuda" if torch.cuda.is_available() else "cpu"
151
+ # model = model.to(device)
152
+ # loaded_models[model_id] = {
153
+ # "model": model,
154
+ # "alphabet": alphabet,
155
+ # "batch_converter": alphabet.get_batch_converter()
156
+ # }
157
+ # print(f"{model_name} loaded on {device}")
utils/handle_files.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from Bio import SeqIO
3
+
4
+
5
def parse_fasta_files(fasta_files):
    """Parse one or multiple FASTA files and return sequences.

    The entire header line is used as the sequence_id to cope with LigandMPNN's
    omission of a unique sequence ID at the beginning of the header.

    Parameters:
    -----------
    fasta_files : list of str
        Paths to FASTA files to be parsed. Files without a .fasta/.fa
        extension are silently skipped.

    Returns:
    --------
    sequences : list of tuples
        (sequence_id, sequence, file_name) for each sequence found.
    file_info : dict
        Maps file names to the list of sequence IDs contained in each file
        (only files that yielded at least one sequence appear).

    Raises:
    -------
    ValueError
        If a file cannot be parsed, or if no sequences are found at all.
    """
    sequences = []
    file_info = {}

    for fasta_file in fasta_files:
        print(fasta_file)
        if fasta_file.endswith('.fasta') or fasta_file.endswith('.fa'):
            file_name = Path(fasta_file).stem
            file_seqs = []

            try:
                for record in SeqIO.parse(fasta_file, "fasta"):
                    # Use the entire header line (without '>') as the sequence ID.
                    full_header = record.description
                    sequences.append((full_header, str(record.seq), file_name))
                    file_seqs.append(full_header)
                if file_seqs:
                    file_info[file_name] = file_seqs
            except Exception as e:
                # Bug fix: fasta_file is a plain path string, so the original
                # `fasta_file.name` raised AttributeError and masked the real error.
                raise ValueError(f"Error parsing {fasta_file}: {str(e)}") from e

    if not sequences:
        raise ValueError("No sequences found in the provided FASTA files.")

    return sequences, file_info
43
+
44
def parse_fasta_files_from_ligandmpnn(fasta_files):
    """Parse FASTA files in the format generated by LigandMPNN.

    LigandMPNN headers carry no leading sequence ID; instead the header holds
    the source file name, quality metrics, and a design counter such as
    ``id=0``. This parser extracts the value of the ``id=`` field as the
    sequence ID, falling back to ``record.id`` when no such field is present.

    Parameters:
    -----------
    fasta_files : list of str
        Paths to FASTA files to be parsed. Files without a .fasta/.fa
        extension are silently skipped.

    Returns:
    --------
    sequences : list of tuples
        (sequence_id, sequence, file_name) for each sequence found.
    file_info : dict
        Maps file names to the list of sequence IDs contained in each file
        (only files that yielded at least one sequence appear).

    Raises:
    -------
    ValueError
        If a file cannot be parsed, or if no sequences are found at all.
    """
    sequences = []
    file_info = {}

    for fasta_file in fasta_files:
        print(fasta_file)
        if fasta_file.endswith('.fasta') or fasta_file.endswith('.fa'):
            file_name = Path(fasta_file).stem
            file_seqs = []

            try:
                for record in SeqIO.parse(fasta_file, "fasta"):
                    # Default to Biopython's record.id; override with the
                    # id=<n> field from the header when present.
                    seq_id = record.id
                    if "id=" in record.description:
                        for part in record.description.split():
                            if part.startswith("id="):
                                seq_id = part[3:]  # strip the "id=" prefix
                                break

                    sequences.append((seq_id, str(record.seq), file_name))
                    file_seqs.append(seq_id)
                if file_seqs:
                    file_info[file_name] = file_seqs
            except Exception as e:
                # Bug fix: fasta_file is a plain path string, so the original
                # `fasta_file.name` raised AttributeError and masked the real error.
                raise ValueError(f"Error parsing {fasta_file}: {str(e)}") from e

    if not sequences:
        raise ValueError("No sequences found in the provided FASTA files.")

    return sequences, file_info
utils/pipelines.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
+ import torch
3
+ import spaces
4
+ import numpy as np
5
+ from utils.handle_files import parse_fasta_files
6
+ import gradio as gr
7
+ import time
8
+ import random
9
+ import os
10
+
11
@spaces.GPU(duration=240)
def generate_embeddings(sequences_batch, model, tokenizer):
    """Generate embeddings for ESM models using the transformers library.

    Parameters:
    -----------
    sequences_batch : list of str
        A batch of protein sequences for which to generate embeddings.
    model : AutoModel
        Pre-loaded ESM model; must already be on the correct device and have
        been created with ``output_hidden_states=True``.
    tokenizer : AutoTokenizer
        Pre-loaded tokenizer corresponding to the ESM model.

    Returns:
    --------
    np.ndarray of shape (batch_size, embedding_dim)
        Mean-pooled sequence-level embedding for each input sequence.
    """
    # Tokenize the whole batch; padding/truncation make a rectangular tensor.
    device = model.device
    tokens = tokenizer(
        sequences_batch,
        return_tensors="pt",
        padding=True,
        truncation=True,
        add_special_tokens=True
    ).to(device)

    # Forward pass without gradients — inference only.
    with torch.no_grad():
        results = model(**tokens)

    # Per-token embeddings from the last layer (needs output_hidden_states=True).
    token_embeddings = results.hidden_states[-1]

    # Mean-pool each sequence individually, slicing off the leading special
    # token and everything after the residues. Per-sequence slicing (rather
    # than one vectorized mean) is needed so padding and special tokens of
    # shorter sequences never enter the average.
    # NOTE(review): assumes the tokenizer emits exactly one token per residue
    # so positions 1..len(seq) are the residue tokens — confirm for ESM vocab.
    sequence_embeddings = []
    for i, seq in enumerate(sequences_batch):
        seq_embedding = token_embeddings[i, 1:len(seq) + 1].mean(dim=0)
        sequence_embeddings.append(seq_embedding.cpu().numpy())

    return np.array(sequence_embeddings)
57
+
58
def full_embedding_pipeline(fasta_files, model, tokenizer, batch_size=8):
    """Full pipeline: parse FASTA files and generate embeddings with the given model.

    Parameters:
    -----------
    fasta_files : list of str
        Paths to FASTA files, as provided by the gradio file input.
    model : AutoModel
        Pre-loaded ESM model; must already be on the correct device.
    tokenizer : AutoTokenizer
        Pre-loaded tokenizer corresponding to the ESM model.
    batch_size : int, optional (default 8)
        Number of sequences to process per embedding batch.

    Returns:
    --------
    all_file_paths : list of str
        Paths of the per-file .npz embedding archives, passed to gradio for download.
    status_string : str
        Human-readable processing log for the gradio status box.
    """
    # Parse FASTA files into (sequence_id, sequence, file_name) tuples.
    sequences_info, file_info = parse_fasta_files(fasta_files)

    # Generate embeddings in batches, accumulating a progress log.
    all_embeddings = []
    n_batches = (len(sequences_info) + batch_size - 1) // batch_size
    status_string = f"Processing {len(sequences_info)} sequences from {len(file_info)} file(s) in {n_batches} batches of {batch_size} sequences...\n"

    for start in range(0, len(sequences_info), batch_size):
        batch = sequences_info[start:start + batch_size]
        batch_sequences = [seq for _, seq, _ in batch]

        embeddings = generate_embeddings(batch_sequences, model, tokenizer)
        status_string += f"Generated {len(embeddings)} embeddings for batch {start // batch_size + 1}/{n_batches}\n"
        all_embeddings.extend(embeddings)

    status_string += f"Generated embeddings for all {len(sequences_info)} sequences.\n"

    # Unique per-session output directory so concurrent users don't collide.
    session_hash = random.getrandbits(128)
    time_stamp = time.strftime("%Y-%m-%d-%H-%M-%S")
    out_dir = f"./outputs/unconditional_generation/session_{session_hash}_{time_stamp}"
    os.makedirs(out_dir, exist_ok=True)

    # One compressed .npz per input file: its embeddings plus their sequence IDs.
    all_file_paths = []
    for file_name in file_info.keys():
        indices = [i for i, (_, _, f) in enumerate(sequences_info) if f == file_name]
        file_embeddings = np.array([all_embeddings[i] for i in indices])
        sequence_ids = [sequences_info[i][0] for i in indices]
        file_path = os.path.join(out_dir, f"{file_name}_embeddings.npz")
        np.savez_compressed(file_path, embeddings=file_embeddings, sequence_ids=sequence_ids)
        status_string += f"Saved compressed embeddings to {file_name}_embeddings.npz\n"
        all_file_paths.append(file_path)

    # Bug fix: return the status log (the documented contract, and what the
    # gradio status output expects) rather than the raw embeddings list.
    return all_file_paths, status_string
113
+