TikZilla-8B
Part of the "Text-Guided TikZ Graphics Program Generation for Scientific Figures" collection (5 items).
TikZilla-8B is a language model for generating TikZ/LaTeX figures from natural language descriptions.
It is based on Qwen3-8B-Base and was trained with supervised fine-tuning (SFT) on DaTikZ-V4 for scientific figure generation.
pip install torch==2.5.1 transformers==4.53.2 accelerate==1.8.1
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
model_id = "nllg/TikZilla-8B"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,  # bf16 halves memory vs fp32
    device_map="auto",           # place weights on available GPU(s)/CPU automatically
)

# End-of-turn token used by the Qwen chat template; generation stops here.
eos_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
# Prefer the tokenizer's own pad token, falling back to its EOS token and
# finally to the chat end-of-turn id. Use explicit `is not None` checks:
# a chained `or` would wrongly skip a valid token id of 0, because 0 is
# falsy in Python.
if tokenizer.pad_token_id is not None:
    pad_token_id = tokenizer.pad_token_id
elif tokenizer.eos_token_id is not None:
    pad_token_id = tokenizer.eos_token_id
else:
    pad_token_id = eos_token_id
# Decoding configuration: nucleus (top-p) sampling, stopping at the chat
# end-of-turn token resolved above.
gen_config = GenerationConfig(
    do_sample=True,  # sample instead of greedy decoding
    temperature=1.0,
    top_p=0.9,  # nucleus sampling: keep the top 90% of probability mass
    max_new_tokens=2048,  # complete LaTeX/TikZ documents can be long
    eos_token_id=eos_token_id,
    pad_token_id=pad_token_id,
)
# Natural-language description of the desired figure; this text is inserted
# verbatim into the prompt below. Replace it with your own description.
your_input_description = "A scientific line plot showing two curves. The x-axis is labeled 'Time' ranging from 0 to 100, and the y-axis is labeled 'Value' ranging from 0 to 1. The first curve is a blue solid line that gradually increases from near 0 and levels off around 0.9. The second curve is a red dashed line that rises to a peak around the middle of the plot and then decreases. A legend in the upper right labels the blue line as 'Model A' and the red dashed line as 'Model B'. The background is white with light gray grid lines."
# Single-turn conversation: one user message carrying the task instruction
# with the figure description spliced into the middle.
prompt_body = (
    "Generate a complete LaTeX document that contains a TikZ figure "
    "according to the following requirements:\n"
    f"{your_input_description}\n"
    "Wrap your code using \\documentclass[tikz]{standalone}, and include "
    "\\begin{document}...\\end{document}. Only output valid LaTeX code with no extra text."
)
messages = [{"role": "user", "content": prompt_body}]
# Render the chat messages into a single prompt string (with the assistant
# generation prompt appended), tokenize it, and move tensors to the model's device.
prompt = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)
model_inputs = tokenizer([prompt], return_tensors="pt").to(model.device)

generated = model.generate(**model_inputs, generation_config=gen_config)

# `generate` echoes the prompt tokens; keep only the newly produced ones.
prompt_len = model_inputs["input_ids"].shape[1]
new_token_ids = generated[0][prompt_len:]
print(tokenizer.decode(new_token_ids, skip_special_tokens=True))