bernardo-de-almeida committed on
Commit
eec0acd
·
1 Parent(s): 33a2516

feat: add timings

Browse files
Files changed (1) hide show
  1. app.py +35 -8
app.py CHANGED
@@ -1,13 +1,13 @@
 
1
  import os
2
- import uuid
3
  import tempfile
4
- import csv
 
5
  from pathlib import Path
 
6
  import numpy as np
7
  import gradio as gr
8
- import asyncio
9
  import spaces
10
- import torch
11
 
12
  # Set matplotlib to use non-interactive backend before importing pyplot
13
  # This is required for Gradio which runs on worker threads
@@ -43,7 +43,6 @@ if HF_TOKEN is None:
43
  PLOT_TARGET_POINTS = int(os.environ.get("PLOT_TARGET_POINTS", "1500"))
44
  SEARCH_MAX_RESULTS = int(os.environ.get("SEARCH_MAX_RESULTS", "50"))
45
 
46
-
47
  # -----------------------------
48
  # Load pipeline (reloadable)
49
  # -----------------------------
@@ -70,6 +69,24 @@ load_pipeline(MODEL_ID, DEFAULT_SPECIES)
70
  # -----------------------------
71
  # Helpers
72
  # -----------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  def _global_stride(L: int, target: int) -> int:
74
  if target <= 0 or L <= target:
75
  return 1
@@ -396,6 +413,8 @@ def predict(
396
  bigwig_selected: list[str],
397
  bed_elements: list[str],
398
  ):
 
 
399
  # Debug: verify species is being passed
400
  if not species:
401
  raise gr.Error("Species parameter is missing. Please select a species.")
@@ -428,18 +447,24 @@ def predict(
428
  if "species" not in inputs:
429
  raise gr.Error(f"Internal error: species not found in inputs dict. Inputs: {list(inputs.keys())}")
430
 
 
 
431
  # move to GPU only once the ZeroGPU context is active
432
  device = "cuda" if torch.cuda.is_available() else "cpu"
433
  pipe.model.to(device)
434
  pipe.model.eval()
435
  print(f"Running on {next(pipe.model.parameters()).device}")
436
 
 
 
437
  # run inference
438
  out = pipe(inputs)
439
 
 
 
440
  # optional: move back to CPU so you don’t rely on any persistent CUDA context
441
- if device == "cuda":
442
- pipe.model.to("cpu")
443
 
444
  bw_names = out.bigwig_track_names or []
445
  bw = out.bigwig_tracks_logits
@@ -516,7 +541,9 @@ def predict(
516
  eidx = bed_names.index(ename)
517
  series.append((ename, probs[:, eidx, 1][::stride].astype(float)))
518
 
 
519
  fig = _make_tracks_figure(x, series)
 
520
 
521
  region = f"{out.chrom}:{out.pred_start}-{out.pred_end}" if out.chrom else f"{x0}-{x1}"
522
  if out.assembly:
@@ -895,7 +922,7 @@ with gr.Blocks(title="NTv3 Tracks Demo") as demo:
895
 
896
  <div class="intro-tip">
897
  <span class="intro-tip-icon">💡</span>
898
- <span><strong>Tip:</strong> The demo includes default settings that you can use to get started, taking ~ 10 seconds to run.</span>
899
  </div>
900
 
901
  <div style="margin-top: 16px; padding: 12px; background: rgba(0,0,0,0.03); border-radius: 12px; font-size: 0.95rem;">
 
1
+ import csv
2
  import os
 
3
  import tempfile
4
+ import time
5
+ import uuid
6
  from pathlib import Path
7
+ import torch
8
  import numpy as np
9
  import gradio as gr
 
10
  import spaces
 
11
 
12
  # Set matplotlib to use non-interactive backend before importing pyplot
13
  # This is required for Gradio which runs on worker threads
 
43
  PLOT_TARGET_POINTS = int(os.environ.get("PLOT_TARGET_POINTS", "1500"))
44
  SEARCH_MAX_RESULTS = int(os.environ.get("SEARCH_MAX_RESULTS", "50"))
45
 
 
46
  # -----------------------------
47
  # Load pipeline (reloadable)
48
  # -----------------------------
 
69
  # -----------------------------
70
  # Helpers
71
  # -----------------------------
72
+
73
+ _t0 = None
74
+ _tlast = None
75
+
76
+ def tprint(msg: str):
77
+ "Function to print timing information"
78
+ global _t0, _tlast
79
+ if _t0 is None:
80
+ _t0 = _tlast = time.perf_counter()
81
+
82
+ # CUDA ops are async → synchronize to get real timings
83
+ if torch.cuda.is_available():
84
+ torch.cuda.synchronize()
85
+
86
+ now = time.perf_counter()
87
+ print(f"[timing] {msg}: {now - _tlast:.3f}s (total {now - _t0:.3f}s)")
88
+ _tlast = now
89
+
90
  def _global_stride(L: int, target: int) -> int:
91
  if target <= 0 or L <= target:
92
  return 1
 
413
  bigwig_selected: list[str],
414
  bed_elements: list[str],
415
  ):
416
+ tprint("start")
417
+
418
  # Debug: verify species is being passed
419
  if not species:
420
  raise gr.Error("Species parameter is missing. Please select a species.")
 
447
  if "species" not in inputs:
448
  raise gr.Error(f"Internal error: species not found in inputs dict. Inputs: {list(inputs.keys())}")
449
 
450
+ tprint("inputs prepared")
451
+
452
  # move to GPU only once the ZeroGPU context is active
453
  device = "cuda" if torch.cuda.is_available() else "cpu"
454
  pipe.model.to(device)
455
  pipe.model.eval()
456
  print(f"Running on {next(pipe.model.parameters()).device}")
457
 
458
+ tprint(f"pipe.model.to({device})")
459
+
460
  # run inference
461
  out = pipe(inputs)
462
 
463
+ tprint("inference completed")
464
+
465
  # optional: move back to CPU so you don’t rely on any persistent CUDA context
466
+ # if device == "cuda":
467
+ # pipe.model.to("cpu")
468
 
469
  bw_names = out.bigwig_track_names or []
470
  bw = out.bigwig_tracks_logits
 
541
  eidx = bed_names.index(ename)
542
  series.append((ename, probs[:, eidx, 1][::stride].astype(float)))
543
 
544
+ tprint("figure data processed created")
545
  fig = _make_tracks_figure(x, series)
546
+ tprint("figure created")
547
 
548
  region = f"{out.chrom}:{out.pred_start}-{out.pred_end}" if out.chrom else f"{x0}-{x1}"
549
  if out.assembly:
 
922
 
923
  <div class="intro-tip">
924
  <span class="intro-tip-icon">💡</span>
925
+ <span><strong>Tip:</strong> The demo includes default settings that you can use to get started, taking ~ 15 seconds to run for the example on human.</span>
926
  </div>
927
 
928
  <div style="margin-top: 16px; padding: 12px; background: rgba(0,0,0,0.03); border-radius: 12px; font-size: 0.95rem;">