42Cummer committed on
Commit
05055dd
·
verified ·
1 Parent(s): 178c45c

Upload generator.py

Browse files
Files changed (1) hide show
  1. scripts/generator.py +68 -81
scripts/generator.py CHANGED
@@ -5,131 +5,118 @@ import warnings
5
  import sys
6
  import shutil
7
 
8
- def run_broteinshake_generator(pdb_path, fixed_chains, variable_chains, num_seqs=20, temp=0.1):
9
  """
10
- Generalized first-half pipeline for protein redesign.
11
-
12
- Args:
13
- pdb_path: Path to the target complex (e.g., 'data/3KAS.pdb').
14
- fixed_chains: Chains to remain unchanged (e.g., 'A'). Empty for single-chain proteins.
15
- variable_chains: Chains to be redesigned/repainted (e.g., 'B'). For single-chain, this is the only chain.
16
  """
17
- # 1. Setup project identifiers and directories
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  pdb_name = os.path.basename(pdb_path).split('.')[0]
19
  output_dir = f"./generated/{pdb_name}"
20
  os.makedirs(output_dir, exist_ok=True)
21
 
22
- # Get the project root directory (where ProteinMPNN should be)
23
  script_dir = os.path.dirname(os.path.abspath(__file__))
24
  project_root = os.path.dirname(script_dir)
25
  proteinmpnn_dir = os.path.join(project_root, "ProteinMPNN")
26
 
27
- # Clone ProteinMPNN if it doesn't exist (for HuggingFace Spaces deployment)
28
  if not os.path.exists(proteinmpnn_dir):
29
  print("ProteinMPNN not found, cloning repository...")
30
- subprocess.run(
31
- ["git", "clone", "https://github.com/dauparas/ProteinMPNN.git"],
32
- cwd=project_root,
33
- check=True,
34
- stdout=subprocess.DEVNULL,
35
- stderr=subprocess.DEVNULL
36
- )
37
 
38
  mpnn_script = os.path.join(proteinmpnn_dir, "protein_mpnn_run.py")
39
 
40
- # 2. Check if single-chain protein (no fixed chains means single-chain)
41
  if not fixed_chains or len(fixed_chains) == 0:
42
- # Single-chain protein: use direct PDB path command
43
- # For single-chain, variable_chains should be the only chain (e.g., "A")
44
  chain_to_design = variable_chains[0] if variable_chains else "A"
45
-
46
  mpnn_cmd = (
47
- f"python -W ignore {mpnn_script} "
48
- f"--pdb_path {pdb_path} "
49
- f"--pdb_path_chains {chain_to_design} "
50
- f"--out_folder {output_dir} "
51
- f"--num_seq_per_target {num_seqs} "
52
- f"--sampling_temp {temp} "
53
- f"--seed 42 "
54
- f"--batch_size 1"
55
  )
56
-
57
- print(f"🚀 Designing sequences for {pdb_name} (single-chain mode)...")
58
- print(f"✏️ Redesigning chain: {chain_to_design}")
59
  else:
60
- # Multi-chain protein: use JSONL-based command
61
- # 2. Parse the PDB into JSONL format for the model
62
- pdb_dir = os.path.dirname(os.path.abspath(pdb_path))
63
- if not pdb_dir:
64
- pdb_dir = "."
65
  jsonl_path = os.path.join(output_dir, "parsed_pdbs.jsonl")
66
-
67
  parse_script = os.path.join(proteinmpnn_dir, "helper_scripts", "parse_multiple_chains.py")
68
 
69
- parse_cmd = f"python -W ignore {parse_script} --input_path={pdb_dir}/ --output_path={jsonl_path}"
70
- subprocess.run(parse_cmd, shell=True, check=True, stderr=subprocess.DEVNULL)
71
 
72
- # Update the name in parsed JSONL to include "_clones"
73
  pdb_name_clones = f"{pdb_name}_clones"
 
74
  with open(jsonl_path, 'r') as f:
75
  jsonl_data = json.loads(f.readline())
76
  jsonl_data['name'] = pdb_name_clones
77
  with open(jsonl_path, 'w') as f:
78
  f.write(json.dumps(jsonl_data) + '\n')
79
 
80
- # 3. Generate the Chain Configuration JSONs (The 'Engine' Logic)
81
- # Format: {"name": [masked_chains_list, visible_chains_list]}
82
- # masked_chains = chains to redesign, visible_chains = chains to keep fixed
83
- masked_chains_list = [c for c in variable_chains]
84
- visible_chains_list = [c for c in fixed_chains]
85
- chain_id_dict = {pdb_name_clones: [masked_chains_list, visible_chains_list]}
86
-
87
  chain_id_json = os.path.join(output_dir, "chain_id_dict.json")
 
88
  with open(chain_id_json, 'w') as f:
89
  json.dump(chain_id_dict, f)
90
 
91
- # 4. Execute optimized ProteinMPNN command for multi-chain
 
 
 
92
  mpnn_cmd = (
93
- f"python -W ignore {mpnn_script} "
94
- f"--jsonl_path {jsonl_path} "
95
- f"--chain_id_jsonl {chain_id_json} "
96
- f"--out_folder {output_dir} "
97
- f"--num_seq_per_target {num_seqs} "
98
- f"--sampling_temp {temp} "
99
- f"--seed 42"
100
  )
101
-
102
- print(f"🚀 Designing sequences for {pdb_name}...")
103
- print(f"🔒 Fixed: {fixed_chains} | ✏️ Redesigning: {variable_chains}")
104
 
105
- # Suppress warnings by redirecting stderr
106
  env = os.environ.copy()
107
  env['PYTHONWARNINGS'] = 'ignore'
108
  subprocess.run(mpnn_cmd, shell=True, check=True, env=env, stderr=subprocess.DEVNULL)
109
 
110
- # For single-chain proteins, ProteinMPNN saves as {pdb_name}.fa
111
- # Rename it to {pdb_name}_clones.fa for consistency
112
- if not fixed_chains or len(fixed_chains) == 0:
113
- seqs_dir = os.path.join(output_dir, "seqs")
114
- old_file = os.path.join(seqs_dir, f"{pdb_name}.fa")
115
- new_file = os.path.join(seqs_dir, f"{pdb_name}_clones.fa")
116
- if os.path.exists(old_file) and not os.path.exists(new_file):
117
- os.rename(old_file, new_file)
118
- print(f"📝 Renamed {pdb_name}.fa → {pdb_name}_clones.fa")
119
-
120
- print(f"✅ Success! Fold the top sequences at https://esmatlas.com/resources?action=fold")
121
 
122
  if __name__ == "__main__":
123
- import sys
124
  if len(sys.argv) < 4:
125
  print("Usage: python scripts/generator.py <pdb_path> <fixed_chains> <variable_chains> [num_seqs] [temp]")
126
- print("Example: python scripts/generator.py data/3kas.pdb A B")
127
  sys.exit(1)
128
-
129
- pdb_path = sys.argv[1]
130
- fixed_chains = sys.argv[2]
131
- variable_chains = sys.argv[3]
132
- num_seqs = int(sys.argv[4]) if len(sys.argv) > 4 else 20
133
- temp = float(sys.argv[5]) if len(sys.argv) > 5 else 0.1
134
-
135
- run_broteinshake_generator(pdb_path, fixed_chains, variable_chains, num_seqs, temp)
 
5
  import sys
6
  import shutil
7
 
8
def sync_protein_metadata(jsonl_path, dict_path):
    """Drop chain IDs that have no parsed sequence from the chain-ID dictionary.

    The JSONL file is read line by line; for every entry, the set of chain IDs
    is derived from its ``seq_chain_<ID>`` keys. The chain-ID dictionary
    (``{name: [redesign_chains, fixed_chains]}``) is then filtered so that both
    lists only contain chains present in the matching JSONL entry, and the
    result is written back to ``dict_path``. This avoids a downstream
    ``KeyError`` (e.g. ``'seq_chain_C'``) when a requested chain was not
    parsed as a sequence.

    Args:
        jsonl_path: Path to the parsed-structure JSONL file (one JSON object
            per line, each with a ``name`` key).
        dict_path: Path to the chain-ID JSON dictionary; rewritten in place.

    Returns:
        None. If either file does not exist the function is a no-op.
    """
    if not (os.path.exists(jsonl_path) and os.path.exists(dict_path)):
        return

    # 1. Identify chains that actually have sequence data in the JSONL.
    prefix = "seq_chain_"
    valid_chains_map = {}
    with open(jsonl_path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                # Tolerate blank/trailing lines instead of crashing in json.loads.
                continue
            entry = json.loads(line)
            # Slice off the prefix rather than splitting on '_' so the whole
            # remainder of the key is treated as the chain ID.
            valid_chains_map[entry["name"]] = {
                key[len(prefix):] for key in entry if key.startswith(prefix)
            }

    # 2. Clean the chain ID dictionary.
    with open(dict_path, "r", encoding="utf-8") as f:
        chain_id_dict = json.load(f)

    for pdb_name, configs in chain_id_dict.items():
        valid = valid_chains_map.get(pdb_name)
        if valid is None:
            # No matching JSONL entry: leave this configuration untouched.
            continue

        # configs[0] = redesign list, configs[1] = fixed list
        redesign, fixed = configs[0], configs[1]
        # Re-assigning an existing key is safe while iterating .items().
        chain_id_dict[pdb_name] = [
            [c for c in redesign if c in valid],
            [c for c in fixed if c in valid],
        ]

        # Diagnostic feedback (sorted so the message is deterministic).
        removed = sorted((set(redesign) | set(fixed)) - valid)
        if removed:
            print(f"🧹 Sanitizer: Pruned non-protein chains from metadata: {removed}")

    # 3. Overwrite with cleaned metadata for ProteinMPNN.
    with open(dict_path, "w", encoding="utf-8") as f:
        json.dump(chain_id_dict, f)
49
+
50
def run_broteinshake_generator(pdb_path, fixed_chains, variable_chains, num_seqs=20, temp=0.1):
    """Run ProteinMPNN sequence design for one PDB structure.

    Outputs are written under ``./generated/<pdb_name>``. If the ``ProteinMPNN``
    checkout is missing from the project root it is cloned first.

    Args:
        pdb_path: Path to the input PDB file.
        fixed_chains: Chain IDs to keep unchanged, one character per chain.
            Empty/None selects single-chain mode.
        variable_chains: Chain IDs to redesign, one character per chain. In
            single-chain mode only the first one is used ("A" if empty).
        num_seqs: Value passed as ``--num_seq_per_target``.
        temp: Value passed as ``--sampling_temp``.

    Raises:
        subprocess.CalledProcessError: If git, the parsing helper, or
            ProteinMPNN exits with a non-zero status.
        ValueError: If the parsing helper produced no entries.
    """
    # 1. Setup identifiers and directories
    pdb_name = os.path.basename(pdb_path).split('.')[0]
    output_dir = f"./generated/{pdb_name}"
    os.makedirs(output_dir, exist_ok=True)

    script_dir = os.path.dirname(os.path.abspath(__file__))
    project_root = os.path.dirname(script_dir)
    proteinmpnn_dir = os.path.join(project_root, "ProteinMPNN")

    if not os.path.exists(proteinmpnn_dir):
        print("ProteinMPNN not found, cloning repository...")
        subprocess.run(
            ["git", "clone", "https://github.com/dauparas/ProteinMPNN.git"],
            cwd=project_root,
            check=True,
        )

    mpnn_script = os.path.join(proteinmpnn_dir, "protein_mpnn_run.py")

    # Commands are built as argument lists (no shell) so paths containing
    # spaces or shell metacharacters are passed through verbatim. The current
    # interpreter is reused so the child sees the same installed packages.
    python_exe = sys.executable or "python"
    sampling_args = [
        "--out_folder", output_dir,
        "--num_seq_per_target", str(num_seqs),
        "--sampling_temp", str(temp),
        "--seed", "42",
    ]

    # 2. Handle Single vs Multi-Chain Logic
    if not fixed_chains:
        chain_to_design = variable_chains[0] if variable_chains else "A"
        mpnn_cmd = [
            python_exe, "-W", "ignore", mpnn_script,
            "--pdb_path", pdb_path,
            "--pdb_path_chains", chain_to_design,
            *sampling_args,
            "--batch_size", "1",
        ]
        print(f"🚀 Designing {pdb_name} (Single-chain: {chain_to_design})...")
    else:
        # Multi-chain setup
        pdb_dir = os.path.dirname(os.path.abspath(pdb_path)) or "."
        jsonl_path = os.path.join(output_dir, "parsed_pdbs.jsonl")
        parse_script = os.path.join(proteinmpnn_dir, "helper_scripts", "parse_multiple_chains.py")

        # Step A: Parse PDB to JSONL
        subprocess.run(
            [
                python_exe, "-W", "ignore", parse_script,
                f"--input_path={pdb_dir}/",
                f"--output_path={jsonl_path}",
            ],
            check=True,
        )

        # Step B: Create initial Chain Dictionary
        pdb_name_clones = f"{pdb_name}_clones"
        # The helper is pointed at the whole directory, so the JSONL may hold
        # entries for sibling structures. Pick the entry for *this* PDB by
        # name instead of blindly taking the first line.
        # NOTE(review): assumes the helper names entries after the file stem -
        # confirm against parse_multiple_chains.py.
        with open(jsonl_path, 'r') as f:
            entries = [json.loads(line) for line in f if line.strip()]
        if not entries:
            raise ValueError(f"No structures were parsed from {pdb_dir}/")
        by_name = {entry.get('name'): entry for entry in entries}
        file_stem = os.path.splitext(os.path.basename(pdb_path))[0]
        jsonl_data = by_name.get(file_stem) or by_name.get(pdb_name) or entries[0]

        # Ensure the name in the JSONL matches the dict key.
        jsonl_data['name'] = pdb_name_clones
        with open(jsonl_path, 'w') as f:
            f.write(json.dumps(jsonl_data) + '\n')

        chain_id_json = os.path.join(output_dir, "chain_id_dict.json")
        chain_id_dict = {pdb_name_clones: [list(variable_chains), list(fixed_chains)]}
        with open(chain_id_json, 'w') as f:
            json.dump(chain_id_dict, f)

        # Step C: AUTOMATED CLEANING - Prunes ghost chains like 'C'
        sync_protein_metadata(jsonl_path, chain_id_json)

        # Step D: Final Execution Command
        mpnn_cmd = [
            python_exe, "-W", "ignore", mpnn_script,
            "--jsonl_path", jsonl_path,
            "--chain_id_jsonl", chain_id_json,
            *sampling_args,
        ]
        print(f"🚀 Designing {pdb_name}... (Fixed: {fixed_chains} | Redesign: {variable_chains})")

    # 3. Execute with suppressed warnings
    env = os.environ.copy()
    env['PYTHONWARNINGS'] = 'ignore'
    try:
        # stderr is captured rather than discarded: it stays silent on
        # success but is surfaced when the run fails.
        subprocess.run(mpnn_cmd, check=True, env=env, stderr=subprocess.PIPE, text=True)
    except subprocess.CalledProcessError as err:
        if err.stderr:
            print(err.stderr, file=sys.stderr)
        raise

    print(f"✅ Success! Design complete for {pdb_name}.")
 
 
 
 
 
 
 
 
 
 
113
 
114
  if __name__ == "__main__":
 
115
  if len(sys.argv) < 4:
116
  print("Usage: python scripts/generator.py <pdb_path> <fixed_chains> <variable_chains> [num_seqs] [temp]")
 
117
  sys.exit(1)
118
+ run_broteinshake_generator(
119
+ sys.argv[1], sys.argv[2], sys.argv[3],
120
+ int(sys.argv[4]) if len(sys.argv) > 4 else 20,
121
+ float(sys.argv[5]) if len(sys.argv) > 5 else 0.1
122
+ )