Honzus24 committed on
Commit
d73fc37
·
verified ·
1 Parent(s): 6886821

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -4
app.py CHANGED
@@ -230,8 +230,7 @@ def flex_seq(input_seq, input_file):
230
  print("Falling back to CPU execution. This may take a while...")
231
  return core_flex_seq(input_seq, input_file, force_cpu=True)
232
 
233
- @spaces.GPU
234
- def flex_3d(input_file):
235
  if not input_file:
236
  return None, "Provide a file or a input sequence"
237
 
@@ -293,10 +292,16 @@ def flex_3d(input_file):
293
  config = yaml.load(open('configs/train_config.yaml', 'r'), Loader=yaml.FullLoader)
294
  class_config=ClassConfig(config)
295
  class_config.adaptor_architecture = 'conv'
296
- config['inference_args']['device'] = config['inference_args']['device'] if torch.cuda.is_available() else 'cpu'
297
- model, tokenizer = PT5_classification_model(half_precision=config['mixed_precision'], class_config=class_config)
298
 
 
 
 
 
 
 
 
299
  model.to(config['inference_args']['device'])
 
300
  repo_id = "Honzus24/Flexpert_weights"
301
  print("Loading 3D model from {}".format(config['inference_args']['3d_model_path']))
302
  file_weights = config['inference_args']['3d_model_path']
@@ -402,6 +407,26 @@ def flex_3d(input_file):
402
 
403
  return output_files, output_message, output_files_enm
404
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
405
  def rescale_bfactors(pdb_file):
406
  base, ext = os.path.splitext(pdb_file)
407
  # Create the new filename
 
230
  print("Falling back to CPU execution. This may take a while...")
231
  return core_flex_seq(input_seq, input_file, force_cpu=True)
232
 
233
+ def core_flex_3d(input_file):
 
234
  if not input_file:
235
  return None, "Provide a file or a input sequence"
236
 
 
292
  config = yaml.load(open('configs/train_config.yaml', 'r'), Loader=yaml.FullLoader)
293
  class_config=ClassConfig(config)
294
  class_config.adaptor_architecture = 'conv'
 
 
295
 
296
+ if force_cpu:
297
+ target_device = 'cpu'
298
+ else:
299
+ target_device = config['inference_args']['device'] if torch.cuda.is_available() else 'cpu'
300
+ config['inference_args']['device'] = target_device
301
+
302
+ model, tokenizer = PT5_classification_model(half_precision=config['mixed_precision'], class_config=class_config)
303
  model.to(config['inference_args']['device'])
304
+
305
  repo_id = "Honzus24/Flexpert_weights"
306
  print("Loading 3D model from {}".format(config['inference_args']['3d_model_path']))
307
  file_weights = config['inference_args']['3d_model_path']
 
407
 
408
  return output_files, output_message, output_files_enm
409
 
410
+ @spaces.GPU
411
+ def flex_3d_gpu(input_seq, input_file):
412
+ """Wrapper that requests ZeroGPU."""
413
+ return core_flex_3d(input_seq, input_file, force_cpu=False)
414
+
415
+ def flex_3d(input_seq, input_file):
416
+ """
417
+ Main entry point for Gradio.
418
+ Tries GPU first, falls back to CPU if quota is exceeded or time limit runs out.
419
+ """
420
+ try:
421
+ print("Attempting to run on ZeroGPU...")
422
+ return flex_3d_gpu(input_seq, input_file)
423
+ except Exception as e:
424
+ # ZeroGPU exceptions (like SpaceTaskError or timeouts) are caught here
425
+ print(f"ZeroGPU failed or timed out. Reason: {e}")
426
+ print("Falling back to CPU execution. This may take longer...")
427
+ return core_flex_3d(input_seq, input_file, force_cpu=True)
428
+
429
+
430
  def rescale_bfactors(pdb_file):
431
  base, ext = os.path.splitext(pdb_file)
432
  # Create the new filename