Spaces:
Sleeping
Sleeping
| """ | |
| SeqTex Texture Generator — Hugging Face Space | |
| ============================================= | |
| Startup-safe + Storage Bucket + separate WAN bucket + cached-model-only Generate. | |
| Plain Python file. Do not paste Markdown fences such as ```python into app.py. | |
| Recommended buckets: | |
| - Existing/general bucket mounted at: /data | |
| - New WAN bucket mounted at: /wan-cache | |
| What this version does: | |
| - Launches Gradio first. | |
| - Keeps SeqTex Space utilities and SeqTex-Transformer cache under /data. | |
| - Stores the large WAN base model in the separate bucket /wan-cache. | |
| - Uses CPU/network prewarm from the Cache / Startup tab. | |
| - Uses ZeroGPU only for Generate Texture. | |
| - During Generate, forces all known model repo IDs to local cached paths. | |
| - If the cache is incomplete, Generate fails fast instead of downloading during ZeroGPU. | |
| - Includes nvdiffrast_plugin and Diffusers WAN compatibility patches. | |
| """ | |
| from __future__ import annotations | |
| # --------------------------------------------------------------------------- | |
| # 0. Minimal startup section | |
| # --------------------------------------------------------------------------- | |
| import fnmatch | |
| import importlib | |
| import logging | |
| import os | |
| import pickle | |
| import shutil | |
| import sys | |
| import tempfile | |
| import threading | |
| import time | |
| import traceback | |
| from typing import Any | |
print("[BOOT 00] app.py started", flush=True)


def _env_flag(name: str, default: str) -> bool:
    """Return True when env var *name* (or *default* if unset) is "1", "true", "yes" or "on" (case-insensitive)."""
    return os.getenv(name, default).strip().lower() in {"1", "true", "yes", "on"}


def _env_int(name: str, default: int, *, minimum: int = 0) -> int:
    """
    Parse integer env var *name*.

    Falls back to *default* when unset, blank or not an integer (with a boot warning)
    and clamps the result to at least *minimum*, so a bad value can never crash startup.
    """
    raw = os.getenv(name)
    if raw is None or not raw.strip():
        return default
    try:
        value = int(raw.strip())
    except ValueError:
        print(f"[BOOT WARN] {name}={raw!r} is not an integer; using {default}", flush=True)
        return default
    return max(minimum, value)


# General persistent bucket. Your existing bucket can stay here.
PERSISTENT_ROOT = os.getenv("PERSISTENT_ROOT", "/data")
CACHE_ROOT = os.getenv("CACHE_ROOT", os.path.join(PERSISTENT_ROOT, "hf_home"))
# New dedicated WAN bucket. Mount the new bucket here in HF Settings -> Storage Buckets.
WAN_BUCKET_ROOT = os.getenv("WAN_BUCKET_ROOT", "/wan-cache")
# Hugging Face cache for SeqTex-Transformer and normal Hub cache.
os.environ.setdefault("HF_HOME", CACHE_ROOT)
os.environ.setdefault("HF_HUB_CACHE", os.path.join(CACHE_ROOT, "hub"))
# Legacy cache variable names, kept so older library versions use the same cache.
os.environ.setdefault("TRANSFORMERS_CACHE", os.path.join(CACHE_ROOT, "hub"))
os.environ.setdefault("DIFFUSERS_CACHE", os.path.join(CACHE_ROOT, "hub"))
os.environ.setdefault("TORCH_HOME", os.path.join(PERSISTENT_ROOT, "torch"))
# nvdiffrast compiles a small runtime extension. Keep this on local tmp.
# Do not use /data or /wan-cache for Torch extensions.
# (setdefault already preserves a value provided by the environment.)
os.environ.setdefault("TORCH_EXTENSIONS_DIR", "/tmp/torch_extensions")
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "0")
os.environ.setdefault("GRADIO_SSR_MODE", "False")
os.environ.setdefault("GRADIO_ANALYTICS_ENABLED", "False")
# During Generate, use only cached/local files. Prewarm can still download.
FORCE_LOCAL_GENERATE = _env_flag("FORCE_LOCAL_GENERATE", "1")
SEQTEX_SPACE_REPO = os.getenv("SEQTEX_SPACE_REPO", "VAST-AI/SeqTex")
SEQTEX_MODEL_REPO = os.getenv("SEQTEX_MODEL_REPO", "VAST-AI/SeqTex-Transformer")
SEQTEX_SPACE_DIR = os.getenv("SEQTEX_SPACE_DIR", os.path.join(PERSISTENT_ROOT, "seqtex_space"))
# SeqTex's official loader pulls this WAN base model at runtime.
DEFAULT_WAN_MODEL_REPO = os.getenv("DEFAULT_WAN_MODEL_REPO", "Wan-AI/Wan2.1-T2V-1.3B-Diffusers")
# Default local folder mirrors the repo id so a different WAN repo never lands in
# (or mixes with) a folder named after another model. An empty repo id falls back
# to the historical Wan2.1 folder name.
_wan_repo_parts = [part for part in DEFAULT_WAN_MODEL_REPO.strip().split("/") if part] or [
    "Wan-AI",
    "Wan2.1-T2V-1.3B-Diffusers",
]
WAN_LOCAL_MODEL_DIR = os.getenv(
    "WAN_LOCAL_MODEL_DIR",
    os.path.join(WAN_BUCKET_ROOT, "models", *_wan_repo_parts),
)
# Optional comma-separated extra model repos. By default, include WAN because SeqTex needs it.
_extra_repos_env = os.getenv("EXTRA_MODEL_REPOS")
if _extra_repos_env is None:
    _extra_repos_env = DEFAULT_WAN_MODEL_REPO
EXTRA_MODEL_REPOS = [repo.strip() for repo in _extra_repos_env.split(",") if repo.strip()]
AUTO_PREWARM = _env_flag("AUTO_PREWARM", "0")
# Clamped to >= 0: time.sleep() rejects negative values.
AUTO_PREWARM_DELAY_SECONDS = _env_int("AUTO_PREWARM_DELAY_SECONDS", 10, minimum=0)
print("[BOOT 01] environment variables set", flush=True)
| # --------------------------------------------------------------------------- | |
| # 1. Light imports only | |
| # --------------------------------------------------------------------------- | |
# Gradio is required: a failed import is reported on stdout and re-raised.
try:
    import gradio as gr
    print("[BOOT 02] gradio imported", flush=True)
except Exception:
    print("[BOOT ERROR] gradio import failed", flush=True)
    raise
# `spaces` (ZeroGPU) is optional: on failure it is set to None, which
# _gpu_decorator and _patch_zero_startup_report_timeout check for.
try:
    import spaces
    print("[BOOT 03] spaces imported", flush=True)
except Exception as exc:
    print(f"[BOOT WARN] spaces import failed: {exc}", flush=True)
    spaces = None
# Pillow is required: a failed import is reported on stdout and re-raised.
try:
    from PIL import Image
    print("[BOOT 04] PIL imported", flush=True)
except Exception:
    print("[BOOT ERROR] PIL import failed", flush=True)
    raise
# Send all logging to stdout at INFO level; force=True replaces any handlers
# already attached to the root logger.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    handlers=[logging.StreamHandler(sys.stdout)],
    force=True,
)
log = logging.getLogger("seqtex-app")
# Echo the effective configuration once at startup for debugging.
log.info("PERSISTENT_ROOT=%s", PERSISTENT_ROOT)
log.info("WAN_BUCKET_ROOT=%s", WAN_BUCKET_ROOT)
log.info("HF_HOME=%s", os.getenv("HF_HOME"))
log.info("HF_HUB_CACHE=%s", os.getenv("HF_HUB_CACHE"))
log.info("TORCH_HOME=%s", os.getenv("TORCH_HOME"))
log.info("TORCH_EXTENSIONS_DIR=%s", os.getenv("TORCH_EXTENSIONS_DIR"))
log.info("SEQTEX_SPACE_REPO=%s", SEQTEX_SPACE_REPO)
log.info("SEQTEX_MODEL_REPO=%s", SEQTEX_MODEL_REPO)
log.info("DEFAULT_WAN_MODEL_REPO=%s", DEFAULT_WAN_MODEL_REPO)
log.info("WAN_LOCAL_MODEL_DIR=%s", WAN_LOCAL_MODEL_DIR)
log.info("SEQTEX_SPACE_DIR=%s", SEQTEX_SPACE_DIR)
log.info("AUTO_PREWARM=%s", AUTO_PREWARM)
log.info("EXTRA_MODEL_REPOS=%s", EXTRA_MODEL_REPOS)
log.info("FORCE_LOCAL_GENERATE=%s", FORCE_LOCAL_GENERATE)
def _patch_zero_startup_report_timeout() -> None:
    """
    Avoid rare startup crash when ZeroGPU local startup-report API times out.
    This does not disable @spaces.GPU for Generate.

    Replaces ``spaces.zero.client.startup_report`` with a wrapper that retries
    timeout-like failures (``httpx.TimeoutException``, or any exception whose
    class name contains "timeout" or whose message contains "timed out") up to
    three times, waiting one second between attempts. If every attempt times out,
    a warning is printed and None is returned so the launch continues. Any other
    exception propagates unchanged.

    No-op when ``spaces`` is unavailable, the ZeroGPU client cannot be imported,
    it has no ``startup_report``, or the wrapper is already installed.
    """
    if spaces is None:
        return
    try:
        import httpx
        import spaces.zero.client as zero_client
    except Exception as exc:
        print(f"[BOOT WARN] could not import ZeroGPU client for timeout patch: {exc}", flush=True)
        return
    original_startup_report = getattr(zero_client, "startup_report", None)
    if original_startup_report is None or getattr(original_startup_report, "_seqtex_timeout_safe", False):
        return

    max_attempts = 3
    retry_delay_seconds = 1.0

    def timeout_safe_startup_report(*args, **kwargs):
        last_exc: Exception | None = None
        for attempt in range(1, max_attempts + 1):
            try:
                print(f"[BOOT ZG] ZeroGPU startup_report attempt {attempt}/{max_attempts}", flush=True)
                return original_startup_report(*args, **kwargs)
            except httpx.TimeoutException as exc:
                last_exc = exc
                print(
                    f"[BOOT WARN] ZeroGPU startup_report timed out on attempt {attempt}/{max_attempts}: {exc}",
                    flush=True,
                )
            except Exception as exc:
                name = exc.__class__.__name__.lower()
                msg = str(exc).lower()
                # Other HTTP clients raise their own timeout types; match by name/message.
                if "timeout" not in name and "timed out" not in msg:
                    raise
                last_exc = exc
                print(
                    f"[BOOT WARN] ZeroGPU startup_report timeout-like error on attempt {attempt}/{max_attempts}: {exc}",
                    flush=True,
                )
            # Only wait when another attempt follows; sleeping after the final
            # failure would just delay startup.
            if attempt < max_attempts:
                time.sleep(retry_delay_seconds)
        print(
            f"[BOOT WARN] ZeroGPU startup_report failed after retries; continuing launch anyway. Last error: {last_exc}",
            flush=True,
        )
        return None

    timeout_safe_startup_report._seqtex_timeout_safe = True
    zero_client.startup_report = timeout_safe_startup_report
    print("[BOOT 04B] ZeroGPU startup_report timeout patch installed", flush=True)
# Install the ZeroGPU startup_report timeout wrapper once, at import time.
_patch_zero_startup_report_timeout()
| # --------------------------------------------------------------------------- | |
| # 2. Global state and helpers | |
| # --------------------------------------------------------------------------- | |
# NOTE(review): _seqtex_modules / _seqtex_pipe are not assigned in the visible part
# of this file; presumably filled lazily by the Generate path — confirm.
_seqtex_modules: dict[str, Any] | None = None
_seqtex_pipe: Any | None = None
# Lock acquired by prewarm_cache_impl around its cache-preparation run.
_prewarm_lock = threading.Lock()
# Background auto-prewarm thread; start_auto_prewarm_once starts it at most once while alive.
_prewarm_thread: threading.Thread | None = None
# Status rendered by _prewarm_log_text. prewarm_cache_impl moves "state" through
# "running" -> "done" / "error"; "log" holds timestamped lines from _append_prewarm_log.
_prewarm_status: dict[str, Any] = {
    "state": "idle",
    "started_at": None,
    "finished_at": None,
    "last_error": None,
    "log": [],
}
# _append_prewarm_log keeps only this many most recent log lines.
MAX_PREWARM_LOG_LINES = 400
# One-shot guards so each monkeypatch is installed only once per process.
_CPP_EXTENSION_LOAD_PATCHED = False
_DIFFUSERS_WAN_CONFIG_PATCHED = False
_CACHE_PATCHED = False
# Set True by _LocalGenerateMode.__enter__; the patched loaders then force local files.
_LOCAL_GENERATE_MODE = False
# Memo of (repo_id, repo_type) -> resolved local snapshot path (see _cached_snapshot_path).
_CACHED_REPO_PATHS: dict[tuple[str, str], str] = {}
class StartupFixError(RuntimeError):
    """Setup/runtime problem whose message is written for the end user and shown in the UI."""
def _gpu_decorator(duration: int = 120):
    """
    Build a decorator that requests a ZeroGPU slot for ``duration`` seconds.

    The returned decorator applies ``spaces.GPU(duration=duration)``. If the
    ``spaces`` package is unavailable, it logs a warning and returns the
    function unchanged.
    """
    def _decorator(fn):
        if spaces is not None:
            gpu_wrapper = spaces.GPU(duration=duration)
            return gpu_wrapper(fn)
        log.warning("spaces module unavailable; running without ZeroGPU decorator")
        return fn

    return _decorator
def _append_prewarm_log(message: str) -> None:
    """Log *message* and keep a timestamped copy in the prewarm buffer, trimmed to MAX_PREWARM_LOG_LINES."""
    stamp = time.strftime('%H:%M:%S')
    entry = f"{stamp} | {message}"
    log.info("PREWARM: %s", message)
    buffer = _prewarm_status["log"]
    buffer.append(entry)
    if len(buffer) > MAX_PREWARM_LOG_LINES:
        # Keep only the newest lines.
        _prewarm_status["log"] = buffer[-MAX_PREWARM_LOG_LINES:]
def _prewarm_log_text() -> str:
    """Render the prewarm status (state, timestamps, last error) followed by the buffered log lines."""
    status = _prewarm_status
    lines = [
        f"state: {status.get('state', 'unknown')}",
        f"started_at: {status.get('started_at') or '-'}",
        f"finished_at: {status.get('finished_at') or '-'}",
        f"last_error: {status.get('last_error') or '-'}",
        "",
        "logs:",
    ]
    lines.extend(status.get("log", []))
    return "\n".join(lines)
| def _get_hf_token() -> str | None: | |
| token = ( | |
| os.getenv("SEQTEX_SPACE_TOKEN") | |
| or os.getenv("HF_TOKEN") | |
| or os.getenv("HUGGINGFACE_HUB_TOKEN") | |
| or os.getenv("HUGGING_FACE_HUB_TOKEN") | |
| ) | |
| if token: | |
| os.environ.setdefault("SEQTEX_SPACE_TOKEN", token) | |
| os.environ.setdefault("HF_TOKEN", token) | |
| os.environ.setdefault("HUGGINGFACE_HUB_TOKEN", token) | |
| os.environ.setdefault("HUGGING_FACE_HUB_TOKEN", token) | |
| return token | |
def _ensure_seqtex_token() -> None:
    """Raise StartupFixError unless a Hugging Face token secret is configured."""
    if _get_hf_token():
        return
    raise StartupFixError(
        "Missing Hugging Face token secret. Add SEQTEX_SPACE_TOKEN or HF_TOKEN "
        "in Settings -> Variables and secrets."
    )
def _prepare_runtime_dirs() -> None:
    """Create every cache/model directory the app uses (skipping unset ones), logging each."""
    candidates = (
        os.getenv("HF_HOME"),
        os.getenv("HF_HUB_CACHE"),
        os.getenv("TORCH_HOME"),
        os.getenv("TORCH_EXTENSIONS_DIR"),
        SEQTEX_SPACE_DIR,
        WAN_BUCKET_ROOT,
        WAN_LOCAL_MODEL_DIR,
    )
    for directory in filter(None, candidates):
        os.makedirs(directory, exist_ok=True)
        _append_prewarm_log(f"directory ready: {directory}")
def _skip_file(filename: str, ignore_patterns: list[str]) -> bool:
    """Return True if *filename* matches at least one fnmatch glob in *ignore_patterns*."""
    for pattern in ignore_patterns:
        if fnmatch.fnmatch(filename, pattern):
            return True
    return False
def _repo_name_tail(repo_id: str) -> str:
    """Return the final "/"-separated segment of *repo_id*, ignoring trailing slashes."""
    _, _, tail = repo_id.rstrip("/").rpartition("/")
    return tail
def _is_wan_repo(repo_id: str) -> bool:
    """Return True when *repo_id*, with surrounding whitespace removed, equals DEFAULT_WAN_MODEL_REPO."""
    normalized = repo_id.strip()
    return normalized == DEFAULT_WAN_MODEL_REPO
| def _is_wan_local_dir_ready() -> bool: | |
| required = [ | |
| "model_index.json", | |
| "scheduler/scheduler_config.json", | |
| "text_encoder/config.json", | |
| "text_encoder/model.safetensors.index.json", | |
| "tokenizer/tokenizer_config.json", | |
| "transformer/config.json", | |
| "transformer/diffusion_pytorch_model.safetensors.index.json", | |
| "vae/config.json", | |
| "vae/diffusion_pytorch_model.safetensors", | |
| ] | |
| return all(os.path.exists(os.path.join(WAN_LOCAL_MODEL_DIR, item)) for item in required) | |
def _download_repo_files_with_logs(
    *,
    repo_id: str,
    repo_type: str,
    local_dir: str | None = None,
    ignore_patterns: list[str] | None = None,
    progress: gr.Progress | None = None,
    progress_start: float = 0.0,
    progress_end: float = 1.0,
) -> None:
    """
    Download/check a Hub repo file-by-file.
    If local_dir is None, files go to HF_HUB_CACHE.
    If local_dir is set, files go to that folder as a normal local snapshot.

    Files matching any glob in *ignore_patterns* (see _skip_file) are skipped.
    Every file is logged via _append_prewarm_log. When *progress* is given, it is
    advanced linearly from progress_start to progress_end.

    Raises:
        Whatever huggingface_hub raises for listing or downloading (network,
        auth, missing repo). prewarm_cache_impl records these as errors.
    """
    from huggingface_hub import HfApi, hf_hub_download

    token = _get_hf_token()
    ignore_patterns = ignore_patterns or []
    _append_prewarm_log(f"listing {repo_type} repo: {repo_id}")
    api = HfApi(token=token)
    all_files = api.list_repo_files(repo_id=repo_id, repo_type=repo_type)
    files = [f for f in all_files if not _skip_file(f, ignore_patterns)]
    total = len(files)
    _append_prewarm_log(f"{repo_id}: {total} files to check/download")
    if total == 0:
        return
    if local_dir:
        os.makedirs(local_dir, exist_ok=True)
        _append_prewarm_log(f"{repo_id}: local_dir={local_dir}")

    span = progress_end - progress_start
    # force_download/resume_download are sent until the installed huggingface_hub
    # rejects them with TypeError. After that, all remaining files use the reduced
    # kwargs instead of failing and retrying once per file.
    send_legacy_kwargs = True
    for index, filename in enumerate(files, start=1):
        frac = progress_start + ((index - 1) / total) * span
        if progress is not None:
            progress(frac, desc=f"Caching {repo_id}: {index}/{total} {filename}")
        _append_prewarm_log(f"{repo_id}: [{index}/{total}] {filename}")
        kwargs: dict[str, Any] = {
            "repo_id": repo_id,
            "repo_type": repo_type,
            "filename": filename,
            "token": token,
        }
        if send_legacy_kwargs:
            kwargs["force_download"] = False
            kwargs["resume_download"] = True
        if local_dir:
            kwargs["local_dir"] = local_dir
        else:
            kwargs["cache_dir"] = os.getenv("HF_HUB_CACHE")
        try:
            hf_hub_download(**kwargs)
        except TypeError:
            if not send_legacy_kwargs:
                raise
            # Compatibility for older huggingface_hub versions.
            send_legacy_kwargs = False
            kwargs.pop("resume_download", None)
            kwargs.pop("force_download", None)
            hf_hub_download(**kwargs)
    if progress is not None:
        progress(progress_end, desc=f"Cached {repo_id}")
    _append_prewarm_log(f"finished caching {repo_id}")
def prewarm_cache_impl(progress: gr.Progress | None = None) -> str:
    """
    CPU/network-only cache preparation.
    This does not allocate ZeroGPU and does not load CUDA.

    Steps, in order:
      1. SeqTex Space helper code -> SEQTEX_SPACE_DIR
      2. SeqTex transformer       -> HF_HUB_CACHE
      3. WAN base model           -> WAN_LOCAL_MODEL_DIR (separate bucket)
      4. EXTRA_MODEL_REPOS not already handled above -> HF_HUB_CACHE

    Args:
        progress: Optional Gradio progress callback; None for background runs.

    Returns:
        The rendered status/log text (see _prewarm_log_text). Failures are not
        raised: they set state "error" and last_error in _prewarm_status.
    """
    space_ignore_patterns = [
        ".git/*",
        "__pycache__/*",
        "*.png",
        "*.jpg",
        "*.jpeg",
        "*.gif",
        "*.mp4",
        "*.webm",
        "examples/*",
        "outputs/*",
    ]
    # Diffusers only needs model_index.json + component folders; skip media assets.
    wan_ignore_patterns = [
        ".git/*",
        "assets/*",
        "examples/*",
        "*.png",
        "*.jpg",
        "*.jpeg",
        "*.gif",
        "*.mp4",
        "*.webm",
    ]
    # Hold the lock for the entire run. The UI button and the auto-prewarm thread
    # can both call this, and two concurrent runs would write into the same
    # directories at once. A second caller waits for the first to finish, then runs.
    with _prewarm_lock:
        _prewarm_status["state"] = "running"
        _prewarm_status["started_at"] = time.strftime("%Y-%m-%d %H:%M:%S")
        _prewarm_status["finished_at"] = None
        _prewarm_status["last_error"] = None
        try:
            if progress is not None:
                progress(0.01, desc="Preparing persistent cache directories...")
            _append_prewarm_log("cache preparation started")
            _prepare_runtime_dirs()
            _ensure_seqtex_token()
            # 1) SeqTex Space helper code.
            _download_repo_files_with_logs(
                repo_id=SEQTEX_SPACE_REPO,
                repo_type="space",
                local_dir=SEQTEX_SPACE_DIR,
                ignore_patterns=space_ignore_patterns,
                progress=progress,
                progress_start=0.05,
                progress_end=0.25,
            )
            # 2) SeqTex transformer into /data HF cache.
            _download_repo_files_with_logs(
                repo_id=SEQTEX_MODEL_REPO,
                repo_type="model",
                local_dir=None,
                ignore_patterns=[],
                progress=progress,
                progress_start=0.25,
                progress_end=0.45,
            )
            # 3) WAN model into the separate /wan-cache local snapshot.
            if DEFAULT_WAN_MODEL_REPO:
                _download_repo_files_with_logs(
                    repo_id=DEFAULT_WAN_MODEL_REPO,
                    repo_type="model",
                    local_dir=WAN_LOCAL_MODEL_DIR,
                    ignore_patterns=wan_ignore_patterns,
                    progress=progress,
                    progress_start=0.45,
                    progress_end=0.95,
                )
            # 4) Extra repos: skip the ones already cached above and duplicates in the list.
            handled = {SEQTEX_MODEL_REPO, DEFAULT_WAN_MODEL_REPO}
            extra_unique: list[str] = []
            for repo in EXTRA_MODEL_REPOS:
                if repo and repo not in handled:
                    handled.add(repo)
                    extra_unique.append(repo)
            if extra_unique:
                span = 0.04 / len(extra_unique)
                start = 0.95
                for repo in extra_unique:
                    end = min(0.99, start + span)
                    _download_repo_files_with_logs(
                        repo_id=repo,
                        repo_type="model",
                        local_dir=None,
                        ignore_patterns=[],
                        progress=progress,
                        progress_start=start,
                        progress_end=end,
                    )
                    start = end
            if progress is not None:
                progress(1.0, desc="Cache preparation complete")
            _prewarm_status["state"] = "done"
            _prewarm_status["finished_at"] = time.strftime("%Y-%m-%d %H:%M:%S")
            _append_prewarm_log("cache preparation complete")
            return _prewarm_log_text()
        except Exception as exc:
            tb = traceback.format_exc()
            _prewarm_status["state"] = "error"
            _prewarm_status["finished_at"] = time.strftime("%Y-%m-%d %H:%M:%S")
            _prewarm_status["last_error"] = str(exc)
            _append_prewarm_log(f"ERROR: {exc}")
            log.error("Prewarm failed:\n%s", tb)
            return _prewarm_log_text()
def prewarm_cache_ui(progress: gr.Progress = gr.Progress(track_tqdm=True)) -> str:
    """Gradio handler: run cache preparation with the UI progress bar and return the status text."""
    status_text = prewarm_cache_impl(progress=progress)
    return status_text
def get_cache_status_ui() -> str:
    """Gradio handler: return the current prewarm status text without starting any work."""
    status_text = _prewarm_log_text()
    return status_text
def _auto_prewarm_worker() -> None:
    """
    Thread target: wait AUTO_PREWARM_DELAY_SECONDS, then prepare the cache without UI progress.

    Any exception is logged with its traceback rather than escaping the thread.
    """
    try:
        delay = AUTO_PREWARM_DELAY_SECONDS
        _append_prewarm_log(f"auto-prewarm will start after {delay}s")
        time.sleep(delay)
        prewarm_cache_impl(progress=None)
    except Exception:
        log.error("Auto-prewarm worker crashed:\n%s", traceback.format_exc())
def start_auto_prewarm_once() -> str:
    """
    Start the background auto-prewarm thread unless it is disabled, alive, or unnecessary.

    Nothing is started when AUTO_PREWARM is off (that fact is logged), when a
    previous thread is still alive, or when a prewarm is already running or done.

    Returns:
        The current prewarm status text.
    """
    global _prewarm_thread
    if not AUTO_PREWARM:
        _append_prewarm_log("auto-prewarm disabled by AUTO_PREWARM=0")
        return _prewarm_log_text()
    thread_alive = _prewarm_thread is not None and _prewarm_thread.is_alive()
    if thread_alive or _prewarm_status.get("state") in {"running", "done"}:
        return _prewarm_log_text()
    worker = threading.Thread(target=_auto_prewarm_worker, daemon=True, name="seqtex-auto-prewarm")
    _prewarm_thread = worker
    worker.start()
    _append_prewarm_log("auto-prewarm thread started")
    return _prewarm_log_text()
| # --------------------------------------------------------------------------- | |
| # 3. Build/runtime compatibility patches | |
| # --------------------------------------------------------------------------- | |
def _clean_nvdiffrast_extension_cache() -> None:
    """
    Delete cached nvdiffrast Torch-extension builds so the plugin is rebuilt.

    First drops ``nvdiffrast_plugin*`` entries from ``sys.modules``. Then, under
    TORCH_EXTENSIONS_DIR (default /tmp/torch_extensions), removes every directory
    and file whose name contains "nvdiffrast". Best effort: failures are logged,
    never raised.
    """
    ext_dir = os.getenv("TORCH_EXTENSIONS_DIR") or "/tmp/torch_extensions"
    log.warning("Cleaning nvdiffrast extension cache under %s", ext_dir)
    for mod_name in list(sys.modules.keys()):
        if mod_name.startswith("nvdiffrast_plugin"):
            sys.modules.pop(mod_name, None)
    if not os.path.isdir(ext_dir):
        return
    for root, dirs, files in os.walk(ext_dir):
        descend_into: list[str] = []
        for dirname in dirs:
            # "nvdiffrast_plugin" also contains "nvdiffrast", so one substring test covers both.
            if "nvdiffrast" not in dirname:
                descend_into.append(dirname)
                continue
            path = os.path.join(root, dirname)
            # rmtree(ignore_errors=True) never raises, so judge success by whether
            # the directory is actually gone.
            shutil.rmtree(path, ignore_errors=True)
            if os.path.exists(path):
                log.warning("Could not remove %s: %s", path, "directory still present after rmtree")
                # Walk the remnant so nvdiffrast files inside it still get removed below.
                descend_into.append(dirname)
            else:
                log.warning("Removed stale nvdiffrast extension dir: %s", path)
        # Prune in place so os.walk does not try to descend into deleted directories.
        dirs[:] = descend_into
        for filename in files:
            if "nvdiffrast" not in filename:
                continue
            path = os.path.join(root, filename)
            try:
                os.remove(path)
                log.warning("Removed stale nvdiffrast extension file: %s", path)
            except OSError as exc:
                log.warning("Could not remove %s: %s", path, exc)
def _patch_torch_cpp_extension_load(torch_module: Any) -> None:
    """
    Make JIT-built Torch extensions importable by name once they are built.

    Replaces ``torch_module.utils.cpp_extension.load`` with a wrapper that calls
    the original loader and then:
      * stores the returned module in ``sys.modules`` under the extension name
        (the ``name=`` keyword, or else the first positional argument);
      * if the module has a ``__file__``, prepends its directory and that
        directory's parent to ``sys.path``.
    Registration errors are logged, not raised. The patch is installed at most
    once (_CPP_EXTENSION_LOAD_PATCHED).
    """
    global _CPP_EXTENSION_LOAD_PATCHED
    if _CPP_EXTENSION_LOAD_PATCHED:
        return
    try:
        cpp_ext = torch_module.utils.cpp_extension
    except Exception as exc:
        log.warning("Could not access torch.utils.cpp_extension: %s", exc)
        return
    original_load = cpp_ext.load

    def register_extension(plugin_name: Any, module: Any) -> None:
        try:
            sys.modules[str(plugin_name)] = module
            module_file = getattr(module, "__file__", None)
            if not module_file:
                log.info("Registered Torch extension module %s", plugin_name)
                return
            module_dir = os.path.dirname(os.path.abspath(module_file))
            for search_path in (module_dir, os.path.dirname(module_dir)):
                if search_path and search_path not in sys.path:
                    sys.path.insert(0, search_path)
            log.info("Registered Torch extension module %s from %s", plugin_name, module_file)
        except Exception as exc:
            log.warning("Torch extension registration failed for %s: %s", plugin_name, exc)

    def load_and_register(*args, **kwargs):
        plugin_name = kwargs.get("name")
        if plugin_name is None and args:
            plugin_name = args[0]
        module = original_load(*args, **kwargs)
        if plugin_name and module is not None:
            register_extension(plugin_name, module)
        return module

    cpp_ext.load = load_and_register
    _CPP_EXTENSION_LOAD_PATCHED = True
    log.info("Patched torch.utils.cpp_extension.load for nvdiffrast plugin registration")
def _reset_nvdiffrast_runtime_modules() -> None:
    """
    Forget any loaded nvdiffrast plugin so the next use loads it again.

    Removes ``nvdiffrast_plugin*`` entries from ``sys.modules``. If
    ``nvdiffrast.torch.ops`` is imported and has ``_cached_plugin``, that attribute
    is reset to ``{False: None, True: None}``. Errors during the reset are logged.
    """
    stale_names = [name for name in list(sys.modules.keys()) if name.startswith("nvdiffrast_plugin")]
    for name in stale_names:
        sys.modules.pop(name, None)
    try:
        ops_mod = sys.modules.get("nvdiffrast.torch.ops")
        if ops_mod is None or not hasattr(ops_mod, "_cached_plugin"):
            return
        ops_mod._cached_plugin = {False: None, True: None}
        log.warning("Reset nvdiffrast.torch.ops._cached_plugin")
    except Exception as exc:
        log.warning("Could not reset nvdiffrast cached plugin state: %s", exc)
def _patch_diffusers_wan_frozendict_config() -> None:
    """
    Install the Diffusers compatibility shims the SeqTex WAN pipeline needs.

    1. ``FrozenDict.__getattr__``: attribute access falls back to item lookup,
       and ``scale_factor_temporal`` / ``scale_factor_spatial`` default to 4 / 8
       when a config lacks them.
    2. ``WanPipeline.__init__``: before the real constructor runs, registers
       those two values on the VAE config if they cannot be read from it. The
       caller's arguments are forwarded untouched, so Diffusers versions with
       extra or reordered ``__init__`` parameters keep working.
    3. ``WanT2TexPipeline.components`` (SeqTex ``wan.pipeline_wan_t2tex_extra``):
       exposes only the five core modules.

    Each step is best effort (failures are logged). The function runs once
    (_DIFFUSERS_WAN_CONFIG_PATCHED), unless FrozenDict itself cannot be imported.
    """
    global _DIFFUSERS_WAN_CONFIG_PATCHED
    if _DIFFUSERS_WAN_CONFIG_PATCHED:
        return
    import functools
    import inspect

    try:
        from diffusers.configuration_utils import FrozenDict
    except Exception as exc:
        log.warning("Could not import diffusers FrozenDict: %s", exc)
        return

    # 1) FrozenDict attribute fallback. Skip if our wrapper is already installed,
    #    so it never wraps itself.
    original_getattr = getattr(FrozenDict, "__getattr__", None)
    if not getattr(original_getattr, "_seqtex_frozendict_patch", False):

        def seqtex_frozendict_getattr(self, name):
            try:
                return self[name]
            except Exception:
                pass
            if name == "scale_factor_temporal":
                return 4
            if name == "scale_factor_spatial":
                return 8
            if original_getattr is not None:
                try:
                    return original_getattr(self, name)
                except Exception:
                    pass
            raise AttributeError(f"'FrozenDict' object has no attribute '{name}'")

        seqtex_frozendict_getattr._seqtex_frozendict_patch = True
        try:
            FrozenDict.__getattr__ = seqtex_frozendict_getattr
            log.info("Patched diffusers FrozenDict for Wan scale_factor_temporal/scale_factor_spatial compatibility")
        except Exception as exc:
            log.warning("Could not patch FrozenDict.__getattr__: %s", exc)

    # 2) WanPipeline.__init__: fill missing VAE scale factors.
    def config_has(cfg: Any, name: str) -> bool:
        try:
            getattr(cfg, name)
            return True
        except Exception:
            return False

    def register_missing_vae_scale_factors(vae: Any) -> None:
        try:
            cfg = getattr(vae, "config", None)
            missing_temporal = cfg is None or not config_has(cfg, "scale_factor_temporal")
            missing_spatial = cfg is None or not config_has(cfg, "scale_factor_spatial")
            if hasattr(vae, "register_to_config"):
                patch_kwargs = {}
                if missing_temporal:
                    patch_kwargs["scale_factor_temporal"] = 4
                if missing_spatial:
                    patch_kwargs["scale_factor_spatial"] = 8
                if patch_kwargs:
                    vae.register_to_config(**patch_kwargs)
                    log.info("Registered missing WAN VAE config values: %s", patch_kwargs)
        except Exception as exc:
            log.warning("Could not register WAN VAE scale factors on VAE config: %s", exc)

    try:
        from diffusers.pipelines.wan.pipeline_wan import WanPipeline

        original_init = getattr(WanPipeline, "__init__", None)
        if original_init is not None and not getattr(original_init, "_seqtex_wan_config_patch", False):
            init_signature = inspect.signature(original_init)

            # functools.wraps sets __wrapped__, so inspect.signature(WanPipeline.__init__)
            # still reports the real parameters. Code that derives component names from
            # the __init__ signature therefore keeps seeing them.
            @functools.wraps(original_init)
            def seqtex_wan_init(self, *args, **kwargs):
                # Locate `vae` by binding against the real signature instead of
                # hard-coding (tokenizer, text_encoder, transformer, vae, scheduler).
                # Newer Diffusers adds transformer_2 / boundary_ratio / expand_timesteps
                # and may order the parameters differently.
                try:
                    vae = init_signature.bind_partial(self, *args, **kwargs).arguments.get("vae")
                except TypeError:
                    vae = kwargs.get("vae")
                if vae is not None:
                    register_missing_vae_scale_factors(vae)
                return original_init(self, *args, **kwargs)

            seqtex_wan_init._seqtex_wan_config_patch = True
            WanPipeline.__init__ = seqtex_wan_init
            log.info("Patched Diffusers WanPipeline.__init__ for SeqTex VAE compatibility")
    except Exception as exc:
        log.warning("Could not patch WanPipeline.__init__: %s", exc)

    # 3) Patch the SeqTex custom WanT2TexPipeline.components property.
    #
    # Diffusers 0.38's DiffusionPipeline.components is stricter than the
    # custom SeqTex pipeline expects. The SeqTex/WAN config contains extra
    # non-module config keys such as boundary_ratio and expand_timesteps, plus
    # optional transformer_2. During TEX_PIPE.to("cuda"), Diffusers calls
    # self.components and raises:
    #
    #   Expected ['scheduler', 'text_encoder', 'tokenizer', 'transformer', 'vae']
    #   but ['boundary_ratio', 'expand_timesteps', ..., 'transformer_2', ...]
    #   are defined.
    #
    # For .to("cuda"), only real pipeline modules are needed, so expose exactly
    # the core modules Diffusers expects for this custom pipeline.
    # NOTE(review): the "Expected [5 names]" list may have come from the old
    # fixed-signature __init__ patch; with the signature now preserved, this
    # override may be unnecessary. Kept for safety — confirm.
    try:
        import wan.pipeline_wan_t2tex_extra as seqtex_wan_extra

        WanT2TexPipeline = getattr(seqtex_wan_extra, "WanT2TexPipeline", None)
        if WanT2TexPipeline is not None and not getattr(WanT2TexPipeline, "_seqtex_components_patch", False):

            def seqtex_components(self):
                component_names = ["scheduler", "text_encoder", "tokenizer", "transformer", "vae"]
                components = {}
                for component_name in component_names:
                    if hasattr(self, component_name):
                        components[component_name] = getattr(self, component_name)
                return components

            WanT2TexPipeline.components = property(seqtex_components)
            WanT2TexPipeline._seqtex_components_patch = True
            log.info("Patched SeqTex WanT2TexPipeline.components to ignore non-module config keys")
    except Exception as exc:
        log.warning("Could not patch SeqTex WanT2TexPipeline.components: %s", exc)
    _DIFFUSERS_WAN_CONFIG_PATCHED = True
| # --------------------------------------------------------------------------- | |
| # 4. Cached/local-only model loading | |
| # --------------------------------------------------------------------------- | |
def _required_model_repos() -> list[str]:
    """
    Return the model repos Generate needs, in priority order.

    Order is SeqTex transformer, WAN base, then extras. Entries are stripped,
    empty ones are dropped, and only the first occurrence of each is kept.
    """
    candidates = (SEQTEX_MODEL_REPO, DEFAULT_WAN_MODEL_REPO, *EXTRA_MODEL_REPOS)
    stripped = ((candidate or "").strip() for candidate in candidates)
    # dict.fromkeys keeps first-seen order while dropping duplicates.
    return list(dict.fromkeys(repo for repo in stripped if repo))
def _cached_snapshot_path(repo_id: str, repo_type: str = "model") -> str:
    """
    Resolve *repo_id* to an on-disk snapshot without touching the network.

    The WAN repo maps to WAN_LOCAL_MODEL_DIR, and FileNotFoundError is raised if
    that directory is incomplete. Other repos are resolved through
    ``snapshot_download(local_files_only=True)`` in HF_HUB_CACHE. Results are
    memoized in _CACHED_REPO_PATHS for as long as the path remains a directory.
    """
    if _is_wan_repo(repo_id):
        if _is_wan_local_dir_ready():
            return WAN_LOCAL_MODEL_DIR
        raise FileNotFoundError(
            f"WAN model local directory is incomplete: {WAN_LOCAL_MODEL_DIR}. "
            "Run Cache / Startup -> Prepare cache now."
        )
    memo_key = (repo_id, repo_type)
    remembered = _CACHED_REPO_PATHS.get(memo_key)
    if remembered is not None and os.path.isdir(remembered):
        return remembered
    from huggingface_hub import snapshot_download

    resolved = snapshot_download(
        repo_id=repo_id,
        repo_type=repo_type,
        cache_dir=os.getenv("HF_HUB_CACHE"),
        token=_get_hf_token(),
        local_files_only=True,
    )
    _CACHED_REPO_PATHS[memo_key] = resolved
    return resolved
def _assert_cached_models_ready() -> None:
    """
    Verify that every required model repo resolves to a local snapshot.

    Each repo is checked through _cached_snapshot_path, and the outcome is logged.

    Raises:
        StartupFixError: listing every repo that could not be resolved locally.
    """
    unresolved: list[str] = []
    for repo in _required_model_repos():
        try:
            snapshot_dir = _cached_snapshot_path(repo, "model")
            log.info("Cached model ready: %s -> %s", repo, snapshot_dir)
        except Exception as exc:
            log.warning("Required cached model missing/incomplete: %s (%s)", repo, exc)
            unresolved.append(repo)
    if not unresolved:
        return
    raise StartupFixError(
        "Required model cache is missing or incomplete: "
        + ", ".join(unresolved)
        + ". Open Cache / Startup and click Prepare cache now. "
        + "Do not click Generate until cache preparation says state: done."
    )
def _repo_to_local_path_if_cached(path_or_repo: Any) -> Any:
    """
    Map a known model repo id to its cached local path.

    Values that are not strings, already exist on disk, or are not one of the
    known repos are returned unchanged. If resolution fails, StartupFixError is
    raised when FORCE_LOCAL_GENERATE is set; otherwise a warning is logged and
    the input is returned.
    """
    if not isinstance(path_or_repo, str) or os.path.exists(path_or_repo):
        return path_or_repo
    known_repos = (SEQTEX_MODEL_REPO, DEFAULT_WAN_MODEL_REPO, *EXTRA_MODEL_REPOS)
    if path_or_repo not in known_repos:
        return path_or_repo
    repo_id = path_or_repo
    try:
        local_path = _cached_snapshot_path(repo_id, "model")
        log.info("Using cached model path for %s: %s", repo_id, local_path)
    except Exception as exc:
        if FORCE_LOCAL_GENERATE:
            raise StartupFixError(
                f"Model repo {repo_id} is not fully cached. "
                "Run Cache / Startup -> Prepare cache now first."
            ) from exc
        log.warning("Could not resolve cached path for %s: %s", repo_id, exc)
        return path_or_repo
    return local_path
def _patch_cached_model_loading() -> None:
    """
    Route Hub downloads and ``from_pretrained`` calls through the local cache.

    Wraps ``huggingface_hub.hf_hub_download`` / ``snapshot_download``, Diffusers
    ``DiffusionPipeline`` / ``ModelMixin.from_pretrained``, and Transformers
    ``AutoTokenizer`` / ``PreTrainedModel`` / ``PreTrainedTokenizerBase.from_pretrained``.
    Every wrapped call defaults ``cache_dir`` to HF_HUB_CACHE and ``token`` to the
    configured HF token. While _LOCAL_GENERATE_MODE and FORCE_LOCAL_GENERATE are
    both true, the wrappers also force ``local_files_only=True`` and drop
    ``force_download``. In that mode, ``from_pretrained`` additionally maps known
    repo ids to cached snapshot paths.

    Runs once (_CACHE_PATCHED). Each patch group is best effort; failures are logged.
    """
    global _CACHE_PATCHED
    if _CACHE_PATCHED:
        return
    import functools

    cache_dir = os.getenv("HF_HUB_CACHE")

    def add_cached_kwargs(kwargs: dict[str, Any]) -> dict[str, Any]:
        # Defaults only: explicit caller values for cache_dir/token win.
        kwargs.setdefault("cache_dir", cache_dir)
        token = _get_hf_token()
        if token:
            kwargs.setdefault("token", token)
        # Globals are read at call time, so entering/leaving local mode takes effect immediately.
        if _LOCAL_GENERATE_MODE and FORCE_LOCAL_GENERATE:
            kwargs["local_files_only"] = True
            kwargs.pop("force_download", None)
        return kwargs

    def wrap_hub_download(original: Any) -> Any:
        """Return *original* wrapped so every call goes through add_cached_kwargs."""

        @functools.wraps(original)
        def cached_download(*args, **kwargs):
            add_cached_kwargs(kwargs)
            return original(*args, **kwargs)

        cached_download._seqtex_cache_patch = True
        return cached_download

    def patch_from_pretrained(cls: Any, attr: str, label: str) -> None:
        """Replace classmethod ``cls.<attr>`` with a cache-aware wrapper (idempotent)."""
        current = getattr(cls, attr)
        underlying = getattr(current, "__func__", current)
        if getattr(underlying, "_seqtex_cache_patch", False):
            return

        @functools.wraps(underlying)
        def cached_from_pretrained(inner_cls, pretrained_model_name_or_path=None, *model_args, **kwargs):
            if _LOCAL_GENERATE_MODE and FORCE_LOCAL_GENERATE:
                pretrained_model_name_or_path = _repo_to_local_path_if_cached(pretrained_model_name_or_path)
            add_cached_kwargs(kwargs)
            # Pass inner_cls through so subclasses still construct themselves.
            return underlying(inner_cls, pretrained_model_name_or_path, *model_args, **kwargs)

        cached_from_pretrained._seqtex_cache_patch = True
        setattr(cls, attr, classmethod(cached_from_pretrained))
        log.info("Patched %s.%s for cached Generate mode", label, attr)

    try:
        import huggingface_hub

        if not getattr(huggingface_hub.hf_hub_download, "_seqtex_cache_patch", False):
            huggingface_hub.hf_hub_download = wrap_hub_download(huggingface_hub.hf_hub_download)
        if not getattr(huggingface_hub.snapshot_download, "_seqtex_cache_patch", False):
            huggingface_hub.snapshot_download = wrap_hub_download(huggingface_hub.snapshot_download)
        log.info("Patched huggingface_hub downloads for cached Generate mode")
    except Exception as exc:
        log.warning("Could not patch huggingface_hub cached loading: %s", exc)

    try:
        from diffusers import DiffusionPipeline
        from diffusers.models.modeling_utils import ModelMixin

        patch_from_pretrained(DiffusionPipeline, "from_pretrained", "DiffusionPipeline")
        patch_from_pretrained(ModelMixin, "from_pretrained", "ModelMixin")
    except Exception as exc:
        log.warning("Could not patch Diffusers from_pretrained: %s", exc)

    try:
        from transformers import AutoTokenizer, PreTrainedModel, PreTrainedTokenizerBase

        patch_from_pretrained(AutoTokenizer, "from_pretrained", "AutoTokenizer")
        patch_from_pretrained(PreTrainedModel, "from_pretrained", "PreTrainedModel")
        patch_from_pretrained(PreTrainedTokenizerBase, "from_pretrained", "PreTrainedTokenizerBase")
    except Exception as exc:
        log.warning("Could not patch Transformers from_pretrained: %s", exc)

    # SeqTex's texture_generation module may have imported the hub helpers by name
    # before the patch above; re-point its references if it is already loaded.
    try:
        tex_mod = sys.modules.get("utils.texture_generation")
        if tex_mod is not None:
            import huggingface_hub

            if hasattr(tex_mod, "hf_hub_download"):
                tex_mod.hf_hub_download = huggingface_hub.hf_hub_download
            if hasattr(tex_mod, "snapshot_download"):
                tex_mod.snapshot_download = huggingface_hub.snapshot_download
            log.info("Patched SeqTex texture_generation hub helpers for cached mode")
    except Exception as exc:
        log.warning("Could not patch SeqTex module hub helpers: %s", exc)
    _CACHE_PATCHED = True
class _LocalGenerateMode:
    """Context manager that puts model loading into cached/offline mode.

    Entering sets the module-level ``_LOCAL_GENERATE_MODE`` flag to True and,
    when ``FORCE_LOCAL_GENERATE`` is enabled, sets HF_HUB_OFFLINE,
    TRANSFORMERS_OFFLINE and DIFFUSERS_OFFLINE to "1". Exiting restores the
    previous flag value and each variable's previous value, removing it again
    if it was unset before. Exceptions raised inside the block propagate.
    """

    def __enter__(self):
        global _LOCAL_GENERATE_MODE
        offline_keys = ("HF_HUB_OFFLINE", "TRANSFORMERS_OFFLINE", "DIFFUSERS_OFFLINE")
        self.old_local = _LOCAL_GENERATE_MODE
        # Snapshot before mutating anything; None records "was not set".
        self.old_env = {key: os.environ.get(key) for key in offline_keys}
        _LOCAL_GENERATE_MODE = True
        if FORCE_LOCAL_GENERATE:
            os.environ.update(dict.fromkeys(offline_keys, "1"))
        return self

    def __exit__(self, exc_type, exc, tb):
        global _LOCAL_GENERATE_MODE
        _LOCAL_GENERATE_MODE = self.old_local
        for key, previous in self.old_env.items():
            if previous is not None:
                os.environ[key] = previous
            else:
                os.environ.pop(key, None)
        # Returning False never swallows an exception from the with-body.
        return False
| # --------------------------------------------------------------------------- | |
| # 5. Lazy SeqTex import/model loading | |
| # --------------------------------------------------------------------------- | |
def _bootstrap_seqtex_utils() -> None:
    """Make the SeqTex ``utils`` package importable from SEQTEX_SPACE_DIR.

    If ``utils/mesh_utils.py`` is missing under SEQTEX_SPACE_DIR, the cache
    prewarm is run to fetch it. The marker is then re-checked. If it is still
    absent, StartupFixError is raised. The alternative would be to put the
    directory on ``sys.path`` and let ``import utils.…`` fail obscurely, or
    silently resolve to some unrelated top-level ``utils`` package.

    Raises:
        StartupFixError: the SeqTex utilities are still missing after the
            prewarm attempt.
    """
    marker = os.path.join(SEQTEX_SPACE_DIR, "utils", "mesh_utils.py")
    if os.path.isfile(marker):
        log.info("SeqTex utilities already present at %s", SEQTEX_SPACE_DIR)
    else:
        _append_prewarm_log("SeqTex utilities missing; downloading before generation")
        prewarm_cache_impl(progress=None)
        # Do not trust the prewarm blindly: verify the files actually landed.
        if not os.path.isfile(marker):
            raise StartupFixError(
                "SeqTex utilities are still missing after cache preparation "
                f"(expected {marker}). Run 'Prepare cache now' in the "
                "Cache / Startup tab and check its logs."
            )
    if SEQTEX_SPACE_DIR not in sys.path:
        # Front of the path so the SeqTex ``utils`` package wins over any
        # other module named ``utils``.
        sys.path.insert(0, SEQTEX_SPACE_DIR)
def _load_seqtex_modules() -> dict[str, Any]:
    """Import the SeqTex helpers once and return them as a name -> object map.

    On the first call this:
    - bootstraps the SeqTex utilities onto ``sys.path``;
    - imports ``utils.mesh_utils``, ``utils.render_utils`` and
      ``utils.texture_generation`` plus ``torch`` and ``numpy``;
    - applies the torch cpp-extension, WAN frozendict-config and
      cached-model-loading patches.

    The resulting dict is memoised in the module-level ``_seqtex_modules``.
    Later calls return it unchanged.
    """
    global _seqtex_modules
    if _seqtex_modules is None:
        _bootstrap_seqtex_utils()
        log.info("Importing SeqTex modules lazily...")
        mesh_mod = importlib.import_module("utils.mesh_utils")
        render_mod = importlib.import_module("utils.render_utils")
        texgen_mod = importlib.import_module("utils.texture_generation")
        torch_mod = importlib.import_module("torch")
        # Patches run after the SeqTex modules are imported, so they can reach
        # already-loaded objects (e.g. texture_generation's hub helpers).
        _patch_torch_cpp_extension_load(torch_mod)
        _patch_diffusers_wan_frozendict_config()
        _patch_cached_model_loading()
        numpy_mod = importlib.import_module("numpy")
        _seqtex_modules = dict(
            Mesh=mesh_mod.Mesh,
            get_mvp_matrix=render_mod.get_mvp_matrix,
            render_geo_map=render_mod.render_geo_map,
            render_geo_views_tensor=render_mod.render_geo_views_tensor,
            get_seqtex_pipe=texgen_mod.get_seqtex_pipe,
            encode_images=texgen_mod.encode_images,
            decode_images=texgen_mod.decode_images,
            convert_img_to_tensor=texgen_mod.convert_img_to_tensor,
            texture_generation_module=texgen_mod,
            torch=torch_mod,
            np=numpy_mod,
        )
        log.info("SeqTex modules imported successfully")
    return _seqtex_modules
def _get_seqtex_pipe() -> Any:
    """Return the shared SeqTex pipeline, building it on first use.

    The SeqTex token is ensured and the helper modules are loaded on every
    call. The pipeline itself is built only once and cached in the
    module-level ``_seqtex_pipe``. The build runs inside
    ``_LocalGenerateMode`` so model loading uses cached files. When
    ``FORCE_LOCAL_GENERATE`` is set, ``_assert_cached_models_ready()`` is
    called first, so an incomplete cache fails fast.
    """
    global _seqtex_pipe
    _ensure_seqtex_token()
    seqtex = _load_seqtex_modules()
    if _seqtex_pipe is not None:
        return _seqtex_pipe

    # Re-applied defensively before building; the cached-loading patch checks
    # its own marker and returns early if it is already installed.
    _patch_diffusers_wan_frozendict_config()
    _patch_cached_model_loading()
    if FORCE_LOCAL_GENERATE:
        log.info("Checking required cached model snapshots before loading pipeline...")
        _assert_cached_models_ready()
    log.info(
        "Loading SeqTex pipeline onto GPU using cached model files only. "
        "WAN local model dir: %s",
        WAN_LOCAL_MODEL_DIR,
    )
    build_pipe = seqtex["get_seqtex_pipe"]
    with _LocalGenerateMode():
        _seqtex_pipe = build_pipe()
    log.info("SeqTex pipeline loaded")
    return _seqtex_pipe
| # --------------------------------------------------------------------------- | |
| # 6. Mesh processing | |
| # --------------------------------------------------------------------------- | |
def step1_process_mesh(
    glb_path: str,
    upside_down: bool,
    uv_size: int,
    mv_size: int,
    progress: gr.Progress | None = None,
) -> tuple[dict[str, Any], Image.Image]:
    """Load and normalise a GLB mesh, render its geometry, and persist it.

    The mesh is loaded on CUDA with ``uv_tool="xAtlas"``, given the Z-up
    vertex transform (optionally flipped upside-down), and normalised. The
    function then renders multi-view position/normal/mask images and UV-space
    position/normal maps. Every tensor and the CPU copy of the mesh are
    written to temporary files so later steps can reload them.

    Args:
        glb_path: Path to the input GLB file.
        upside_down: Apply ``vertex_transform_upsidedown`` before normalising.
        uv_size: Side length of the square UV geometry maps.
        mv_size: Side length of the square multi-view renders.
        progress: Optional Gradio progress callback. It is deliberately not
            forwarded to the SeqTex ``Mesh`` helper.

    Returns:
        ``(geo_data, preview)``.
        ``geo_data`` maps "pos_imgs", "norm_imgs", "mask_imgs", "pos_map",
        "norm_map", "w2c" and "mvp" to ``torch.save`` files. It maps
        "mesh_pkl" to a pickle of the CPU mesh, and also holds the integers
        "uv_size" and "mv_size". The caller owns these temporary files.
        ``preview`` is a horizontal strip of up to four normal-map views.
    """
    modules = _load_seqtex_modules()
    Mesh = modules["Mesh"]
    get_mvp_matrix = modules["get_mvp_matrix"]
    render_geo_views_tensor = modules["render_geo_views_tensor"]
    render_geo_map = modules["render_geo_map"]
    torch = modules["torch"]
    np = modules["np"]
    device = "cuda"

    log.info("Step 1: loading mesh from %s", glb_path)
    if progress is not None:
        progress(0.08, desc="Loading mesh and generating UVs if needed...")
    # Do not pass Gradio Progress into official Mesh helper.
    mesh = Mesh(glb_path, uv_tool="xAtlas", device=device)

    if progress is not None:
        progress(0.16, desc="Applying Z-UP orientation and normalizing mesh...")
    mesh.vertex_transform()
    if upside_down:
        mesh.vertex_transform_upsidedown()
    mesh.normalize()

    img_size = (int(mv_size), int(mv_size))
    uv_sz = (int(uv_size), int(uv_size))
    try:
        mvp_matrix, w2c = get_mvp_matrix(mesh, num_views=4, width=int(mv_size), height=int(mv_size))
    except TypeError:
        # Older SeqTex render_utils versions only accept the mesh.
        mvp_matrix, w2c = get_mvp_matrix(mesh)
    mvp_matrix = mvp_matrix.to(device)
    w2c = w2c.to(device)

    if progress is not None:
        progress(0.24, desc="Rendering geometry views / compiling rasterizer if needed...")
    try:
        pos_imgs, norm_imgs, mask_imgs = render_geo_views_tensor(mesh, mvp_matrix, img_size)
    except ModuleNotFoundError as exc:
        if "nvdiffrast_plugin" not in str(exc):
            raise
        # A stale/partial JIT build of the nvdiffrast extension: wipe it and
        # retry exactly once.
        log.warning("nvdiffrast_plugin not importable. Cleaning extension cache and retrying once.")
        if progress is not None:
            progress(0.25, desc="Cleaning stale nvdiffrast plugin cache and retrying...")
        _clean_nvdiffrast_extension_cache()
        _reset_nvdiffrast_runtime_modules()
        pos_imgs, norm_imgs, mask_imgs = render_geo_views_tensor(mesh, mvp_matrix, img_size)

    if progress is not None:
        progress(0.30, desc="Rendering UV-space geometry maps...")
    pos_map, norm_map = render_geo_map(mesh, map_size=uv_sz)

    def _save_tensor(tensor: Any, prefix: str) -> str:
        # Close our handle before torch.save reopens the path, so the file is
        # never held open twice.
        handle = tempfile.NamedTemporaryFile(delete=False, suffix=".pt", prefix=f"{prefix}_")
        handle.close()
        torch.save(tensor.detach().cpu(), handle.name)
        return handle.name

    mesh_cpu = mesh.to("cpu")
    # Write through the temporary file's own handle. The previous code reopened
    # the path by name and never closed this handle, leaking a descriptor per call.
    with tempfile.NamedTemporaryFile(
        delete=False, suffix="_processed_mesh.pkl", prefix="seqtex_mesh_"
    ) as mesh_file:
        pickle.dump(mesh_cpu, mesh_file)

    result = {
        "pos_imgs": _save_tensor(pos_imgs, "pos_imgs"),
        "norm_imgs": _save_tensor(norm_imgs, "norm_imgs"),
        "mask_imgs": _save_tensor(mask_imgs, "mask_imgs"),
        "pos_map": _save_tensor(pos_map, "pos_map"),
        "norm_map": _save_tensor(norm_map, "norm_map"),
        "w2c": _save_tensor(w2c, "w2c"),
        "mvp": _save_tensor(mvp_matrix, "mvp"),
        "mesh_pkl": mesh_file.name,
        "uv_size": int(uv_size),
        "mv_size": int(mv_size),
    }

    # Preview: map normals (assumed in [-1, 1]) to [0, 1], then tile up to four
    # views side by side. Assumes norm_imgs is [F, H, W, C]; TODO confirm.
    norm_np = norm_imgs.detach().cpu().numpy()
    norm_np = (norm_np * 0.5 + 0.5).clip(0, 1)
    tiles = [Image.fromarray((norm_np[i] * 255).astype(np.uint8)) for i in range(min(4, norm_np.shape[0]))]
    w, h = tiles[0].size
    preview = Image.new("RGB", (w * len(tiles), h))
    for i, tile in enumerate(tiles):
        preview.paste(tile, (i * w, 0))

    log.info("Step 1 complete")
    return result, preview
| # --------------------------------------------------------------------------- | |
| # 7. SeqTex generation | |
| # --------------------------------------------------------------------------- | |
def step2_generate_texture(
    geo_data: dict[str, Any],
    condition_image: Image.Image,
    text_prompt: str,
    seed: int,
    steps: int,
    guidance_scale: float,
    num_views: int,
    progress: gr.Progress | None = None,
) -> Image.Image:
    """Run SeqTex img2tex diffusion and decode the predicted UV texture.

    Args:
        geo_data: Output of ``step1_process_mesh``. The entries "pos_imgs",
            "norm_imgs", "pos_map" and "norm_map" are ``torch.save`` files;
            "mv_size" and "uv_size" are ints.
        condition_image: Front-view reference image. It is converted to RGB
            and resized to ``mv_size`` x ``mv_size`` with LANCZOS.
        text_prompt: Text prompt. Blank or None falls back to
            "high quality texture, clean details".
        seed: Seed for the CUDA ``torch.Generator``.
        steps: Number of diffusion inference steps.
        guidance_scale: Guidance scale passed to the pipeline.
        num_views: Number of multi-view frames requested from the pipeline.
        progress: Optional Gradio progress callback.

    Returns:
        The decoded UV texture (last frame of the UV latent) as an RGB image.
    """
    modules = _load_seqtex_modules()
    torch = modules["torch"]
    np = modules["np"]
    encode_images = modules["encode_images"]
    decode_images = modules["decode_images"]
    convert_img_to_tensor = modules["convert_img_to_tensor"]
    texture_generation_module = modules.get("texture_generation_module")
    device = "cuda"
    mv_size = int(geo_data["mv_size"])
    uv_size = int(geo_data["uv_size"])

    if progress is not None:
        progress(0.36, desc="Loading SeqTex pipeline onto GPU from local cache...")
    pipe = _get_seqtex_pipe()

    def _load_tensor(path: str) -> Any:
        return torch.load(path, map_location=device)

    pos_imgs = _load_tensor(geo_data["pos_imgs"])
    norm_imgs = _load_tensor(geo_data["norm_imgs"])
    pos_map = _load_tensor(geo_data["pos_map"])
    norm_map = _load_tensor(geo_data["norm_map"])

    if progress is not None:
        progress(0.46, desc="Encoding geometry latents...")

    def _to_bfhwc(frames: Any, name: str) -> Any:
        """Normalize geometry tensors to SeqTex encode_images() shape [B, F, H, W, C].

        render_geo_views_tensor() usually returns multi-view tensors as
        [F, H, W, C]. render_geo_map() may return UV maps either as [H, W, C]
        or [1, H, W, C], depending on the SeqTex utility version. Adding an
        unconditional extra unsqueeze for UV maps produced [1, 1, 1, H, W, C]
        and "einops.EinopsError: expected 5 dims, received 6 dims".
        """
        ndim = getattr(frames, "ndim", None)
        if ndim == 3:
            # [H, W, C] -> [1, 1, H, W, C]
            return frames.unsqueeze(0).unsqueeze(0)
        if ndim == 4:
            # [F, H, W, C] -> [1, F, H, W, C]
            return frames.unsqueeze(0)
        if ndim == 5:
            # Already [B, F, H, W, C]
            return frames
        shape = getattr(frames, "shape", "unknown")
        raise ValueError(f"{name} has unsupported shape for encode_images: {shape}")

    def _reset_wan_vae_encode_cache() -> None:
        """Clear the WAN VAE's stateful feature caches before an independent encode.

        The VAE object (``texture_generation_module.VAE``) may keep internal
        feature-map caches between calls. encode_images() is called here for
        unrelated inputs (multi-view geometry, UV geometry, condition image),
        so the caches are reset each time to keep stale state from leaking
        into the next encode. Best effort: missing attributes are skipped.
        """
        if texture_generation_module is None:
            return
        vae = getattr(texture_generation_module, "VAE", None)
        if vae is None:
            return
        for attr in ("_enc_feat_map", "_enc_conv_idx", "_dec_feat_map", "_dec_conv_idx"):
            if hasattr(vae, attr):
                try:
                    setattr(vae, attr, None)
                except Exception:
                    pass

    def _enc(frames: Any, name: str) -> Any:
        frames_5d = _to_bfhwc(frames, name)
        # Force float32 on CUDA. PIL/NumPy condition tensors can become float64;
        # Conv3D with float64 on CUDA falls back to aten::slow_conv3d_forward,
        # which is CPU-only in the ZeroGPU PyTorch build.
        frames_5d = frames_5d.to(device=device, dtype=torch.float32, non_blocking=True).contiguous()
        log.info("Encoding %s with shape %s dtype=%s device=%s", name, tuple(frames_5d.shape), frames_5d.dtype, frames_5d.device)
        _reset_wan_vae_encode_cache()
        return encode_images(frames_5d, encode_as_first=True)

    nat_pos_lat = _enc(pos_imgs, "pos_imgs")
    nat_norm_lat = _enc(norm_imgs, "norm_imgs")
    uv_pos_lat = _enc(pos_map, "pos_map")
    uv_norm_lat = _enc(norm_map, "norm_map")
    # The raw geometry tensors are not used after encoding; release them before
    # the diffusion run so they do not count against peak VRAM.
    del pos_imgs, norm_imgs, pos_map, norm_map
    nat_geo = torch.cat([nat_pos_lat, nat_norm_lat], dim=1)
    uv_geo = torch.cat([uv_pos_lat, uv_norm_lat], dim=1)
    cond_model_latents = (nat_geo, uv_geo)
    del nat_pos_lat, nat_norm_lat, uv_pos_lat, uv_norm_lat
    torch.cuda.empty_cache()

    if progress is not None:
        progress(0.56, desc="Encoding reference image...")
    cond_pil = condition_image.convert("RGB").resize((mv_size, mv_size), Image.LANCZOS)
    cond_t = convert_img_to_tensor(cond_pil, device=device)
    gt_latent = _enc(cond_t, "condition_image")
    gt_condition = (gt_latent, None)
    del cond_t
    torch.cuda.empty_cache()

    text_prompt = (text_prompt or "high quality texture, clean details").strip()
    # Each latent frame spans 2 ** sum(temperal_downsample) pixel frames
    # ("temperal" is the attribute's real spelling in Diffusers'
    # AutoencoderKLWan). The Diffusers default is [False, True, True], i.e. a
    # factor of 4. The previous fallback of [2, 2] would have yielded 16.
    temporal_downsample = getattr(pipe.vae.config, "temperal_downsample", [False, True, True])
    frame_factor = 2 ** sum(temporal_downsample)
    # NOTE(review): step1_process_mesh always renders 4 views, while
    # num_views comes from the UI (2-4). Confirm the pipeline tolerates
    # num_views != 4 with 4-view geometry latents.
    num_frames = int(num_views) * frame_factor
    uv_num_frames = 1 * frame_factor

    if progress is not None:
        progress(0.66, desc=f"Running SeqTex diffusion ({int(steps)} steps)...")
    with torch.inference_mode():
        latents = pipe(
            prompt=text_prompt,
            negative_prompt=None,
            num_frames=num_frames,
            generator=torch.Generator(device=device).manual_seed(int(seed)),
            num_inference_steps=int(steps),
            guidance_scale=float(guidance_scale),
            height=mv_size,
            width=mv_size,
            output_type="latent",
            cond_model_latents=cond_model_latents,
            uv_height=uv_size,
            uv_width=uv_size,
            uv_num_frames=uv_num_frames,
            treat_as_first=True,
            gt_condition=gt_condition,
            inference_img_cond_frame=0,
            use_qk_geometry=True,
            max_sequence_length=1024,
            task_type="img2tex",
        ).frames
    del cond_model_latents, gt_latent, gt_condition
    torch.cuda.empty_cache()
    mv_latents, uv_latents = latents

    if progress is not None:
        progress(0.84, desc="Decoding UV texture...")
    uv_frames = decode_images(uv_latents, decode_as_first=True)
    del uv_latents, mv_latents
    torch.cuda.empty_cache()

    # Assumes decode output is [B, C, F, H, W]; take the last frame -> [C, H, W].
    uv_pred = uv_frames[:, :, -1, ...].squeeze(0).clamp(0.0, 1.0).cpu()
    uv_np = (uv_pred.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
    uv_pil = Image.fromarray(uv_np).convert("RGB")
    log.info("Step 2 complete")
    return uv_pil
| # --------------------------------------------------------------------------- | |
| # 8. Export | |
| # --------------------------------------------------------------------------- | |
def step3_export_glb(geo_data: dict[str, Any], uv_texture: Image.Image) -> str:
    """Bake *uv_texture* into the processed mesh and write it as a GLB file.

    Args:
        geo_data: Output of ``step1_process_mesh``; only "mesh_pkl" (path to
            the pickled CPU mesh) is read.
        uv_texture: UV texture image passed to ``Mesh.export`` as texture_map.

    Returns:
        Path of a new temporary ``*_textured.glb`` file.
    """
    # Loading the SeqTex modules first also puts the ``utils`` package on
    # sys.path, which unpickling the Mesh instance presumably requires.
    mesh_cls = _load_seqtex_modules()["Mesh"]
    mesh_pkl_path = geo_data["mesh_pkl"]
    with open(mesh_pkl_path, "rb") as pkl_stream:
        processed_mesh = pickle.load(pkl_stream)

    # Reserve a unique path, then close the handle so export can write to it.
    target = tempfile.NamedTemporaryFile(delete=False, suffix="_textured.glb", prefix="seqtex_")
    target.close()
    glb_path = target.name

    mesh_cls.export(processed_mesh, save_path=glb_path, texture_map=uv_texture)
    log.info("Exported textured GLB to %s", glb_path)
    return glb_path
| # --------------------------------------------------------------------------- | |
| # 9. Main Generate handler — only ZeroGPU path | |
| # --------------------------------------------------------------------------- | |
def _remove_geo_temp_files(geo_data: dict[str, Any] | None) -> None:
    """Best-effort removal of the intermediate files listed in *geo_data*.

    ``step1_process_mesh`` returns a dict whose string values are temporary
    file paths (tensor dumps and the pickled mesh) and whose int values are
    sizes. Those files belong to a single Generate call and are never
    returned to the UI, so they are deleted once the call is done.
    """
    if not geo_data:
        return
    for key, value in geo_data.items():
        if not isinstance(value, str):
            continue  # "uv_size" / "mv_size" are plain ints
        try:
            os.remove(value)
        except FileNotFoundError:
            pass
        except OSError as exc:
            log.warning("Could not remove temporary file %s (%s): %s", value, key, exc)


def run(
    glb_file: str | None,
    condition_image: Image.Image | None,
    text_prompt: str,
    seed: int,
    steps: int,
    guidance_scale: float,
    num_views: int,
    upside_down: bool,
    uv_size: int,
    mv_size: int,
    progress: gr.Progress = gr.Progress(track_tqdm=True),
):
    """Gradio handler for the "Generate Texture" button.

    Pipeline: prepare runtime dirs, then (when FORCE_LOCAL_GENERATE is set)
    verify the model cache, then mesh processing (step 1), SeqTex diffusion
    (step 2) and GLB export (step 3). The intermediate files written by
    step 1 are deleted afterwards, whether generation succeeded or failed.

    Returns:
        ``(textured_glb_path, uv_texture, geometry_preview, textured_glb_path)``
        for the Model3D, UV image, geometry preview and download outputs.

    Raises:
        gr.Error: missing inputs, configuration errors (StartupFixError), or
            any other failure; the full traceback is logged.
    """
    if glb_file is None:
        raise gr.Error("Please upload a GLB mesh.")
    if condition_image is None:
        raise gr.Error("Please upload a front-view reference image.")
    geo_data: dict[str, Any] | None = None
    try:
        glb_path = glb_file if isinstance(glb_file, str) else glb_file.name
        log.info("Generate clicked: glb=%s uv=%s mv=%s steps=%s", glb_path, uv_size, mv_size, steps)
        progress(0.02, desc="Preparing runtime cache directories...")
        _prepare_runtime_dirs()
        if FORCE_LOCAL_GENERATE:
            progress(0.03, desc="Verifying cached model snapshots...")
            _assert_cached_models_ready()
        progress(0.04, desc="Preparing SeqTex utilities...")
        _load_seqtex_modules()
        geo_data, preview = step1_process_mesh(
            glb_path=glb_path,
            upside_down=bool(upside_down),
            uv_size=int(uv_size),
            mv_size=int(mv_size),
            progress=progress,
        )
        uv_texture = step2_generate_texture(
            geo_data=geo_data,
            condition_image=condition_image,
            text_prompt=text_prompt,
            seed=int(seed),
            steps=int(steps),
            guidance_scale=float(guidance_scale),
            num_views=int(num_views),
            progress=progress,
        )
        progress(0.92, desc="Baking texture into GLB...")
        textured_glb = step3_export_glb(geo_data, uv_texture)
        progress(1.0, desc="Done")
        return textured_glb, uv_texture, preview, textured_glb
    except StartupFixError as exc:
        log.error("Configuration error: %s", exc)
        raise gr.Error(str(exc)) from exc
    except Exception as exc:
        tb = traceback.format_exc()
        log.error("Generation failed:\n%s", tb)
        raise gr.Error(f"Generation failed: {exc}\n\nCheck the Container logs for the full traceback.") from exc
    finally:
        # Without this, every Generate call left several tensor dumps and a
        # pickled mesh behind in the temp directory.
        _remove_geo_temp_files(geo_data)
| # --------------------------------------------------------------------------- | |
| # 10. UI | |
| # --------------------------------------------------------------------------- | |
print("[BOOT 05] building gradio UI", flush=True)
# User-facing paths are interpolated from the configured constants rather than
# hard-coded. PERSISTENT_ROOT, WAN_BUCKET_ROOT and CACHE_ROOT come from
# environment variables, so literal "/data" or "/wan-cache" text could be wrong.
with gr.Blocks(title="SeqTex Texture Generator", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        f"""
# 🎨 SeqTex Texture Generator

Upload an **untextured GLB mesh** and a **front-view reference image**.

This version stores the WAN base model in a separate bucket mounted at
`{WAN_BUCKET_ROOT}`, while keeping the normal Hugging Face cache at `{CACHE_ROOT}`.
During Generate, model loading is forced to use cached/local files only.
"""
    )
    with gr.Tab("Generate Texture"):
        with gr.Row():
            with gr.Column(scale=1):
                glb_input = gr.File(
                    label="Input GLB Mesh",
                    file_types=[".glb"],
                    type="filepath",
                )
                cond_image = gr.Image(
                    label="Front-View Reference Image",
                    type="pil",
                    sources=["upload", "clipboard"],
                    height=300,
                )
                text_prompt = gr.Textbox(
                    label="Text prompt",
                    placeholder="e.g. anime character, colorful clothing, high quality",
                    value="high quality texture, clean details",
                    lines=2,
                )
                run_btn = gr.Button("✨ Generate Texture", variant="primary", size="lg")
                with gr.Accordion("Advanced settings", open=False):
                    seed = gr.Slider(label="Seed", minimum=0, maximum=2**31 - 1, value=42, step=1)
                    steps = gr.Slider(
                        label="Diffusion steps",
                        minimum=5,
                        maximum=30,
                        value=10,
                        step=1,
                        info="Use 10 on ZeroGPU. Higher values need more time/VRAM.",
                    )
                    guidance_scale = gr.Slider(
                        label="Guidance scale",
                        minimum=1.0,
                        maximum=10.0,
                        value=1.0,
                        step=0.5,
                        info="SeqTex usually works best around 1.0.",
                    )
                    num_views = gr.Slider(
                        label="Multi-view count",
                        minimum=2,
                        maximum=4,
                        value=4,
                        step=1,
                        info="4 matches the reference SeqTex Space.",
                    )
                    upside_down = gr.Checkbox(
                        label="Flip mesh upside-down",
                        value=False,
                        info="Enable only if your mesh appears inverted.",
                    )
                    uv_size = gr.Radio(
                        label="UV texture resolution",
                        choices=[512, 1024],
                        value=1024,
                        info="2048 is disabled for ZeroGPU stability.",
                    )
                    mv_size = gr.Radio(
                        label="Multi-view render resolution",
                        choices=[256, 512],
                        value=512,
                    )
            with gr.Column(scale=1):
                output_3d = gr.Model3D(
                    label="Textured 3-D Model",
                    height=450,
                    clear_color=[0.15, 0.15, 0.15, 1.0],
                )
                uv_preview = gr.Image(
                    label="Generated UV Texture Map",
                    type="pil",
                    interactive=False,
                    height=256,
                )
                geo_preview = gr.Image(
                    label="Geometry Views Preview",
                    type="pil",
                    interactive=False,
                    height=150,
                )
                download_btn = gr.File(label="Download Textured GLB")
        run_btn.click(
            fn=run,
            inputs=[
                glb_input,
                cond_image,
                text_prompt,
                seed,
                steps,
                guidance_scale,
                num_views,
                upside_down,
                uv_size,
                mv_size,
            ],
            outputs=[output_3d, uv_preview, geo_preview, download_btn],
        )
    with gr.Tab("Cache / Startup"):
        gr.Markdown(
            f"""
## Cache preparation

This downloads/checks:

- SeqTex helper code -> `{SEQTEX_SPACE_DIR}`
- SeqTex transformer -> Hugging Face cache under `{CACHE_ROOT}`
- WAN base model -> `{WAN_LOCAL_MODEL_DIR}`

This uses CPU/network runtime only. It does **not** allocate ZeroGPU.
Press **Prepare cache now** before generating.
"""
        )
        with gr.Row():
            prewarm_btn = gr.Button("Prepare cache now", variant="primary")
            refresh_cache_btn = gr.Button("Refresh cache status")
        cache_log_box = gr.Textbox(
            label="Cache preparation logs",
            value=_prewarm_log_text(),
            lines=22,
            max_lines=35,
            interactive=False,
        )
        prewarm_btn.click(fn=prewarm_cache_ui, inputs=[], outputs=[cache_log_box])
        refresh_cache_btn.click(fn=get_cache_status_ui, inputs=[], outputs=[cache_log_box])
    gr.Markdown(
        f"""
---
**Recommended Space setup**

1. Keep or mount a bucket at `{PERSISTENT_ROOT}`.
2. Create/mount the new WAN bucket at `{WAN_BUCKET_ROOT}`.
3. Add a Space secret named `HF_TOKEN` or `SEQTEX_SPACE_TOKEN`.
4. Keep `AUTO_PREWARM=0` while debugging. Use **Prepare cache now** manually.
5. Keep `FORCE_LOCAL_GENERATE=1`.

Current WAN local model directory:
`{WAN_LOCAL_MODEL_DIR}`

If Generate says the cache is incomplete, run **Prepare cache now** again.
"""
    )
    if AUTO_PREWARM:
        demo.load(fn=start_auto_prewarm_once, inputs=[], outputs=[cache_log_box])
print("[BOOT 06] gradio UI built", flush=True)
if __name__ == "__main__":
    port = int(os.getenv("PORT", "7860"))
    print(f"[BOOT 07] launching gradio on 0.0.0.0:{port} with SSR disabled", flush=True)
    queued_demo = demo.queue(max_size=3)
    # Some Gradio releases reject some of these keyword arguments (ssr_mode
    # does not exist in older versions; show_api may be gone in newer ones).
    # Drop only the rejected one and retry, so a rejected show_api does not
    # also throw away ssr_mode=False. The loop ends because each retry removes
    # one kwarg, and a failure with no optional kwargs left is re-raised.
    launch_extras = {"ssr_mode": False, "show_api": False}
    while True:
        try:
            queued_demo.launch(
                server_name="0.0.0.0",
                server_port=port,
                **launch_extras,
            )
            break
        except TypeError as exc:
            if not launch_extras:
                raise
            message = str(exc)
            rejected = next(
                (name for name in launch_extras if name in message),
                next(reversed(launch_extras)),
            )
            launch_extras.pop(rejected)
            print(
                f"[BOOT WARN] launch kwargs fallback (dropping {rejected!r}) because: {exc}",
                flush=True,
            )