Spaces:
Sleeping
Sleeping
Upload generator.py
Browse files- scripts/generator.py +68 -81
scripts/generator.py
CHANGED
|
@@ -5,131 +5,118 @@ import warnings
|
|
| 5 |
import sys
|
| 6 |
import shutil
|
| 7 |
|
| 8 |
-
def
|
| 9 |
"""
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
Args:
|
| 13 |
-
pdb_path: Path to the target complex (e.g., 'data/3KAS.pdb').
|
| 14 |
-
fixed_chains: Chains to remain unchanged (e.g., 'A'). Empty for single-chain proteins.
|
| 15 |
-
variable_chains: Chains to be redesigned/repainted (e.g., 'B'). For single-chain, this is the only chain.
|
| 16 |
"""
|
| 17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
pdb_name = os.path.basename(pdb_path).split('.')[0]
|
| 19 |
output_dir = f"./generated/{pdb_name}"
|
| 20 |
os.makedirs(output_dir, exist_ok=True)
|
| 21 |
|
| 22 |
-
# Get the project root directory (where ProteinMPNN should be)
|
| 23 |
script_dir = os.path.dirname(os.path.abspath(__file__))
|
| 24 |
project_root = os.path.dirname(script_dir)
|
| 25 |
proteinmpnn_dir = os.path.join(project_root, "ProteinMPNN")
|
| 26 |
|
| 27 |
-
# Clone ProteinMPNN if it doesn't exist (for HuggingFace Spaces deployment)
|
| 28 |
if not os.path.exists(proteinmpnn_dir):
|
| 29 |
print("ProteinMPNN not found, cloning repository...")
|
| 30 |
-
subprocess.run(
|
| 31 |
-
["git", "clone", "https://github.com/dauparas/ProteinMPNN.git"],
|
| 32 |
-
cwd=project_root,
|
| 33 |
-
check=True,
|
| 34 |
-
stdout=subprocess.DEVNULL,
|
| 35 |
-
stderr=subprocess.DEVNULL
|
| 36 |
-
)
|
| 37 |
|
| 38 |
mpnn_script = os.path.join(proteinmpnn_dir, "protein_mpnn_run.py")
|
| 39 |
|
| 40 |
-
# 2.
|
| 41 |
if not fixed_chains or len(fixed_chains) == 0:
|
| 42 |
-
# Single-chain protein: use direct PDB path command
|
| 43 |
-
# For single-chain, variable_chains should be the only chain (e.g., "A")
|
| 44 |
chain_to_design = variable_chains[0] if variable_chains else "A"
|
| 45 |
-
|
| 46 |
mpnn_cmd = (
|
| 47 |
-
f"python -W ignore {mpnn_script} "
|
| 48 |
-
f"--
|
| 49 |
-
f"--pdb_path_chains {chain_to_design} "
|
| 50 |
-
f"--out_folder {output_dir} "
|
| 51 |
-
f"--num_seq_per_target {num_seqs} "
|
| 52 |
-
f"--sampling_temp {temp} "
|
| 53 |
-
f"--seed 42 "
|
| 54 |
-
f"--batch_size 1"
|
| 55 |
)
|
| 56 |
-
|
| 57 |
-
print(f"🚀 Designing sequences for {pdb_name} (single-chain mode)...")
|
| 58 |
-
print(f"✏️ Redesigning chain: {chain_to_design}")
|
| 59 |
else:
|
| 60 |
-
# Multi-chain
|
| 61 |
-
|
| 62 |
-
pdb_dir = os.path.dirname(os.path.abspath(pdb_path))
|
| 63 |
-
if not pdb_dir:
|
| 64 |
-
pdb_dir = "."
|
| 65 |
jsonl_path = os.path.join(output_dir, "parsed_pdbs.jsonl")
|
| 66 |
-
|
| 67 |
parse_script = os.path.join(proteinmpnn_dir, "helper_scripts", "parse_multiple_chains.py")
|
| 68 |
|
| 69 |
-
|
| 70 |
-
subprocess.run(
|
| 71 |
|
| 72 |
-
#
|
| 73 |
pdb_name_clones = f"{pdb_name}_clones"
|
|
|
|
| 74 |
with open(jsonl_path, 'r') as f:
|
| 75 |
jsonl_data = json.loads(f.readline())
|
| 76 |
jsonl_data['name'] = pdb_name_clones
|
| 77 |
with open(jsonl_path, 'w') as f:
|
| 78 |
f.write(json.dumps(jsonl_data) + '\n')
|
| 79 |
|
| 80 |
-
# 3. Generate the Chain Configuration JSONs (The 'Engine' Logic)
|
| 81 |
-
# Format: {"name": [masked_chains_list, visible_chains_list]}
|
| 82 |
-
# masked_chains = chains to redesign, visible_chains = chains to keep fixed
|
| 83 |
-
masked_chains_list = [c for c in variable_chains]
|
| 84 |
-
visible_chains_list = [c for c in fixed_chains]
|
| 85 |
-
chain_id_dict = {pdb_name_clones: [masked_chains_list, visible_chains_list]}
|
| 86 |
-
|
| 87 |
chain_id_json = os.path.join(output_dir, "chain_id_dict.json")
|
|
|
|
| 88 |
with open(chain_id_json, 'w') as f:
|
| 89 |
json.dump(chain_id_dict, f)
|
| 90 |
|
| 91 |
-
#
|
|
|
|
|
|
|
|
|
|
| 92 |
mpnn_cmd = (
|
| 93 |
-
f"python -W ignore {mpnn_script} "
|
| 94 |
-
f"--
|
| 95 |
-
f"--chain_id_jsonl {chain_id_json} "
|
| 96 |
-
f"--out_folder {output_dir} "
|
| 97 |
-
f"--num_seq_per_target {num_seqs} "
|
| 98 |
-
f"--sampling_temp {temp} "
|
| 99 |
-
f"--seed 42"
|
| 100 |
)
|
| 101 |
-
|
| 102 |
-
print(f"🚀 Designing sequences for {pdb_name}...")
|
| 103 |
-
print(f"🔒 Fixed: {fixed_chains} | ✏️ Redesigning: {variable_chains}")
|
| 104 |
|
| 105 |
-
#
|
| 106 |
env = os.environ.copy()
|
| 107 |
env['PYTHONWARNINGS'] = 'ignore'
|
| 108 |
subprocess.run(mpnn_cmd, shell=True, check=True, env=env, stderr=subprocess.DEVNULL)
|
| 109 |
|
| 110 |
-
|
| 111 |
-
# Rename it to {pdb_name}_clones.fa for consistency
|
| 112 |
-
if not fixed_chains or len(fixed_chains) == 0:
|
| 113 |
-
seqs_dir = os.path.join(output_dir, "seqs")
|
| 114 |
-
old_file = os.path.join(seqs_dir, f"{pdb_name}.fa")
|
| 115 |
-
new_file = os.path.join(seqs_dir, f"{pdb_name}_clones.fa")
|
| 116 |
-
if os.path.exists(old_file) and not os.path.exists(new_file):
|
| 117 |
-
os.rename(old_file, new_file)
|
| 118 |
-
print(f"📝 Renamed {pdb_name}.fa → {pdb_name}_clones.fa")
|
| 119 |
-
|
| 120 |
-
print(f"✅ Success! Fold the top sequences at https://esmatlas.com/resources?action=fold")
|
| 121 |
|
| 122 |
if __name__ == "__main__":
|
| 123 |
-
import sys
|
| 124 |
if len(sys.argv) < 4:
|
| 125 |
print("Usage: python scripts/generator.py <pdb_path> <fixed_chains> <variable_chains> [num_seqs] [temp]")
|
| 126 |
-
print("Example: python scripts/generator.py data/3kas.pdb A B")
|
| 127 |
sys.exit(1)
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
temp = float(sys.argv[5]) if len(sys.argv) > 5 else 0.1
|
| 134 |
-
|
| 135 |
-
run_broteinshake_generator(pdb_path, fixed_chains, variable_chains, num_seqs, temp)
|
|
|
|
| 5 |
import sys
|
| 6 |
import shutil
|
| 7 |
|
| 8 |
+
def sync_protein_metadata(jsonl_path, dict_path):
    """Sanitize ProteinMPNN chain metadata before a design run.

    Prunes chain IDs from the chain-ID dictionary that have no corresponding
    ``seq_chain_<ID>`` entry in the parsed JSONL (e.g. non-proteogenic or
    ghost chains), preventing KeyErrors such as 'seq_chain_C' inside
    ProteinMPNN.

    Args:
        jsonl_path: Path to the parsed-PDBs JSONL produced by
            ``parse_multiple_chains.py`` (one JSON object per line).
        dict_path: Path to the chain-ID JSON of the form
            ``{name: [redesign_chains, fixed_chains]}``; rewritten in place.

    Returns:
        None. Silently no-ops if either file is missing.
    """
    if not os.path.exists(jsonl_path) or not os.path.exists(dict_path):
        return

    prefix = "seq_chain_"

    # 1. Identify chains that actually have proteogenic sequence data.
    valid_chains_map = {}
    with open(jsonl_path, 'r') as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank lines in the JSONL
            entry = json.loads(line)
            name = entry['name']
            # Strip the fixed prefix rather than split('_')[-1]: a chain ID
            # containing '_' would otherwise be truncated to its last segment.
            valid = {k[len(prefix):] for k in entry if k.startswith(prefix)}
            valid_chains_map[name] = valid

    # 2. Clean the chain ID dictionary.
    with open(dict_path, 'r') as f:
        chain_id_dict = json.load(f)

    for pdb_name, configs in chain_id_dict.items():
        if pdb_name in valid_chains_map:
            valid = valid_chains_map[pdb_name]
            # configs[0] = redesign list, configs[1] = fixed list
            original_chains = set(configs[0] + configs[1])

            chain_id_dict[pdb_name] = [
                [c for c in configs[0] if c in valid],
                [c for c in configs[1] if c in valid]
            ]

            # Diagnostic feedback
            removed = original_chains - valid
            if removed:
                print(f"🧹 Sanitizer: Pruned non-protein chains from metadata: {removed}")

    # 3. Overwrite with cleaned metadata for ProteinMPNN.
    with open(dict_path, 'w') as f:
        json.dump(chain_id_dict, f)
| 50 |
+
def run_broteinshake_generator(pdb_path, fixed_chains, variable_chains, num_seqs=20, temp=0.1):
    """Run ProteinMPNN sequence design on a target PDB.

    Args:
        pdb_path: Path to the target complex (e.g., 'data/3KAS.pdb').
        fixed_chains: Chain IDs to keep unchanged (e.g., 'A'). Empty/falsy
            for single-chain proteins.
        variable_chains: Chain IDs to redesign (e.g., 'B'). For single-chain
            proteins this is the only chain.
        num_seqs: Sequences to generate per target.
        temp: Sampling temperature.

    Side effects: creates ./generated/<name>/, may git-clone ProteinMPNN,
    and spawns subprocesses. Raises CalledProcessError if any step fails.
    """
    # 1. Setup identifiers and directories.
    # NOTE(review): split('.')[0] truncates dotted basenames ('x.v2.pdb' -> 'x');
    # kept as-is since downstream paths may rely on this naming.
    pdb_name = os.path.basename(pdb_path).split('.')[0]
    output_dir = f"./generated/{pdb_name}"
    os.makedirs(output_dir, exist_ok=True)

    script_dir = os.path.dirname(os.path.abspath(__file__))
    project_root = os.path.dirname(script_dir)
    proteinmpnn_dir = os.path.join(project_root, "ProteinMPNN")

    # Clone ProteinMPNN on first run (e.g., fresh HuggingFace Spaces deploy).
    if not os.path.exists(proteinmpnn_dir):
        print("ProteinMPNN not found, cloning repository...")
        subprocess.run(["git", "clone", "https://github.com/dauparas/ProteinMPNN.git"],
                       cwd=project_root, check=True)

    mpnn_script = os.path.join(proteinmpnn_dir, "protein_mpnn_run.py")
    # Use the current interpreter; argument lists (shell=False) keep paths
    # with spaces/metacharacters safe — no shell string interpolation.
    py = [sys.executable, "-W", "ignore"]

    # 2. Handle Single vs Multi-Chain Logic.
    if not fixed_chains or len(fixed_chains) == 0:
        # Single-chain: point ProteinMPNN directly at the PDB file.
        chain_to_design = variable_chains[0] if variable_chains else "A"
        mpnn_cmd = py + [
            mpnn_script,
            "--pdb_path", pdb_path,
            "--pdb_path_chains", chain_to_design,
            "--out_folder", output_dir,
            "--num_seq_per_target", str(num_seqs),
            "--sampling_temp", str(temp),
            "--seed", "42",
            "--batch_size", "1",
        ]
        print(f"🚀 Designing {pdb_name} (Single-chain: {chain_to_design})...")
    else:
        # Multi-chain setup
        pdb_dir = os.path.dirname(os.path.abspath(pdb_path)) or "."
        jsonl_path = os.path.join(output_dir, "parsed_pdbs.jsonl")
        parse_script = os.path.join(proteinmpnn_dir, "helper_scripts", "parse_multiple_chains.py")

        # Step A: Parse PDB to JSONL.
        subprocess.run(py + [parse_script,
                             f"--input_path={pdb_dir}/",
                             f"--output_path={jsonl_path}"],
                       check=True)

        # Step B: Create initial Chain Dictionary.
        pdb_name_clones = f"{pdb_name}_clones"
        # Ensure the name in the JSONL matches the dict key.
        # NOTE(review): only the first JSONL record is kept — assumes the
        # parse step emitted exactly one entry for this target.
        with open(jsonl_path, 'r') as f:
            jsonl_data = json.loads(f.readline())
        jsonl_data['name'] = pdb_name_clones
        with open(jsonl_path, 'w') as f:
            f.write(json.dumps(jsonl_data) + '\n')

        chain_id_json = os.path.join(output_dir, "chain_id_dict.json")
        # {name: [chains_to_redesign, chains_to_keep_fixed]}
        chain_id_dict = {pdb_name_clones: [[c for c in variable_chains], [c for c in fixed_chains]]}
        with open(chain_id_json, 'w') as f:
            json.dump(chain_id_dict, f)

        # Step C: AUTOMATED CLEANING - prunes ghost chains like 'C'.
        sync_protein_metadata(jsonl_path, chain_id_json)

        # Step D: Final Execution Command.
        mpnn_cmd = py + [
            mpnn_script,
            "--jsonl_path", jsonl_path,
            "--chain_id_jsonl", chain_id_json,
            "--out_folder", output_dir,
            "--num_seq_per_target", str(num_seqs),
            "--sampling_temp", str(temp),
            "--seed", "42",
        ]
        print(f"🚀 Designing {pdb_name}... (Fixed: {fixed_chains} | Redesign: {variable_chains})")

    # 3. Execute with suppressed warnings.
    env = os.environ.copy()
    env['PYTHONWARNINGS'] = 'ignore'
    subprocess.run(mpnn_cmd, check=True, env=env, stderr=subprocess.DEVNULL)

    print(f"✅ Success! Design complete for {pdb_name}.")
| 113 |
|
| 114 |
if __name__ == "__main__":
    # CLI entry: <pdb_path> <fixed_chains> <variable_chains> [num_seqs] [temp]
    argv = sys.argv
    if len(argv) < 4:
        print("Usage: python scripts/generator.py <pdb_path> <fixed_chains> <variable_chains> [num_seqs] [temp]")
        sys.exit(1)
    pdb_path, fixed_chains, variable_chains = argv[1], argv[2], argv[3]
    num_seqs = int(argv[4]) if len(argv) > 4 else 20
    temp = float(argv[5]) if len(argv) > 5 else 0.1
    run_broteinshake_generator(pdb_path, fixed_chains, variable_chains, num_seqs, temp)