gabboud committed on
Commit
2898034
·
1 Parent(s): b3297e4

first try ESM app

Browse files
Files changed (2) hide show
  1. app.py +242 -4
  2. requirements.txt +5 -0
app.py CHANGED
@@ -1,7 +1,245 @@
1
  import gradio as gr
 
 
 
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
1
import gradio as gr
import torch
import numpy as np
from Bio import SeqIO  # biopython's import package is "Bio" (capitalized), not "bio"
import tempfile
import os
import json
from pathlib import Path
import zipfile

# Load ESM2 model at import time so the Gradio app serves requests without a
# per-request model load.
print("Loading ESM2 model...")
import esm

# Load the model and alphabet: try a locally cached checkpoint first, and fall
# back to downloading from the hub.
model_name = "esm2_t33_650M_UR50D"
try:
    model, alphabet = esm.pretrained.load_model_and_alphabet_local(model_name)
except Exception:  # narrowed from bare `except:` so SystemExit/KeyboardInterrupt still propagate
    print(f"Loading {model_name} from HuggingFace...")
    model, alphabet = esm.pretrained.load_model_and_alphabet_hub(model_name)

# Inference only: disable dropout/batch-norm updates and move to the fastest
# available device.
model = model.eval()
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
batch_converter = alphabet.get_batch_converter()

print(f"Model loaded on {device}")
31
def parse_fasta_files(fasta_files):
    """Parse one or multiple FASTA files and return their sequences.

    Args:
        fasta_files: iterable of uploaded files; each entry may be a plain
            filesystem path (str) or an object with a ``.name`` path
            attribute — gradio returns either depending on version/config.

    Returns:
        tuple: ``(sequences, file_info)`` where ``sequences`` is a list of
        ``(record_id, sequence_str, file_stem)`` triples and ``file_info``
        maps each file stem to the list of record ids found in that file.

    Raises:
        ValueError: if a file cannot be parsed as FASTA, or if no sequences
            are found at all.
    """
    sequences = []
    file_info = {}

    # `or []` guards against gradio passing None when nothing was uploaded.
    for fasta_file in fasta_files or []:
        # Accept both plain path strings and gradio file objects.
        path = getattr(fasta_file, "name", fasta_file)
        file_name = Path(path).stem
        file_seqs = []

        try:
            for record in SeqIO.parse(path, "fasta"):
                sequences.append((record.id, str(record.seq), file_name))
                file_seqs.append(record.id)
        except Exception as e:
            raise ValueError(f"Error parsing {path}: {str(e)}")
        # Record the ids once, after the whole file has been parsed
        # (the original re-assigned this inside the per-record loop).
        file_info[file_name] = file_seqs

    if not sequences:
        raise ValueError("No sequences found in the provided FASTA files.")

    return sequences, file_info
54
def generate_embeddings(sequences_batch):
    """Compute mean-pooled ESM2 embeddings for a batch of (label, seq) pairs.

    Args:
        sequences_batch: list of ``(label, sequence)`` tuples as expected by
            the ESM batch converter.

    Returns:
        list of 1-D numpy arrays, one per input sequence (layer-33 token
        representations, mean-pooled over the residue positions).
    """
    # Tokenize the batch and move the token tensor to the model's device.
    labels, strs, tokens = batch_converter(sequences_batch)
    tokens = tokens.to(device)

    # Forward pass without autograd — only layer-33 representations needed.
    with torch.no_grad():
        out = model(tokens, repr_layers=[33], return_contacts=False)
    per_token = out["representations"][33]

    # Mean-pool each sequence's residue embeddings, skipping the special
    # tokens at position 0 and position len(seq)+1.
    pooled = []
    for idx, seq in enumerate(strs):
        vec = per_token[idx, 1:len(seq) + 1].mean(dim=0)
        pooled.append(vec.cpu().numpy())

    return pooled
79
def process_embeddings(fasta_files):
    """Parse uploaded FASTA files, embed every sequence with ESM2, and write
    the downloadable result files.

    Args:
        fasta_files: list of uploaded FASTA files (value of the gr.File input).

    Returns:
        tuple: ``(output_files, status_message)`` — a list of generated file
        paths (embeddings.npz, metadata.json, per-file npz files, summary.txt)
        and a human-readable progress log for the status textbox.

    Raises:
        gr.Error: wrapping any failure so gradio displays it in the UI.
    """
    try:
        # Parse FASTA files
        sequences, file_info = parse_fasta_files(fasta_files)

        # Generate embeddings in batches to bound peak GPU/CPU memory.
        batch_size = 8
        all_embeddings = {}
        status_updates = [f"Processing {len(sequences)} sequences from {len(file_info)} file(s)..."]

        for i in range(0, len(sequences), batch_size):
            batch = sequences[i:i + batch_size]

            status_updates.append(
                f"Generating embeddings for sequences {i + 1}-{min(i + batch_size, len(sequences))}..."
            )

            # The ESM batch converter wants plain (label, sequence) pairs.
            embeddings = generate_embeddings([(seq_id, seq) for seq_id, seq, _ in batch])

            # Store embeddings keyed by "<file stem>_<sequence id>".
            for (seq_id, seq, file_name), embedding in zip(batch, embeddings):
                key = f"{file_name}_{seq_id}"
                all_embeddings[key] = {
                    "sequence_id": seq_id,
                    "file": file_name,
                    "sequence_length": len(seq),
                    "embedding": embedding.tolist(),
                }

        # Create output files in a fresh temp dir so runs don't collide.
        output_files = []
        temp_dir = tempfile.mkdtemp()

        # Save embeddings as NPZ (numpy compressed format)
        npz_path = os.path.join(temp_dir, "embeddings.npz")
        embeddings_array = {k: np.array(v["embedding"]) for k, v in all_embeddings.items()}
        np.savez_compressed(npz_path, **embeddings_array)
        output_files.append(npz_path)
        status_updates.append("Saved compressed embeddings to embeddings.npz")

        # Derive the embedding dimension from the data instead of hard-coding
        # 1280, so the report stays correct if the model is ever swapped.
        embedding_dim = next(iter(embeddings_array.values())).shape[0] if embeddings_array else 0

        # Save metadata as JSON
        metadata_path = os.path.join(temp_dir, "metadata.json")
        metadata = {
            "num_sequences": len(all_embeddings),
            "embedding_dim": embedding_dim,
            "model": model_name,
            "sequences": {
                k: {
                    "sequence_id": v["sequence_id"],
                    "file": v["file"],
                    "sequence_length": v["sequence_length"],
                }
                for k, v in all_embeddings.items()
            },
        }
        with open(metadata_path, 'w') as f:
            json.dump(metadata, f, indent=2)
        output_files.append(metadata_path)
        status_updates.append("Saved metadata to metadata.json")

        # Create per-file embedding files. Match on the recorded source file,
        # not a key prefix: startswith() would wrongly sweep in entries when
        # one file stem is a prefix of another (e.g. "seqs" vs "seqs_extra").
        for file_name in file_info.keys():
            file_embeddings = {
                k: embeddings_array[k]
                for k, v in all_embeddings.items()
                if v["file"] == file_name
            }
            if file_embeddings:
                file_npz_path = os.path.join(temp_dir, f"embeddings_{file_name}.npz")
                np.savez_compressed(file_npz_path, **file_embeddings)
                output_files.append(file_npz_path)
                status_updates.append(f"Saved {len(file_embeddings)} embeddings for {file_name}")

        # Create a human-readable summary report.
        summary_path = os.path.join(temp_dir, "summary.txt")
        with open(summary_path, 'w') as f:
            f.write("ESM2 Protein Sequence Embedding Summary\n")
            f.write("=" * 50 + "\n\n")
            f.write(f"Model: {model_name}\n")
            f.write(f"Device: {device}\n")
            f.write(f"Embedding Dimension: {embedding_dim}\n\n")
            f.write(f"Input Files: {len(file_info)}\n")
            f.write(f"Total Sequences: {len(all_embeddings)}\n\n")
            f.write("Sequences per file:\n")
            for file_name, seq_ids in file_info.items():
                f.write(f" - {file_name}: {len(seq_ids)} sequences\n")
        output_files.append(summary_path)

        status_message = "\n".join(status_updates)
        status_message += f"\n\nSuccessfully generated embeddings for {len(all_embeddings)} sequences!"

        return output_files, status_message

    except Exception as e:
        # Surface the failure in the gradio UI instead of a server traceback.
        raise gr.Error(f"Error processing sequences: {str(e)}")
170
# Create Gradio interface
# Layout: instructions, then (uploads + button | status log), then downloads.
with gr.Blocks(title="ESM2 Protein Embeddings") as demo:
    # Header / usage instructions rendered above the controls.
    gr.Markdown("""
    # ESM2 Protein Sequence Embeddings

    Generate embeddings for protein sequences using Meta's ESM2 language model.

    **Features:**
    - Process one or multiple FASTA files
    - Generate high-dimensional embeddings (1280-D) using ESM2-650M
    - Download embeddings in NumPy format or as JSON metadata
    - Supports batch processing for efficiency

    **Instructions:**
    1. Upload one or more FASTA files containing protein sequences
    2. Click "Generate Embeddings"
    3. Download the output files (embeddings.npz, metadata.json, summary.txt)

    **Output Files:**
    - `embeddings.npz`: Compressed NumPy file with all embeddings
    - `metadata.json`: JSON file with sequence IDs and metadata
    - `summary.txt`: Human-readable summary
    - `embeddings_[filename].npz`: Per-file embeddings
    """)

    with gr.Row():
        with gr.Column():
            # Input side: multiple FASTA uploads plus the action button.
            input_files = gr.File(
                label="Upload FASTA files",
                file_count="multiple",
                file_types=[".fasta", ".fa", ".faa"]
            )
            submit_btn = gr.Button("Generate Embeddings", variant="primary", size="lg")

        with gr.Column():
            # Output side: read-only progress log filled by process_embeddings.
            status_output = gr.Textbox(
                label="Processing Status",
                interactive=False,
                lines=6
            )

    with gr.Row():
        # Generated artifact files offered for download.
        download_output = gr.File(
            label="Download Output Files",
            file_count="multiple"
        )

    # Wire the button to the processing pipeline: files in, (files, log) out.
    submit_btn.click(
        fn=process_embeddings,
        inputs=[input_files],
        outputs=[download_output, status_output]
    )

    # Footer: example code showing how to consume the generated files.
    gr.Markdown("""
    ### How to use the embeddings:

    ```python
    import numpy as np
    import json

    # Load embeddings
    embeddings = np.load('embeddings.npz')

    # Access a specific embedding
    embedding = embeddings['file_name_sequence_id']

    # Load metadata
    with open('metadata.json', 'r') as f:
        metadata = json.load(f)
    ```
    """)


if __name__ == "__main__":
    demo.launch()
245
 
 
 
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ gradio==4.26.0
2
+ torch>=2.0.0
3
+ fair-esm>=2.0.0
4
+ biopython>=1.81
5
+ numpy>=1.21.0