DanielRegaladoCardoso committed on
Commit
a57eca6
·
verified ·
1 Parent(s): c2ac226

ZeroGPU best practice: load models at module level (on CUDA); run inference only inside @spaces.GPU

Browse files
Files changed (1) hide show
  1. src/models/chart_reasoner.py +14 -31
src/models/chart_reasoner.py CHANGED
@@ -1,8 +1,7 @@
1
  """
2
- Chart Reasoner: query results -> chart spec via the trained Phi-3 Mini LoRA.
3
 
4
- Uses the adapter-only repo so the LoRA loads on top of the original
5
- Phi-3-mini-4k-instruct base, keeping Hub downloads small.
6
  """
7
 
8
  import json
@@ -10,7 +9,8 @@ import logging
10
  import re
11
  from typing import Any, Dict, List
12
 
13
- from src.models.base import BaseModel
 
14
 
15
  logger = logging.getLogger(__name__)
16
 
@@ -22,51 +22,39 @@ SYSTEM_PROMPT = (
22
  "Return only valid JSON, no commentary."
23
  )
24
 
 
25
 
26
- class ChartReasoner(BaseModel):
27
- """Generate chart specs from SQL result sets."""
28
 
29
- DEFAULT_MERGED = "DanielRegaladoCardoso/chart-reasoner-phi3-mini-lora"
 
30
 
31
  def __init__(
32
  self,
33
- hf_model: str = DEFAULT_MERGED,
34
  temperature: float = 0.0,
35
  max_new_tokens: int = 300,
36
  ) -> None:
37
- super().__init__(model_name="chart-reasoner")
38
  self.hf_model = hf_model
39
  self.temperature = temperature
40
  self.max_new_tokens = max_new_tokens
41
 
42
- def load(self) -> None:
43
- from transformers import AutoModelForCausalLM, AutoTokenizer
44
- import torch
45
-
46
- logger.info(f"Loading chart reasoner: {self.hf_model}")
47
- device = "cuda" if torch.cuda.is_available() else "cpu"
48
- dtype = torch.bfloat16 if device == "cuda" else torch.float32
49
-
50
  self.tokenizer = AutoTokenizer.from_pretrained(self.hf_model)
51
  self.model = AutoModelForCausalLM.from_pretrained(
52
  self.hf_model,
53
- torch_dtype=dtype,
54
- device_map=device,
55
  )
56
  self.model.eval()
57
- self.is_loaded = True
58
- logger.info(f"Chart reasoner loaded on {device}")
59
 
60
- def generate( # type: ignore[override]
61
  self,
62
  question: str,
63
  sql: str,
64
  results: List[Dict[str, Any]],
65
  columns: List[Dict[str, Any]],
66
  ) -> Dict[str, Any]:
67
- self._validate_loaded()
68
- import torch
69
-
70
  sample = results[:5]
71
  col_names = [c["name"] for c in columns]
72
  user_content = (
@@ -102,18 +90,14 @@ class ChartReasoner(BaseModel):
102
  def _parse_spec(
103
  self, text: str, columns: List[Dict[str, Any]]
104
  ) -> Dict[str, Any]:
105
- # Try to extract a JSON object from the response
106
  match = re.search(r"\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}", text, re.DOTALL)
107
  if not match:
108
- logger.warning("No JSON found in chart reasoner output")
109
  return self._fallback_spec(columns)
110
  try:
111
  spec = json.loads(match.group(0))
112
- except json.JSONDecodeError as e:
113
- logger.warning(f"Chart spec JSON invalid: {e}")
114
  return self._fallback_spec(columns)
115
 
116
- # Normalize
117
  return {
118
  "chart_type": spec.get("chart_type", "bar").lower(),
119
  "title": spec.get("title", "Result"),
@@ -124,7 +108,6 @@ class ChartReasoner(BaseModel):
124
  }
125
 
126
  def _fallback_spec(self, columns: List[Dict[str, Any]]) -> Dict[str, Any]:
127
- """Heuristic fallback when the model output can't be parsed."""
128
  if not columns:
129
  return {"chart_type": "table", "title": "Result"}
130
  if len(columns) == 1:
 
1
  """
2
+ Chart Reasoner: query results -> chart spec via the Phi-3 Mini LoRA.
3
 
4
+ Model loaded at root module level (ZeroGPU best practice).
 
5
  """
6
 
7
  import json
 
9
  import re
10
  from typing import Any, Dict, List
11
 
12
+ import torch
13
+ from transformers import AutoModelForCausalLM, AutoTokenizer
14
 
15
  logger = logging.getLogger(__name__)
16
 
 
22
  "Return only valid JSON, no commentary."
23
  )
24
 
25
+ DEFAULT_MODEL = "DanielRegaladoCardoso/chart-reasoner-phi3-mini-lora"
26
 
 
 
27
 
28
+ class ChartReasoner:
29
+ """Generate chart specs from SQL result sets."""
30
 
31
  def __init__(
32
  self,
33
+ hf_model: str = DEFAULT_MODEL,
34
  temperature: float = 0.0,
35
  max_new_tokens: int = 300,
36
  ) -> None:
 
37
  self.hf_model = hf_model
38
  self.temperature = temperature
39
  self.max_new_tokens = max_new_tokens
40
 
41
+ logger.info(f"Loading chart reasoner at module level: {self.hf_model}")
 
 
 
 
 
 
 
42
  self.tokenizer = AutoTokenizer.from_pretrained(self.hf_model)
43
  self.model = AutoModelForCausalLM.from_pretrained(
44
  self.hf_model,
45
+ torch_dtype=torch.bfloat16,
46
+ device_map="cuda",
47
  )
48
  self.model.eval()
49
+ logger.info("Chart reasoner ready")
 
50
 
51
+ def generate(
52
  self,
53
  question: str,
54
  sql: str,
55
  results: List[Dict[str, Any]],
56
  columns: List[Dict[str, Any]],
57
  ) -> Dict[str, Any]:
 
 
 
58
  sample = results[:5]
59
  col_names = [c["name"] for c in columns]
60
  user_content = (
 
90
  def _parse_spec(
91
  self, text: str, columns: List[Dict[str, Any]]
92
  ) -> Dict[str, Any]:
 
93
  match = re.search(r"\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}", text, re.DOTALL)
94
  if not match:
 
95
  return self._fallback_spec(columns)
96
  try:
97
  spec = json.loads(match.group(0))
98
+ except json.JSONDecodeError:
 
99
  return self._fallback_spec(columns)
100
 
 
101
  return {
102
  "chart_type": spec.get("chart_type", "bar").lower(),
103
  "title": spec.get("title", "Result"),
 
108
  }
109
 
110
  def _fallback_spec(self, columns: List[Dict[str, Any]]) -> Dict[str, Any]:
 
111
  if not columns:
112
  return {"chart_type": "table", "title": "Result"}
113
  if len(columns) == 1: