# sql-agent/src/models/svg_renderer.py
# Commit 0ac4ef2: SVG standalone download (XML prolog, white bg, fixed dims) + Plotly first
"""
SVG Renderer: load the trained LoRA on top of DeepSeek Coder 1.3B base.
Falls back to themed Plotly if the model output isn't a valid SVG.
"""
import json
import logging
import re
from typing import Any, Dict, List
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
from src.visualization.plotly_fallback import PlotlyRenderer
from src.visualization.svg_theme import apply_theme, is_renderable_svg
logger = logging.getLogger(__name__)
# System prompt steering the model to emit a bare inline SVG (no prose, no markdown fences).
SYSTEM_PROMPT = (
"You are an SVG chart artist. Given a chart spec and a small data "
"sample, produce a single inline SVG visualization. Use a clean, "
"minimalist style. Return only the SVG, starting with <svg."
)
# Base checkpoint on the Hub and the LoRA adapter trained on top of it.
BASE_MODEL = "deepseek-ai/deepseek-coder-1.3b-instruct"
ADAPTER_REPO = "DanielRegaladoCardoso/svg-renderer-deepseek-coder-1.3b-lora"
class SVGRenderer:
    """Render chart specs to SVG via a LoRA-tuned DeepSeek Coder model.

    Strategy: Plotly first (reliable, consistently themed); the fine-tuned
    model is the fallback; a themed placeholder SVG is the last resort.
    If model loading fails entirely, the instance stays in Plotly-only mode.
    """

    def __init__(
        self,
        temperature: float = 0.2,
        max_new_tokens: int = 1500,
        device_map: str = "cuda",
    ) -> None:
        """Load tokenizer, base model, and LoRA adapter; degrade gracefully.

        Args:
            temperature: Sampling temperature; 0 disables sampling (greedy).
            max_new_tokens: Generation budget for the SVG output.
            device_map: Passed through to ``from_pretrained`` (e.g. "cuda",
                "auto"). Defaults to "cuda" for backward compatibility.
        """
        self.temperature = temperature
        self.max_new_tokens = max_new_tokens
        self._plotly = PlotlyRenderer()
        self.model = None
        self.tokenizer = None
        try:
            logger.info("Loading SVG base: %s", BASE_MODEL)
            self.tokenizer = AutoTokenizer.from_pretrained(
                BASE_MODEL, trust_remote_code=True
            )
            base = AutoModelForCausalLM.from_pretrained(
                BASE_MODEL,
                torch_dtype=torch.bfloat16,
                device_map=device_map,
                trust_remote_code=True,
            )
            # Try LoRA. If it fails (e.g., the repo holds a one-piece weights
            # file rather than a PEFT adapter), fall back to the base model.
            try:
                self.model = PeftModel.from_pretrained(
                    base,
                    ADAPTER_REPO,
                    torch_dtype=torch.bfloat16,
                )
                logger.info("SVG renderer ready (LoRA applied)")
            except Exception as e:
                logger.warning("LoRA load failed (%s); using base model", e)
                self.model = base
            self.model.eval()
        except Exception as e:
            # Any failure (download, OOM, missing CUDA) leaves the renderer
            # in Plotly-only mode rather than crashing the caller.
            logger.warning(
                "SVG model load failed entirely (%s); Plotly fallback only", e
            )
            self.model = None
            self.tokenizer = None

    def generate(self, chart_spec: Dict[str, Any], data: List[Dict[str, Any]]) -> str:
        """Return a themed SVG string for the given chart spec and rows.

        Plotly first (reliable, consistent theming), trained model as
        fallback, and a placeholder SVG as the last resort; never raises.
        """
        try:
            svg = self._plotly.render(chart_spec, data)
            if is_renderable_svg(svg):
                return apply_theme(svg)
            logger.info("Plotly returned non-SVG; trying model")
        except Exception as e:
            logger.warning("Plotly render failed (%s); trying model", e)
        if self.model is not None and self.tokenizer is not None:
            try:
                svg = self._generate_model(chart_spec, data)
                if is_renderable_svg(svg):
                    return apply_theme(svg)
            except Exception as e:
                logger.warning("Model SVG generation error: %s", e)
        # Last resort: Plotly's empty-chart placeholder (always yields an SVG).
        svg = self._plotly._empty("Could not render chart; see Data section.")
        return apply_theme(svg)

    def _generate_model(self, chart_spec: Dict[str, Any], data: List[Dict[str, Any]]) -> str:
        """Prompt the model with the spec plus a capped data sample; extract the SVG.

        Raises whatever the tokenizer/model raise; `generate` catches them.
        """
        # Cap the prompt to 50 rows; the true row count is still reported.
        sample = data[:50]
        user_content = (
            f"Chart spec: {json.dumps(chart_spec, default=str)}\n"
            f"Data ({len(data)} rows, showing {len(sample)}): "
            f"{json.dumps(sample, default=str)}\n\n"
            "Render an inline SVG. Use viewBox 0 0 600 400."
        )
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_content},
        ]
        input_ids = self.tokenizer.apply_chat_template(
            messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
        ).to(self.model.device)
        with torch.no_grad():
            out = self.model.generate(
                input_ids,
                max_new_tokens=self.max_new_tokens,
                # Greedy decoding when temperature == 0; the dummy 1.0 keeps
                # transformers from rejecting a temperature of 0.
                do_sample=self.temperature > 0,
                temperature=self.temperature if self.temperature > 0 else 1.0,
                pad_token_id=self.tokenizer.eos_token_id,
            )
        # Decode only the newly generated tokens, not the echoed prompt.
        text = self.tokenizer.decode(
            out[0][input_ids.shape[1]:], skip_special_tokens=True
        )
        return self._extract_svg(text)

    @staticmethod
    def _extract_svg(text: str) -> str:
        """Return the first <svg>...</svg> span in *text*, else the stripped text."""
        m = re.search(r"<svg[\s\S]*?</svg>", text, re.IGNORECASE)
        return m.group(0) if m else text.strip()