File size: 4,783 Bytes
a067ada
1bbdff9
 
a067ada
 
1bbdff9
a067ada
1bbdff9
a067ada
 
61aee8d
 
1bbdff9
61aee8d
a067ada
 
 
 
 
 
 
 
 
 
 
 
1bbdff9
 
a067ada
 
61aee8d
1bbdff9
 
a067ada
 
 
1bbdff9
 
a067ada
 
1bbdff9
 
 
 
61aee8d
 
1bbdff9
a067ada
1bbdff9
 
 
 
 
 
 
 
 
 
 
 
a067ada
 
1bbdff9
a067ada
 
 
1bbdff9
0ac4ef2
 
 
 
 
 
 
 
 
a067ada
 
 
 
 
 
0ac4ef2
a067ada
0ac4ef2
 
 
a067ada
 
1bbdff9
a067ada
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
"""
SVG Renderer: load the trained LoRA on top of DeepSeek Coder 1.3B base.
Falls back to themed Plotly if the model output isn't a valid SVG.
"""

import json
import logging
import re
from typing import Any, Dict, List

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

from src.visualization.plotly_fallback import PlotlyRenderer
from src.visualization.svg_theme import apply_theme, is_renderable_svg

logger = logging.getLogger(__name__)  # module-scoped logger for this renderer


# System-turn instruction sent ahead of every model prompt in _generate_model.
SYSTEM_PROMPT = (
    "You are an SVG chart artist. "
    "Given a chart spec and a small data sample, "
    "produce a single inline SVG visualization. "
    "Use a clean, minimalist style. "
    "Return only the SVG, starting with <svg."
)

# Model id handed to AutoTokenizer / AutoModelForCausalLM.from_pretrained.
BASE_MODEL = "deepseek-ai/deepseek-coder-1.3b-instruct"
# Adapter id handed to PeftModel.from_pretrained on top of BASE_MODEL.
ADAPTER_REPO = "DanielRegaladoCardoso/svg-renderer-deepseek-coder-1.3b-lora"


class SVGRenderer:
    """Turn a chart spec plus data rows into a themed inline SVG string.

    ``generate`` tries, in order: the Plotly renderer, the LoRA-tuned
    DeepSeek Coder model (only if it loaded), and finally a placeholder SVG.
    Every returned SVG is passed through ``apply_theme``.
    """

    # Maximum number of data rows serialized into the model prompt.
    _MAX_SAMPLE_ROWS = 50

    # Opening/closing tags whose name is exactly "svg" (not e.g. "<svgfoo" or
    # "<svg:rect"). _extract_svg walks these to find the outermost element.
    _SVG_TAG_RE = re.compile(r"<(/?)svg(?=[\s/>])[^>]*>", re.IGNORECASE)

    def __init__(self, temperature: float = 0.2, max_new_tokens: int = 1500) -> None:
        """Create the Plotly renderer and try to load the SVG model.

        Args:
            temperature: Sampling temperature for the model; values <= 0
                select greedy decoding (see ``_generate_model``).
            max_new_tokens: Generation budget for the model's SVG output.

        Model loading never raises: if the LoRA adapter fails, the plain base
        model is used; if loading fails entirely, ``model`` and ``tokenizer``
        stay ``None`` and only the Plotly/placeholder paths are available.
        """
        self.temperature = temperature
        self.max_new_tokens = max_new_tokens
        self._plotly = PlotlyRenderer()
        self.model = None
        self.tokenizer = None

        try:
            logger.info(f"Loading SVG base: {BASE_MODEL}")
            self.tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
            # NOTE(review): device_map="cuda" presumably fails without a CUDA
            # device; that failure lands in the outer except (Plotly only).
            base = AutoModelForCausalLM.from_pretrained(
                BASE_MODEL,
                torch_dtype=torch.bfloat16,
                device_map="cuda",
                trust_remote_code=True,
            )
            # Try LoRA. If it fails (e.g., adapter has only model weights as one-piece file
            # rather than a peft adapter), fall back to base model.
            try:
                self.model = PeftModel.from_pretrained(
                    base,
                    ADAPTER_REPO,
                    torch_dtype=torch.bfloat16,
                )
                logger.info("SVG renderer ready (LoRA applied)")
            except Exception as e:
                logger.warning(f"LoRA load failed ({e}); using base model")
                self.model = base
            self.model.eval()
        except Exception as e:
            logger.warning(f"SVG model load failed entirely ({e}); Plotly fallback only")
            self.model = None
            self.tokenizer = None

    def generate(self, chart_spec: Dict[str, Any], data: List[Dict[str, Any]]) -> str:
        """Plotly first (reliable, consistent theming), trained model as fallback.

        Args:
            chart_spec: Chart description forwarded to the Plotly renderer and
                JSON-serialized into the model prompt.
            data: Row dicts to plot.

        Returns:
            A themed SVG string. Falls back to a placeholder message SVG when
            neither Plotly nor the model yields a renderable SVG.
        """
        try:
            svg = self._plotly.render(chart_spec, data)
            if is_renderable_svg(svg):
                return apply_theme(svg)
            logger.info("Plotly returned non-SVG; trying model")
        except Exception as e:
            logger.warning(f"Plotly render failed ({e}); trying model")

        if self.model is not None and self.tokenizer is not None:
            try:
                svg = self._generate_model(chart_spec, data)
                if is_renderable_svg(svg):
                    return apply_theme(svg)
                # Previously dropped silently; log so bad generations are visible.
                logger.info("Model output was not a renderable SVG; using placeholder")
            except Exception as e:
                logger.warning(f"Model SVG generation error: {e}")

        # Last resort: placeholder SVG built by the Plotly renderer's _empty helper.
        svg = self._plotly._empty("Could not render chart; see Data section.")
        return apply_theme(svg)

    def _generate_model(self, chart_spec: Dict[str, Any], data: List[Dict[str, Any]]) -> str:
        """Prompt the loaded model for an SVG and return the extracted markup.

        Only the first ``_MAX_SAMPLE_ROWS`` rows are put in the prompt (the
        prompt still reports the full row count). Requires ``self.model`` and
        ``self.tokenizer`` to be loaded.
        """
        sample = data[: self._MAX_SAMPLE_ROWS]
        user_content = (
            f"Chart spec: {json.dumps(chart_spec, default=str)}\n"
            f"Data ({len(data)} rows, showing {len(sample)}): "
            f"{json.dumps(sample, default=str)}\n\n"
            "Render an inline SVG. Use viewBox 0 0 600 400."
        )
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_content},
        ]
        input_ids = self.tokenizer.apply_chat_template(
            messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
        ).to(self.model.device)
        # A single, unpadded prompt: every position is a real token. Passing
        # the mask explicitly means HF need not infer it, which it cannot do
        # reliably when pad_token_id == eos_token_id (as set below).
        attention_mask = torch.ones_like(input_ids)

        with torch.no_grad():
            out = self.model.generate(
                input_ids,
                attention_mask=attention_mask,
                max_new_tokens=self.max_new_tokens,
                do_sample=self.temperature > 0,
                # Neutral 1.0 when greedy, so no unused-temperature warning.
                temperature=self.temperature if self.temperature > 0 else 1.0,
                pad_token_id=self.tokenizer.eos_token_id,
            )
        # Decode only the newly generated tokens, not the echoed prompt.
        text = self.tokenizer.decode(
            out[0][input_ids.shape[1]:], skip_special_tokens=True
        )
        return self._extract_svg(text)

    @staticmethod
    def _extract_svg(text: str) -> str:
        """Return the first complete, outermost ``<svg>...</svg>`` in *text*.

        Nested ``<svg>`` elements are kept inside their parent (a lazy regex
        would stop at the first inner ``</svg>``). If no balanced element is
        found (no ``<svg`` at all, or output truncated before the closing
        tag), the stripped text is returned unchanged.
        """
        start = None
        depth = 0
        for tag in SVGRenderer._SVG_TAG_RE.finditer(text):
            is_close = tag.group(1) == "/"
            is_self_closing = tag.group(0).endswith("/>")
            if start is None:
                if is_close:
                    continue  # stray closing tag before any opening tag
                if is_self_closing:
                    return tag.group(0)
                start = tag.start()
                depth = 1
                continue
            if is_close:
                depth -= 1
            elif not is_self_closing:
                depth += 1
            if depth == 0:
                return text[start:tag.end()]
        return text.strip()