"""
Chart Reasoner: load the trained LoRA on top of Phi-3 Mini base.
"""
import json
import logging
import re
from typing import Any, Dict, List
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
# Module-level logger following the standard `__name__` convention.
logger = logging.getLogger(__name__)
# System prompt prepended to every chat request; instructs the model to
# answer with a bare JSON chart specification and nothing else.
SYSTEM_PROMPT = (
"You are a data visualization expert. Given a question, the SQL that "
"answers it, and a sample of the result rows, produce a JSON chart "
"specification. Return only valid JSON, no commentary."
)
# Base model the LoRA adapter was fine-tuned on.
BASE_MODEL = "microsoft/Phi-3-mini-4k-instruct"
# Hugging Face repo containing only the trained LoRA adapter weights.
ADAPTER_REPO = "DanielRegaladoCardoso/chart-reasoner-phi3-mini-adapter-only"
class ChartReasoner:
    """Produces JSON chart specifications from SQL query results.

    Loads the Phi-3 Mini base model, applies the fine-tuned LoRA adapter,
    and exposes :meth:`generate`, which prompts the model with the question,
    SQL, and a sample of result rows, then parses its output into a
    normalized chart-spec dict (falling back to a heuristic spec when the
    model output is not valid JSON).
    """

    def __init__(self, temperature: float = 0.0, max_new_tokens: int = 300) -> None:
        """Load the base model and apply the LoRA adapter.

        Args:
            temperature: Sampling temperature; 0 selects greedy decoding.
            max_new_tokens: Maximum tokens generated per request.
        """
        self.temperature = temperature
        self.max_new_tokens = max_new_tokens
        logger.info("Loading chart base: %s", BASE_MODEL)
        self.tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
        # Fall back to CPU so the reasoner still works on machines without
        # a GPU (previously device_map="cuda" crashed on CPU-only hosts).
        device = "cuda" if torch.cuda.is_available() else "cpu"
        base = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            torch_dtype=torch.bfloat16,
            device_map=device,
            trust_remote_code=True,
        )
        logger.info("Applying LoRA adapter: %s", ADAPTER_REPO)
        self.model = PeftModel.from_pretrained(
            base,
            ADAPTER_REPO,
            torch_dtype=torch.bfloat16,
        )
        self.model.eval()
        logger.info("Chart reasoner ready")

    def generate(
        self,
        question: str,
        sql: str,
        results: List[Dict[str, Any]],
        columns: List[Dict[str, Any]],
    ) -> Dict[str, Any]:
        """Ask the model for a chart spec describing *results*.

        Args:
            question: The natural-language question that was asked.
            sql: The SQL statement that produced *results*.
            results: Result rows as dicts; only the first 5 are sent.
            columns: Column descriptors; each must carry a "name" key.

        Returns:
            A dict with chart_type, title, x_column, y_column,
            color_column, and rationale keys (fallback spec on parse failure).
        """
        # Keep the prompt small: the model only needs a representative sample.
        sample = results[:5]
        col_names = [c["name"] for c in columns]
        user_content = (
            f"Question: {question}\n"
            f"SQL: {sql}\n"
            f"Columns: {col_names}\n"
            f"Sample rows: {json.dumps(sample, default=str)}\n\n"
            "Return JSON with: chart_type (one of: bar, line, scatter, pie, "
            "area, table), title, x_column, y_column, color_column "
            "(optional), rationale."
        )
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_content},
        ]
        input_ids = self.tokenizer.apply_chat_template(
            messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
        ).to(self.model.device)
        with torch.no_grad():
            out = self.model.generate(
                input_ids,
                max_new_tokens=self.max_new_tokens,
                do_sample=self.temperature > 0,
                # temperature must be > 0 even when sampling is off; 1.0 is inert.
                temperature=self.temperature if self.temperature > 0 else 1.0,
                pad_token_id=self.tokenizer.eos_token_id,
            )
        # Decode only the newly generated tokens, not the echoed prompt.
        raw = self.tokenizer.decode(
            out[0][input_ids.shape[1]:], skip_special_tokens=True
        )
        return self._parse_spec(raw, columns)

    def _parse_spec(self, text: str, columns: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Extract the first JSON object from *text* and normalize it.

        Falls back to :meth:`_fallback_spec` when no parseable JSON object
        is found in the model output.
        """
        # Match a brace-balanced object up to one nesting level deep.
        match = re.search(r"\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}", text, re.DOTALL)
        if not match:
            return self._fallback_spec(columns)
        try:
            spec = json.loads(match.group(0))
        except json.JSONDecodeError:
            return self._fallback_spec(columns)
        # Guard against a null or non-string chart_type from the model:
        # the previous code called .lower() directly and raised
        # AttributeError on {"chart_type": null}.
        chart_type = spec.get("chart_type") or "bar"
        return {
            "chart_type": str(chart_type).lower(),
            "title": spec.get("title", "Result"),
            "x_column": spec.get("x_column"),
            "y_column": spec.get("y_column"),
            "color_column": spec.get("color_column"),
            "rationale": spec.get("rationale", ""),
        }

    def _fallback_spec(self, columns: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Heuristic spec used when the model output cannot be parsed.

        No columns -> plain table; one column -> table keyed on it;
        two or more -> bar chart of the first two columns.
        """
        if not columns:
            return {"chart_type": "table", "title": "Result"}
        if len(columns) == 1:
            return {"chart_type": "table", "title": "Result",
                    "x_column": columns[0]["name"], "y_column": None}
        return {"chart_type": "bar", "title": "Result",
                "x_column": columns[0]["name"],
                "y_column": columns[1]["name"], "color_column": None}