Spaces:
Running on Zero
Running on Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -24,9 +24,31 @@ from biotite.structure import annotate_sse
|
|
| 24 |
import biotite.structure.io as strucio
|
| 25 |
import biotite.structure.residues as residues
|
| 26 |
import numpy as np
|
|
|
|
| 27 |
|
| 28 |
from data.scripts.data_utils import modify_bfactor_biotite
|
| 29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
def process_pdb_file(pdb_file, backbones, sequences, names):
|
| 31 |
_name = pdb_file[:-4]
|
| 32 |
_chain = ""
|
|
@@ -111,8 +133,14 @@ def flex_seq(input_seq, input_file):
|
|
| 111 |
config['inference_args']['device'] = config['inference_args']['device'] if torch.cuda.is_available() else 'cpu'
|
| 112 |
model, tokenizer = PT5_classification_model(half_precision=config['mixed_precision'], class_config=class_config)
|
| 113 |
model.to(config['inference_args']['device'])
|
| 114 |
-
|
| 115 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
model.eval()
|
| 117 |
|
| 118 |
data_to_collate = []
|
|
@@ -240,7 +268,14 @@ def flex_3d(input_file):
|
|
| 240 |
|
| 241 |
model.to(config['inference_args']['device'])
|
| 242 |
print("Loading 3D model from {}".format(config['inference_args']['3d_model_path']))
|
| 243 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 244 |
model.load_state_dict(state_dict, strict=False)
|
| 245 |
model.eval()
|
| 246 |
|
|
|
|
| 24 |
import biotite.structure.io as strucio
|
| 25 |
import biotite.structure.residues as residues
|
| 26 |
import numpy as np
|
| 27 |
+
from huggingface_hub import hf_hub_download, utils
|
| 28 |
|
| 29 |
from data.scripts.data_utils import modify_bfactor_biotite
|
| 30 |
|
| 31 |
+
def get_weights_path(repo_id, filename):
|
| 32 |
+
"""
|
| 33 |
+
Tries to get the local path immediately. If not found, downloads it.
|
| 34 |
+
"""
|
| 35 |
+
print(f"Looking for {filename} in {repo_id}...")
|
| 36 |
+
try:
|
| 37 |
+
# 1. FASTEST: Try loading entirely from local cache (no internet check)
|
| 38 |
+
return hf_hub_download(
|
| 39 |
+
repo_id=repo_id,
|
| 40 |
+
filename=filename,
|
| 41 |
+
local_files_only=True
|
| 42 |
+
)
|
| 43 |
+
except (utils.EntryNotFoundError, utils.LocalEntryNotFoundError, FileNotFoundError):
|
| 44 |
+
# 2. FALLBACK: If not found locally, download it (cached for next time)
|
| 45 |
+
print(f"Weights not found locally. Downloading from HF Hub...")
|
| 46 |
+
return hf_hub_download(
|
| 47 |
+
repo_id=repo_id,
|
| 48 |
+
filename=filename,
|
| 49 |
+
local_files_only=False
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
def process_pdb_file(pdb_file, backbones, sequences, names):
|
| 53 |
_name = pdb_file[:-4]
|
| 54 |
_chain = ""
|
|
|
|
| 133 |
config['inference_args']['device'] = config['inference_args']['device'] if torch.cuda.is_available() else 'cpu'
|
| 134 |
model, tokenizer = PT5_classification_model(half_precision=config['mixed_precision'], class_config=class_config)
|
| 135 |
model.to(config['inference_args']['device'])
|
| 136 |
+
repo_id = "Honzus24/Flexpert_weights"
|
| 137 |
+
file_weights = config['inference_args']['seq_model_path']
|
| 138 |
+
|
| 139 |
+
# Get path (instant if cached)
|
| 140 |
+
weights_path = get_weights_path(repo_id, file_weights)
|
| 141 |
+
|
| 142 |
+
# Load weights
|
| 143 |
+
state_dict = torch.load(weights_path, map_location=config['inference_args']['device']) model.load_state_dict(state_dict, strict=False)
|
| 144 |
model.eval()
|
| 145 |
|
| 146 |
data_to_collate = []
|
|
|
|
| 268 |
|
| 269 |
model.to(config['inference_args']['device'])
|
| 270 |
print("Loading 3D model from {}".format(config['inference_args']['3d_model_path']))
|
| 271 |
+
repo_id = "Honzus24/Flexpert_weights"
|
| 272 |
+
file_weights = config['inference_args']['3d_model_path']
|
| 273 |
+
|
| 274 |
+
# Get path (instant if cached)
|
| 275 |
+
weights_path = get_weights_path(repo_id, file_weights)
|
| 276 |
+
|
| 277 |
+
# Load weights
|
| 278 |
+
state_dict = torch.load(weights_path, map_location=config['inference_args']['device'])
|
| 279 |
model.load_state_dict(state_dict, strict=False)
|
| 280 |
model.eval()
|
| 281 |
|