"""Local Manim model inference with render-time self correction. This module recreates the runtime ideas from manim-trainer's inference script without importing that repo: load a local Unsloth/Transformers model, generate Manim code, extract code blocks, render, and feed render errors back in. """ from __future__ import annotations import json import os import re from dataclasses import dataclass from pathlib import Path from typing import Dict, List, Optional, Tuple from executor import render_manim_scene from prompt_templates import build_feedback_messages, build_initial_messages DEFAULT_MODEL_PATH = "vankhieu/Seed_Coder_8B_Instruct_unsloth_bnb_4bit_lora_r8_sft_grpo_rw_mean_text_visual" DEFAULT_SELECTED_MODEL = "unsloth/Seed-Coder-8B-Instruct-unsloth-bnb-4bit" THINKING_TOKEN_ID = 151668 THINKING_TOKEN = "" @dataclass class LocalModelConfig: model_path: str = DEFAULT_MODEL_PATH selected_model: str = DEFAULT_SELECTED_MODEL model_kind: str = "adapter" base_model_path: Optional[str] = None backend: str = "auto" prompt_mode: str = "chat" device_map: str = "auto" load_in_4bit: bool = True max_new_tokens: int = 8192 temperature: float = 0.0 top_p: float = 0.9 use_stop_criteria: bool = True timeout_seconds: int = 300 class LocalManimModel: """Small inference wrapper compatible with Unsloth and plain Transformers.""" def __init__(self, config: LocalModelConfig) -> None: self.config = config self.model = None self.tokenizer = None self.backend_used = "" self.error: Optional[str] = None self.resolved_model_path: Optional[Path] = None model_name_for_rules = config.selected_model or config.model_path self.remove_token_type_ids = "seed-coder" in model_name_for_rules.lower() self.no_system_role_support = "codegemma" in model_name_for_rules.lower() self._load() @property def ready(self) -> bool: return self.model is not None and self.tokenizer is not None and self.error is None def _load(self) -> None: model_path = self._resolve_model_path() if model_path is None: return self.resolved_model_path = model_path requested = self.config.backend.lower() if self.config.model_kind == "adapter" and requested in ("auto", "unsloth"): if self._load_with_unsloth(model_path): return if requested == "unsloth": return self._load_with_transformers(model_path) def _resolve_model_path(self) -> Optional[Path]: configured_path = Path(self.config.model_path) if configured_path.exists(): return configured_path try: from huggingface_hub import snapshot_download except Exception as exc: self.error = ( f"Model path `{self.config.model_path}` is not local, and huggingface_hub " f"is unavailable for download: {type(exc).__name__}: {exc}" ) return None try: if self.config.model_kind == "base": downloaded_path = snapshot_download( repo_id=self.config.model_path, repo_type="model", ) else: downloaded_path = snapshot_download( repo_id=self.config.model_path, repo_type="model", allow_patterns=[ "README.md", "adapter_config.json", "adapter_model.safetensors", "chat_template.jinja", "special_tokens_map.json", "tokenizer.json", "tokenizer_config.json", ], ) return Path(downloaded_path) except Exception as exc: self.error = f"Failed to download model adapter `{self.config.model_path}`: {type(exc).__name__}: {exc}" return None def _load_with_unsloth(self, model_path: Path) -> bool: try: from unsloth import FastLanguageModel, FastModel except Exception as exc: self.error = f"Unsloth is unavailable: {type(exc).__name__}: {exc}" return False try: loader = FastModel if self._is_moe(self.config.selected_model) else FastLanguageModel model, tokenizer = loader.from_pretrained( model_name=str(model_path), max_seq_length=self.config.max_new_tokens, dtype=None, load_in_4bit=self.config.load_in_4bit, ) loader.for_inference(model) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token self.model = model self.tokenizer = tokenizer self.backend_used = "unsloth" self.error = None return True except Exception as exc: self.error = f"Failed to load with Unsloth: {type(exc).__name__}: {exc}" return False def _load_with_transformers(self, model_path: Path) -> None: try: import torch from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig except Exception as exc: self.error = ( "Could not import local model dependencies. Install torch, transformers, " f"accelerate, peft, and optionally unsloth. Details: {type(exc).__name__}: {exc}" ) return try: tokenizer = AutoTokenizer.from_pretrained( str(model_path), trust_remote_code=True, ) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token adapter_config = model_path / "adapter_config.json" if self.config.model_kind == "adapter" and adapter_config.exists(): from peft import PeftModel base_model_path = self.config.base_model_path or self._adapter_base_path(adapter_config) kwargs = self._transformers_load_kwargs(torch, BitsAndBytesConfig) base_model = AutoModelForCausalLM.from_pretrained( base_model_path, trust_remote_code=True, **kwargs, ) model = PeftModel.from_pretrained(base_model, str(model_path)) else: kwargs = self._transformers_load_kwargs(torch, BitsAndBytesConfig) model = AutoModelForCausalLM.from_pretrained( str(model_path), trust_remote_code=True, **kwargs, ) model.eval() self.model = model self.tokenizer = tokenizer self.backend_used = "transformers" self.error = None except Exception as exc: self.error = f"Failed to load local model with Transformers: {type(exc).__name__}: {exc}" def _transformers_load_kwargs(self, torch, bits_and_bytes_config_cls) -> Dict[str, object]: if self.config.load_in_4bit: return { "device_map": self.config.device_map, "torch_dtype": "auto", "quantization_config": bits_and_bytes_config_cls( load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_use_double_quant=False, bnb_4bit_compute_dtype=torch.bfloat16, ), } return {"device_map": self.config.device_map, "torch_dtype": "auto"} def _adapter_base_path(self, adapter_config: Path) -> str: if self.config.base_model_path: return self.config.base_model_path data = json.loads(adapter_config.read_text(encoding="utf-8")) base_model = data.get("base_model_name_or_path") if not base_model: raise ValueError("adapter_config.json does not define base_model_name_or_path.") return base_model def _is_moe(self, model_id: str) -> bool: lowered = model_id.lower() return any(marker in lowered for marker in ("mixtral", "moe", "deepseek-v2", "qwen3-moe")) def generate(self, messages: List[Dict[str, str]]) -> Tuple[str, str]: if not self.ready: raise RuntimeError(self.error or "Local model is not ready.") import torch if self.no_system_role_support and len(messages) >= 2 and messages[0]["role"] == "system": messages = [{"role": "user", "content": messages[0]["content"] + "\n\n" + messages[1]["content"]}] prompt = self._format_prompt(messages) model_inputs = self.tokenizer(text=[prompt], return_tensors="pt") model_device = getattr(self.model, "device", None) if model_device is not None: model_inputs = model_inputs.to(model_device) if self.remove_token_type_ids and "token_type_ids" in model_inputs: del model_inputs["token_type_ids"] generate_kwargs = { **model_inputs, "max_new_tokens": self.config.max_new_tokens, "pad_token_id": self.tokenizer.pad_token_id, "eos_token_id": self.tokenizer.eos_token_id, "use_cache": True, } if self.config.use_stop_criteria: from transformers import StoppingCriteriaList generate_kwargs["stopping_criteria"] = StoppingCriteriaList( [_StopOnTokenSequence(self.tokenizer.encode("", add_special_tokens=False))] ) if self.config.temperature and self.config.temperature > 0: generate_kwargs.update( { "do_sample": True, "temperature": self.config.temperature, "top_p": self.config.top_p, } ) else: generate_kwargs["do_sample"] = False with torch.inference_mode(): generated_ids = self.model.generate(**generate_kwargs) output_ids = generated_ids[0][len(model_inputs.input_ids[0]) :].tolist() try: index = len(output_ids) - output_ids[::-1].index(THINKING_TOKEN_ID) except ValueError: index = 0 thinking = self.tokenizer.decode(output_ids[:index], skip_special_tokens=True).strip() completion = self.tokenizer.decode(output_ids[index:], skip_special_tokens=True).strip() if completion.startswith(THINKING_TOKEN): completion = completion.removeprefix(THINKING_TOKEN).strip() return thinking, completion def _format_prompt(self, messages: List[Dict[str, str]]) -> str: if self.config.prompt_mode == "chat" and getattr(self.tokenizer, "chat_template", None): try: return self.tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True, enable_thinking=False, ) except TypeError: return self.tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True, ) chunks = [] for message in messages: chunks.append(f"### {message['role'].upper()}\n{message['content'].strip()}") chunks.append("### ASSISTANT\n") return "\n\n".join(chunks) class _StopOnTokenSequence: """Stop generation when a specific token suffix appears.""" def __init__(self, stop_token_ids: List[int]) -> None: self.stop_token_ids = stop_token_ids def __call__(self, input_ids, scores, **kwargs) -> bool: del scores, kwargs if not self.stop_token_ids: return False if input_ids.shape[-1] < len(self.stop_token_ids): return False tail = input_ids[0, -len(self.stop_token_ids) :].tolist() return tail == self.stop_token_ids class ManimVisualAgent: """Generate Manim code with local inference and repair failures.""" def __init__(self, config: Optional[LocalModelConfig] = None, **overrides: object) -> None: if config is None: config = LocalModelConfig(**overrides) self.config = config self.runtime = LocalManimModel(config) self.model_path = config.model_path self.model_error = self.runtime.error @property def ready(self) -> bool: return self.runtime.ready def _model_status_line(self) -> str: resolved = self.runtime.resolved_model_path if resolved is None: return f"[model] backend=unavailable kind={self.config.model_kind} model={self.config.model_path}" return ( f"[model] backend={self.runtime.backend_used or 'unavailable'} " f"kind={self.config.model_kind} model={self.config.model_path} cache={resolved}" ) def generate_and_fix( self, user_prompt: str, domain: str = "Mathematics", max_retries: int = 3, ) -> Tuple[bool, Optional[str], str, str]: if not user_prompt or not user_prompt.strip(): return False, None, "", "Prompt is empty. Describe the concept to visualize." terminal_log_parts = [self._model_status_line()] messages = self._initial_messages(user_prompt, domain) best_code = "" total_attempts = max(1, int(max_retries) + 1) for attempt in range(1, total_attempts + 1): terminal_log_parts.append(f"[attempt {attempt}/{total_attempts}] Generating Manim code...") try: raw_output = self._call_or_fallback(messages) best_code = extract_manim_code(raw_output) if not best_code: best_code = raw_output.strip() except Exception as exc: return ( False, None, best_code, "\n".join(terminal_log_parts) + f"\nLocal model generation failed: {type(exc).__name__}: {exc}", ) terminal_log_parts.append("[executor] Rendering with isolated Manim subprocess...") success, render_result = render_manim_scene( best_code, timeout_seconds=self.config.timeout_seconds, ) if success: terminal_log_parts.append(f"[success] Rendered video: {render_result}") return True, render_result, best_code, "\n".join(terminal_log_parts) terminal_log_parts.append("[error] Manim failed. Feeding render errors back to the model.") terminal_log_parts.append(render_result) messages = self._feedback_messages(user_prompt, domain, best_code, render_result) terminal_log_parts.append("[failed] Maximum self-correction retries exhausted.") return False, None, best_code, "\n".join(terminal_log_parts) def _call_or_fallback(self, messages: List[Dict[str, str]]) -> str: if not self.ready: return f"\n{self._fallback_scene(messages[-1]['content'])}\n" _, completion = self.runtime.generate(messages) return completion def _initial_messages(self, user_prompt: str, domain: str) -> List[Dict[str, str]]: del domain return build_initial_messages(user_prompt) def _feedback_messages( self, user_prompt: str, domain: str, initial_code: str, render_errors: str, ) -> List[Dict[str, str]]: del domain return build_feedback_messages(user_prompt, initial_code, render_errors) def _fallback_scene(self, prompt: str) -> str: safe_message = (self.model_error or "Local model is not ready.").replace("\\", "\\\\").replace('"', '\\"') safe_prompt = prompt.replace("\\", "\\\\").replace('"', '\\"')[:260] return f"""from manim import * class MainScene(Scene): def construct(self): self.camera.background_color = "#0b0f19" title = Text("SciVisual-Agent Offline", color="#00ffcc", font_size=42) subtitle = Text("Local model could not be initialized.", color="#ff0055", font_size=24) detail = Text("{safe_message}", color=GRAY_B, font_size=16).scale_to_fit_width(config.frame_width - 1) prompt = Text("Prompt: {safe_prompt}", color=WHITE, font_size=18).scale_to_fit_width(config.frame_width - 1) group = VGroup(title, subtitle, detail, prompt).arrange(DOWN, buff=0.35) self.play(FadeIn(group, shift=UP * 0.2)) self.wait(2) """ def extract_manim_code(response: str, select_index: int = -1) -> str: """Extract Manim code from manim-trainer-style responses.""" response = response or "" response = re.sub(r"``|``", "", response) response = ( response.replace("", "") .replace("", "") .replace("CODE>", "") .replace("", "") ) matches = re.findall(r"(.*?)", response, re.DOTALL) if not matches: matches = re.findall(r"```python\s*(.*?)```", response, re.DOTALL | re.IGNORECASE) if not matches: matches = re.findall(r"```\s*(.*?)```", response, re.DOTALL) if matches: return _clean_extracted_code(matches[select_index]) if "from manim import" in response and "class " in response: return _clean_extracted_code(response) return "" def _clean_extracted_code(code: str) -> str: """Remove wrapper tags and markdown fences from extracted Manim code.""" code = code or "" code = re.sub(r"", "", code, flags=re.IGNORECASE) code = re.sub(r"^\s*```(?:python)?\s*", "", code, flags=re.IGNORECASE) code = re.sub(r"\s*```\s*$", "", code) return code.strip() def build_default_agent() -> ManimVisualAgent: model_kind = os.environ.get("SCIVISUAL_MODEL_KIND", "adapter").strip().lower() if model_kind not in {"adapter", "base"}: model_kind = "adapter" config = LocalModelConfig( model_path=os.environ.get("SCIVISUAL_MODEL_PATH", DEFAULT_MODEL_PATH), selected_model=os.environ.get("SCIVISUAL_SELECTED_MODEL", DEFAULT_SELECTED_MODEL), model_kind=model_kind, base_model_path=os.environ.get("SCIVISUAL_BASE_MODEL_PATH") or None, backend=os.environ.get("SCIVISUAL_BACKEND", "auto"), prompt_mode=os.environ.get("SCIVISUAL_PROMPT_MODE", "chat"), device_map=os.environ.get("SCIVISUAL_DEVICE_MAP", "auto"), load_in_4bit=os.environ.get("SCIVISUAL_LOAD_IN_4BIT", "1").lower() not in ("0", "false", "no"), max_new_tokens=int(os.environ.get("SCIVISUAL_MAX_NEW_TOKENS", "8192")), temperature=float(os.environ.get("SCIVISUAL_TEMPERATURE", "0")), top_p=float(os.environ.get("SCIVISUAL_TOP_P", "0.9")), use_stop_criteria=os.environ.get("SCIVISUAL_USE_STOP_CRITERIA", "1").lower() not in ("0", "false", "no"), timeout_seconds=int(os.environ.get("SCIVISUAL_RENDER_TIMEOUT", "600")), ) return ManimVisualAgent(config=config)