bernardo-de-almeida commited on
Commit
e2bc640
·
1 Parent(s): 9dd80fe

feat: improve timings

Browse files
Files changed (1) hide show
  1. app.py +9 -3
app.py CHANGED
@@ -491,11 +491,16 @@ def predict(
491
 
492
  # move to GPU only once the ZeroGPU context is active
493
  device = "cuda" if torch.cuda.is_available() else "cpu"
494
- pipe.model.to(device)
 
 
 
 
 
 
495
  pipe.model.eval()
496
  print(f"Running on {next(pipe.model.parameters()).device}")
497
-
498
- tprint(f"pipe.model.to({device})")
499
 
500
  # run inference
501
  out = pipe(inputs)
@@ -597,6 +602,7 @@ def predict(
597
  fig.axes[-1].set_xlabel(region)
598
 
599
  png_path = _save_fig_png(fig)
 
600
 
601
  meta = {
602
  "model_id": current_model_id,
 
491
 
492
  # move to GPU only once the ZeroGPU context is active
493
  device = "cuda" if torch.cuda.is_available() else "cpu"
494
+ # check where the model currently lives
495
+ current = next(pipe.model.parameters()).device.type # "cpu" or "cuda"
496
+ # only move if needed
497
+ if current != device:
498
+ pipe.model.to(device)
499
+ tprint(f"model.ensure_on({device})")
500
+
501
  pipe.model.eval()
502
  print(f"Running on {next(pipe.model.parameters()).device}")
503
+ tprint("model ready to run inference")
 
504
 
505
  # run inference
506
  out = pipe(inputs)
 
602
  fig.axes[-1].set_xlabel(region)
603
 
604
  png_path = _save_fig_png(fig)
605
+ tprint("figure png saved")
606
 
607
  meta = {
608
  "model_id": current_model_id,