Spaces:
Sleeping
Sleeping
File size: 4,783 Bytes
a067ada 1bbdff9 a067ada 1bbdff9 a067ada 1bbdff9 a067ada 61aee8d 1bbdff9 61aee8d a067ada 1bbdff9 a067ada 61aee8d 1bbdff9 a067ada 1bbdff9 a067ada 1bbdff9 61aee8d 1bbdff9 a067ada 1bbdff9 a067ada 1bbdff9 a067ada 1bbdff9 0ac4ef2 a067ada 0ac4ef2 a067ada 0ac4ef2 a067ada 1bbdff9 a067ada | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 | """
SVG Renderer: load the trained LoRA on top of the DeepSeek Coder 1.3B Instruct base model.
Themed Plotly is tried first; the model is used only when Plotly does not
produce a renderable SVG, and a placeholder SVG is the last resort.
"""
import json
import logging
import re
from typing import Any, Dict, List
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
from src.visualization.plotly_fallback import PlotlyRenderer
from src.visualization.svg_theme import apply_theme, is_renderable_svg
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Hugging Face Hub identifiers: the causal LM that is loaded first, and the
# adapter repo that is layered on top of it with PEFT.
BASE_MODEL = "deepseek-ai/deepseek-coder-1.3b-instruct"
ADAPTER_REPO = "DanielRegaladoCardoso/svg-renderer-deepseek-coder-1.3b-lora"

# System message placed ahead of every chart request sent to the model.
# The sentences are joined with single spaces into one prompt string.
SYSTEM_PROMPT = " ".join(
    (
        "You are an SVG chart artist.",
        "Given a chart spec and a small data sample, produce a single inline SVG visualization.",
        "Use a clean, minimalist style.",
        "Return only the SVG, starting with <svg.",
    )
)
class SVGRenderer:
    """Render a chart spec plus data rows to themed SVG markup.

    ``generate`` tries three sources in order and returns the first usable one:

    1. ``PlotlyRenderer.render``.
    2. The causal LM (``BASE_MODEL``, with the ``ADAPTER_REPO`` LoRA applied
       when it loads), if a model and tokenizer are available.
    3. The placeholder SVG from the Plotly renderer's ``_empty`` helper.

    Every returned SVG is passed through ``apply_theme``.
    """

    def __init__(self, temperature: float = 0.2, max_new_tokens: int = 1500) -> None:
        """Load the tokenizer and model, degrading to Plotly-only on failure.

        Args:
            temperature: Sampling temperature for the model. Sampling is
                enabled only when this is greater than 0.
            max_new_tokens: Upper bound on tokens generated per chart.
        """
        self.temperature = temperature
        self.max_new_tokens = max_new_tokens
        self._plotly = PlotlyRenderer()
        # Both stay None when loading fails; generate() checks them before
        # attempting the model path.
        self.model = None
        self.tokenizer = None
        try:
            logger.info("Loading SVG base: %s", BASE_MODEL)
            # NOTE(review): trust_remote_code=True lets the model repo run its
            # own Python code at load time; keep only if that repo is trusted.
            self.tokenizer = AutoTokenizer.from_pretrained(
                BASE_MODEL, trust_remote_code=True
            )
            # NOTE(review): device_map="cuda" presumably makes this call raise
            # on hosts without a GPU, which the outer handler below turns into
            # Plotly-only mode -- confirm that this is the intended behaviour.
            base = AutoModelForCausalLM.from_pretrained(
                BASE_MODEL,
                torch_dtype=torch.bfloat16,
                device_map="cuda",
                trust_remote_code=True,
            )
            # Try LoRA. If it fails (e.g., adapter has only model weights as a
            # one-piece file rather than a peft adapter), fall back to the
            # base model.
            try:
                self.model = PeftModel.from_pretrained(
                    base,
                    ADAPTER_REPO,
                    torch_dtype=torch.bfloat16,
                )
                logger.info("SVG renderer ready (LoRA applied)")
            except Exception as e:
                logger.warning("LoRA load failed (%s); using base model", e)
                self.model = base
            self.model.eval()
        except Exception as e:
            logger.warning(
                "SVG model load failed entirely (%s); Plotly fallback only", e
            )
            # Reset both so a half-initialised pair is never used.
            self.model = None
            self.tokenizer = None

    def generate(self, chart_spec: Dict[str, Any], data: List[Dict[str, Any]]) -> str:
        """Return themed SVG markup for ``chart_spec`` rendered over ``data``.

        Plotly is tried first and the trained model second; if neither yields
        output accepted by ``is_renderable_svg``, a placeholder SVG is
        returned instead.

        Args:
            chart_spec: Chart description passed to the Plotly renderer and
                serialised into the model prompt.
            data: Data rows for the chart.

        Returns:
            The result of ``apply_theme`` on the first acceptable SVG.
        """
        try:
            svg = self._plotly.render(chart_spec, data)
            if is_renderable_svg(svg):
                return apply_theme(svg)
            logger.info("Plotly returned non-SVG; trying model")
        except Exception as e:
            logger.warning("Plotly render failed (%s); trying model", e)

        if self.model is not None and self.tokenizer is not None:
            try:
                svg = self._generate_model(chart_spec, data)
                if is_renderable_svg(svg):
                    return apply_theme(svg)
                # Previously this case fell through without any trace.
                logger.info("Model output was not a renderable SVG; using placeholder")
            except Exception as e:
                logger.warning("Model SVG generation error: %s", e)

        # Last resort: placeholder SVG built by the Plotly renderer itself.
        svg = self._plotly._empty("Could not render chart; see Data section.")
        return apply_theme(svg)

    def _generate_model(self, chart_spec: Dict[str, Any], data: List[Dict[str, Any]]) -> str:
        """Prompt the language model for an SVG and return the extracted markup.

        Only the first 50 rows of ``data`` are included in the prompt; the
        prompt states both the total and the shown row counts.

        Args:
            chart_spec: Chart description, serialised with ``json.dumps``.
            data: Data rows for the chart.

        Returns:
            The text selected by ``_extract_svg`` from the decoded completion.
        """
        sample = data[:50]
        user_content = (
            f"Chart spec: {json.dumps(chart_spec, default=str)}\n"
            f"Data ({len(data)} rows, showing {len(sample)}): "
            f"{json.dumps(sample, default=str)}\n\n"
            "Render an inline SVG. Use viewBox 0 0 600 400."
        )
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_content},
        ]
        # Ask explicitly for a mapping so the result does not depend on the
        # library's default return type, and so the attention mask can be
        # forwarded to generate(); pad_token_id equals eos_token_id below, so
        # the mask cannot be inferred from the ids alone.
        encoded = self.tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
            return_dict=True,
        ).to(self.model.device)
        input_ids = encoded["input_ids"]
        prompt_len = input_ids.shape[1]

        sampling = self.temperature > 0
        with torch.no_grad():
            out = self.model.generate(
                input_ids=input_ids,
                attention_mask=encoded["attention_mask"],
                max_new_tokens=self.max_new_tokens,
                do_sample=sampling,
                # A neutral 1.0 is passed when sampling is off.
                temperature=self.temperature if sampling else 1.0,
                pad_token_id=self.tokenizer.eos_token_id,
            )
        # Drop the prompt tokens and decode only the newly generated tail.
        text = self.tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True)
        return self._extract_svg(text)

    @staticmethod
    def _extract_svg(text: str) -> str:
        """Return the first ``<svg>...</svg>`` span in ``text``.

        Matching is case-insensitive and non-greedy, so it stops at the first
        closing tag. When no complete span exists, the whitespace-stripped
        input is returned unchanged.
        """
        m = re.search(r"<svg[\s\S]*?</svg>", text, re.IGNORECASE)
        return m.group(0) if m else text.strip()
|