DanielRegaladoCardoso commited on
Commit
1bbdff9
·
verified ·
1 Parent(s): eb30a86

Load LoRA via PeftModel on top of standard base models (fixes r=16 vs r=8 mismatch)

Browse files
Files changed (1) hide show
  1. src/models/svg_renderer.py +31 -31
src/models/svg_renderer.py CHANGED
@@ -1,15 +1,16 @@
1
  """
2
- SVG Renderer: chart spec + data -> inline SVG.
3
-
4
- Model loaded at root module level (ZeroGPU best practice). If the model
5
- output isn't a valid SVG, falls back to themed Plotly.
6
  """
7
 
 
8
  import logging
 
9
  from typing import Any, Dict, List
10
 
11
  import torch
12
  from transformers import AutoModelForCausalLM, AutoTokenizer
 
13
 
14
  from src.visualization.plotly_fallback import PlotlyRenderer
15
  from src.visualization.svg_theme import apply_theme, is_renderable_svg
@@ -23,43 +24,47 @@ SYSTEM_PROMPT = (
23
  "minimalist style. Return only the SVG, starting with <svg."
24
  )
25
 
26
- DEFAULT_MODEL = "DanielRegaladoCardoso/svg-renderer-deepseek-coder-1.3b-lora"
 
27
 
28
 
29
  class SVGRenderer:
30
- """Render a chart spec to inline SVG."""
31
-
32
- def __init__(
33
- self,
34
- hf_model: str = DEFAULT_MODEL,
35
- temperature: float = 0.2,
36
- max_new_tokens: int = 1500,
37
- ) -> None:
38
- self.hf_model = hf_model
39
  self.temperature = temperature
40
  self.max_new_tokens = max_new_tokens
41
  self._plotly = PlotlyRenderer()
 
 
42
 
43
- logger.info(f"Loading SVG renderer at module level: {self.hf_model}")
44
  try:
45
- self.tokenizer = AutoTokenizer.from_pretrained(self.hf_model)
46
- self.model = AutoModelForCausalLM.from_pretrained(
47
- self.hf_model,
 
48
  torch_dtype=torch.bfloat16,
49
  device_map="cuda",
 
50
  )
 
 
 
 
 
 
 
 
 
 
 
 
51
  self.model.eval()
52
- logger.info("SVG renderer ready")
53
  except Exception as e:
54
- logger.warning(f"SVG model load failed ({e}); will use Plotly fallback only")
55
  self.model = None
56
  self.tokenizer = None
57
 
58
- def generate(
59
- self,
60
- chart_spec: Dict[str, Any],
61
- data: List[Dict[str, Any]],
62
- ) -> str:
63
  if self.model is not None and self.tokenizer is not None:
64
  try:
65
  svg = self._generate_model(chart_spec, data)
@@ -72,11 +77,7 @@ class SVGRenderer:
72
  svg = self._plotly.render(chart_spec, data)
73
  return apply_theme(svg)
74
 
75
- def _generate_model(
76
- self, chart_spec: Dict[str, Any], data: List[Dict[str, Any]]
77
- ) -> str:
78
- import json
79
-
80
  sample = data[:50]
81
  user_content = (
82
  f"Chart spec: {json.dumps(chart_spec, default=str)}\n"
@@ -107,6 +108,5 @@ class SVGRenderer:
107
 
108
  @staticmethod
109
  def _extract_svg(text: str) -> str:
110
- import re
111
  m = re.search(r"<svg[\s\S]*?</svg>", text, re.IGNORECASE)
112
  return m.group(0) if m else text.strip()
 
1
  """
2
+ SVG Renderer: load the trained LoRA on top of DeepSeek Coder 1.3B base.
3
+ Falls back to themed Plotly if the model output isn't a valid SVG.
 
 
4
  """
5
 
6
+ import json
7
  import logging
8
+ import re
9
  from typing import Any, Dict, List
10
 
11
  import torch
12
  from transformers import AutoModelForCausalLM, AutoTokenizer
13
+ from peft import PeftModel
14
 
15
  from src.visualization.plotly_fallback import PlotlyRenderer
16
  from src.visualization.svg_theme import apply_theme, is_renderable_svg
 
24
  "minimalist style. Return only the SVG, starting with <svg."
25
  )
26
 
27
+ BASE_MODEL = "deepseek-ai/deepseek-coder-1.3b-instruct"
28
+ ADAPTER_REPO = "DanielRegaladoCardoso/svg-renderer-deepseek-coder-1.3b-lora"
29
 
30
 
31
  class SVGRenderer:
32
+
33
+ def __init__(self, temperature: float = 0.2, max_new_tokens: int = 1500) -> None:
 
 
 
 
 
 
 
34
  self.temperature = temperature
35
  self.max_new_tokens = max_new_tokens
36
  self._plotly = PlotlyRenderer()
37
+ self.model = None
38
+ self.tokenizer = None
39
 
 
40
  try:
41
+ logger.info(f"Loading SVG base: {BASE_MODEL}")
42
+ self.tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
43
+ base = AutoModelForCausalLM.from_pretrained(
44
+ BASE_MODEL,
45
  torch_dtype=torch.bfloat16,
46
  device_map="cuda",
47
+ trust_remote_code=True,
48
  )
49
+ # Try LoRA. If it fails (e.g., adapter has only model weights as one-piece file
50
+ # rather than a peft adapter), fall back to base model.
51
+ try:
52
+ self.model = PeftModel.from_pretrained(
53
+ base,
54
+ ADAPTER_REPO,
55
+ torch_dtype=torch.bfloat16,
56
+ )
57
+ logger.info("SVG renderer ready (LoRA applied)")
58
+ except Exception as e:
59
+ logger.warning(f"LoRA load failed ({e}); using base model")
60
+ self.model = base
61
  self.model.eval()
 
62
  except Exception as e:
63
+ logger.warning(f"SVG model load failed entirely ({e}); Plotly fallback only")
64
  self.model = None
65
  self.tokenizer = None
66
 
67
+ def generate(self, chart_spec: Dict[str, Any], data: List[Dict[str, Any]]) -> str:
 
 
 
 
68
  if self.model is not None and self.tokenizer is not None:
69
  try:
70
  svg = self._generate_model(chart_spec, data)
 
77
  svg = self._plotly.render(chart_spec, data)
78
  return apply_theme(svg)
79
 
80
+ def _generate_model(self, chart_spec: Dict[str, Any], data: List[Dict[str, Any]]) -> str:
 
 
 
 
81
  sample = data[:50]
82
  user_content = (
83
  f"Chart spec: {json.dumps(chart_spec, default=str)}\n"
 
108
 
109
  @staticmethod
110
  def _extract_svg(text: str) -> str:
 
111
  m = re.search(r"<svg[\s\S]*?</svg>", text, re.IGNORECASE)
112
  return m.group(0) if m else text.strip()