DanielRegaladoCardoso committed on
Commit
05de2b3
·
verified ·
1 Parent(s): 420b1db

ZeroGPU best practice: load models at module level (cuda), inference only inside @spaces.GPU

Browse files
Files changed (1) hide show
  1. app.py +17 -27
app.py CHANGED
@@ -19,32 +19,8 @@ sys.path.insert(0, str(ROOT))
19
  logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s | %(message)s")
20
  logger = logging.getLogger(__name__)
21
 
22
- # CRITICAL: pre-download model weights at module-load time (CPU phase, no GPU
23
- # needed). When @spaces.GPU is later invoked, from_pretrained() finds the
24
- # files already cached and just moves them to GPU — that loads in ~10s
25
- # instead of 30-60s, which keeps us inside the ZeroGPU quota window.
26
- MODEL_REPOS = [
27
- "DanielRegaladoCardoso/sql-generator-qwen25-coder-7b-lora",
28
- "DanielRegaladoCardoso/chart-reasoner-phi3-mini-lora",
29
- "DanielRegaladoCardoso/svg-renderer-deepseek-coder-1.3b-lora",
30
- ]
31
-
32
- try:
33
- from huggingface_hub import snapshot_download
34
- for repo in MODEL_REPOS:
35
- try:
36
- logger.info(f"Pre-downloading {repo}...")
37
- snapshot_download(repo)
38
- logger.info(f" cached")
39
- except Exception as e:
40
- logger.warning(f" pre-download failed (will retry on first use): {e}")
41
- except Exception as e:
42
- logger.warning(f"snapshot_download unavailable: {e}")
43
-
44
  import gradio as gr # noqa: E402
45
 
46
- from src.orchestrator.pipeline import SQLAgentOrchestrator # noqa: E402
47
-
48
  try:
49
  import spaces # type: ignore
50
  HAS_SPACES = True
@@ -60,6 +36,20 @@ except ImportError:
60
 
61
  spaces = _SpacesShim() # type: ignore
62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
  # ============================================================ THEME / CSS
65
  THEME_CSS = """
@@ -416,7 +406,7 @@ _AGENT: Optional[SQLAgentOrchestrator] = None
416
  def get_agent() -> SQLAgentOrchestrator:
417
  global _AGENT
418
  if _AGENT is None:
419
- _AGENT = SQLAgentOrchestrator()
420
  return _AGENT
421
 
422
 
@@ -599,9 +589,9 @@ def on_load_demo() -> Tuple[str, str, list]:
599
  return "", f'<div class="turn-error">Could not load demo: {e}</div>', []
600
 
601
 
602
- @spaces.GPU(duration=120)
603
  def _gpu_process(question: str) -> dict:
604
- """The GPU-bound call. Models initialize lazily inside this scope."""
605
  agent = get_agent()
606
  return agent.process(question)
607
 
 
19
  logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s | %(message)s")
20
  logger = logging.getLogger(__name__)
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  import gradio as gr # noqa: E402
23
 
 
 
24
  try:
25
  import spaces # type: ignore
26
  HAS_SPACES = True
 
36
 
37
  spaces = _SpacesShim() # type: ignore
38
 
39
+ # CRITICAL: load all 3 models on cuda at module level per ZeroGPU best
40
+ # practice. PyTorch CUDA emulation handles this when no real GPU is present;
41
+ # inside @spaces.GPU calls, the real GPU is used and inference is fast.
42
+ logger.info("Loading models at module level...")
43
+ from src.models.sql_generator import SQLGenerator # noqa: E402
44
+ from src.models.chart_reasoner import ChartReasoner # noqa: E402
45
+ from src.models.svg_renderer import SVGRenderer # noqa: E402
46
+ from src.orchestrator.pipeline import SQLAgentOrchestrator # noqa: E402
47
+
48
+ _SQL_GEN = SQLGenerator()
49
+ _CHART_REASONER = ChartReasoner()
50
+ _SVG_RENDERER = SVGRenderer()
51
+ logger.info("All models loaded")
52
+
53
 
54
  # ============================================================ THEME / CSS
55
  THEME_CSS = """
 
406
  def get_agent() -> SQLAgentOrchestrator:
407
  global _AGENT
408
  if _AGENT is None:
409
+ _AGENT = SQLAgentOrchestrator(_SQL_GEN, _CHART_REASONER, _SVG_RENDERER)
410
  return _AGENT
411
 
412
 
 
589
  return "", f'<div class="turn-error">Could not load demo: {e}</div>', []
590
 
591
 
592
+ @spaces.GPU(duration=60)
593
  def _gpu_process(question: str) -> dict:
594
+ """Inference only models already on cuda from module-level loading."""
595
  agent = get_agent()
596
  return agent.process(question)
597