kabudadada committed on
Commit
c6562b0
·
1 Parent(s): b2105f3

feat(esm-mcp): enable variant effect & fixed-backbone; align adapter returns

Browse files
esm/mcp_output/mcp_plugin/adapter.py CHANGED
@@ -5,16 +5,12 @@ import sys
5
  source_path = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "source")
6
  sys.path.insert(0, source_path)
7
 
8
- # Import modules
9
  try:
10
  from esm.pretrained import load_model_and_alphabet, load_model_and_alphabet_local
 
11
  from esm.data import Alphabet, BatchConverter
12
- from esm.inverse_folding import load_inverse_folding_model
13
  from esm.model import ESM1, ESM2, MSATransformer
14
- from examples.lm_design.lm_design import generate_fixed_backbone, generate_free_backbone
15
- from examples.variant_prediction.predict import predict_variant_effect
16
- from scripts.extract import extract_features
17
- from scripts.fold import predict_structure
18
  except ImportError as e:
19
  print(f"Module import failed: {e}, some functions will be unavailable.")
20
 
@@ -49,11 +45,11 @@ class Adapter:
49
  else:
50
  model, alphabet = load_model_and_alphabet(model_name)
51
  self.models[model_name] = model
52
- return {"status": "success", "model": model, "alphabet": alphabet}
53
  except Exception as e:
54
- return {"status": "error", "message": f"Failed to load model: {e}"}
55
 
56
- def load_inverse_folding_model(self, model_name):
57
  """
58
  Load inverse folding model.
59
 
@@ -64,11 +60,12 @@ class Adapter:
64
  - dict: Information containing status and model instance.
65
  """
66
  try:
67
- model = load_inverse_folding_model(model_name)
 
68
  self.models[model_name] = model
69
- return {"status": "success", "model": model}
70
  except Exception as e:
71
- return {"status": "error", "message": f"Failed to load inverse folding model: {e}"}
72
 
73
  # ------------------------- Data Processing Module -------------------------
74
 
@@ -81,9 +78,9 @@ class Adapter:
81
  """
82
  try:
83
  alphabet = Alphabet()
84
- return {"status": "success", "alphabet": alphabet}
85
  except Exception as e:
86
- return {"status": "error", "message": f"Failed to create alphabet: {e}"}
87
 
88
  def create_batch_converter(self, alphabet):
89
  """
@@ -97,9 +94,9 @@ class Adapter:
97
  """
98
  try:
99
  batch_converter = BatchConverter(alphabet)
100
- return {"status": "success", "batch_converter": batch_converter}
101
  except Exception as e:
102
- return {"status": "error", "message": f"Failed to create batch converter: {e}"}
103
 
104
  # ------------------------- Model Instantiation Module -------------------------
105
 
@@ -123,9 +120,9 @@ class Adapter:
123
  attention_heads=attention_heads,
124
  alphabet_size=alphabet_size
125
  )
126
- return {"status": "success", "model": model}
127
  except Exception as e:
128
- return {"status": "error", "message": f"Failed to instantiate ESM1 model: {e}"}
129
 
130
  def create_esm2_model(self, num_layers=33, embed_dim=1280, attention_heads=20, alphabet_size=33):
131
  """
@@ -147,9 +144,9 @@ class Adapter:
147
  attention_heads=attention_heads,
148
  alphabet_size=alphabet_size
149
  )
150
- return {"status": "success", "model": model}
151
  except Exception as e:
152
- return {"status": "error", "message": f"Failed to instantiate ESM2 model: {e}"}
153
 
154
  def create_msa_transformer(self, num_layers=12, embed_dim=768, attention_heads=12, max_tokens_per_msa=2**14):
155
  """
@@ -171,41 +168,69 @@ class Adapter:
171
  attention_heads=attention_heads,
172
  max_tokens_per_msa=max_tokens_per_msa
173
  )
174
- return {"status": "success", "model": model}
175
  except Exception as e:
176
- return {"status": "error", "message": f"Failed to instantiate MSA Transformer model: {e}"}
177
 
178
  # ------------------------- Function Call Module -------------------------
179
 
180
- def generate_fixed_backbone(self, model, alphabet, pdb_file, chain_id, temperature=1.0, num_samples=1):
181
  """
182
  Call fixed backbone generation function.
183
 
184
  Parameters:
185
- - model: ESM model instance
186
- - alphabet: Alphabet instance
187
- - pdb_file: str, path to PDB file
188
- - chain_id: str, chain identifier
189
  - temperature: float, sampling temperature (default: 1.0)
190
  - num_samples: int, number of samples to generate (default: 1)
 
 
191
 
192
  Returns:
193
  - dict: Information containing status and generation result.
194
  """
195
  try:
196
- result = generate_fixed_backbone(
197
- model=model,
198
- alphabet=alphabet,
199
- pdb_file=pdb_file,
200
- chain_id=chain_id,
201
- temperature=temperature,
202
- num_samples=num_samples
203
- )
204
- return {"status": "success", "result": result}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
  except Exception as e:
206
- return {"status": "error", "message": f"Failed to generate fixed backbone: {e}"}
207
 
208
- def generate_free_backbone(self, model, alphabet, length, temperature=1.0, num_samples=1, device="cpu"):
209
  """
210
  Call free backbone generation function.
211
 
@@ -221,47 +246,64 @@ class Adapter:
221
  - dict: Information containing status and generation result.
222
  """
223
  try:
224
- result = generate_free_backbone(
225
- model=model,
226
- alphabet=alphabet,
227
- length=length,
228
- temperature=temperature,
229
- num_samples=num_samples,
230
- device=device
231
- )
232
- return {"status": "success", "result": result}
233
  except Exception as e:
234
- return {"status": "error", "message": f"Failed to generate free backbone: {e}"}
235
 
236
- def predict_variant_effect(self, model, alphabet, sequence, mutations, batch_size=1, device="cpu"):
237
  """
238
  Call variant effect prediction function.
239
 
240
  Parameters:
241
- - model: ESM model instance
242
- - alphabet: Alphabet instance
243
  - sequence: str, wild-type protein sequence
244
- - mutations: list, list of mutations in format ["A123V", "G456D"]
245
- - batch_size: int, batch size for processing (default: 1)
246
- - device: str, device to use for computation (default: "cpu")
 
 
247
 
248
  Returns:
249
  - dict: Information containing status and prediction result.
250
  """
251
  try:
252
- result = predict_variant_effect(
253
- model=model,
254
- alphabet=alphabet,
255
- sequence=sequence,
256
- mutations=mutations,
257
- batch_size=batch_size,
258
- device=device
259
- )
260
- return {"status": "success", "result": result}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
261
  except Exception as e:
262
- return {"status": "error", "message": f"Failed to predict variant effect: {e}"}
263
 
264
- def extract_features(self, model, alphabet, sequences, repr_layers=[-1], include_contacts=False, device="cpu"):
265
  """
266
  Call feature extraction function.
267
 
@@ -277,19 +319,11 @@ class Adapter:
277
  - dict: Information containing status and extraction result.
278
  """
279
  try:
280
- result = extract_features(
281
- model=model,
282
- alphabet=alphabet,
283
- sequences=sequences,
284
- repr_layers=repr_layers,
285
- include_contacts=include_contacts,
286
- device=device
287
- )
288
- return {"status": "success", "result": result}
289
  except Exception as e:
290
- return {"status": "error", "message": f"Failed to extract features: {e}"}
291
 
292
- def predict_structure_local(self, model, alphabet, sequence, device="cpu"):
293
  """
294
  Call local structure prediction function.
295
 
@@ -303,15 +337,9 @@ class Adapter:
303
  - dict: Information containing status and prediction result.
304
  """
305
  try:
306
- result = predict_structure(
307
- model=model,
308
- alphabet=alphabet,
309
- sequence=sequence,
310
- device=device
311
- )
312
- return {"status": "success", "result": result}
313
  except Exception as e:
314
- return {"status": "error", "message": f"Failed to predict structure: {e}"}
315
 
316
  def predict_structure(self, sequence):
317
  """
@@ -346,15 +374,14 @@ class Adapter:
346
  "num_atoms": len(list(structure.get_atoms())),
347
  "pdb_content": response.text
348
  }
349
-
350
- return {"status": "success", "result": structure_info}
351
  else:
352
- return {"status": "error", "message": f"API returned error: {response.status_code}"}
353
 
354
  except requests.exceptions.Timeout:
355
- return {"status": "error", "message": "ESMFold API request timed out"}
356
  except Exception as e:
357
- return {"status": "error", "message": f"Error predicting structure: {e}"}
358
 
359
  def analyze_protein_sequence(self, sequence):
360
  """
@@ -380,10 +407,9 @@ class Adapter:
380
  "composition": composition,
381
  "sequence": sequence
382
  }
383
-
384
- return {"status": "success", "result": result}
385
  except Exception as e:
386
- return {"status": "error", "message": f"Failed to analyze sequence: {e}"}
387
 
388
  def validate_protein_sequence(self, sequence):
389
  """
@@ -409,10 +435,9 @@ class Adapter:
409
  "length": len(sequence),
410
  "uppercase_sequence": sequence_upper
411
  }
412
-
413
- return {"status": "success", "result": result}
414
  except Exception as e:
415
- return {"status": "error", "message": f"Failed to validate sequence: {e}"}
416
 
417
  # ------------------------- Fallback Mode Handling -------------------------
418
 
@@ -420,4 +445,4 @@ class Adapter:
420
  """
421
  Enable fallback mode, prompting the user that some functions are unavailable.
422
  """
423
- return {"status": "warning", "message": "Some functions are unavailable, please check module import status."}
 
5
  source_path = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "source")
6
  sys.path.insert(0, source_path)
7
 
8
+ # Minimal, stable imports only; avoid examples/scripts at import time
9
  try:
10
  from esm.pretrained import load_model_and_alphabet, load_model_and_alphabet_local
11
+ from esm import pretrained, inverse_folding
12
  from esm.data import Alphabet, BatchConverter
 
13
  from esm.model import ESM1, ESM2, MSATransformer
 
 
 
 
14
  except ImportError as e:
15
  print(f"Module import failed: {e}, some functions will be unavailable.")
16
 
 
45
  else:
46
  model, alphabet = load_model_and_alphabet(model_name)
47
  self.models[model_name] = model
48
+ return {"success": True, "result": {"model": model, "alphabet": alphabet}, "error": None}
49
  except Exception as e:
50
+ return {"success": False, "result": None, "error": f"Failed to load model: {e}"}
51
 
52
+ def load_inverse_folding_model(self, model_name="esm_if1_gvp4_t16_142M_UR50"):
53
  """
54
  Load inverse folding model.
55
 
 
60
  - dict: Information containing status and model instance.
61
  """
62
  try:
63
+ # Use pretrained helper consistent with service
64
+ model, _alphabet = getattr(pretrained, model_name)() if hasattr(pretrained, model_name) else pretrained.esm_if1_gvp4_t16_142M_UR50()
65
  self.models[model_name] = model
66
+ return {"success": True, "result": {"model_name": model_name}, "error": None}
67
  except Exception as e:
68
+ return {"success": False, "result": None, "error": f"Failed to load inverse folding model: {e}"}
69
 
70
  # ------------------------- Data Processing Module -------------------------
71
 
 
78
  """
79
  try:
80
  alphabet = Alphabet()
81
+ return {"success": True, "result": {"alphabet": alphabet}, "error": None}
82
  except Exception as e:
83
+ return {"success": False, "result": None, "error": f"Failed to create alphabet: {e}"}
84
 
85
  def create_batch_converter(self, alphabet):
86
  """
 
94
  """
95
  try:
96
  batch_converter = BatchConverter(alphabet)
97
+ return {"success": True, "result": {"batch_converter": batch_converter}, "error": None}
98
  except Exception as e:
99
+ return {"success": False, "result": None, "error": f"Failed to create batch converter: {e}"}
100
 
101
  # ------------------------- Model Instantiation Module -------------------------
102
 
 
120
  attention_heads=attention_heads,
121
  alphabet_size=alphabet_size
122
  )
123
+ return {"success": True, "result": {"model": model}, "error": None}
124
  except Exception as e:
125
+ return {"success": False, "result": None, "error": f"Failed to instantiate ESM1 model: {e}"}
126
 
127
  def create_esm2_model(self, num_layers=33, embed_dim=1280, attention_heads=20, alphabet_size=33):
128
  """
 
144
  attention_heads=attention_heads,
145
  alphabet_size=alphabet_size
146
  )
147
+ return {"success": True, "result": {"model": model}, "error": None}
148
  except Exception as e:
149
+ return {"success": False, "result": None, "error": f"Failed to instantiate ESM2 model: {e}"}
150
 
151
  def create_msa_transformer(self, num_layers=12, embed_dim=768, attention_heads=12, max_tokens_per_msa=2**14):
152
  """
 
168
  attention_heads=attention_heads,
169
  max_tokens_per_msa=max_tokens_per_msa
170
  )
171
+ return {"success": True, "result": {"model": model}, "error": None}
172
  except Exception as e:
173
+ return {"success": False, "result": None, "error": f"Failed to instantiate MSA Transformer model: {e}"}
174
 
175
  # ------------------------- Function Call Module -------------------------
176
 
177
+ def generate_fixed_backbone(self, pdbfile, chain_id=None, temperature=1.0, num_samples=1, multichain_backbone=False, nogpu=False):
178
  """
179
  Call fixed backbone generation function.
180
 
181
  Parameters:
182
+ - pdbfile: str, path to PDB/CIF file
183
+ - chain_id: str or None, chain identifier (ignored when multichain)
 
 
184
  - temperature: float, sampling temperature (default: 1.0)
185
  - num_samples: int, number of samples to generate (default: 1)
186
+ - multichain_backbone: bool, condition on complex if True
187
+ - nogpu: bool, force CPU
188
 
189
  Returns:
190
  - dict: Information containing status and generation result.
191
  """
192
  try:
193
+ import torch
194
+ model_obj, _alphabet = pretrained.esm_if1_gvp4_t16_142M_UR50()
195
+ model_obj = model_obj.eval()
196
+
197
+ sampled, recoveries = [], []
198
+
199
+ if not torch.cuda.is_available() or nogpu:
200
+ device = torch.device("cpu")
201
+ else:
202
+ model_obj = model_obj.cuda()
203
+ device = torch.device("cuda")
204
+
205
+ if multichain_backbone:
206
+ structure = inverse_folding.util.load_structure(pdbfile)
207
+ coords, native_seqs = inverse_folding.multichain_util.extract_coords_from_complex(structure)
208
+ target_chain_id = chain_id if (chain_id in native_seqs if chain_id is not None else False) else next(iter(native_seqs.keys()))
209
+ native_seq = native_seqs[target_chain_id]
210
+ for _ in range(num_samples):
211
+ sampled_seq = inverse_folding.multichain_util.sample_sequence_in_complex(
212
+ model_obj, coords, target_chain_id, temperature=temperature
213
+ )
214
+ sampled.append(sampled_seq)
215
+ try:
216
+ recoveries.append(sum(a == b for a, b in zip(native_seq, sampled_seq)) / max(1, len(native_seq)))
217
+ except Exception:
218
+ recoveries.append(None)
219
+ else:
220
+ coords, native_seq = inverse_folding.util.load_coords(pdbfile, chain_id)
221
+ for _ in range(num_samples):
222
+ sampled_seq = model_obj.sample(coords, temperature=temperature, device=device)
223
+ sampled.append(sampled_seq)
224
+ try:
225
+ recoveries.append(sum(a == b for a, b in zip(native_seq, sampled_seq)) / max(1, len(native_seq)))
226
+ except Exception:
227
+ recoveries.append(None)
228
+
229
+ return {"success": True, "result": {"sampled_sequences": sampled, "recovery": recoveries}, "error": None}
230
  except Exception as e:
231
+ return {"success": False, "result": None, "error": f"Failed to generate fixed backbone: {e}"}
232
 
233
+ def generate_free_backbone(self, *args, **kwargs):
234
  """
235
  Call free backbone generation function.
236
 
 
246
  - dict: Information containing status and generation result.
247
  """
248
  try:
249
+ return {"success": False, "result": None, "error": "free_backbone generation is not exposed in MCP"}
 
 
 
 
 
 
 
 
250
  except Exception as e:
251
+ return {"success": False, "result": None, "error": f"Failed to handle free backbone: {e}"}
252
 
253
+ def predict_variant_effect(self, sequence, mutation, model_location=None, scoring_strategy="wt-marginals", offset_idx=0, nogpu=False):
254
  """
255
  Call variant effect prediction function.
256
 
257
  Parameters:
 
 
258
  - sequence: str, wild-type protein sequence
259
+ - mutation: str, single mutation like "A42G" (WT, 1-based pos, MUT)
260
+ - model_location: optional model name/path (default ESM-1v)
261
+ - scoring_strategy: currently only "wt-marginals"
262
+ - offset_idx: int, position offset
263
+ - nogpu: bool
264
 
265
  Returns:
266
  - dict: Information containing status and prediction result.
267
  """
268
  try:
269
+ import re
270
+ import torch
271
+
272
+ sequence = sequence.strip()
273
+ m = re.match(r"^([ACDEFGHIKLMNPQRSTVWY])(\d+)([ACDEFGHIKLMNPQRSTVWY])$", mutation.strip().upper())
274
+ if not m:
275
+ return {"success": False, "result": None, "error": "Invalid mutation format. Use like 'A42G'"}
276
+ wt, pos_str, mt = m.group(1), m.group(2), m.group(3)
277
+ pos = int(pos_str) - offset_idx
278
+ if pos < 0 or pos >= len(sequence):
279
+ return {"success": False, "result": None, "error": "Mutation position out of range after offset"}
280
+ if sequence[pos].upper() != wt:
281
+ return {"success": False, "result": None, "error": "Wildtype residue does not match sequence at position"}
282
+
283
+ model_name = model_location or "esm1v_t33_650M_UR90S_1"
284
+ model_obj, alphabet = load_model_and_alphabet(model_name)
285
+ model_obj = model_obj.eval()
286
+ if torch.cuda.is_available() and not nogpu:
287
+ model_obj = model_obj.cuda()
288
+
289
+ batch_converter = alphabet.get_batch_converter()
290
+ data = [("protein1", sequence)]
291
+ _labels, _strs, batch_tokens = batch_converter(data)
292
+ with torch.no_grad():
293
+ if torch.cuda.is_available() and not nogpu:
294
+ batch_tokens = batch_tokens.cuda()
295
+ logits = model_obj(batch_tokens)["logits"]
296
+ token_log_probs = torch.log_softmax(logits, dim=-1)
297
+
298
+ wt_idx = alphabet.get_idx(wt)
299
+ mt_idx = alphabet.get_idx(mt)
300
+ score = (token_log_probs[0, 1 + pos, mt_idx] - token_log_probs[0, 1 + pos, wt_idx]).item()
301
+
302
+ return {"success": True, "result": {"score": score, "model": model_name, "strategy": scoring_strategy, "position_0_based": pos}, "error": None}
303
  except Exception as e:
304
+ return {"success": False, "result": None, "error": f"Failed to predict variant effect: {e}"}
305
 
306
+ def extract_features(self, *args, **kwargs):
307
  """
308
  Call feature extraction function.
309
 
 
319
  - dict: Information containing status and extraction result.
320
  """
321
  try:
322
+ return {"success": False, "result": None, "error": "extract_features not exposed via Adapter"}
 
 
 
 
 
 
 
 
323
  except Exception as e:
324
+ return {"success": False, "result": None, "error": f"Failed to handle extract_features: {e}"}
325
 
326
+ def predict_structure_local(self, *args, **kwargs):
327
  """
328
  Call local structure prediction function.
329
 
 
337
  - dict: Information containing status and prediction result.
338
  """
339
  try:
340
+ return {"success": False, "result": None, "error": "local structure prediction is not exposed via Adapter"}
 
 
 
 
 
 
341
  except Exception as e:
342
+ return {"success": False, "result": None, "error": f"Failed to handle predict_structure_local: {e}"}
343
 
344
  def predict_structure(self, sequence):
345
  """
 
374
  "num_atoms": len(list(structure.get_atoms())),
375
  "pdb_content": response.text
376
  }
377
+ return {"success": True, "result": structure_info, "error": None}
 
378
  else:
379
+ return {"success": False, "result": None, "error": f"API returned error: {response.status_code}"}
380
 
381
  except requests.exceptions.Timeout:
382
+ return {"success": False, "result": None, "error": "ESMFold API request timed out"}
383
  except Exception as e:
384
+ return {"success": False, "result": None, "error": f"Error predicting structure: {e}"}
385
 
386
  def analyze_protein_sequence(self, sequence):
387
  """
 
407
  "composition": composition,
408
  "sequence": sequence
409
  }
410
+ return {"success": True, "result": result, "error": None}
 
411
  except Exception as e:
412
+ return {"success": False, "result": None, "error": f"Failed to analyze sequence: {e}"}
413
 
414
  def validate_protein_sequence(self, sequence):
415
  """
 
435
  "length": len(sequence),
436
  "uppercase_sequence": sequence_upper
437
  }
438
+ return {"success": True, "result": result, "error": None}
 
439
  except Exception as e:
440
+ return {"success": False, "result": None, "error": f"Failed to validate sequence: {e}"}
441
 
442
  # ------------------------- Fallback Mode Handling -------------------------
443
 
 
445
  """
446
  Enable fallback mode, prompting the user that some functions are unavailable.
447
  """
448
+ return {"success": False, "result": None, "error": "Some functions are unavailable, please check module import status."}
esm/mcp_output/mcp_plugin/mcp_service.py CHANGED
@@ -56,51 +56,174 @@ def process_sequence_data(sequences: list):
56
  return {"success": False, "result": None, "error": str(e)}
57
 
58
  @mcp.tool(name="inverse_folding_model", description="Load inverse folding model")
59
- def inverse_folding_model():
60
  """
61
- Load the core model for inverse folding tasks.
 
 
 
62
 
63
  Returns:
64
- dict: Contains success/result/error fields.
65
  """
66
  try:
67
- model = inverse_folding.load_inverse_folding_model()
68
- return {"success": True, "result": model, "error": None}
 
 
 
 
 
 
 
 
 
69
  except Exception as e:
70
  return {"success": False, "result": None, "error": str(e)}
71
 
72
  @mcp.tool(name="generate_fixed_backbone", description="Generate protein sequence with fixed backbone")
73
- def generate_fixed_backbone(input_data: dict):
 
 
 
 
 
 
 
74
  """
75
- Generate protein sequences using a fixed backbone.
76
 
77
  Parameters:
78
- input_data (dict): Input data payload.
 
 
 
 
 
79
 
80
  Returns:
81
- dict: Contains success/result/error fields.
82
  """
83
  try:
84
- result = lm_design.generate_fixed_backbone(input_data)
85
- return {"success": False, "result": None, "error": "This feature is currently unavailable"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  except Exception as e:
87
  return {"success": False, "result": None, "error": str(e)}
88
 
89
  @mcp.tool(name="predict_variant_effect", description="Predict protein variant effects")
90
- def predict_variant_effect(sequence: str, mutation: str):
 
 
 
 
 
 
 
91
  """
92
- Predict the effect of a mutation in a protein sequence.
93
 
94
  Parameters:
95
- sequence (str): Protein sequence.
96
- mutation (str): Mutation description.
 
 
 
 
97
 
98
  Returns:
99
- dict: Contains success/result/error fields.
100
  """
101
  try:
102
- # result = predict.predict_variant_effect(sequence, mutation)
103
- return {"success": False, "result": None, "error": "This feature is currently unavailable"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
  except Exception as e:
105
  return {"success": False, "result": None, "error": str(e)}
106
 
 
56
  return {"success": False, "result": None, "error": str(e)}
57
 
58
  @mcp.tool(name="inverse_folding_model", description="Load inverse folding model")
59
+ def inverse_folding_model(model_name: str = "esm_if1_gvp4_t16_142M_UR50"):
60
  """
61
+ Ensure the inverse folding model weights are available and loadable.
62
+
63
+ Parameters:
64
+ model_name (str): Pretrained inverse folding model identifier.
65
 
66
  Returns:
67
+ dict: success/result/error. result contains { model_name }
68
  """
69
  try:
70
+ # Load to ensure environment and weights are OK; don't return the torch object
71
+ model_obj, _alphabet = pretrained.__dict__[model_name]() if hasattr(pretrained, model_name) else pretrained.esm_if1_gvp4_t16_142M_UR50()
72
+ # Put into eval mode and immediately free GPU if any
73
+ model_obj = model_obj.eval()
74
+ try:
75
+ # move back to CPU to avoid holding GPU memory
76
+ import torch # local import to avoid hard dep on torch at import time
77
+ model_obj.cpu()
78
+ except Exception:
79
+ pass
80
+ return {"success": True, "result": {"model_name": model_name}, "error": None}
81
  except Exception as e:
82
  return {"success": False, "result": None, "error": str(e)}
83
 
84
  @mcp.tool(name="generate_fixed_backbone", description="Generate protein sequence with fixed backbone")
85
+ def generate_fixed_backbone(
86
+ pdbfile: str,
87
+ chain: str | None = None,
88
+ temperature: float = 1.0,
89
+ num_samples: int = 1,
90
+ multichain_backbone: bool = False,
91
+ nogpu: bool = False,
92
+ ):
93
  """
94
+ Sample protein sequences conditioned on a fixed backbone structure.
95
 
96
  Parameters:
97
+ pdbfile (str): Path to input PDB/CIF file.
98
+ chain (str|None): Chain ID for single-chain conditioning. Ignored when multichain_backbone=True.
99
+ temperature (float): Sampling temperature (>1 for diversity).
100
+ num_samples (int): Number of sequences to sample.
101
+ multichain_backbone (bool): If True, condition on all chains in the complex.
102
+ nogpu (bool): If True, do not use GPU even if available.
103
 
104
  Returns:
105
+ dict: success/result/error. result contains { sampled_sequences, recovery (if native available) }
106
  """
107
  try:
108
+ import torch
109
+ model_obj, _alphabet = pretrained.esm_if1_gvp4_t16_142M_UR50()
110
+ model_obj = model_obj.eval()
111
+
112
+ sampled = []
113
+ recoveries = []
114
+
115
+ if not torch.cuda.is_available() or nogpu:
116
+ device = torch.device("cpu")
117
+ else:
118
+ model_obj = model_obj.cuda()
119
+ device = torch.device("cuda")
120
+
121
+ if multichain_backbone:
122
+ structure = inverse_folding.util.load_structure(pdbfile)
123
+ coords, native_seqs = inverse_folding.multichain_util.extract_coords_from_complex(structure)
124
+ # choose target chain: if chain provided and exists, use it; else pick first
125
+ target_chain_id = chain if (chain in native_seqs if chain is not None else False) else next(iter(native_seqs.keys()))
126
+ native_seq = native_seqs[target_chain_id]
127
+ for _ in range(num_samples):
128
+ sampled_seq = inverse_folding.multichain_util.sample_sequence_in_complex(
129
+ model_obj, coords, target_chain_id, temperature=temperature
130
+ )
131
+ sampled.append(sampled_seq)
132
+ try:
133
+ recoveries.append(sum(a == b for a, b in zip(native_seq, sampled_seq)) / max(1, len(native_seq)))
134
+ except Exception:
135
+ recoveries.append(None)
136
+ else:
137
+ coords, native_seq = inverse_folding.util.load_coords(pdbfile, chain)
138
+ for _ in range(num_samples):
139
+ sampled_seq = model_obj.sample(coords, temperature=temperature, device=device)
140
+ sampled.append(sampled_seq)
141
+ try:
142
+ recoveries.append(sum(a == b for a, b in zip(native_seq, sampled_seq)) / max(1, len(native_seq)))
143
+ except Exception:
144
+ recoveries.append(None)
145
+
146
+ return {
147
+ "success": True,
148
+ "result": {
149
+ "sampled_sequences": sampled,
150
+ "recovery": recoveries,
151
+ },
152
+ "error": None,
153
+ }
154
  except Exception as e:
155
  return {"success": False, "result": None, "error": str(e)}
156
 
157
  @mcp.tool(name="predict_variant_effect", description="Predict protein variant effects")
158
+ def predict_variant_effect(
159
+ sequence: str,
160
+ mutation: str,
161
+ model_location: str | None = None,
162
+ scoring_strategy: str = "wt-marginals",
163
+ offset_idx: int = 0,
164
+ nogpu: bool = False,
165
+ ):
166
  """
167
+ Score a single point mutation using a pretrained LM.
168
 
169
  Parameters:
170
+ sequence (str): Wildtype protein sequence.
171
+ mutation (str): In the form 'A42G' (WT + 1-based position + MUT). offset_idx can shift position.
172
+ model_location (str|None): Pretrained model name or path. Defaults to an ESM-1v model.
173
+ scoring_strategy (str): 'wt-marginals' (default). Others not implemented in this minimal API.
174
+ offset_idx (int): Position offset (e.g., 1 if your mutation indices are 1-based).
175
+ nogpu (bool): Do not use GPU even if available.
176
 
177
  Returns:
178
+ dict: success/result/error. result contains { score, model, strategy }
179
  """
180
  try:
181
+ import re
182
+ import torch
183
+
184
+ sequence = sequence.strip()
185
+ m = re.match(r"^([ACDEFGHIKLMNPQRSTVWY])(\d+)([ACDEFGHIKLMNPQRSTVWY])$", mutation.strip().upper())
186
+ if not m:
187
+ return {"success": False, "result": None, "error": "Invalid mutation format. Use like 'A42G'"}
188
+ wt, pos_str, mt = m.group(1), m.group(2), m.group(3)
189
+ pos = int(pos_str) - offset_idx # convert to 0-based index
190
+ if pos < 0 or pos >= len(sequence):
191
+ return {"success": False, "result": None, "error": "Mutation position out of range after offset"}
192
+ if sequence[pos].upper() != wt:
193
+ return {"success": False, "result": None, "error": "Wildtype residue does not match sequence at position"}
194
+
195
+ model_name = model_location or "esm1v_t33_650M_UR90S_1"
196
+ model_obj, alphabet = pretrained.load_model_and_alphabet(model_name)
197
+ model_obj = model_obj.eval()
198
+
199
+ if torch.cuda.is_available() and not nogpu:
200
+ model_obj = model_obj.cuda()
201
+
202
+ batch_converter = alphabet.get_batch_converter()
203
+ data = [("protein1", sequence)]
204
+ batch_labels, batch_strs, batch_tokens = batch_converter(data)
205
+
206
+ with torch.no_grad():
207
+ if torch.cuda.is_available() and not nogpu:
208
+ batch_tokens = batch_tokens.cuda()
209
+ logits = model_obj(batch_tokens)["logits"]
210
+ token_log_probs = torch.log_softmax(logits, dim=-1)
211
+
212
+ wt_idx = alphabet.get_idx(wt)
213
+ mt_idx = alphabet.get_idx(mt)
214
+ # +1 for BOS token alignment
215
+ score = (token_log_probs[0, 1 + pos, mt_idx] - token_log_probs[0, 1 + pos, wt_idx]).item()
216
+
217
+ return {
218
+ "success": True,
219
+ "result": {
220
+ "score": score,
221
+ "model": model_name,
222
+ "strategy": scoring_strategy,
223
+ "position_0_based": pos,
224
+ },
225
+ "error": None,
226
+ }
227
  except Exception as e:
228
  return {"success": False, "result": None, "error": str(e)}
229