fid input parsing and cuda
Browse files
app.py
CHANGED
|
@@ -21,6 +21,7 @@ from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError
|
|
| 21 |
from loguru import logger
|
| 22 |
from matplotlib.cm import ScalarMappable
|
| 23 |
from matplotlib.colors import Normalize
|
|
|
|
| 24 |
|
| 25 |
from rocketshp import RocketSHP
|
| 26 |
from rocketshp import load_sequence as get_sequence_features
|
|
@@ -163,9 +164,11 @@ def predict_rocketshp(
|
|
| 163 |
if structure_code == "":
|
| 164 |
raise gr.Error("Structure input is required for the selected model.")
|
| 165 |
|
|
|
|
|
|
|
| 166 |
structure_tmp_dir = tempfile.TemporaryDirectory()
|
| 167 |
structure_file = rcsb.fetch(
|
| 168 |
-
structure_code,
|
| 169 |
"pdb",
|
| 170 |
target_path=structure_tmp_dir.name,
|
| 171 |
)
|
|
@@ -247,7 +250,14 @@ def predict_rocketshp(
|
|
| 247 |
rmsf = dynamics_pred["rmsf"].squeeze().cpu().numpy()
|
| 248 |
gcc_lmi = dynamics_pred["gcc_lmi"].squeeze().cpu().numpy()
|
| 249 |
shp = dynamics_pred["shp"].squeeze().cpu().numpy()
|
| 250 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 251 |
|
| 252 |
fig, plot_file_name = plot_predictions(
|
| 253 |
rmsf,
|
|
@@ -439,12 +449,10 @@ def check_user_access(oauth_token: gr.OAuthToken | None):
|
|
| 439 |
raise gr.Error("Please log in to use this Space")
|
| 440 |
|
| 441 |
token = oauth_token.token
|
| 442 |
-
print(token)
|
| 443 |
|
| 444 |
try:
|
| 445 |
# Try to access a file from your private repo
|
| 446 |
api = HfApi(token=token)
|
| 447 |
-
print(api)
|
| 448 |
|
| 449 |
# Test access by trying to get repo info
|
| 450 |
_ = api.repo_info(
|
|
@@ -531,6 +539,7 @@ with rocketshp_gradio:
|
|
| 531 |
sequence = gr.State([])
|
| 532 |
|
| 533 |
gr.LoginButton()
|
|
|
|
| 534 |
model_variant = gr.Dropdown(
|
| 535 |
label="Select RocketSHP Model",
|
| 536 |
choices=["latest", "v1_seq", "v1_mini"],
|
|
@@ -645,4 +654,4 @@ with rocketshp_gradio:
|
|
| 645 |
|
| 646 |
|
| 647 |
if __name__ == "__main__":
|
| 648 |
-
rocketshp_gradio.launch(share=
|
|
|
|
| 21 |
from loguru import logger
|
| 22 |
from matplotlib.cm import ScalarMappable
|
| 23 |
from matplotlib.colors import Normalize
|
| 24 |
+
from scipy.spatial.distance import cdist
|
| 25 |
|
| 26 |
from rocketshp import RocketSHP
|
| 27 |
from rocketshp import load_sequence as get_sequence_features
|
|
|
|
| 164 |
if structure_code == "":
|
| 165 |
raise gr.Error("Structure input is required for the selected model.")
|
| 166 |
|
| 167 |
+
structure_code = structure_code.strip().upper()
|
| 168 |
+
|
| 169 |
structure_tmp_dir = tempfile.TemporaryDirectory()
|
| 170 |
structure_file = rcsb.fetch(
|
| 171 |
+
structure_code.strip(),
|
| 172 |
"pdb",
|
| 173 |
target_path=structure_tmp_dir.name,
|
| 174 |
)
|
|
|
|
| 250 |
rmsf = dynamics_pred["rmsf"].squeeze().cpu().numpy()
|
| 251 |
gcc_lmi = dynamics_pred["gcc_lmi"].squeeze().cpu().numpy()
|
| 252 |
shp = dynamics_pred["shp"].squeeze().cpu().numpy()
|
| 253 |
+
|
| 254 |
+
if is_sequence_model:
|
| 255 |
+
ca_dist = dynamics_pred["ca_dist"].squeeze().cpu().numpy()
|
| 256 |
+
else:
|
| 257 |
+
ca_struct = structure[bs.filter_amino_acids(structure)]
|
| 258 |
+
ca_struct = structure[structure.atom_name == "CA"]
|
| 259 |
+
ca_dist = cdist(ca_struct.coord, ca_struct.coord)
|
| 260 |
+
ca_dist /= 10.0 # convert to nm
|
| 261 |
|
| 262 |
fig, plot_file_name = plot_predictions(
|
| 263 |
rmsf,
|
|
|
|
| 449 |
raise gr.Error("Please log in to use this Space")
|
| 450 |
|
| 451 |
token = oauth_token.token
|
|
|
|
| 452 |
|
| 453 |
try:
|
| 454 |
# Try to access a file from your private repo
|
| 455 |
api = HfApi(token=token)
|
|
|
|
| 456 |
|
| 457 |
# Test access by trying to get repo info
|
| 458 |
_ = api.repo_info(
|
|
|
|
| 539 |
sequence = gr.State([])
|
| 540 |
|
| 541 |
gr.LoginButton()
|
| 542 |
+
|
| 543 |
model_variant = gr.Dropdown(
|
| 544 |
label="Select RocketSHP Model",
|
| 545 |
choices=["latest", "v1_seq", "v1_mini"],
|
|
|
|
| 654 |
|
| 655 |
|
| 656 |
if __name__ == "__main__":
|
| 657 |
+
rocketshp_gradio.launch(share=True)
|