Honzus24 commited on
Commit
26dd42d
·
verified ·
1 Parent(s): 5f38146

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -3
app.py CHANGED
@@ -24,9 +24,31 @@ from biotite.structure import annotate_sse
24
  import biotite.structure.io as strucio
25
  import biotite.structure.residues as residues
26
  import numpy as np
 
27
 
28
  from data.scripts.data_utils import modify_bfactor_biotite
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  def process_pdb_file(pdb_file, backbones, sequences, names):
31
  _name = pdb_file[:-4]
32
  _chain = ""
@@ -111,8 +133,14 @@ def flex_seq(input_seq, input_file):
111
  config['inference_args']['device'] = config['inference_args']['device'] if torch.cuda.is_available() else 'cpu'
112
  model, tokenizer = PT5_classification_model(half_precision=config['mixed_precision'], class_config=class_config)
113
  model.to(config['inference_args']['device'])
114
- state_dict = torch.load(config['inference_args']['seq_model_path'], map_location=config['inference_args']['device'])
115
- model.load_state_dict(state_dict, strict=False)
 
 
 
 
 
 
116
  model.eval()
117
 
118
  data_to_collate = []
@@ -240,7 +268,14 @@ def flex_3d(input_file):
240
 
241
  model.to(config['inference_args']['device'])
242
  print("Loading 3D model from {}".format(config['inference_args']['3d_model_path']))
243
- state_dict = torch.load(config['inference_args']['3d_model_path'], map_location=config['inference_args']['device'])
 
 
 
 
 
 
 
244
  model.load_state_dict(state_dict, strict=False)
245
  model.eval()
246
 
 
24
  import biotite.structure.io as strucio
25
  import biotite.structure.residues as residues
26
  import numpy as np
27
+ from huggingface_hub import hf_hub_download, utils
28
 
29
  from data.scripts.data_utils import modify_bfactor_biotite
30
 
31
def get_weights_path(repo_id, filename):
    """Return a local filesystem path for ``filename`` from the HF repo ``repo_id``.

    Fast path: resolve the file from the local Hugging Face cache only
    (``local_files_only=True`` -- no network round-trip). If the file is
    not cached yet, fall back to a normal ``hf_hub_download`` call, which
    downloads it once and populates the cache for subsequent calls.

    Parameters
    ----------
    repo_id : str
        Hugging Face Hub repository id, e.g. ``"Honzus24/Flexpert_weights"``.
    filename : str
        Path of the weights file inside the repository.

    Returns
    -------
    str
        Absolute path to the (possibly freshly downloaded) local file.
    """
    # NOTE(review): the original message printed "(unknown)" here, which looks
    # like a rendering artifact; include the filename so the log is useful.
    print(f"Looking for {filename} in {repo_id}...")
    try:
        # 1. FASTEST: try loading entirely from the local cache (no internet check).
        return hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            local_files_only=True,
        )
    except (utils.EntryNotFoundError, utils.LocalEntryNotFoundError, FileNotFoundError):
        # 2. FALLBACK: not cached yet -- download it (cached for next time).
        print("Weights not found locally. Downloading from HF Hub...")
        return hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            local_files_only=False,
        )
52
  def process_pdb_file(pdb_file, backbones, sequences, names):
53
  _name = pdb_file[:-4]
54
  _chain = ""
 
133
  config['inference_args']['device'] = config['inference_args']['device'] if torch.cuda.is_available() else 'cpu'
134
  model, tokenizer = PT5_classification_model(half_precision=config['mixed_precision'], class_config=class_config)
135
  model.to(config['inference_args']['device'])
136
+ repo_id = "Honzus24/Flexpert_weights"
137
+ file_weights = config['inference_args']['seq_model_path']
138
+
139
+ # Get path (instant if cached)
140
+ weights_path = get_weights_path(repo_id, file_weights)
141
+
142
+ # Load weights
143
+ state_dict = torch.load(weights_path, map_location=config['inference_args']['device'])
+ model.load_state_dict(state_dict, strict=False)
144
  model.eval()
145
 
146
  data_to_collate = []
 
268
 
269
  model.to(config['inference_args']['device'])
270
  print("Loading 3D model from {}".format(config['inference_args']['3d_model_path']))
271
+ repo_id = "Honzus24/Flexpert_weights"
272
+ file_weights = config['inference_args']['3d_model_path']
273
+
274
+ # Get path (instant if cached)
275
+ weights_path = get_weights_path(repo_id, file_weights)
276
+
277
+ # Load weights
278
+ state_dict = torch.load(weights_path, map_location=config['inference_args']['device'])
279
  model.load_state_dict(state_dict, strict=False)
280
  model.eval()
281