Spaces:
Running
Running
replace fair-esm model access with huggingface hub, modularize and simplify post-processing
Browse files- app.py +37 -157
- requirements.txt +3 -1
- utils/download_models.py +157 -0
- utils/handle_files.py +90 -0
- utils/pipelines.py +113 -0
app.py
CHANGED
|
@@ -8,164 +8,20 @@ import json
|
|
| 8 |
from pathlib import Path
|
| 9 |
import zipfile
|
| 10 |
import spaces
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
-
|
| 13 |
-
print("Loading ESM2 model...")
|
| 14 |
-
import esm
|
| 15 |
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
print(f"Loading {model_name} from HuggingFace...")
|
| 22 |
-
model, alphabet = esm.pretrained.load_model_and_alphabet_hub(model_name)
|
| 23 |
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
model = model.to(device)
|
| 27 |
-
batch_converter = alphabet.get_batch_converter()
|
| 28 |
-
|
| 29 |
-
print(f"Model loaded on {device}")
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
def parse_fasta_files(fasta_files):
|
| 33 |
-
"""Parse one or multiple FASTA files and return sequences."""
|
| 34 |
-
sequences = []
|
| 35 |
-
file_info = {}
|
| 36 |
-
|
| 37 |
-
for fasta_file in fasta_files:
|
| 38 |
-
file_name = Path(fasta_file.name).stem
|
| 39 |
-
file_seqs = []
|
| 40 |
-
|
| 41 |
-
try:
|
| 42 |
-
for record in SeqIO.parse(fasta_file, "fasta"):
|
| 43 |
-
sequences.append((record.id, str(record.seq), file_name))
|
| 44 |
-
file_seqs.append(record.id)
|
| 45 |
-
file_info[file_name] = file_seqs
|
| 46 |
-
except Exception as e:
|
| 47 |
-
raise ValueError(f"Error parsing {fasta_file.name}: {str(e)}")
|
| 48 |
-
|
| 49 |
-
if not sequences:
|
| 50 |
-
raise ValueError("No sequences found in the provided FASTA files.")
|
| 51 |
-
|
| 52 |
-
return sequences, file_info
|
| 53 |
-
|
| 54 |
-
@spaces.GPU(duration=240)
|
| 55 |
-
def generate_embeddings(sequences_batch):
|
| 56 |
-
"""Generate embeddings for a batch of sequences."""
|
| 57 |
-
# Prepare batch for ESM2
|
| 58 |
-
batch_labels, batch_strs, batch_tokens = batch_converter(sequences_batch)
|
| 59 |
-
|
| 60 |
-
# Move to device
|
| 61 |
-
batch_tokens = batch_tokens.to(device)
|
| 62 |
-
|
| 63 |
-
# Generate embeddings
|
| 64 |
-
with torch.no_grad():
|
| 65 |
-
results = model(batch_tokens, repr_layers=[33], return_contacts=False)
|
| 66 |
-
|
| 67 |
-
# Extract embeddings (token representations from layer 33)
|
| 68 |
-
token_embeddings = results["representations"][33]
|
| 69 |
-
|
| 70 |
-
# Get sequence-level embeddings (mean pooling of token embeddings, excluding special tokens)
|
| 71 |
-
sequence_embeddings = []
|
| 72 |
-
for i, (label, seq) in enumerate(zip(batch_labels, batch_strs)):
|
| 73 |
-
# Remove special tokens (first and last)
|
| 74 |
-
seq_embedding = token_embeddings[i, 1:len(seq) + 1].mean(dim=0)
|
| 75 |
-
sequence_embeddings.append(seq_embedding.cpu().numpy())
|
| 76 |
-
|
| 77 |
-
return sequence_embeddings
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
def process_embeddings(fasta_files):
|
| 81 |
-
"""Main function to process FASTA files and generate embeddings."""
|
| 82 |
-
try:
|
| 83 |
-
# Parse FASTA files
|
| 84 |
-
sequences, file_info = parse_fasta_files(fasta_files)
|
| 85 |
-
|
| 86 |
-
# Generate embeddings in batches
|
| 87 |
-
batch_size = 8
|
| 88 |
-
all_embeddings = {}
|
| 89 |
-
status_updates = [f"Processing {len(sequences)} sequences from {len(file_info)} file(s)..."]
|
| 90 |
-
|
| 91 |
-
for i in range(0, len(sequences), batch_size):
|
| 92 |
-
batch = sequences[i:i + batch_size]
|
| 93 |
-
batch_labels = [(seq_id, seq, file_name) for seq_id, seq, file_name in batch]
|
| 94 |
-
|
| 95 |
-
status_updates.append(f"Generating embeddings for sequences {i + 1}-{min(i + batch_size, len(sequences))}...")
|
| 96 |
-
|
| 97 |
-
# Generate embeddings
|
| 98 |
-
embeddings = generate_embeddings([(label, seq) for label, seq, _ in batch_labels])
|
| 99 |
-
|
| 100 |
-
# Store embeddings
|
| 101 |
-
for (seq_id, seq, file_name), embedding in zip(batch_labels, embeddings):
|
| 102 |
-
key = f"{file_name}_{seq_id}"
|
| 103 |
-
all_embeddings[key] = {
|
| 104 |
-
"sequence_id": seq_id,
|
| 105 |
-
"file": file_name,
|
| 106 |
-
"sequence_length": len(seq),
|
| 107 |
-
"embedding": embedding.tolist()
|
| 108 |
-
}
|
| 109 |
-
|
| 110 |
-
# Create output files
|
| 111 |
-
output_files = []
|
| 112 |
-
temp_dir = tempfile.mkdtemp()
|
| 113 |
-
|
| 114 |
-
# Save embeddings as NPZ (numpy compressed format)
|
| 115 |
-
npz_path = os.path.join(temp_dir, "embeddings.npz")
|
| 116 |
-
embeddings_array = {k: np.array(v["embedding"]) for k, v in all_embeddings.items()}
|
| 117 |
-
np.savez_compressed(npz_path, **embeddings_array)
|
| 118 |
-
output_files.append(npz_path)
|
| 119 |
-
status_updates.append(f"Saved compressed embeddings to embeddings.npz")
|
| 120 |
-
|
| 121 |
-
# Save metadata as JSON
|
| 122 |
-
metadata_path = os.path.join(temp_dir, "metadata.json")
|
| 123 |
-
metadata = {
|
| 124 |
-
"num_sequences": len(all_embeddings),
|
| 125 |
-
"embedding_dim": 1280, # ESM2-650M has 1280-dimensional embeddings
|
| 126 |
-
"model": model_name,
|
| 127 |
-
"sequences": {k: {
|
| 128 |
-
"sequence_id": v["sequence_id"],
|
| 129 |
-
"file": v["file"],
|
| 130 |
-
"sequence_length": v["sequence_length"]
|
| 131 |
-
} for k, v in all_embeddings.items()}
|
| 132 |
-
}
|
| 133 |
-
with open(metadata_path, 'w') as f:
|
| 134 |
-
json.dump(metadata, f, indent=2)
|
| 135 |
-
output_files.append(metadata_path)
|
| 136 |
-
status_updates.append(f"Saved metadata to metadata.json")
|
| 137 |
-
|
| 138 |
-
# Create per-file embedding files
|
| 139 |
-
for file_name in file_info.keys():
|
| 140 |
-
file_embeddings = {k: v for k, v in embeddings_array.items() if k.startswith(file_name)}
|
| 141 |
-
if file_embeddings:
|
| 142 |
-
file_npz_path = os.path.join(temp_dir, f"embeddings_{file_name}.npz")
|
| 143 |
-
np.savez_compressed(file_npz_path, **file_embeddings)
|
| 144 |
-
output_files.append(file_npz_path)
|
| 145 |
-
status_updates.append(f"Saved {len(file_embeddings)} embeddings for {file_name}")
|
| 146 |
-
|
| 147 |
-
# Create a summary report
|
| 148 |
-
summary_path = os.path.join(temp_dir, "summary.txt")
|
| 149 |
-
with open(summary_path, 'w') as f:
|
| 150 |
-
f.write("ESM2 Protein Sequence Embedding Summary\n")
|
| 151 |
-
f.write("=" * 50 + "\n\n")
|
| 152 |
-
f.write(f"Model: {model_name}\n")
|
| 153 |
-
f.write(f"Device: {device}\n")
|
| 154 |
-
f.write(f"Embedding Dimension: 1280\n\n")
|
| 155 |
-
f.write(f"Input Files: {len(file_info)}\n")
|
| 156 |
-
f.write(f"Total Sequences: {len(all_embeddings)}\n\n")
|
| 157 |
-
f.write("Sequences per file:\n")
|
| 158 |
-
for file_name, seq_ids in file_info.items():
|
| 159 |
-
f.write(f" - {file_name}: {len(seq_ids)} sequences\n")
|
| 160 |
-
output_files.append(summary_path)
|
| 161 |
-
|
| 162 |
-
status_message = "\n".join(status_updates)
|
| 163 |
-
status_message += f"\n\nSuccessfully generated embeddings for {len(all_embeddings)} sequences!"
|
| 164 |
-
|
| 165 |
-
return output_files, status_message
|
| 166 |
-
|
| 167 |
-
except Exception as e:
|
| 168 |
-
raise gr.Error(f"Error processing sequences: {str(e)}")
|
| 169 |
|
| 170 |
|
| 171 |
# Create Gradio interface
|
|
@@ -214,12 +70,36 @@ with gr.Blocks(title="ESM2 Protein Embeddings") as demo:
|
|
| 214 |
label="Download Output Files",
|
| 215 |
file_count="multiple"
|
| 216 |
)
|
| 217 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 218 |
submit_btn.click(
|
| 219 |
-
fn=
|
| 220 |
inputs=[input_files],
|
| 221 |
outputs=[download_output, status_output]
|
| 222 |
)
|
|
|
|
|
|
|
| 223 |
|
| 224 |
gr.Markdown("""
|
| 225 |
### How to use the embeddings:
|
|
|
|
| 8 |
from pathlib import Path
|
| 9 |
import zipfile
|
| 10 |
import spaces
|
| 11 |
+
from utils.download_models import *
|
| 12 |
+
from utils.handle_files import parse_fasta_files
|
| 13 |
+
from utils.pipelines import generate_embeddings, full_embedding_pipeline
|
| 14 |
|
| 15 |
+
print("Downloading ESM2 models...")
|
|
|
|
|
|
|
| 16 |
|
| 17 |
+
MODELS = {
|
| 18 |
+
"facebook/esm2_t6_8M_UR50D": "ESM2-8M",
|
| 19 |
+
"facebook/esm2_t12_35M_UR50D": "ESM2-35M",
|
| 20 |
+
#"esm2_t36_650M_UR50D": "ESM2-650M"
|
| 21 |
+
}
|
|
|
|
|
|
|
| 22 |
|
| 23 |
+
cache_dirs = cache_all_models(MODELS)
|
| 24 |
+
models_and_tokenizers = load_all_models(MODELS)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
|
| 27 |
# Create Gradio interface
|
|
|
|
| 70 |
label="Download Output Files",
|
| 71 |
file_count="multiple"
|
| 72 |
)
|
| 73 |
+
|
| 74 |
+
with gr.Row():
|
| 75 |
+
model_dropdown = gr.Dropdown(
|
| 76 |
+
choices=list(MODELS.values()),
|
| 77 |
+
value=list(MODELS.values())[0],
|
| 78 |
+
label="Select Model"
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
model_to_use = gr.State(value=models_and_tokenizers[model_dropdown.value][0])
|
| 83 |
+
tokenizer_to_use = gr.State(value=models_and_tokenizers[model_dropdown.value][1])
|
| 84 |
+
|
| 85 |
+
def pick_model(model_name):
|
| 86 |
+
model_key = [key for key, value in MODELS.items() if value == model_name][0]
|
| 87 |
+
print(f"Selected model: {model_name} ({model_key})")
|
| 88 |
+
return models_and_tokenizers[model_key]
|
| 89 |
+
|
| 90 |
+
model_dropdown.change(
|
| 91 |
+
fn=pick_model,
|
| 92 |
+
inputs=model_dropdown,
|
| 93 |
+
outputs=[model_to_use, tokenizer_to_use]
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
submit_btn.click(
|
| 97 |
+
fn=full_embedding_pipeline,
|
| 98 |
inputs=[input_files],
|
| 99 |
outputs=[download_output, status_output]
|
| 100 |
)
|
| 101 |
+
|
| 102 |
+
|
| 103 |
|
| 104 |
gr.Markdown("""
|
| 105 |
### How to use the embeddings:
|
requirements.txt
CHANGED
|
@@ -1,4 +1,6 @@
|
|
| 1 |
torch>=2.0.0
|
| 2 |
-
fair-esm>=2.0.0
|
| 3 |
biopython>=1.81
|
| 4 |
numpy>=1.21.0
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
torch>=2.0.0
|
|
|
|
| 2 |
biopython>=1.81
|
| 3 |
numpy>=1.21.0
|
| 4 |
+
huggingface_hub
|
| 5 |
+
transformers
|
| 6 |
+
|
utils/download_models.py
ADDED
|
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# NOTE: `import esm` removed — this commit drops fair-esm from
# requirements.txt, so importing it crashes the space; it was only
# referenced by commented-out legacy code below.
import torch
import huggingface_hub
from transformers import AutoTokenizer, AutoModel
|
| 5 |
+
|
| 6 |
+
def cache_model_weights(model_id):
    """
    Download ESM2 model weights into the local HF cache without loading them into memory.

    Called when the space (re)starts so the weights are already on disk by the
    time inference is first requested; no (GPU) memory is consumed at this stage.

    Parameters:
    -----------
    model_id : str
        Model identifier (e.g., "facebook/esm2_t6_8M_UR50D")

    Returns:
    --------
    str : Path to cached model directory
    """
    cache_dir = huggingface_hub.snapshot_download(model_id)
    print(f"Model {model_id} cached at: {cache_dir}")
    return cache_dir
|
| 23 |
+
|
| 24 |
+
def cache_all_models(models):
    """
    Pre-download every model in *models* into the local HF cache.

    Parameters:
    -----------
    models : dict
        Maps model identifiers (e.g., "facebook/esm2_t6_8M_UR50D") to
        human-readable names (e.g., "ESM2-8M"); only the keys are used here.

    Returns:
    --------
    dict : A dictionary mapping model identifiers to their cache directories.
    """
    return {model_id: cache_model_weights(model_id) for model_id in models}
|
| 41 |
+
|
| 42 |
+
def load_model(model_id):
    """
    Load an ESM model and its tokenizer via from_pretrained.

    Reads from the default HuggingFace cache (downloading if missing); intended
    to run after cache_model_weights so download timing stays under the
    caller's control.

    Parameters:
    -----------
    model_id : str
        Model identifier (e.g., "facebook/esm2_t6_8M_UR50D")

    Returns:
    --------
    tuple : (model, tokenizer) loaded from cache; the model is in eval mode
        and moved to GPU when one is available.
    """
    try:
        print(f"Loading {model_id} from local cache...")
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        # hidden_states are consumed downstream for embedding extraction
        model = AutoModel.from_pretrained(model_id, output_hidden_states=True)
    except Exception as e:
        raise RuntimeError(f"Failed to load model {model_id} from cache: {e}")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.eval().to(device)
    print(f"{model_id} loaded on {device}")
    return model, tokenizer
|
| 71 |
+
|
| 72 |
+
def load_all_models(models):
    """
    Load every model in *models* into memory.

    Parameters:
    -----------
    models : dict
        Maps model identifiers (e.g., "facebook/esm2_t6_8M_UR50D") to
        human-readable names (e.g., "ESM2-8M"); only the keys are used here.

    Returns:
    --------
    dict : A dictionary mapping model identifiers to their loaded (model, tokenizer) tuples.
    """
    return {model_id: load_model(model_id) for model_id in models}
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
#def cache_models(models):
|
| 92 |
+
# """
|
| 93 |
+
# Download weights to ESM models in cache to be loaded later.
|
| 94 |
+
# We do not load the models into memory at this stage to avoid using GPU memory for models that are not used in the current session.
|
| 95 |
+
#
|
| 96 |
+
# Parameters:
|
| 97 |
+
# ----------
|
| 98 |
+
# models: dict
|
| 99 |
+
# A dictionary where keys are model identifiers (e.g., "esm2_t6_8M_UR50D") and values are human-readable model names (e.g., "ESM2-8M").
|
| 100 |
+
#
|
| 101 |
+
# Returns:
|
| 102 |
+
# -------
|
| 103 |
+
#
|
| 104 |
+
# """
|
| 105 |
+
# loaded_models = {}
|
| 106 |
+
# for model_id, model_name in models.items():
|
| 107 |
+
# print(f"Loading {model_name}...")
|
| 108 |
+
# try:
|
| 109 |
+
# #load from local cache if avilable, upon startup of space will fail and load from HF
|
| 110 |
+
# model, alphabet = esm.pretrained.load_model_and_alphabet_local(model_id)
|
| 111 |
+
# except:
|
| 112 |
+
# print(f"Loading {model_name} from HuggingFace...")
|
| 113 |
+
# model, alphabet = esm.pretrained.load_model_and_alphabet_hub(model_id)
|
| 114 |
+
#
|
| 115 |
+
# model = model.eval()
|
| 116 |
+
# device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 117 |
+
# model = model.to(device)
|
| 118 |
+
# loaded_models[model_id] = {
|
| 119 |
+
# "model": model,
|
| 120 |
+
# "alphabet": alphabet,
|
| 121 |
+
# "batch_converter": alphabet.get_batch_converter()
|
| 122 |
+
# }
|
| 123 |
+
# print(f"{model_name} loaded on {device}")
|
| 124 |
+
#
|
| 125 |
+
#def download_models(models):
|
| 126 |
+
# """
|
| 127 |
+
# Download weights to ESM models in cache to be loaded later.
|
| 128 |
+
# We do not load the models into memory at this stage to avoid using GPU memory for models that are not used in the current session.
|
| 129 |
+
#
|
| 130 |
+
# Parameters:
|
| 131 |
+
# ----------
|
| 132 |
+
# models: dict
|
| 133 |
+
# A dictionary where keys are model identifiers (e.g., "esm2_t6_8M_UR50D") and values are human-readable model names (e.g., "ESM2-8M").
|
| 134 |
+
#
|
| 135 |
+
# Returns:
|
| 136 |
+
# -------
|
| 137 |
+
#
|
| 138 |
+
# """
|
| 139 |
+
# loaded_models = {}
|
| 140 |
+
# for model_id, model_name in models.items():
|
| 141 |
+
# print(f"Loading {model_name}...")
|
| 142 |
+
# try:
|
| 143 |
+
# #load from local cache if avilable, upon startup of space will fail and load from HF
|
| 144 |
+
# model, alphabet = esm.pretrained.load_model_and_alphabet_local(model_id)
|
| 145 |
+
# except:
|
| 146 |
+
# print(f"Loading {model_name} from HuggingFace...")
|
| 147 |
+
# model, alphabet = esm.pretrained.load_model_and_alphabet_hub(model_id)
|
| 148 |
+
#
|
| 149 |
+
# model = model.eval()
|
| 150 |
+
# device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 151 |
+
# model = model.to(device)
|
| 152 |
+
# loaded_models[model_id] = {
|
| 153 |
+
# "model": model,
|
| 154 |
+
# "alphabet": alphabet,
|
| 155 |
+
# "batch_converter": alphabet.get_batch_converter()
|
| 156 |
+
# }
|
| 157 |
+
# print(f"{model_name} loaded on {device}")
|
utils/handle_files.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
from Bio import SeqIO
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def parse_fasta_files(fasta_files):
    """Parse one or multiple FASTA files and return sequences.

    The entire header line is used as the sequence ID to deal with
    LigandMPNN's omittance of a unique sequence ID at the beginning of the
    header. Paths not ending in .fasta/.fa are silently skipped.

    Parameters:
    -----------
    fasta_files : list of str
        List of paths to FASTA files to be parsed.

    Returns:
    --------
    sequences : list of tuples
        (sequence_id, sequence, file_name) for each sequence found.
    file_info : dict
        Maps each file's stem to the list of sequence IDs it contained.

    Raises:
    -------
    ValueError
        If a file fails to parse, or no sequences are found at all.
    """
    sequences = []
    file_info = {}

    for fasta_file in fasta_files:
        print(fasta_file)
        if fasta_file.endswith('.fasta') or fasta_file.endswith('.fa'):
            file_name = Path(fasta_file).stem
            file_seqs = []

            try:
                for record in SeqIO.parse(fasta_file, "fasta"):
                    # Full header line (without '>') serves as the sequence ID.
                    full_header = record.description
                    sequences.append((full_header, str(record.seq), file_name))
                    file_seqs.append(full_header)
                # Record once per file (only when it yielded records, matching
                # the previous behavior for empty files).
                if file_seqs:
                    file_info[file_name] = file_seqs
            except Exception as e:
                # BUG FIX: fasta_file is a plain path string, not a file
                # object — it has no .name attribute, so the old message
                # raised AttributeError and masked the real parse error.
                raise ValueError(f"Error parsing {fasta_file}: {str(e)}") from e

    if not sequences:
        raise ValueError("No sequences found in the provided FASTA files.")

    return sequences, file_info
|
| 43 |
+
|
| 44 |
+
def parse_fasta_files_from_ligandmpnn(fasta_files):
    """Parse FASTA files in the format generated by LigandMPNN.

    LigandMPNN headers carry no leading sequence ID — just the file name, some
    quality info, and a design number token like "id=0" — so the ID is
    extracted from the "id=" token when present, falling back to record.id.

    Parameters:
    -----------
    fasta_files : list of str
        List of paths to FASTA files to be parsed.

    Returns:
    --------
    sequences : list of tuples
        (sequence_id, sequence, file_name) for each sequence found.
    file_info : dict
        Maps each file's stem to the list of sequence IDs it contained.

    Raises:
    -------
    ValueError
        If a file fails to parse, or no sequences are found at all.
    """
    sequences = []
    file_info = {}

    for fasta_file in fasta_files:
        print(fasta_file)
        if fasta_file.endswith('.fasta') or fasta_file.endswith('.fa'):
            file_name = Path(fasta_file).stem
            file_seqs = []

            try:
                for record in SeqIO.parse(fasta_file, "fasta"):
                    seq_id = record.id
                    if "id=" in record.description:
                        # Take the first whitespace-separated "id=..." token.
                        for part in record.description.split():
                            if part.startswith("id="):
                                seq_id = part[3:]  # strip the "id=" prefix
                                break

                    sequences.append((seq_id, str(record.seq), file_name))
                    file_seqs.append(seq_id)
                # Record once per file (only when it yielded records, matching
                # the previous behavior for empty files).
                if file_seqs:
                    file_info[file_name] = file_seqs
            except Exception as e:
                # BUG FIX: fasta_file is a plain path string, not a file
                # object — it has no .name attribute, so the old message
                # raised AttributeError and masked the real parse error.
                raise ValueError(f"Error parsing {fasta_file}: {str(e)}") from e

    if not sequences:
        raise ValueError("No sequences found in the provided FASTA files.")

    return sequences, file_info
|
utils/pipelines.py
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Dependency block: stdlib first, then third-party, then local.
# (Duplicate `import spaces` removed.)
import os
import random
import time

import gradio as gr
import numpy as np
import spaces
import torch

from utils.handle_files import parse_fasta_files
|
| 10 |
+
|
| 11 |
+
# BUG FIX: the decorator was `@spaces(duration=240)`, which calls the module
# object itself; the ZeroGPU decorator is `spaces.GPU` (as the pre-refactor
# app.py used).
@spaces.GPU(duration=240)
def generate_embeddings(sequences_batch, model, tokenizer):
    """Generate embeddings for ESM models using the transformers library.

    Parameters:
    -----------
    sequences_batch : list of str
        A batch of sequences for which to generate embeddings.
    model : AutoModel
        The pre-loaded ESM model; must already be on the correct device
        (CPU or GPU) and loaded with output_hidden_states=True.
    tokenizer : AutoTokenizer
        The pre-loaded tokenizer corresponding to the ESM model.

    Returns:
    --------
    sequence_embeddings : 2D np.array of shape (batch_size, embedding_dim)
        Mean-pooled sequence-level embedding for each input sequence.
    """
    # Tokenize on the model's device; padding aligns variable-length sequences.
    device = model.device
    tokens = tokenizer(
        sequences_batch,
        return_tensors="pt",
        padding=True,
        truncation=True,
        add_special_tokens=True
    ).to(device)

    # Inference only — no gradients needed.
    with torch.no_grad():
        results = model(**tokens)

    # Last layer hidden states (available because the model was loaded with
    # output_hidden_states=True).
    token_embeddings = results.hidden_states[-1]

    # Mean-pool per sequence over its own residue tokens only: skip the
    # leading special token and stop at the sequence's true length. Pooling
    # per sequence (rather than one vectorized slice) is necessary because
    # sequences have different lengths and padding must not be averaged in.
    sequence_embeddings = []
    for i, seq in enumerate(sequences_batch):
        seq_embedding = token_embeddings[i, 1:len(seq) + 1].mean(dim=0)
        sequence_embeddings.append(seq_embedding.cpu().numpy())

    return np.array(sequence_embeddings)
|
| 57 |
+
|
| 58 |
+
def full_embedding_pipeline(fasta_files, model, tokenizer, batch_size=8):
    """Full pipeline to process FASTA files and generate embeddings from desired model.

    Parameters:
    -----------
    fasta_files : list of str, obtained from gradio file input
        List of paths to FASTA files to be parsed.
    model : AutoModel
        The pre-loaded ESM model; must already be on the correct device (CPU or GPU).
    tokenizer : AutoTokenizer
        The pre-loaded tokenizer corresponding to the ESM model.
    batch_size : int, optional
        Number of sequences per embedding batch (default 8, matching the
        pre-refactor app.py), so gradio callers need not supply it.

    Returns:
    --------
    all_file_paths : list of str
        Paths of the saved per-file .npz archives, to be passed to gradio for download.
    status_string : str
        A string summarizing the processing steps and output files generated,
        to be displayed in the gradio interface.
    """
    # Parse FASTA files
    sequences_info, file_info = parse_fasta_files(fasta_files)

    # Generate embeddings in batches
    all_embeddings = []
    n_batches = (len(sequences_info) + batch_size - 1) // batch_size
    status_string = f"Processing {len(sequences_info)} sequences from {len(file_info)} file(s) in {n_batches} batches of {batch_size} sequences...\n"

    for start in range(0, len(sequences_info), batch_size):
        batch = sequences_info[start:start + batch_size]
        batch_sequences = [seq for _, seq, _ in batch]

        embeddings = generate_embeddings(batch_sequences, model, tokenizer)
        status_string += f"Generated {len(embeddings)} embeddings for batch {start // batch_size + 1}/{n_batches}\n"
        all_embeddings.extend(embeddings)

    status_string += f"Generated embeddings for all {len(sequences_info)} sequences.\n"

    # One archive per input file, grouped under a unique session directory so
    # concurrent sessions cannot clobber each other's outputs.
    session_hash = random.getrandbits(128)
    time_stamp = time.strftime("%Y-%m-%d-%H-%M-%S")
    out_dir = f"./outputs/unconditional_generation/session_{session_hash}_{time_stamp}"
    os.makedirs(out_dir, exist_ok=True)

    all_file_paths = []
    for file_name in file_info:
        indices = [i for i, (_, _, f) in enumerate(sequences_info) if f == file_name]
        file_embeddings = np.array([all_embeddings[i] for i in indices])
        sequence_ids = [sequences_info[i][0] for i in indices]  # IDs for this file
        file_path = os.path.join(out_dir, f"{file_name}_embeddings.npz")
        np.savez_compressed(file_path, embeddings=file_embeddings, sequence_ids=sequence_ids)
        status_string += f"Saved compressed embeddings to {file_name}_embeddings.npz\n"
        all_file_paths.append(file_path)

    # BUG FIX: return the status log as the second value. The docstring (and
    # the gradio wiring outputs=[download_output, status_output]) expect the
    # status string here, but the code returned the raw embeddings list.
    return all_file_paths, status_string
|
| 113 |
+
|