Spaces: Running on Zero
Commit 33a2516 · Parent(s): 3fb06c7
feat: improve usage of GPU
app.py CHANGED
@@ -7,6 +7,7 @@ import numpy as np
 import gradio as gr
 import asyncio
 import spaces
+import torch
 
 # Set matplotlib to use non-interactive backend before importing pyplot
 # This is required for Gradio which runs on worker threads
@@ -55,7 +56,7 @@ def load_pipeline(model_id: str, species: str = DEFAULT_SPECIES):
     pipe = load_ntv3_tracks_pipeline(
         model=model_id,
         token=HF_TOKEN,
-        device="
+        device="cpu",  # This prevents the pipeline constructor from doing model.to(cuda) during import.
         default_species=species,
         verbose=False,
     )
@@ -427,8 +428,19 @@ def predict(
     if "species" not in inputs:
         raise gr.Error(f"Internal error: species not found in inputs dict. Inputs: {list(inputs.keys())}")
 
+    # move to GPU only once the ZeroGPU context is active
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    pipe.model.to(device)
+    pipe.model.eval()
+    print(f"Running on {next(pipe.model.parameters()).device}")
+
+    # run inference
     out = pipe(inputs)
 
+    # optional: move back to CPU so you don’t rely on any persistent CUDA context
+    if device == "cuda":
+        pipe.model.to("cpu")
+
     bw_names = out.bigwig_track_names or []
     bw = out.bigwig_tracks_logits
     bed_names = out.bed_element_names or []
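
For context: this commit follows the usual ZeroGPU pattern. On a ZeroGPU Space, CUDA does not exist at import time; a GPU is attached only while a handler decorated with @spaces.GPU is running. The pipeline is therefore built on CPU (the device="cpu" change above) and moved to CUDA only inside the request handler. Below is a minimal sketch of that pattern, assuming predict carries the @spaces.GPU decorator elsewhere in app.py (the decorator itself is not visible in this diff; import spaces already is), and using a placeholder torch.nn.Linear in place of the real pipeline:

import spaces
import torch

# Placeholder model, loaded on CPU at import time (no CUDA exists yet on ZeroGPU).
model = torch.nn.Linear(4, 2)

@spaces.GPU  # a GPU is attached only for the duration of this call
def predict(values: list[float]) -> list[float]:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    model.eval()
    with torch.no_grad():
        out = model(torch.tensor(values, device=device))
    model.to("cpu")  # mirrors the diff's optional cleanup step
    return out.cpu().tolist()

Moving the parameters back to CPU after the call matches the diff's own comment: the CUDA context on ZeroGPU is transient, so nothing between requests should assume the model still sits on a live device.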