Honzus24 commited on
Commit
20fff56
·
verified ·
1 Parent(s): d73fc37

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +81 -22
app.py CHANGED
@@ -32,6 +32,44 @@ import biotite.sequence as seq
32
 
33
  from data.scripts.data_utils import modify_bfactor_biotite
34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  def get_first_chain_id(pdb_file):
36
  try:
37
  # Load the PDB file
@@ -156,16 +194,29 @@ def core_flex_seq(input_seq, input_file, force_cpu=False):
156
  target_device = config['inference_args']['device'] if torch.cuda.is_available() else 'cpu'
157
  config['inference_args']['device'] = target_device
158
 
159
- model, tokenizer = PT5_classification_model(half_precision=config['mixed_precision'], class_config=class_config)
160
- model.to(target_device)
161
 
162
- repo_id = "Honzus24/Flexpert_weights"
163
- file_weights = config['inference_args']['seq_model_path']
164
- weights_path = get_weights_path(repo_id, file_weights)
165
-
166
- state_dict = torch.load(weights_path, map_location=target_device)
167
- model.load_state_dict(state_dict, strict=False)
168
- model.eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
169
 
170
  data_to_collate = []
171
  for idx, (backbone, sequence) in enumerate(zip(backbones, sequences)):
@@ -299,20 +350,28 @@ def core_flex_3d(input_file):
299
  target_device = config['inference_args']['device'] if torch.cuda.is_available() else 'cpu'
300
  config['inference_args']['device'] = target_device
301
 
302
- model, tokenizer = PT5_classification_model(half_precision=config['mixed_precision'], class_config=class_config)
303
- model.to(config['inference_args']['device'])
304
-
305
- repo_id = "Honzus24/Flexpert_weights"
306
- print("Loading 3D model from {}".format(config['inference_args']['3d_model_path']))
307
- file_weights = config['inference_args']['3d_model_path']
308
 
309
- # Get path (instant if cached)
310
- weights_path = get_weights_path(repo_id, file_weights)
311
-
312
- # Load weights
313
- state_dict = torch.load(weights_path, map_location=config['inference_args']['device'])
314
- model.load_state_dict(state_dict, strict=False)
315
- model.eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
316
 
317
  data_to_collate = []
318
  for idx, (backbone, sequence) in enumerate(zip(backbones, sequences)):
 
32
 
33
  from data.scripts.data_utils import modify_bfactor_biotite
34
 
35
GLOBAL_MODEL_CACHE = {}

# Module-level singletons for the sequence model, populated on first use.
# BUGFIX: these must exist before get_loaded_model_and_tokenizer() reads
# them — without the `= None` initializers, the `_GLOBAL_MODEL is not None`
# check below raises NameError on the very first call.
_GLOBAL_MODEL = None
_GLOBAL_TOKENIZER = None


def get_loaded_model_and_tokenizer(target_device):
    """Load the PT5 classification model once and cache it in memory.

    On the first call the architecture is built, the weights are fetched
    from the Hugging Face Hub via ``get_weights_path`` and loaded on CPU,
    and the resulting (model, tokenizer) pair is cached in module globals.
    Subsequent calls reuse the cached pair and only move the model to
    ``target_device``.

    Args:
        target_device: Device to move the model to (e.g. ``'cpu'`` or
            ``'cuda'``; on ZeroGPU Spaces this is assigned dynamically).

    Returns:
        Tuple ``(model, tokenizer)`` with the model on ``target_device``
        and in eval mode.
    """
    global _GLOBAL_MODEL, _GLOBAL_TOKENIZER

    if _GLOBAL_MODEL is not None and _GLOBAL_TOKENIZER is not None:
        # Move to the requested device (ZeroGPU handles this dynamically)
        _GLOBAL_MODEL.to(target_device)
        return _GLOBAL_MODEL, _GLOBAL_TOKENIZER

    print("First run: Initializing model and downloading/loading weights...")

    # Initialize architecture
    model, tokenizer = PT5_classification_model(
        half_precision=config['mixed_precision'],
        class_config=class_config
    )

    # Fetch weights (instant if already cached on disk)
    repo_id = "Honzus24/Flexpert_weights"
    file_weights = config['inference_args']['seq_model_path']
    weights_path = get_weights_path(repo_id, file_weights)

    # Load weights into CPU memory first (best practice for HF Spaces:
    # avoids touching the GPU outside the GPU-scheduled call)
    state_dict = torch.load(weights_path, map_location='cpu')
    # strict=False: checkpoint may omit/extend heads relative to the
    # freshly-initialized architecture — presumably intentional; confirm.
    model.load_state_dict(state_dict, strict=False)
    model.eval()

    # Cache them globally
    _GLOBAL_MODEL = model
    _GLOBAL_TOKENIZER = tokenizer

    # Move to the target device
    _GLOBAL_MODEL.to(target_device)

    return _GLOBAL_MODEL, _GLOBAL_TOKENIZER
72
+
73
  def get_first_chain_id(pdb_file):
74
  try:
75
  # Load the PDB file
 
194
  target_device = config['inference_args']['device'] if torch.cuda.is_available() else 'cpu'
195
  config['inference_args']['device'] = target_device
196
 
197
+ global GLOBAL_MODEL_CACHE
198
+ model_key = 'seq'
199
 
200
+ if model_key not in GLOBAL_MODEL_CACHE:
201
+ # 1. Initialize model
202
+ model, tokenizer = PT5_classification_model(half_precision=config['mixed_precision'], class_config=class_config)
203
+
204
+ # 2. Get weights path
205
+ repo_id = "Honzus24/Flexpert_weights"
206
+ file_weights = config['inference_args']['seq_model_path'] # Update for 3D if needed
207
+ weights_path = get_weights_path(repo_id, file_weights)
208
+
209
+ # 3. Load weights to CPU first (Crucial for ZeroGPU Spaces compatibility)
210
+ state_dict = torch.load(weights_path, map_location='cpu')
211
+ model.load_state_dict(state_dict, strict=False)
212
+ model.eval()
213
+
214
+ # 4. Save to cache
215
+ GLOBAL_MODEL_CACHE[model_key] = (model, tokenizer)
216
+
217
+ # Retrieve from cache and move to the dynamically assigned device
218
+ model, tokenizer = GLOBAL_MODEL_CACHE[model_key]
219
+ model.to(target_device)
220
 
221
  data_to_collate = []
222
  for idx, (backbone, sequence) in enumerate(zip(backbones, sequences)):
 
350
  target_device = config['inference_args']['device'] if torch.cuda.is_available() else 'cpu'
351
  config['inference_args']['device'] = target_device
352
 
353
+ model_key = '3d'
 
 
 
 
 
354
 
355
+ if model_key not in GLOBAL_MODEL_CACHE:
356
+ # 1. Initialize model
357
+ model, tokenizer = PT5_classification_model(half_precision=config['mixed_precision'], class_config=class_config)
358
+
359
+ # 2. Get weights path
360
+ repo_id = "Honzus24/Flexpert_weights"
361
+ file_weights = config['inference_args']['3d_model_path'] # Update for 3D if needed
362
+ weights_path = get_weights_path(repo_id, file_weights)
363
+
364
+ # 3. Load weights to CPU first (Crucial for ZeroGPU Spaces compatibility)
365
+ state_dict = torch.load(weights_path, map_location='cpu')
366
+ model.load_state_dict(state_dict, strict=False)
367
+ model.eval()
368
+
369
+ # 4. Save to cache
370
+ GLOBAL_MODEL_CACHE[model_key] = (model, tokenizer)
371
+
372
+ # Retrieve from cache and move to the dynamically assigned device
373
+ model, tokenizer = GLOBAL_MODEL_CACHE[model_key]
374
+ model.to(target_device)
375
 
376
  data_to_collate = []
377
  for idx, (backbone, sequence) in enumerate(zip(backbones, sequences)):