DanielRegaladoCardoso committed on
Commit
eb30a86
·
verified ·
1 Parent(s): 730b25d

Load LoRA via PeftModel on top of standard base models (fixes r=16 vs r=8 mismatch)

Browse files
Files changed (1) hide show
  1. src/models/chart_reasoner.py +25 -37
src/models/chart_reasoner.py CHANGED
@@ -1,7 +1,5 @@
1
  """
2
- Chart Reasoner: query results -> chart spec via the Phi-3 Mini LoRA.
3
-
4
- Model loaded at root module level (ZeroGPU best practice).
5
  """
6
 
7
  import json
@@ -11,6 +9,7 @@ from typing import Any, Dict, List
11
 
12
  import torch
13
  from transformers import AutoModelForCausalLM, AutoTokenizer
 
14
 
15
  logger = logging.getLogger(__name__)
16
 
@@ -18,32 +17,32 @@ logger = logging.getLogger(__name__)
18
  SYSTEM_PROMPT = (
19
  "You are a data visualization expert. Given a question, the SQL that "
20
  "answers it, and a sample of the result rows, produce a JSON chart "
21
- "specification. Choose the chart type that tells the clearest story. "
22
- "Return only valid JSON, no commentary."
23
  )
24
 
25
- DEFAULT_MODEL = "DanielRegaladoCardoso/chart-reasoner-phi3-mini-lora"
 
26
 
27
 
28
  class ChartReasoner:
29
- """Generate chart specs from SQL result sets."""
30
 
31
- def __init__(
32
- self,
33
- hf_model: str = DEFAULT_MODEL,
34
- temperature: float = 0.0,
35
- max_new_tokens: int = 300,
36
- ) -> None:
37
- self.hf_model = hf_model
38
  self.temperature = temperature
39
  self.max_new_tokens = max_new_tokens
40
 
41
- logger.info(f"Loading chart reasoner at module level: {self.hf_model}")
42
- self.tokenizer = AutoTokenizer.from_pretrained(self.hf_model)
43
- self.model = AutoModelForCausalLM.from_pretrained(
44
- self.hf_model,
45
  torch_dtype=torch.bfloat16,
46
  device_map="cuda",
 
 
 
 
 
 
 
47
  )
48
  self.model.eval()
49
  logger.info("Chart reasoner ready")
@@ -62,9 +61,9 @@ class ChartReasoner:
62
  f"SQL: {sql}\n"
63
  f"Columns: {col_names}\n"
64
  f"Sample rows: {json.dumps(sample, default=str)}\n\n"
65
- "Return JSON with: chart_type (one of: bar, line, scatter, "
66
- "pie, area, table), title, x_column, y_column, "
67
- "color_column (optional), rationale."
68
  )
69
  messages = [
70
  {"role": "system", "content": SYSTEM_PROMPT},
@@ -87,9 +86,7 @@ class ChartReasoner:
87
  )
88
  return self._parse_spec(raw, columns)
89
 
90
- def _parse_spec(
91
- self, text: str, columns: List[Dict[str, Any]]
92
- ) -> Dict[str, Any]:
93
  match = re.search(r"\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}", text, re.DOTALL)
94
  if not match:
95
  return self._fallback_spec(columns)
@@ -97,7 +94,6 @@ class ChartReasoner:
97
  spec = json.loads(match.group(0))
98
  except json.JSONDecodeError:
99
  return self._fallback_spec(columns)
100
-
101
  return {
102
  "chart_type": spec.get("chart_type", "bar").lower(),
103
  "title": spec.get("title", "Result"),
@@ -111,16 +107,8 @@ class ChartReasoner:
111
  if not columns:
112
  return {"chart_type": "table", "title": "Result"}
113
  if len(columns) == 1:
114
- return {
115
- "chart_type": "table",
116
- "title": "Result",
117
  "x_column": columns[0]["name"],
118
- "y_column": None,
119
- }
120
- return {
121
- "chart_type": "bar",
122
- "title": "Result",
123
- "x_column": columns[0]["name"],
124
- "y_column": columns[1]["name"],
125
- "color_column": None,
126
- }
 
1
  """
2
+ Chart Reasoner: load the trained LoRA on top of Phi-3 Mini base.
 
 
3
  """
4
 
5
  import json
 
9
 
10
  import torch
11
  from transformers import AutoModelForCausalLM, AutoTokenizer
12
+ from peft import PeftModel
13
 
14
  logger = logging.getLogger(__name__)
15
 
 
17
  SYSTEM_PROMPT = (
18
  "You are a data visualization expert. Given a question, the SQL that "
19
  "answers it, and a sample of the result rows, produce a JSON chart "
20
+ "specification. Return only valid JSON, no commentary."
 
21
  )
22
 
23
+ BASE_MODEL = "microsoft/Phi-3-mini-4k-instruct"
24
+ ADAPTER_REPO = "DanielRegaladoCardoso/chart-reasoner-phi3-mini-adapter-only"
25
 
26
 
27
  class ChartReasoner:
 
28
 
29
+ def __init__(self, temperature: float = 0.0, max_new_tokens: int = 300) -> None:
 
 
 
 
 
 
30
  self.temperature = temperature
31
  self.max_new_tokens = max_new_tokens
32
 
33
+ logger.info(f"Loading chart base: {BASE_MODEL}")
34
+ self.tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
35
+ base = AutoModelForCausalLM.from_pretrained(
36
+ BASE_MODEL,
37
  torch_dtype=torch.bfloat16,
38
  device_map="cuda",
39
+ trust_remote_code=True,
40
+ )
41
+ logger.info(f"Applying LoRA adapter: {ADAPTER_REPO}")
42
+ self.model = PeftModel.from_pretrained(
43
+ base,
44
+ ADAPTER_REPO,
45
+ torch_dtype=torch.bfloat16,
46
  )
47
  self.model.eval()
48
  logger.info("Chart reasoner ready")
 
61
  f"SQL: {sql}\n"
62
  f"Columns: {col_names}\n"
63
  f"Sample rows: {json.dumps(sample, default=str)}\n\n"
64
+ "Return JSON with: chart_type (one of: bar, line, scatter, pie, "
65
+ "area, table), title, x_column, y_column, color_column "
66
+ "(optional), rationale."
67
  )
68
  messages = [
69
  {"role": "system", "content": SYSTEM_PROMPT},
 
86
  )
87
  return self._parse_spec(raw, columns)
88
 
89
+ def _parse_spec(self, text: str, columns: List[Dict[str, Any]]) -> Dict[str, Any]:
 
 
90
  match = re.search(r"\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}", text, re.DOTALL)
91
  if not match:
92
  return self._fallback_spec(columns)
 
94
  spec = json.loads(match.group(0))
95
  except json.JSONDecodeError:
96
  return self._fallback_spec(columns)
 
97
  return {
98
  "chart_type": spec.get("chart_type", "bar").lower(),
99
  "title": spec.get("title", "Result"),
 
107
  if not columns:
108
  return {"chart_type": "table", "title": "Result"}
109
  if len(columns) == 1:
110
+ return {"chart_type": "table", "title": "Result",
111
+ "x_column": columns[0]["name"], "y_column": None}
112
+ return {"chart_type": "bar", "title": "Result",
113
  "x_column": columns[0]["name"],
114
+ "y_column": columns[1]["name"], "color_column": None}