samsl commited on
Commit
602a930
·
1 Parent(s): f0a85c1

fid input parsing and cuda

Browse files
Files changed (1) hide show
  1. app.py +14 -5
app.py CHANGED
@@ -21,6 +21,7 @@ from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError
21
  from loguru import logger
22
  from matplotlib.cm import ScalarMappable
23
  from matplotlib.colors import Normalize
 
24
 
25
  from rocketshp import RocketSHP
26
  from rocketshp import load_sequence as get_sequence_features
@@ -163,9 +164,11 @@ def predict_rocketshp(
163
  if structure_code == "":
164
  raise gr.Error("Structure input is required for the selected model.")
165
 
 
 
166
  structure_tmp_dir = tempfile.TemporaryDirectory()
167
  structure_file = rcsb.fetch(
168
- structure_code,
169
  "pdb",
170
  target_path=structure_tmp_dir.name,
171
  )
@@ -247,7 +250,14 @@ def predict_rocketshp(
247
  rmsf = dynamics_pred["rmsf"].squeeze().cpu().numpy()
248
  gcc_lmi = dynamics_pred["gcc_lmi"].squeeze().cpu().numpy()
249
  shp = dynamics_pred["shp"].squeeze().cpu().numpy()
250
- ca_dist = dynamics_pred["ca_dist"].squeeze().cpu().numpy()
 
 
 
 
 
 
 
251
 
252
  fig, plot_file_name = plot_predictions(
253
  rmsf,
@@ -439,12 +449,10 @@ def check_user_access(oauth_token: gr.OAuthToken | None):
439
  raise gr.Error("Please log in to use this Space")
440
 
441
  token = oauth_token.token
442
- print(token)
443
 
444
  try:
445
  # Try to access a file from your private repo
446
  api = HfApi(token=token)
447
- print(api)
448
 
449
  # Test access by trying to get repo info
450
  _ = api.repo_info(
@@ -531,6 +539,7 @@ with rocketshp_gradio:
531
  sequence = gr.State([])
532
 
533
  gr.LoginButton()
 
534
  model_variant = gr.Dropdown(
535
  label="Select RocketSHP Model",
536
  choices=["latest", "v1_seq", "v1_mini"],
@@ -645,4 +654,4 @@ with rocketshp_gradio:
645
 
646
 
647
  if __name__ == "__main__":
648
- rocketshp_gradio.launch(share=False)
 
21
  from loguru import logger
22
  from matplotlib.cm import ScalarMappable
23
  from matplotlib.colors import Normalize
24
+ from scipy.spatial.distance import cdist
25
 
26
  from rocketshp import RocketSHP
27
  from rocketshp import load_sequence as get_sequence_features
 
164
  if structure_code == "":
165
  raise gr.Error("Structure input is required for the selected model.")
166
 
167
+ structure_code = structure_code.strip().upper()
168
+
169
  structure_tmp_dir = tempfile.TemporaryDirectory()
170
  structure_file = rcsb.fetch(
171
+ structure_code.strip(),
172
  "pdb",
173
  target_path=structure_tmp_dir.name,
174
  )
 
250
  rmsf = dynamics_pred["rmsf"].squeeze().cpu().numpy()
251
  gcc_lmi = dynamics_pred["gcc_lmi"].squeeze().cpu().numpy()
252
  shp = dynamics_pred["shp"].squeeze().cpu().numpy()
253
+
254
+ if is_sequence_model:
255
+ ca_dist = dynamics_pred["ca_dist"].squeeze().cpu().numpy()
256
+ else:
257
+ ca_struct = structure[bs.filter_amino_acids(structure)]
258
+ ca_struct = structure[structure.atom_name == "CA"]
259
+ ca_dist = cdist(ca_struct.coord, ca_struct.coord)
260
+ ca_dist /= 10.0 # convert to nm
261
 
262
  fig, plot_file_name = plot_predictions(
263
  rmsf,
 
449
  raise gr.Error("Please log in to use this Space")
450
 
451
  token = oauth_token.token
 
452
 
453
  try:
454
  # Try to access a file from your private repo
455
  api = HfApi(token=token)
 
456
 
457
  # Test access by trying to get repo info
458
  _ = api.repo_info(
 
539
  sequence = gr.State([])
540
 
541
  gr.LoginButton()
542
+
543
  model_variant = gr.Dropdown(
544
  label="Select RocketSHP Model",
545
  choices=["latest", "v1_seq", "v1_mini"],
 
654
 
655
 
656
  if __name__ == "__main__":
657
+ rocketshp_gradio.launch(share=True)