bernardo-de-almeida committed on
Commit
33a2516
·
1 Parent(s): 3fb06c7

feat: improve usage of GPU

Browse files
Files changed (1) hide show
  1. app.py +13 -1
app.py CHANGED
@@ -7,6 +7,7 @@ import numpy as np
7
  import gradio as gr
8
  import asyncio
9
  import spaces
 
10
 
11
  # Set matplotlib to use non-interactive backend before importing pyplot
12
  # This is required for Gradio which runs on worker threads
@@ -55,7 +56,7 @@ def load_pipeline(model_id: str, species: str = DEFAULT_SPECIES):
55
  pipe = load_ntv3_tracks_pipeline(
56
  model=model_id,
57
  token=HF_TOKEN,
58
- device="auto",
59
  default_species=species,
60
  verbose=False,
61
  )
@@ -427,8 +428,19 @@ def predict(
427
  if "species" not in inputs:
428
  raise gr.Error(f"Internal error: species not found in inputs dict. Inputs: {list(inputs.keys())}")
429
 
 
 
 
 
 
 
 
430
  out = pipe(inputs)
431
 
 
 
 
 
432
  bw_names = out.bigwig_track_names or []
433
  bw = out.bigwig_tracks_logits
434
  bed_names = out.bed_element_names or []
 
7
  import gradio as gr
8
  import asyncio
9
  import spaces
10
+ import torch
11
 
12
  # Set matplotlib to use non-interactive backend before importing pyplot
13
  # This is required for Gradio which runs on worker threads
 
56
  pipe = load_ntv3_tracks_pipeline(
57
  model=model_id,
58
  token=HF_TOKEN,
59
+ device="cpu", # This prevents the pipeline constructor from doing model.to(cuda) during import.
60
  default_species=species,
61
  verbose=False,
62
  )
 
428
  if "species" not in inputs:
429
  raise gr.Error(f"Internal error: species not found in inputs dict. Inputs: {list(inputs.keys())}")
430
 
431
+ # move to GPU only once the ZeroGPU context is active
432
+ device = "cuda" if torch.cuda.is_available() else "cpu"
433
+ pipe.model.to(device)
434
+ pipe.model.eval()
435
+ print(f"Running on {next(pipe.model.parameters()).device}")
436
+
437
+ # run inference
438
  out = pipe(inputs)
439
 
440
+ # optional: move back to CPU so you don’t rely on any persistent CUDA context
441
+ if device == "cuda":
442
+ pipe.model.to("cpu")
443
+
444
  bw_names = out.bigwig_track_names or []
445
  bw = out.bigwig_tracks_logits
446
  bed_names = out.bed_element_names or []