math-to-visual-agent / generator.py
vankhieu's picture
Add stop criteria toggle
c13d7d0
Raw
History Blame Contribute Delete
19.4 kB
"""Local Manim model inference with render-time self correction.
This module recreates the runtime ideas from manim-trainer's inference script
without importing that repo: load a local Unsloth/Transformers model, generate
Manim code, extract code blocks, render, and feed render errors back in.
"""
from __future__ import annotations
import json
import os
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from executor import render_manim_scene
from prompt_templates import build_feedback_messages, build_initial_messages
DEFAULT_MODEL_PATH = "vankhieu/Seed_Coder_8B_Instruct_unsloth_bnb_4bit_lora_r8_sft_grpo_rw_mean_text_visual"
DEFAULT_SELECTED_MODEL = "unsloth/Seed-Coder-8B-Instruct-unsloth-bnb-4bit"
THINKING_TOKEN_ID = 151668
THINKING_TOKEN = "</think>"
@dataclass
class LocalModelConfig:
    """Settings for loading the local model, sampling from it, and rendering."""

    # Local directory or Hugging Face repo id; it is downloaded with
    # snapshot_download when no such local path exists.
    model_path: str = DEFAULT_MODEL_PATH
    # Model name inspected for model-specific rules (token_type_ids removal,
    # system-role merging, MoE loader choice); falls back to model_path if empty.
    selected_model: str = DEFAULT_SELECTED_MODEL
    # "adapter" or "base": selects which files are downloaded and whether a
    # PEFT adapter is loaded on top of a base model.
    model_kind: str = "adapter"
    # Explicit base model for adapters; when unset, the value of
    # base_model_name_or_path in adapter_config.json is used.
    base_model_path: Optional[str] = None
    # "auto" tries Unsloth first and falls back to Transformers, "unsloth"
    # disables the fallback, any other value goes straight to Transformers.
    backend: str = "auto"
    # "chat" uses the tokenizer chat template when one is available; otherwise
    # a plain "### ROLE" prompt is built.
    prompt_mode: str = "chat"
    # Forwarded as device_map to from_pretrained on the Transformers backend.
    device_map: str = "auto"
    # Enables 4-bit loading (Unsloth flag / bitsandbytes NF4 config).
    load_in_4bit: bool = True
    # Generation cap; also passed to Unsloth as max_seq_length.
    max_new_tokens: int = 8192
    # Values above 0 enable sampling; 0 selects greedy decoding.
    temperature: float = 0.0
    # Nucleus sampling threshold, only used when sampling is enabled.
    top_p: float = 0.9
    # When true, generation stops once the token sequence for "</CODE>" is emitted.
    use_stop_criteria: bool = True
    # Passed to render_manim_scene as its timeout_seconds argument.
    timeout_seconds: int = 300
class LocalManimModel:
    """Small inference wrapper compatible with Unsloth and plain Transformers.

    Constructing the wrapper resolves the model location (downloading it from
    the Hugging Face Hub when ``config.model_path`` is not a local path) and
    loads the model. Loading problems are stored in ``error`` rather than
    raised, so callers should check ``ready`` before calling ``generate``.
    """

    def __init__(self, config: "LocalModelConfig") -> None:
        """Store ``config``, derive model-specific rules, and load the model."""
        self.config = config
        self.model = None
        self.tokenizer = None
        self.backend_used = ""
        self.error: Optional[str] = None
        self.resolved_model_path: Optional[Path] = None
        model_name_for_rules = config.selected_model or config.model_path
        self.remove_token_type_ids = "seed-coder" in model_name_for_rules.lower()
        self.no_system_role_support = "codegemma" in model_name_for_rules.lower()
        self._load()

    @property
    def ready(self) -> bool:
        """Return True when a model and tokenizer are loaded without error."""
        return self.model is not None and self.tokenizer is not None and self.error is None

    def _load(self) -> None:
        """Resolve the model path and load it with the configured backend."""
        model_path = self._resolve_model_path()
        if model_path is None:
            return
        self.resolved_model_path = model_path
        requested = self.config.backend.lower()
        if self.config.model_kind == "adapter" and requested in ("auto", "unsloth"):
            if self._load_with_unsloth(model_path):
                return
            if requested == "unsloth":
                # Unsloth was requested explicitly: keep its error, no fallback.
                return
        self._load_with_transformers(model_path)

    def _resolve_model_path(self) -> Optional[Path]:
        """Return a local model directory, downloading from the Hub if needed.

        Returns None and sets ``self.error`` when the path is not local and
        the download is impossible or fails.
        """
        configured_path = Path(self.config.model_path)
        if configured_path.exists():
            return configured_path
        try:
            from huggingface_hub import snapshot_download
        except Exception as exc:
            self.error = (
                f"Model path `{self.config.model_path}` is not local, and huggingface_hub "
                f"is unavailable for download: {type(exc).__name__}: {exc}"
            )
            return None
        is_base = self.config.model_kind == "base"
        try:
            if is_base:
                downloaded_path = snapshot_download(
                    repo_id=self.config.model_path,
                    repo_type="model",
                )
            else:
                # Adapters only need the LoRA weights and tokenizer files.
                downloaded_path = snapshot_download(
                    repo_id=self.config.model_path,
                    repo_type="model",
                    allow_patterns=[
                        "README.md",
                        "adapter_config.json",
                        "adapter_model.safetensors",
                        "chat_template.jinja",
                        "special_tokens_map.json",
                        "tokenizer.json",
                        "tokenizer_config.json",
                    ],
                )
            return Path(downloaded_path)
        except Exception as exc:
            # Name the artifact that was actually requested in the message.
            artifact = "base model" if is_base else "model adapter"
            self.error = (
                f"Failed to download {artifact} `{self.config.model_path}`: "
                f"{type(exc).__name__}: {exc}"
            )
            return None

    def _load_with_unsloth(self, model_path: Path) -> bool:
        """Try to load with Unsloth; return True on success, else set ``error``."""
        try:
            from unsloth import FastLanguageModel, FastModel
        except Exception as exc:
            self.error = f"Unsloth is unavailable: {type(exc).__name__}: {exc}"
            return False
        try:
            loader = FastModel if self._is_moe(self.config.selected_model) else FastLanguageModel
            # NOTE(review): max_seq_length reuses the generation cap, so prompt
            # plus output may exceed it - confirm this is intended.
            model, tokenizer = loader.from_pretrained(
                model_name=str(model_path),
                max_seq_length=self.config.max_new_tokens,
                dtype=None,
                load_in_4bit=self.config.load_in_4bit,
            )
            loader.for_inference(model)
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
            self.model = model
            self.tokenizer = tokenizer
            self.backend_used = "unsloth"
            self.error = None
            return True
        except Exception as exc:
            self.error = f"Failed to load with Unsloth: {type(exc).__name__}: {exc}"
            return False

    def _load_with_transformers(self, model_path: Path) -> None:
        """Load with Transformers (plus PEFT for adapters); set ``error`` on failure."""
        try:
            import torch
            from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
        except Exception as exc:
            self.error = (
                "Could not import local model dependencies. Install torch, transformers, "
                f"accelerate, peft, and optionally unsloth. Details: {type(exc).__name__}: {exc}"
            )
            return
        try:
            tokenizer = AutoTokenizer.from_pretrained(
                str(model_path),
                trust_remote_code=True,
            )
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
            adapter_config = model_path / "adapter_config.json"
            if self.config.model_kind == "adapter" and adapter_config.exists():
                from peft import PeftModel

                base_model_path = self.config.base_model_path or self._adapter_base_path(adapter_config)
                kwargs = self._transformers_load_kwargs(torch, BitsAndBytesConfig)
                base_model = AutoModelForCausalLM.from_pretrained(
                    base_model_path,
                    trust_remote_code=True,
                    **kwargs,
                )
                model = PeftModel.from_pretrained(base_model, str(model_path))
            else:
                kwargs = self._transformers_load_kwargs(torch, BitsAndBytesConfig)
                model = AutoModelForCausalLM.from_pretrained(
                    str(model_path),
                    trust_remote_code=True,
                    **kwargs,
                )
            model.eval()
            self.model = model
            self.tokenizer = tokenizer
            self.backend_used = "transformers"
            # Clears any error left behind by a failed Unsloth attempt.
            self.error = None
        except Exception as exc:
            self.error = f"Failed to load local model with Transformers: {type(exc).__name__}: {exc}"

    def _transformers_load_kwargs(self, torch, bits_and_bytes_config_cls) -> Dict[str, object]:
        """Build ``from_pretrained`` keyword arguments for the Transformers backend."""
        if self.config.load_in_4bit:
            return {
                "device_map": self.config.device_map,
                "torch_dtype": "auto",
                "quantization_config": bits_and_bytes_config_cls(
                    load_in_4bit=True,
                    bnb_4bit_quant_type="nf4",
                    bnb_4bit_use_double_quant=False,
                    bnb_4bit_compute_dtype=torch.bfloat16,
                ),
            }
        return {"device_map": self.config.device_map, "torch_dtype": "auto"}

    def _adapter_base_path(self, adapter_config: Path) -> str:
        """Return the base model for an adapter.

        Prefers ``config.base_model_path``; otherwise reads
        ``base_model_name_or_path`` from ``adapter_config``.

        Raises:
            ValueError: if the adapter config does not name a base model.
        """
        if self.config.base_model_path:
            return self.config.base_model_path
        data = json.loads(adapter_config.read_text(encoding="utf-8"))
        base_model = data.get("base_model_name_or_path")
        if not base_model:
            raise ValueError("adapter_config.json does not define base_model_name_or_path.")
        return base_model

    def _is_moe(self, model_id: str) -> bool:
        """Return True when ``model_id`` contains a known mixture-of-experts marker."""
        lowered = model_id.lower()
        return any(marker in lowered for marker in ("mixtral", "moe", "deepseek-v2", "qwen3-moe"))

    @staticmethod
    def _merge_system_message(messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
        """Fold a leading system message into the message that follows it.

        The merged message gets the ``user`` role. Any later messages are kept,
        so follow-up turns are not lost. Conversations that do not start with a
        system message, or that hold fewer than two messages, are returned
        unchanged.
        """
        if len(messages) < 2 or messages[0]["role"] != "system":
            return messages
        merged = {
            "role": "user",
            "content": messages[0]["content"] + "\n\n" + messages[1]["content"],
        }
        return [merged, *messages[2:]]

    def _thinking_token_id(self) -> Optional[int]:
        """Return the loaded tokenizer's id for ``THINKING_TOKEN``, or None.

        The id is looked up by token string because a fixed numeric id is only
        valid for the vocabulary it was taken from; in any other vocabulary it
        denotes an unrelated token.
        """
        to_ids = getattr(self.tokenizer, "convert_tokens_to_ids", None)
        to_tokens = getattr(self.tokenizer, "convert_ids_to_tokens", None)
        if to_ids is None or to_tokens is None:
            return None
        token_id = to_ids(THINKING_TOKEN)
        if not isinstance(token_id, int) or isinstance(token_id, bool) or token_id < 0:
            return None
        # Unknown tokens may map to the unk id, so confirm the round trip.
        if to_tokens(token_id) != THINKING_TOKEN:
            return None
        return token_id

    def _split_thinking(self, output_ids: List[int]) -> Tuple[str, str]:
        """Split generated ids into decoded ``(thinking, completion)`` text.

        Everything up to and including the last thinking-end token is treated
        as thinking. When the tokenizer has no such token, or it was not
        generated, the whole output is the completion.
        """
        split_at = 0
        think_id = self._thinking_token_id()
        if think_id is not None and think_id in output_ids:
            split_at = len(output_ids) - output_ids[::-1].index(think_id)
        thinking = self.tokenizer.decode(output_ids[:split_at], skip_special_tokens=True).strip()
        completion = self.tokenizer.decode(output_ids[split_at:], skip_special_tokens=True).strip()
        if completion.startswith(THINKING_TOKEN):
            completion = completion.removeprefix(THINKING_TOKEN).strip()
        return thinking, completion

    def generate(self, messages: List[Dict[str, str]]) -> Tuple[str, str]:
        """Generate a reply for ``messages``.

        Args:
            messages: chat messages as dicts with ``role`` and ``content`` keys.

        Returns:
            A ``(thinking, completion)`` tuple of decoded text.

        Raises:
            RuntimeError: if the model is not ready.
        """
        if not self.ready:
            raise RuntimeError(self.error or "Local model is not ready.")
        import torch

        if self.no_system_role_support:
            messages = self._merge_system_message(messages)
        prompt = self._format_prompt(messages)
        model_inputs = self.tokenizer(text=[prompt], return_tensors="pt")
        model_device = getattr(self.model, "device", None)
        if model_device is not None:
            model_inputs = model_inputs.to(model_device)
        if self.remove_token_type_ids and "token_type_ids" in model_inputs:
            del model_inputs["token_type_ids"]
        generate_kwargs = {
            **model_inputs,
            "max_new_tokens": self.config.max_new_tokens,
            "pad_token_id": self.tokenizer.pad_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
            "use_cache": True,
        }
        if self.config.use_stop_criteria:
            from transformers import StoppingCriteriaList

            generate_kwargs["stopping_criteria"] = StoppingCriteriaList(
                [_StopOnTokenSequence(self.tokenizer.encode("</CODE>", add_special_tokens=False))]
            )
        if self.config.temperature and self.config.temperature > 0:
            generate_kwargs.update(
                {
                    "do_sample": True,
                    "temperature": self.config.temperature,
                    "top_p": self.config.top_p,
                }
            )
        else:
            generate_kwargs["do_sample"] = False
        with torch.inference_mode():
            generated_ids = self.model.generate(**generate_kwargs)
        # Drop the prompt tokens; only the newly generated ids are decoded.
        output_ids = generated_ids[0][len(model_inputs.input_ids[0]) :].tolist()
        return self._split_thinking(output_ids)

    def _format_prompt(self, messages: List[Dict[str, str]]) -> str:
        """Render ``messages`` into a single prompt string.

        Uses the tokenizer chat template in chat mode when one is defined;
        otherwise builds a plain ``### ROLE`` transcript ending with an open
        assistant section.
        """
        if self.config.prompt_mode == "chat" and getattr(self.tokenizer, "chat_template", None):
            try:
                return self.tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True,
                    enable_thinking=False,
                )
            except TypeError:
                # Retry without enable_thinking if that argument is rejected.
                return self.tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True,
                )
        chunks = []
        for message in messages:
            chunks.append(f"### {message['role'].upper()}\n{message['content'].strip()}")
        chunks.append("### ASSISTANT\n")
        return "\n\n".join(chunks)
class _StopOnTokenSequence:
    """Stopping criterion that fires once the output ends with a token suffix.

    Only the first row of ``input_ids`` is inspected.
    """

    def __init__(self, stop_token_ids: List[int]) -> None:
        # Token ids that must appear, in order, at the very end of the sequence.
        self.stop_token_ids = stop_token_ids

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        """Return True when the last tokens of row 0 equal the stop sequence."""
        del scores, kwargs
        if not self.stop_token_ids:
            # An empty stop sequence never matches.
            return False
        suffix_length = len(self.stop_token_ids)
        if input_ids.shape[-1] < suffix_length:
            return False
        trailing_ids = input_ids[0, -suffix_length:].tolist()
        return trailing_ids == self.stop_token_ids
class ManimVisualAgent:
    """Generate Manim code with local inference and repair failures."""

    def __init__(self, config: "Optional[LocalModelConfig]" = None, **overrides: object) -> None:
        """Create the agent and load the local model.

        Args:
            config: full configuration. When given, ``overrides`` are ignored.
            **overrides: keyword arguments for ``LocalModelConfig``; only used
                when ``config`` is None.
        """
        if config is None:
            config = LocalModelConfig(**overrides)
        self.config = config
        self.runtime = LocalManimModel(config)
        self.model_path = config.model_path
        self.model_error = self.runtime.error

    @property
    def ready(self) -> bool:
        """Return True when the underlying model runtime is ready."""
        return self.runtime.ready

    def _model_status_line(self) -> str:
        """Return a one-line summary of backend, model kind, and model location."""
        resolved = self.runtime.resolved_model_path
        if resolved is None:
            return f"[model] backend=unavailable kind={self.config.model_kind} model={self.config.model_path}"
        return (
            f"[model] backend={self.runtime.backend_used or 'unavailable'} "
            f"kind={self.config.model_kind} model={self.config.model_path} cache={resolved}"
        )

    def generate_and_fix(
        self,
        user_prompt: str,
        domain: str = "Mathematics",
        max_retries: int = 3,
    ) -> Tuple[bool, Optional[str], str, str]:
        """Generate Manim code, render it, and retry with the render errors.

        Args:
            user_prompt: description of the concept to visualize.
            domain: accepted for interface compatibility; not used by the
                message builders.
            max_retries: number of correction attempts after the first try.

        Returns:
            ``(success, render_result, code, log)``. ``render_result`` is None
            on failure; ``code`` is the last code that was tried.
        """
        if not user_prompt or not user_prompt.strip():
            return False, None, "", "Prompt is empty. Describe the concept to visualize."
        terminal_log_parts = [self._model_status_line()]
        messages = self._initial_messages(user_prompt, domain)
        best_code = ""
        # One initial attempt plus the requested number of retries.
        total_attempts = max(1, int(max_retries) + 1)
        for attempt in range(1, total_attempts + 1):
            terminal_log_parts.append(f"[attempt {attempt}/{total_attempts}] Generating Manim code...")
            try:
                raw_output = self._call_or_fallback(messages)
                best_code = extract_manim_code(raw_output)
                if not best_code:
                    # No recognizable code block: try the raw output as code.
                    best_code = raw_output.strip()
            except Exception as exc:
                return (
                    False,
                    None,
                    best_code,
                    "\n".join(terminal_log_parts)
                    + f"\nLocal model generation failed: {type(exc).__name__}: {exc}",
                )
            terminal_log_parts.append("[executor] Rendering with isolated Manim subprocess...")
            success, render_result = render_manim_scene(
                best_code,
                timeout_seconds=self.config.timeout_seconds,
            )
            if success:
                terminal_log_parts.append(f"[success] Rendered video: {render_result}")
                return True, render_result, best_code, "\n".join(terminal_log_parts)
            terminal_log_parts.append("[error] Manim failed. Feeding render errors back to the model.")
            terminal_log_parts.append(render_result)
            messages = self._feedback_messages(user_prompt, domain, best_code, render_result)
        terminal_log_parts.append("[failed] Maximum self-correction retries exhausted.")
        return False, None, best_code, "\n".join(terminal_log_parts)

    def _call_or_fallback(self, messages: List[Dict[str, str]]) -> str:
        """Return the model completion, or a wrapped offline scene if not ready."""
        if not self.ready:
            return f"<CODE>\n{self._fallback_scene(messages[-1]['content'])}\n</CODE>"
        _, completion = self.runtime.generate(messages)
        return completion

    def _initial_messages(self, user_prompt: str, domain: str) -> List[Dict[str, str]]:
        """Build the first-attempt messages; ``domain`` is currently unused."""
        del domain
        return build_initial_messages(user_prompt)

    def _feedback_messages(
        self,
        user_prompt: str,
        domain: str,
        initial_code: str,
        render_errors: str,
    ) -> List[Dict[str, str]]:
        """Build the repair messages; ``domain`` is currently unused."""
        del domain
        return build_feedback_messages(user_prompt, initial_code, render_errors)

    @staticmethod
    def _single_line(text: str) -> str:
        """Collapse every run of whitespace (including newlines) to one space."""
        return " ".join(text.split())

    def _fallback_scene(self, prompt: str) -> str:
        """Return Manim source for an offline scene showing the error and prompt.

        Both texts are flattened to a single line and embedded with ``repr`` so
        the generated source stays valid for any input, including newlines,
        quotes, and backslashes. The prompt is truncated to 260 characters
        before escaping, so an escape sequence is never cut in half.
        """
        detail_literal = repr(self._single_line(self.model_error or "Local model is not ready."))
        prompt_literal = repr("Prompt: " + self._single_line(prompt)[:260])
        return f"""from manim import *
class MainScene(Scene):
    def construct(self):
        self.camera.background_color = "#0b0f19"
        title = Text("SciVisual-Agent Offline", color="#00ffcc", font_size=42)
        subtitle = Text("Local model could not be initialized.", color="#ff0055", font_size=24)
        detail = Text({detail_literal}, color=GRAY_B, font_size=16).scale_to_fit_width(config.frame_width - 1)
        prompt = Text({prompt_literal}, color=WHITE, font_size=18).scale_to_fit_width(config.frame_width - 1)
        group = VGroup(title, subtitle, detail, prompt).arrange(DOWN, buff=0.35)
        self.play(FadeIn(group, shift=UP * 0.2))
        self.wait(2)
"""
def extract_manim_code(response: str, select_index: int = -1) -> str:
    """Extract Manim code from manim-trainer-style responses.

    Sources are tried in this order: ``<CODE>...</CODE>`` blocks,
    ```` ```python ```` fences, generic fences, and finally the raw text when
    it looks like a Manim script.

    Args:
        response: raw model output; None is treated as empty.
        select_index: which match to return when several blocks are found
            (default: the last one).

    Returns:
        The cleaned code, or an empty string when nothing usable is found.

    Raises:
        IndexError: if ``select_index`` is out of range for the matches found.
    """
    text = response or ""
    # Tags quoted in backticks are prose mentions, not delimiters.
    text = re.sub(r"`<CODE>`|`</CODE>`", "", text)
    text = text.replace("<code>", "<CODE>").replace("</code>", "</CODE>")
    # Repair tags that lost their leading "<" without touching intact ones.
    text = re.sub(r"(?<!<)/CODE>", "</CODE>", text)
    text = re.sub(r"(?<![</])CODE>", "<CODE>", text)

    matches = re.findall(r"<CODE>(.*?)</CODE>", text, re.DOTALL)
    if not matches:
        matches = re.findall(r"```python\s*(.*?)```", text, re.DOTALL | re.IGNORECASE)
    if not matches:
        matches = re.findall(r"```\s*(.*?)```", text, re.DOTALL)
    if matches:
        return _clean_extracted_code(matches[select_index])

    # Unterminated block (e.g. generation was cut off): prefer the text after
    # the last opening tag so prose before it is not returned as code.
    after_last_opener = text.rpartition("<CODE>")[2]
    for candidate in (after_last_opener, text):
        if "from manim import" in candidate and "class " in candidate:
            return _clean_extracted_code(candidate)
    return ""


def _clean_extracted_code(code: str) -> str:
    """Remove wrapper tags and markdown fences from extracted Manim code."""
    cleaned = code or ""
    cleaned = re.sub(r"</?CODE>", "", cleaned, flags=re.IGNORECASE)
    # Strip one leading fence (optionally tagged "python") and one trailing fence.
    cleaned = re.sub(r"^\s*```(?:python)?\s*", "", cleaned, flags=re.IGNORECASE)
    cleaned = re.sub(r"\s*```\s*$", "", cleaned)
    return cleaned.strip()
def _env_text(name: str, default: str) -> str:
    """Return the stripped value of env var ``name``, or ``default`` if unset/blank."""
    return os.environ.get(name, "").strip() or default


def _env_flag(name: str, default: bool = True) -> bool:
    """Return env var ``name`` as a bool; unset/blank yields ``default``.

    "0", "false", "no", and "off" (any case) are false; anything else is true.
    """
    value = os.environ.get(name, "").strip().lower()
    if not value:
        return default
    return value not in ("0", "false", "no", "off")


def _env_number(name: str, default, cast):
    """Return env var ``name`` converted with ``cast``; unset/blank yields ``default``.

    Raises:
        ValueError: if the value is present but cannot be converted; the
            message names the offending variable.
    """
    raw = os.environ.get(name, "").strip()
    if not raw:
        return default
    try:
        return cast(raw)
    except ValueError as exc:
        raise ValueError(
            f"Environment variable {name} must be a valid {cast.__name__}, got {raw!r}."
        ) from exc


def build_default_agent() -> "ManimVisualAgent":
    """Build a ``ManimVisualAgent`` configured from ``SCIVISUAL_*`` env vars.

    Blank variables are treated as unset so the defaults apply. In particular
    a blank ``SCIVISUAL_MODEL_PATH`` no longer resolves to the current
    directory. An unknown ``SCIVISUAL_MODEL_KIND`` falls back to "adapter".

    Raises:
        ValueError: if a numeric variable holds a malformed value.
    """
    model_kind = _env_text("SCIVISUAL_MODEL_KIND", "adapter").lower()
    if model_kind not in {"adapter", "base"}:
        model_kind = "adapter"
    config = LocalModelConfig(
        model_path=_env_text("SCIVISUAL_MODEL_PATH", DEFAULT_MODEL_PATH),
        selected_model=_env_text("SCIVISUAL_SELECTED_MODEL", DEFAULT_SELECTED_MODEL),
        model_kind=model_kind,
        base_model_path=os.environ.get("SCIVISUAL_BASE_MODEL_PATH", "").strip() or None,
        backend=_env_text("SCIVISUAL_BACKEND", "auto"),
        prompt_mode=_env_text("SCIVISUAL_PROMPT_MODE", "chat"),
        device_map=_env_text("SCIVISUAL_DEVICE_MAP", "auto"),
        load_in_4bit=_env_flag("SCIVISUAL_LOAD_IN_4BIT"),
        max_new_tokens=_env_number("SCIVISUAL_MAX_NEW_TOKENS", 8192, int),
        temperature=_env_number("SCIVISUAL_TEMPERATURE", 0.0, float),
        top_p=_env_number("SCIVISUAL_TOP_P", 0.9, float),
        use_stop_criteria=_env_flag("SCIVISUAL_USE_STOP_CRITERIA"),
        timeout_seconds=_env_number("SCIVISUAL_RENDER_TIMEOUT", 600, int),
    )
    return ManimVisualAgent(config=config)