DanielRegaladoCardoso committed on
Commit
61aee8d
·
verified ·
1 Parent(s): a57eca6

ZeroGPU best practice: load models at module level (cuda), inference only inside @spaces.GPU

Browse files
Files changed (1) hide show
  1. src/models/svg_renderer.py +14 -30
src/models/svg_renderer.py CHANGED
@@ -1,19 +1,16 @@
1
  """
2
  SVG Renderer: chart spec + data -> inline SVG.
3
 
4
- Strategy:
5
- 1. Try the trained DeepSeek-Coder-1.3B SVG renderer model.
6
- 2. If its output isn't a valid SVG, fall back to the Plotly themed renderer.
7
-
8
- Either path goes through `apply_theme()` to enforce a consistent
9
- Apple/Claude visual: monochrome with one warm accent, thin strokes,
10
- SF font stack, responsive viewBox.
11
  """
12
 
13
  import logging
14
  from typing import Any, Dict, List
15
 
16
- from src.models.base import BaseModel
 
 
17
  from src.visualization.plotly_fallback import PlotlyRenderer
18
  from src.visualization.svg_theme import apply_theme, is_renderable_svg
19
 
@@ -26,11 +23,11 @@ SYSTEM_PROMPT = (
26
  "minimalist style. Return only the SVG, starting with <svg."
27
  )
28
 
 
29
 
30
- class SVGRenderer(BaseModel):
31
- """Render a chart spec to inline SVG."""
32
 
33
- DEFAULT_MODEL = "DanielRegaladoCardoso/svg-renderer-deepseek-coder-1.3b-lora"
 
34
 
35
  def __init__(
36
  self,
@@ -38,42 +35,31 @@ class SVGRenderer(BaseModel):
38
  temperature: float = 0.2,
39
  max_new_tokens: int = 1500,
40
  ) -> None:
41
- super().__init__(model_name="svg-renderer")
42
  self.hf_model = hf_model
43
  self.temperature = temperature
44
  self.max_new_tokens = max_new_tokens
45
  self._plotly = PlotlyRenderer()
46
 
47
- def load(self) -> None:
48
- from transformers import AutoModelForCausalLM, AutoTokenizer
49
- import torch
50
-
51
- logger.info(f"Loading SVG renderer: {self.hf_model}")
52
- device = "cuda" if torch.cuda.is_available() else "cpu"
53
- dtype = torch.bfloat16 if device == "cuda" else torch.float32
54
-
55
  try:
56
  self.tokenizer = AutoTokenizer.from_pretrained(self.hf_model)
57
  self.model = AutoModelForCausalLM.from_pretrained(
58
  self.hf_model,
59
- torch_dtype=dtype,
60
- device_map=device,
61
  )
62
  self.model.eval()
63
- self.is_loaded = True
64
- logger.info(f"SVG renderer loaded on {device}")
65
  except Exception as e:
66
- logger.warning(f"SVG model load failed ({e}); will use Plotly fallback")
67
  self.model = None
68
  self.tokenizer = None
69
- self.is_loaded = True # we can still render via Plotly
70
 
71
- def generate( # type: ignore[override]
72
  self,
73
  chart_spec: Dict[str, Any],
74
  data: List[Dict[str, Any]],
75
  ) -> str:
76
- # 1) Try trained model
77
  if self.model is not None and self.tokenizer is not None:
78
  try:
79
  svg = self._generate_model(chart_spec, data)
@@ -83,7 +69,6 @@ class SVGRenderer(BaseModel):
83
  except Exception as e:
84
  logger.warning(f"Model SVG generation error: {e}; falling back")
85
 
86
- # 2) Plotly fallback
87
  svg = self._plotly.render(chart_spec, data)
88
  return apply_theme(svg)
89
 
@@ -91,7 +76,6 @@ class SVGRenderer(BaseModel):
91
  self, chart_spec: Dict[str, Any], data: List[Dict[str, Any]]
92
  ) -> str:
93
  import json
94
- import torch
95
 
96
  sample = data[:50]
97
  user_content = (
 
1
  """
2
  SVG Renderer: chart spec + data -> inline SVG.
3
 
4
+ Model loaded at root module level (ZeroGPU best practice). If the model
5
+ output isn't a valid SVG, falls back to themed Plotly.
 
 
 
 
 
6
  """
7
 
8
  import logging
9
  from typing import Any, Dict, List
10
 
11
+ import torch
12
+ from transformers import AutoModelForCausalLM, AutoTokenizer
13
+
14
  from src.visualization.plotly_fallback import PlotlyRenderer
15
  from src.visualization.svg_theme import apply_theme, is_renderable_svg
16
 
 
23
  "minimalist style. Return only the SVG, starting with <svg."
24
  )
25
 
26
+ DEFAULT_MODEL = "DanielRegaladoCardoso/svg-renderer-deepseek-coder-1.3b-lora"
27
 
 
 
28
 
29
+ class SVGRenderer:
30
+ """Render a chart spec to inline SVG."""
31
 
32
  def __init__(
33
  self,
 
35
  temperature: float = 0.2,
36
  max_new_tokens: int = 1500,
37
  ) -> None:
 
38
  self.hf_model = hf_model
39
  self.temperature = temperature
40
  self.max_new_tokens = max_new_tokens
41
  self._plotly = PlotlyRenderer()
42
 
43
+ logger.info(f"Loading SVG renderer at module level: {self.hf_model}")
 
 
 
 
 
 
 
44
  try:
45
  self.tokenizer = AutoTokenizer.from_pretrained(self.hf_model)
46
  self.model = AutoModelForCausalLM.from_pretrained(
47
  self.hf_model,
48
+ torch_dtype=torch.bfloat16,
49
+ device_map="cuda",
50
  )
51
  self.model.eval()
52
+ logger.info("SVG renderer ready")
 
53
  except Exception as e:
54
+ logger.warning(f"SVG model load failed ({e}); will use Plotly fallback only")
55
  self.model = None
56
  self.tokenizer = None
 
57
 
58
+ def generate(
59
  self,
60
  chart_spec: Dict[str, Any],
61
  data: List[Dict[str, Any]],
62
  ) -> str:
 
63
  if self.model is not None and self.tokenizer is not None:
64
  try:
65
  svg = self._generate_model(chart_spec, data)
 
69
  except Exception as e:
70
  logger.warning(f"Model SVG generation error: {e}; falling back")
71
 
 
72
  svg = self._plotly.render(chart_spec, data)
73
  return apply_theme(svg)
74
 
 
76
  self, chart_spec: Dict[str, Any], data: List[Dict[str, Any]]
77
  ) -> str:
78
  import json
 
79
 
80
  sample = data[:50]
81
  user_content = (