Spaces:
Running on Zero
Running on Zero
"""Local Manim model inference with render-time self-correction.

This module recreates the runtime ideas from manim-trainer's inference script
without importing that repo: load a local Unsloth/Transformers model, generate
Manim code, extract code blocks, render them, and feed render errors back to
the model for another attempt.

The default agent is configured from ``SCIVISUAL_*`` environment variables
(see ``build_default_agent``).
"""
| from __future__ import annotations | |
| import json | |
| import os | |
| import re | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Dict, List, Optional, Tuple | |
| from executor import render_manim_scene | |
| from prompt_templates import build_feedback_messages, build_initial_messages | |
DEFAULT_MODEL_PATH = "vankhieu/Seed_Coder_8B_Instruct_unsloth_bnb_4bit_lora_r8_sft_grpo_rw_mean_text_visual"
DEFAULT_SELECTED_MODEL = "unsloth/Seed-Coder-8B-Instruct-unsloth-bnb-4bit"
# Token id of `</think>` in the Qwen3 tokenizer; only meaningful for tokenizers
# that share that vocabulary.
THINKING_TOKEN_ID = 151668
THINKING_TOKEN = "</think>"


@dataclass
class LocalModelConfig:
    """Settings for loading the local Manim model and sampling from it.

    Every field has a default, so ``LocalModelConfig()`` is a valid config and
    individual fields can be overridden by keyword.
    """

    # Local directory, or a Hugging Face Hub repo id downloaded when no such path exists.
    model_path: str = DEFAULT_MODEL_PATH
    # Model name used for family-specific quirks and MoE loader selection.
    selected_model: str = DEFAULT_SELECTED_MODEL
    # "adapter" (PEFT adapter applied on a base model) or "base".
    model_kind: str = "adapter"
    # Overrides base_model_name_or_path from the adapter's adapter_config.json.
    base_model_path: Optional[str] = None
    # "auto" or "unsloth" try Unsloth first for adapters; any other value uses Transformers.
    backend: str = "auto"
    # "chat" uses the tokenizer's chat template when present; otherwise a plain transcript.
    prompt_mode: str = "chat"
    # Passed to from_pretrained on the Transformers backend.
    device_map: str = "auto"
    load_in_4bit: bool = True
    # Generation budget; the Unsloth loader also uses it as max_seq_length.
    max_new_tokens: int = 8192
    # 0 (or any non-positive value) means greedy decoding.
    temperature: float = 0.0
    # Only used when sampling (temperature > 0).
    top_p: float = 0.9
    # Stop generation as soon as `</CODE>` is emitted.
    use_stop_criteria: bool = True
    # Timeout for the Manim render subprocess, in seconds.
    timeout_seconds: int = 300
class LocalManimModel:
    """Small inference wrapper compatible with Unsloth and plain Transformers.

    The model is resolved (downloaded from the Hugging Face Hub when
    ``config.model_path`` is not an existing local path) and loaded in
    ``__init__``. Load failures are recorded in ``self.error`` rather than
    raised, and are reflected by the ``ready`` property.
    """

    def __init__(self, config: LocalModelConfig) -> None:
        self.config = config
        self.model = None
        self.tokenizer = None
        self.backend_used = ""
        self.error: Optional[str] = None
        self.resolved_model_path: Optional[Path] = None
        # Family-specific quirks are keyed off the model name (selected_model wins).
        model_name_for_rules = (config.selected_model or config.model_path).lower()
        # Seed-Coder: token_type_ids are removed from the inputs before generate().
        self.remove_token_type_ids = "seed-coder" in model_name_for_rules
        # CodeGemma: the system prompt is folded into the first user turn.
        self.no_system_role_support = "codegemma" in model_name_for_rules
        self._load()

    @property
    def ready(self) -> bool:
        """Whether a model and tokenizer are loaded and no load error is recorded."""
        return self.model is not None and self.tokenizer is not None and self.error is None

    def _load(self) -> None:
        """Resolve the model path, then load it with Unsloth and/or Transformers.

        For adapters with ``backend`` "auto" or "unsloth", Unsloth is tried first.
        With "auto", an Unsloth failure falls back to Transformers; with
        "unsloth", loading stops and ``self.error`` stays set.
        """
        model_path = self._resolve_model_path()
        if model_path is None:
            return
        self.resolved_model_path = model_path
        requested = self.config.backend.lower()
        if self.config.model_kind == "adapter" and requested in ("auto", "unsloth"):
            if self._load_with_unsloth(model_path):
                return
            if requested == "unsloth":
                return
        self._load_with_transformers(model_path)

    def _resolve_model_path(self) -> Optional[Path]:
        """Return a local directory for ``config.model_path``, downloading it if needed.

        Existing local paths are returned unchanged. Otherwise the value is
        treated as a Hub repo id. Base models are fetched in full; adapters
        fetch only the adapter and tokenizer files listed below. On failure,
        returns None and sets ``self.error``.
        """
        configured_path = Path(self.config.model_path)
        if configured_path.exists():
            return configured_path
        try:
            from huggingface_hub import snapshot_download
        except Exception as exc:
            self.error = (
                f"Model path `{self.config.model_path}` is not local, and huggingface_hub "
                f"is unavailable for download: {type(exc).__name__}: {exc}"
            )
            return None
        is_base = self.config.model_kind == "base"
        download_kwargs: Dict[str, object] = {
            "repo_id": self.config.model_path,
            "repo_type": "model",
        }
        if not is_base:
            download_kwargs["allow_patterns"] = [
                "README.md",
                "adapter_config.json",
                "adapter_model.safetensors",
                "chat_template.jinja",
                "special_tokens_map.json",
                "tokenizer.json",
                "tokenizer_config.json",
            ]
        artifact = "base model" if is_base else "model adapter"
        try:
            downloaded_path = snapshot_download(**download_kwargs)
        except Exception as exc:
            self.error = f"Failed to download {artifact} `{self.config.model_path}`: {type(exc).__name__}: {exc}"
            return None
        return Path(downloaded_path)

    def _load_with_unsloth(self, model_path: Path) -> bool:
        """Try to load ``model_path`` with Unsloth and return True on success.

        On failure, returns False and ``self.error`` describes the problem.
        """
        try:
            from unsloth import FastLanguageModel, FastModel
        except Exception as exc:
            self.error = f"Unsloth is unavailable: {type(exc).__name__}: {exc}"
            return False
        try:
            loader = FastModel if self._is_moe(self.config.selected_model) else FastLanguageModel
            model, tokenizer = loader.from_pretrained(
                model_name=str(model_path),
                # NOTE(review): max_new_tokens doubles as the context length here, so the
                # prompt and the completion share that budget. Confirm this is intended.
                max_seq_length=self.config.max_new_tokens,
                dtype=None,
                load_in_4bit=self.config.load_in_4bit,
            )
            loader.for_inference(model)
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
            self.model = model
            self.tokenizer = tokenizer
            self.backend_used = "unsloth"
            self.error = None
            return True
        except Exception as exc:
            self.error = f"Failed to load with Unsloth: {type(exc).__name__}: {exc}"
            return False

    def _load_with_transformers(self, model_path: Path) -> None:
        """Load ``model_path`` with Transformers, applying a PEFT adapter if present.

        When ``model_kind`` is "adapter" and ``adapter_config.json`` exists, the
        base model is loaded first and the adapter is applied on top. Otherwise
        ``model_path`` itself is loaded. Errors are stored in ``self.error``.
        """
        try:
            import torch
            from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
        except Exception as exc:
            self.error = (
                "Could not import local model dependencies. Install torch, transformers, "
                f"accelerate, peft, and optionally unsloth. Details: {type(exc).__name__}: {exc}"
            )
            return
        try:
            tokenizer = AutoTokenizer.from_pretrained(
                str(model_path),
                trust_remote_code=True,
            )
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
            load_kwargs = self._transformers_load_kwargs(torch, BitsAndBytesConfig)
            adapter_config = model_path / "adapter_config.json"
            if self.config.model_kind == "adapter" and adapter_config.exists():
                from peft import PeftModel

                # _adapter_base_path prefers config.base_model_path over the file.
                base_model = AutoModelForCausalLM.from_pretrained(
                    self._adapter_base_path(adapter_config),
                    trust_remote_code=True,
                    **load_kwargs,
                )
                model = PeftModel.from_pretrained(base_model, str(model_path))
            else:
                model = AutoModelForCausalLM.from_pretrained(
                    str(model_path),
                    trust_remote_code=True,
                    **load_kwargs,
                )
            model.eval()
            self.model = model
            self.tokenizer = tokenizer
            self.backend_used = "transformers"
            self.error = None
        except Exception as exc:
            self.error = f"Failed to load local model with Transformers: {type(exc).__name__}: {exc}"

    def _transformers_load_kwargs(self, torch, bits_and_bytes_config_cls) -> Dict[str, object]:
        """Build ``from_pretrained`` kwargs, adding an NF4 4-bit config when enabled."""
        kwargs: Dict[str, object] = {"device_map": self.config.device_map, "torch_dtype": "auto"}
        if self.config.load_in_4bit:
            kwargs["quantization_config"] = bits_and_bytes_config_cls(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_use_double_quant=False,
                bnb_4bit_compute_dtype=torch.bfloat16,
            )
        return kwargs

    def _adapter_base_path(self, adapter_config: Path) -> str:
        """Return the base model for an adapter.

        ``config.base_model_path`` wins; otherwise ``base_model_name_or_path``
        is read from ``adapter_config``.

        Raises:
            ValueError: if neither source provides a base model.
        """
        if self.config.base_model_path:
            return self.config.base_model_path
        data = json.loads(adapter_config.read_text(encoding="utf-8"))
        base_model = data.get("base_model_name_or_path")
        if not base_model:
            raise ValueError("adapter_config.json does not define base_model_name_or_path.")
        return base_model

    def _is_moe(self, model_id: str) -> bool:
        """Heuristically detect mixture-of-experts models by name."""
        lowered = model_id.lower()
        return any(marker in lowered for marker in ("mixtral", "moe", "deepseek-v2", "qwen3-moe"))

    @staticmethod
    def _merge_system_into_user(messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
        """Fold a leading system message into the next turn, keeping all later turns."""
        if len(messages) < 2 or messages[0]["role"] != "system":
            return messages
        merged = {"role": "user", "content": messages[0]["content"] + "\n\n" + messages[1]["content"]}
        return [merged, *messages[2:]]

    def _single_token_id(self, text: str) -> Optional[int]:
        """Return the id of ``text`` if the tokenizer encodes it as exactly one token."""
        token_ids = self.tokenizer.encode(text, add_special_tokens=False)
        return token_ids[0] if len(token_ids) == 1 else None

    @staticmethod
    def _thinking_split_index(output_ids: List[int], think_id: Optional[int]) -> int:
        """Return the position just past the last ``think_id`` (0 if absent or None)."""
        if think_id is None:
            return 0
        try:
            return len(output_ids) - output_ids[::-1].index(think_id)
        except ValueError:
            return 0

    def generate(self, messages: List[Dict[str, str]]) -> Tuple[str, str]:
        """Generate a completion for a chat history.

        Args:
            messages: chat turns as ``{"role": ..., "content": ...}`` dicts.

        Returns:
            ``(thinking, completion)``: decoded text up to and after the last
            ``</think>`` token. ``thinking`` is empty when no such token is
            produced, or when the tokenizer has no single-token ``</think>``.

        Raises:
            RuntimeError: if the model is not ready.
        """
        if not self.ready:
            raise RuntimeError(self.error or "Local model is not ready.")
        import torch

        if self.no_system_role_support:
            messages = self._merge_system_into_user(messages)
        prompt = self._format_prompt(messages)
        model_inputs = self.tokenizer(text=[prompt], return_tensors="pt")
        model_device = getattr(self.model, "device", None)
        if model_device is not None:
            model_inputs = model_inputs.to(model_device)
        if self.remove_token_type_ids and "token_type_ids" in model_inputs:
            del model_inputs["token_type_ids"]
        generate_kwargs = {
            **model_inputs,
            "max_new_tokens": self.config.max_new_tokens,
            "pad_token_id": self.tokenizer.pad_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
            "use_cache": True,
        }
        if self.config.use_stop_criteria:
            from transformers import StoppingCriteriaList

            # Stop once the closing tag of the code block has been emitted.
            generate_kwargs["stopping_criteria"] = StoppingCriteriaList(
                [_StopOnTokenSequence(self.tokenizer.encode("</CODE>", add_special_tokens=False))]
            )
        if self.config.temperature and self.config.temperature > 0:
            generate_kwargs.update(
                {
                    "do_sample": True,
                    "temperature": self.config.temperature,
                    "top_p": self.config.top_p,
                }
            )
        else:
            generate_kwargs["do_sample"] = False
        with torch.inference_mode():
            generated_ids = self.model.generate(**generate_kwargs)
        # Keep only the newly generated tokens (drop the echoed prompt).
        output_ids = generated_ids[0][len(model_inputs.input_ids[0]) :].tolist()
        # Resolve `</think>` against this tokenizer rather than trusting a fixed id,
        # which would match an unrelated token in a different vocabulary.
        index = self._thinking_split_index(output_ids, self._single_token_id(THINKING_TOKEN))
        thinking = self.tokenizer.decode(output_ids[:index], skip_special_tokens=True).strip()
        completion = self.tokenizer.decode(output_ids[index:], skip_special_tokens=True).strip()
        if completion.startswith(THINKING_TOKEN):
            completion = completion.removeprefix(THINKING_TOKEN).strip()
        return thinking, completion

    def _format_prompt(self, messages: List[Dict[str, str]]) -> str:
        """Render ``messages`` into a prompt string.

        In "chat" mode with a chat template, the template is used; thinking is
        disabled where the template accepts ``enable_thinking``. Otherwise a
        plain ``### ROLE`` transcript ending in an open ASSISTANT turn is built.
        """
        if self.config.prompt_mode == "chat" and getattr(self.tokenizer, "chat_template", None):
            try:
                return self.tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True,
                    enable_thinking=False,
                )
            except TypeError:
                # Tokenizers that reject the enable_thinking keyword.
                return self.tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True,
                )
        chunks = []
        for message in messages:
            chunks.append(f"### {message['role'].upper()}\n{message['content'].strip()}")
        chunks.append("### ASSISTANT\n")
        return "\n\n".join(chunks)
| class _StopOnTokenSequence: | |
| """Stop generation when a specific token suffix appears.""" | |
| def __init__(self, stop_token_ids: List[int]) -> None: | |
| self.stop_token_ids = stop_token_ids | |
| def __call__(self, input_ids, scores, **kwargs) -> bool: | |
| del scores, kwargs | |
| if not self.stop_token_ids: | |
| return False | |
| if input_ids.shape[-1] < len(self.stop_token_ids): | |
| return False | |
| tail = input_ids[0, -len(self.stop_token_ids) :].tolist() | |
| return tail == self.stop_token_ids | |
class ManimVisualAgent:
    """Generate Manim code with local inference and repair failures.

    Each attempt generates code and renders it in an isolated Manim subprocess.
    On a render failure, the render errors are fed back to the model. When the
    local model is not ready, a static "offline" scene is produced instead of
    calling the model.
    """

    def __init__(self, config: Optional[LocalModelConfig] = None, **overrides: object) -> None:
        """Create the agent and load its model.

        ``overrides`` are ``LocalModelConfig`` fields. They are only used when
        ``config`` is None.
        """
        if config is None:
            config = LocalModelConfig(**overrides)
        self.config = config
        self.runtime = LocalManimModel(config)
        self.model_path = config.model_path
        # Snapshot of the load error; shown by the offline fallback scene.
        self.model_error = self.runtime.error

    @property
    def ready(self) -> bool:
        """Whether the underlying local model loaded successfully."""
        return self.runtime.ready

    def _model_status_line(self) -> str:
        """One-line summary of the model backend for the terminal log."""
        resolved = self.runtime.resolved_model_path
        if resolved is None:
            return f"[model] backend=unavailable kind={self.config.model_kind} model={self.config.model_path}"
        return (
            f"[model] backend={self.runtime.backend_used or 'unavailable'} "
            f"kind={self.config.model_kind} model={self.config.model_path} cache={resolved}"
        )

    def generate_and_fix(
        self,
        user_prompt: str,
        domain: str = "Mathematics",
        max_retries: int = 3,
    ) -> Tuple[bool, Optional[str], str, str]:
        """Generate a Manim scene for ``user_prompt``, retrying on render failure.

        Args:
            user_prompt: natural-language description of the visualization.
            domain: accepted for API compatibility; the prompt builders ignore it.
            max_retries: self-correction attempts after the first one (at least
                one attempt is always made).

        Returns:
            ``(success, render_result, last_code, terminal_log)``.
            ``render_result`` is the executor's result on success and None otherwise.
        """
        if not user_prompt or not user_prompt.strip():
            return False, None, "", "Prompt is empty. Describe the concept to visualize."
        terminal_log_parts = [self._model_status_line()]
        messages = self._initial_messages(user_prompt, domain)
        best_code = ""
        total_attempts = max(1, int(max_retries) + 1)
        for attempt in range(1, total_attempts + 1):
            terminal_log_parts.append(f"[attempt {attempt}/{total_attempts}] Generating Manim code...")
            try:
                raw_output = self._call_or_fallback(messages)
                best_code = extract_manim_code(raw_output) or raw_output.strip()
            except Exception as exc:
                return (
                    False,
                    None,
                    best_code,
                    "\n".join(terminal_log_parts)
                    + f"\nLocal model generation failed: {type(exc).__name__}: {exc}",
                )
            terminal_log_parts.append("[executor] Rendering with isolated Manim subprocess...")
            success, render_result = render_manim_scene(
                best_code,
                timeout_seconds=self.config.timeout_seconds,
            )
            if success:
                terminal_log_parts.append(f"[success] Rendered video: {render_result}")
                return True, render_result, best_code, "\n".join(terminal_log_parts)
            if attempt == total_attempts:
                # No attempts left: record the error without building another repair prompt.
                terminal_log_parts.append("[error] Manim failed.")
                terminal_log_parts.append(render_result)
                break
            terminal_log_parts.append("[error] Manim failed. Feeding render errors back to the model.")
            terminal_log_parts.append(render_result)
            messages = self._feedback_messages(user_prompt, domain, best_code, render_result)
        terminal_log_parts.append("[failed] Maximum self-correction retries exhausted.")
        return False, None, best_code, "\n".join(terminal_log_parts)

    def _call_or_fallback(self, messages: List[Dict[str, str]]) -> str:
        """Return the model completion, or a `<CODE>`-wrapped offline scene if not ready."""
        if not self.ready:
            return f"<CODE>\n{self._fallback_scene(messages[-1]['content'])}\n</CODE>"
        _, completion = self.runtime.generate(messages)
        return completion

    def _initial_messages(self, user_prompt: str, domain: str) -> List[Dict[str, str]]:
        """Build the first-attempt chat messages (``domain`` is unused)."""
        del domain
        return build_initial_messages(user_prompt)

    def _feedback_messages(
        self,
        user_prompt: str,
        domain: str,
        initial_code: str,
        render_errors: str,
    ) -> List[Dict[str, str]]:
        """Build repair messages from the failed code and its render errors (``domain`` is unused)."""
        del domain
        return build_feedback_messages(user_prompt, initial_code, render_errors)

    def _fallback_scene(self, prompt: str) -> str:
        """Return Manim source for a scene explaining that the local model is offline.

        The message and prompt are whitespace-collapsed to one line and embedded
        with ``repr``. This keeps newlines, quotes and backslashes from breaking
        the generated source. The prompt is truncated to 260 characters before
        quoting, so an escape sequence can never be cut in half.
        """
        message = " ".join((self.model_error or "Local model is not ready.").split())
        prompt_text = "Prompt: " + " ".join(prompt.split())[:260]
        return f"""from manim import *


class MainScene(Scene):
    def construct(self):
        self.camera.background_color = "#0b0f19"
        title = Text("SciVisual-Agent Offline", color="#00ffcc", font_size=42)
        subtitle = Text("Local model could not be initialized.", color="#ff0055", font_size=24)
        detail = Text({message!r}, color=GRAY_B, font_size=16).scale_to_fit_width(config.frame_width - 1)
        prompt = Text({prompt_text!r}, color=WHITE, font_size=18).scale_to_fit_width(config.frame_width - 1)
        group = VGroup(title, subtitle, detail, prompt).arrange(DOWN, buff=0.35)
        self.play(FadeIn(group, shift=UP * 0.2))
        self.wait(2)
"""
| def extract_manim_code(response: str, select_index: int = -1) -> str: | |
| """Extract Manim code from manim-trainer-style responses.""" | |
| response = response or "" | |
| response = re.sub(r"`<CODE>`|`</CODE>`", "", response) | |
| response = ( | |
| response.replace("<code>", "<CODE>") | |
| .replace("</code>", "</CODE>") | |
| .replace("CODE>", "<CODE>") | |
| .replace("</<CODE>", "</CODE>") | |
| ) | |
| matches = re.findall(r"<CODE>(.*?)</CODE>", response, re.DOTALL) | |
| if not matches: | |
| matches = re.findall(r"```python\s*(.*?)```", response, re.DOTALL | re.IGNORECASE) | |
| if not matches: | |
| matches = re.findall(r"```\s*(.*?)```", response, re.DOTALL) | |
| if matches: | |
| return _clean_extracted_code(matches[select_index]) | |
| if "from manim import" in response and "class " in response: | |
| return _clean_extracted_code(response) | |
| return "" | |
| def _clean_extracted_code(code: str) -> str: | |
| """Remove wrapper tags and markdown fences from extracted Manim code.""" | |
| code = code or "" | |
| code = re.sub(r"</?CODE>", "", code, flags=re.IGNORECASE) | |
| code = re.sub(r"^\s*```(?:python)?\s*", "", code, flags=re.IGNORECASE) | |
| code = re.sub(r"\s*```\s*$", "", code) | |
| return code.strip() | |
def build_default_agent() -> ManimVisualAgent:
    """Build a ``ManimVisualAgent`` configured from ``SCIVISUAL_*`` environment variables.

    Unset variables fall back to the defaults below. Surrounding whitespace is
    ignored everywhere. Boolean flags treat ``0``/``false``/``no``
    (case-insensitive) as false and anything else as true. An unknown
    ``SCIVISUAL_MODEL_KIND`` falls back to "adapter", and ``SCIVISUAL_PROMPT_MODE``
    is case-insensitive.

    Raises:
        ValueError: if a numeric variable is not a valid number.
    """

    def env(name: str, default: str) -> str:
        return os.environ.get(name, default).strip()

    def env_flag(name: str, default: str) -> bool:
        return env(name, default).lower() not in ("0", "false", "no")

    model_kind = env("SCIVISUAL_MODEL_KIND", "adapter").lower()
    if model_kind not in {"adapter", "base"}:
        model_kind = "adapter"
    config = LocalModelConfig(
        model_path=env("SCIVISUAL_MODEL_PATH", DEFAULT_MODEL_PATH),
        selected_model=env("SCIVISUAL_SELECTED_MODEL", DEFAULT_SELECTED_MODEL),
        model_kind=model_kind,
        base_model_path=env("SCIVISUAL_BASE_MODEL_PATH", "") or None,
        backend=env("SCIVISUAL_BACKEND", "auto"),
        # Lower-cased because the prompt formatter compares against "chat" exactly.
        prompt_mode=env("SCIVISUAL_PROMPT_MODE", "chat").lower(),
        device_map=env("SCIVISUAL_DEVICE_MAP", "auto"),
        load_in_4bit=env_flag("SCIVISUAL_LOAD_IN_4BIT", "1"),
        max_new_tokens=int(env("SCIVISUAL_MAX_NEW_TOKENS", "8192")),
        temperature=float(env("SCIVISUAL_TEMPERATURE", "0")),
        top_p=float(env("SCIVISUAL_TOP_P", "0.9")),
        use_stop_criteria=env_flag("SCIVISUAL_USE_STOP_CRITERIA", "1"),
        # NOTE: this env default (600 s) differs from LocalModelConfig's default (300 s).
        timeout_seconds=int(env("SCIVISUAL_RENDER_TIMEOUT", "600")),
    )
    return ManimVisualAgent(config=config)