Spaces:
Sleeping
Sleeping
ZeroGPU best practice: load models at module level (cuda), inference only inside @spaces.GPU
Browse files
app.py
CHANGED
|
@@ -19,32 +19,8 @@ sys.path.insert(0, str(ROOT))
|
|
| 19 |
logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s | %(message)s")
|
| 20 |
logger = logging.getLogger(__name__)
|
| 21 |
|
| 22 |
-
# CRITICAL: pre-download model weights at module-load time (CPU phase, no GPU
|
| 23 |
-
# needed). When @spaces.GPU is later invoked, from_pretrained() finds the
|
| 24 |
-
# files already cached and just moves them to GPU — that loads in ~10s
|
| 25 |
-
# instead of 30-60s, which keeps us inside the ZeroGPU quota window.
|
| 26 |
-
MODEL_REPOS = [
|
| 27 |
-
"DanielRegaladoCardoso/sql-generator-qwen25-coder-7b-lora",
|
| 28 |
-
"DanielRegaladoCardoso/chart-reasoner-phi3-mini-lora",
|
| 29 |
-
"DanielRegaladoCardoso/svg-renderer-deepseek-coder-1.3b-lora",
|
| 30 |
-
]
|
| 31 |
-
|
| 32 |
-
try:
|
| 33 |
-
from huggingface_hub import snapshot_download
|
| 34 |
-
for repo in MODEL_REPOS:
|
| 35 |
-
try:
|
| 36 |
-
logger.info(f"Pre-downloading {repo}...")
|
| 37 |
-
snapshot_download(repo)
|
| 38 |
-
logger.info(f" cached")
|
| 39 |
-
except Exception as e:
|
| 40 |
-
logger.warning(f" pre-download failed (will retry on first use): {e}")
|
| 41 |
-
except Exception as e:
|
| 42 |
-
logger.warning(f"snapshot_download unavailable: {e}")
|
| 43 |
-
|
| 44 |
import gradio as gr # noqa: E402
|
| 45 |
|
| 46 |
-
from src.orchestrator.pipeline import SQLAgentOrchestrator # noqa: E402
|
| 47 |
-
|
| 48 |
try:
|
| 49 |
import spaces # type: ignore
|
| 50 |
HAS_SPACES = True
|
|
@@ -60,6 +36,20 @@ except ImportError:
|
|
| 60 |
|
| 61 |
spaces = _SpacesShim() # type: ignore
|
| 62 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
|
| 64 |
# ============================================================ THEME / CSS
|
| 65 |
THEME_CSS = """
|
|
@@ -416,7 +406,7 @@ _AGENT: Optional[SQLAgentOrchestrator] = None
|
|
| 416 |
def get_agent() -> SQLAgentOrchestrator:
|
| 417 |
global _AGENT
|
| 418 |
if _AGENT is None:
|
| 419 |
-
_AGENT = SQLAgentOrchestrator()
|
| 420 |
return _AGENT
|
| 421 |
|
| 422 |
|
|
@@ -599,9 +589,9 @@ def on_load_demo() -> Tuple[str, str, list]:
|
|
| 599 |
return "", f'<div class="turn-error">Could not load demo: {e}</div>', []
|
| 600 |
|
| 601 |
|
| 602 |
-
@spaces.GPU(duration=
|
| 603 |
def _gpu_process(question: str) -> dict:
|
| 604 |
-
"""
|
| 605 |
agent = get_agent()
|
| 606 |
return agent.process(question)
|
| 607 |
|
|
|
|
| 19 |
logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s | %(message)s")
|
| 20 |
logger = logging.getLogger(__name__)
|
| 21 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
import gradio as gr # noqa: E402
|
| 23 |
|
|
|
|
|
|
|
| 24 |
try:
|
| 25 |
import spaces # type: ignore
|
| 26 |
HAS_SPACES = True
|
|
|
|
| 36 |
|
| 37 |
spaces = _SpacesShim() # type: ignore
|
| 38 |
|
| 39 |
+
# CRITICAL: load all 3 models on cuda at module level per ZeroGPU best
|
| 40 |
+
# practice. PyTorch CUDA emulation handles this when no real GPU is present;
|
| 41 |
+
# inside @spaces.GPU calls, the real GPU is used and inference is fast.
|
| 42 |
+
logger.info("Loading models at module level...")
|
| 43 |
+
from src.models.sql_generator import SQLGenerator # noqa: E402
|
| 44 |
+
from src.models.chart_reasoner import ChartReasoner # noqa: E402
|
| 45 |
+
from src.models.svg_renderer import SVGRenderer # noqa: E402
|
| 46 |
+
from src.orchestrator.pipeline import SQLAgentOrchestrator # noqa: E402
|
| 47 |
+
|
| 48 |
+
_SQL_GEN = SQLGenerator()
|
| 49 |
+
_CHART_REASONER = ChartReasoner()
|
| 50 |
+
_SVG_RENDERER = SVGRenderer()
|
| 51 |
+
logger.info("All models loaded")
|
| 52 |
+
|
| 53 |
|
| 54 |
# ============================================================ THEME / CSS
|
| 55 |
THEME_CSS = """
|
|
|
|
| 406 |
def get_agent() -> SQLAgentOrchestrator:
|
| 407 |
global _AGENT
|
| 408 |
if _AGENT is None:
|
| 409 |
+
_AGENT = SQLAgentOrchestrator(_SQL_GEN, _CHART_REASONER, _SVG_RENDERER)
|
| 410 |
return _AGENT
|
| 411 |
|
| 412 |
|
|
|
|
| 589 |
return "", f'<div class="turn-error">Could not load demo: {e}</div>', []
|
| 590 |
|
| 591 |
|
| 592 |
+
@spaces.GPU(duration=60)
|
| 593 |
def _gpu_process(question: str) -> dict:
|
| 594 |
+
"""Inference only — models already on cuda from module-level loading."""
|
| 595 |
agent = get_agent()
|
| 596 |
return agent.process(question)
|
| 597 |
|