DanielRegaladoCardoso committed on
Commit
420b1db
·
verified ·
1 Parent(s): 61aee8d

ZeroGPU best practice: load models at module level (cuda), inference only inside @spaces.GPU

Browse files
Files changed (1) hide show
  1. src/orchestrator/pipeline.py +12 -46
src/orchestrator/pipeline.py CHANGED
@@ -1,11 +1,6 @@
1
  """
2
- SQL Agent orchestrator.
3
-
4
- Holds an in-memory DuckDB connection and the three specialist models, and
5
- walks a question through the pipeline:
6
-
7
- schema (DuckDB) -> SQL (Qwen) -> execute (DuckDB)
8
- -> chart spec (Phi-3) -> SVG (DeepSeek + theme)
9
  """
10
 
11
  import logging
@@ -26,24 +21,23 @@ logger = logging.getLogger(__name__)
26
  class SQLAgentOrchestrator:
27
  """End-to-end NL -> SQL -> chart pipeline backed by DuckDB."""
28
 
29
- def __init__(self) -> None:
 
 
 
 
 
30
  self.executor = SQLExecutor()
31
  self.rag = RAGEngine(self.executor.con)
 
 
 
32
 
33
- # Models are constructed eagerly but loaded lazily (HF Spaces ZeroGPU
34
- # gives us a GPU only inside @spaces.GPU calls, so model.load() must
35
- # happen there, not at import time).
36
- self.sql_generator = SQLGenerator()
37
- self.chart_reasoner = ChartReasoner()
38
- self.svg_renderer = SVGRenderer()
39
-
40
- # --------------------------------------------------------------- data
41
  def load_data(
42
  self,
43
  source: Union[str, Path, pd.DataFrame],
44
  table_name: Optional[str] = None,
45
  ) -> str:
46
- """Register a DataFrame or file as a queryable table. Returns the table name."""
47
  if isinstance(source, pd.DataFrame):
48
  name = table_name or "data"
49
  self.executor.register_dataframe(name, source)
@@ -59,14 +53,8 @@ class SQLAgentOrchestrator:
59
  def sample(self, table: str, n: int = 5) -> pd.DataFrame:
60
  return self.executor.get_sample(table, n)
61
 
62
- # ----------------------------------------------------------- pipeline
63
  def process(self, question: str) -> Dict[str, Any]:
64
- """
65
- Run the full pipeline for one question.
66
-
67
- Models are loaded and unloaded sequentially to keep peak VRAM low
68
- (only one of the 3 models lives in GPU at a time).
69
- """
70
  result: Dict[str, Any] = {
71
  "question": question,
72
  "sql": None,
@@ -83,37 +71,23 @@ class SQLAgentOrchestrator:
83
  result["error"] = "No data loaded. Upload a CSV/JSON first."
84
  return result
85
 
86
- # 1) SQL — load Qwen, generate, unload
87
- logger.info("Step 1/4: SQL generation")
88
- self.sql_generator.load()
89
  sql = self.sql_generator.generate(question=question, schema=schema)
90
- self.sql_generator.unload()
91
  result["sql"] = sql
92
 
93
  if not self.executor.validate_query(sql):
94
  result["error"] = f"Generated SQL is invalid:\n{sql}"
95
  return result
96
 
97
- # 2) Execute (CPU-only, no model needed)
98
- logger.info("Step 2/4: SQL execution")
99
  rows, cols = self.executor.execute(sql)
100
  result["results"] = rows
101
  result["columns"] = cols
102
 
103
- # 3) Chart spec — load Phi-3, generate, unload
104
- logger.info("Step 3/4: chart reasoning")
105
- self.chart_reasoner.load()
106
  spec = self.chart_reasoner.generate(
107
  question=question, sql=sql, results=rows, columns=cols,
108
  )
109
- self.chart_reasoner.unload()
110
  result["chart_spec"] = spec
111
 
112
- # 4) Render — load DeepSeek (or Plotly fallback), render, unload
113
- logger.info("Step 4/4: SVG rendering")
114
- self.svg_renderer.load()
115
  svg = self.svg_renderer.generate(spec, rows)
116
- self.svg_renderer.unload()
117
  result["svg"] = svg
118
 
119
  return result
@@ -121,17 +95,9 @@ class SQLAgentOrchestrator:
121
  except Exception as e:
122
  logger.exception("Pipeline failed")
123
  result["error"] = str(e)
124
- # Best-effort cleanup so a failure doesn't leak a model in VRAM
125
- for m in (self.sql_generator, self.chart_reasoner, self.svg_renderer):
126
- try:
127
- if m.is_loaded:
128
- m.unload()
129
- except Exception:
130
- pass
131
  return result
132
 
133
  def reset(self) -> None:
134
- """Drop all data tables (keeps the connection alive)."""
135
  self.executor.close()
136
  self.executor = SQLExecutor()
137
  self.rag.bind(self.executor.con)
 
1
  """
2
+ SQL Agent orchestrator. Models are constructed (loaded onto cuda) at
3
+ import time per ZeroGPU best practices. The pipeline runs inference only.
 
 
 
 
 
4
  """
5
 
6
  import logging
 
21
  class SQLAgentOrchestrator:
22
  """End-to-end NL -> SQL -> chart pipeline backed by DuckDB."""
23
 
24
+ def __init__(
25
+ self,
26
+ sql_generator: SQLGenerator,
27
+ chart_reasoner: ChartReasoner,
28
+ svg_renderer: SVGRenderer,
29
+ ) -> None:
30
  self.executor = SQLExecutor()
31
  self.rag = RAGEngine(self.executor.con)
32
+ self.sql_generator = sql_generator
33
+ self.chart_reasoner = chart_reasoner
34
+ self.svg_renderer = svg_renderer
35
 
 
 
 
 
 
 
 
 
36
  def load_data(
37
  self,
38
  source: Union[str, Path, pd.DataFrame],
39
  table_name: Optional[str] = None,
40
  ) -> str:
 
41
  if isinstance(source, pd.DataFrame):
42
  name = table_name or "data"
43
  self.executor.register_dataframe(name, source)
 
53
  def sample(self, table: str, n: int = 5) -> pd.DataFrame:
54
  return self.executor.get_sample(table, n)
55
 
 
56
  def process(self, question: str) -> Dict[str, Any]:
57
+ """Inference-only pipeline; models already loaded at module level."""
 
 
 
 
 
58
  result: Dict[str, Any] = {
59
  "question": question,
60
  "sql": None,
 
71
  result["error"] = "No data loaded. Upload a CSV/JSON first."
72
  return result
73
 
 
 
 
74
  sql = self.sql_generator.generate(question=question, schema=schema)
 
75
  result["sql"] = sql
76
 
77
  if not self.executor.validate_query(sql):
78
  result["error"] = f"Generated SQL is invalid:\n{sql}"
79
  return result
80
 
 
 
81
  rows, cols = self.executor.execute(sql)
82
  result["results"] = rows
83
  result["columns"] = cols
84
 
 
 
 
85
  spec = self.chart_reasoner.generate(
86
  question=question, sql=sql, results=rows, columns=cols,
87
  )
 
88
  result["chart_spec"] = spec
89
 
 
 
 
90
  svg = self.svg_renderer.generate(spec, rows)
 
91
  result["svg"] = svg
92
 
93
  return result
 
95
  except Exception as e:
96
  logger.exception("Pipeline failed")
97
  result["error"] = str(e)
 
 
 
 
 
 
 
98
  return result
99
 
100
  def reset(self) -> None:
 
101
  self.executor.close()
102
  self.executor = SQLExecutor()
103
  self.rag.bind(self.executor.con)