GMTexture / app.py
vivekchakraverty's picture
Upload app.py
9c4371f verified
Raw
History Blame Contribute Delete
59.9 kB
"""
SeqTex Texture Generator — Hugging Face Space
=============================================
Startup-safe + Storage Bucket + separate WAN bucket + cached-model-only Generate.
Plain Python file. Do not paste Markdown fences such as ```python into app.py.
Recommended buckets:
- Existing/general bucket mounted at: /data
- New WAN bucket mounted at: /wan-cache
What this version does:
- Launches Gradio first.
- Keeps SeqTex Space utilities and SeqTex-Transformer cache under /data.
- Stores the large WAN base model in the separate bucket /wan-cache.
- Uses CPU/network prewarm from the Cache / Startup tab.
- Uses ZeroGPU only for Generate Texture.
- During Generate, forces all known model repo IDs to local cached paths.
- If the cache is incomplete, Generate fails fast instead of downloading during ZeroGPU.
- Includes nvdiffrast_plugin and Diffusers WAN compatibility patches.
"""
from __future__ import annotations
# ---------------------------------------------------------------------------
# 0. Minimal startup section
# ---------------------------------------------------------------------------
import fnmatch
import importlib
import logging
import os
import pickle
import shutil
import sys
import tempfile
import threading
import time
import traceback
from typing import Any
print("[BOOT 00] app.py started", flush=True)


def _env_flag(name: str, default: str) -> bool:
    """Return True when env var *name* (or *default* if unset) is 1/true/yes/on, ignoring case and surrounding spaces."""
    return os.getenv(name, default).strip().lower() in {"1", "true", "yes", "on"}


def _env_int(name: str, default: int) -> int:
    """
    Parse env var *name* as an int, falling back to *default* when unset, blank or malformed.

    A typo in a Space variable must not crash the app at import time (this file is
    meant to be startup-safe), so a bad value is reported on stdout and ignored.
    """
    raw = os.getenv(name)
    if raw is None or not raw.strip():
        return default
    try:
        return int(raw)
    except ValueError:
        print(f"[BOOT WARN] {name}={raw!r} is not an integer; using {default}", flush=True)
        return default


# General persistent bucket. Your existing bucket can stay here.
PERSISTENT_ROOT = os.getenv("PERSISTENT_ROOT", "/data")
CACHE_ROOT = os.getenv("CACHE_ROOT", os.path.join(PERSISTENT_ROOT, "hf_home"))
# New dedicated WAN bucket. Mount the new bucket here in HF Settings -> Storage Buckets.
WAN_BUCKET_ROOT = os.getenv("WAN_BUCKET_ROOT", "/wan-cache")
# Hugging Face cache for SeqTex-Transformer and normal Hub cache.
os.environ.setdefault("HF_HOME", CACHE_ROOT)
os.environ.setdefault("HF_HUB_CACHE", os.path.join(CACHE_ROOT, "hub"))
os.environ.setdefault("TRANSFORMERS_CACHE", os.path.join(CACHE_ROOT, "hub"))
os.environ.setdefault("DIFFUSERS_CACHE", os.path.join(CACHE_ROOT, "hub"))
os.environ.setdefault("TORCH_HOME", os.path.join(PERSISTENT_ROOT, "torch"))
# nvdiffrast compiles a small runtime extension. Keep this on local tmp.
# Do not use /data or /wan-cache for Torch extensions.
os.environ.setdefault("TORCH_EXTENSIONS_DIR", "/tmp/torch_extensions")
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "0")
os.environ.setdefault("GRADIO_SSR_MODE", "False")
os.environ.setdefault("GRADIO_ANALYTICS_ENABLED", "False")
# During Generate, use only cached/local files. Prewarm can still download.
FORCE_LOCAL_GENERATE = _env_flag("FORCE_LOCAL_GENERATE", "1")
SEQTEX_SPACE_REPO = os.getenv("SEQTEX_SPACE_REPO", "VAST-AI/SeqTex")
SEQTEX_MODEL_REPO = os.getenv("SEQTEX_MODEL_REPO", "VAST-AI/SeqTex-Transformer")
SEQTEX_SPACE_DIR = os.getenv("SEQTEX_SPACE_DIR", os.path.join(PERSISTENT_ROOT, "seqtex_space"))
# SeqTex's official loader pulls this WAN base model at runtime.
DEFAULT_WAN_MODEL_REPO = os.getenv("DEFAULT_WAN_MODEL_REPO", "Wan-AI/Wan2.1-T2V-1.3B-Diffusers")
WAN_LOCAL_MODEL_DIR = os.getenv(
    "WAN_LOCAL_MODEL_DIR",
    os.path.join(WAN_BUCKET_ROOT, "models", "Wan-AI", "Wan2.1-T2V-1.3B-Diffusers"),
)
# Optional comma-separated extra model repos. By default (variable unset), include WAN
# because SeqTex needs it; an explicitly empty value means "no extra repos".
_extra_repos_env = os.getenv("EXTRA_MODEL_REPOS", DEFAULT_WAN_MODEL_REPO)
EXTRA_MODEL_REPOS = [repo.strip() for repo in _extra_repos_env.split(",") if repo.strip()]
AUTO_PREWARM = _env_flag("AUTO_PREWARM", "0")
# time.sleep() rejects negative values, so clamp the delay at zero.
AUTO_PREWARM_DELAY_SECONDS = max(0, _env_int("AUTO_PREWARM_DELAY_SECONDS", 10))
print("[BOOT 01] environment variables set", flush=True)
# ---------------------------------------------------------------------------
# 1. Light imports only
# ---------------------------------------------------------------------------
try:
import gradio as gr
print("[BOOT 02] gradio imported", flush=True)
except Exception:
print("[BOOT ERROR] gradio import failed", flush=True)
raise
try:
import spaces
print("[BOOT 03] spaces imported", flush=True)
except Exception as exc:
print(f"[BOOT WARN] spaces import failed: {exc}", flush=True)
spaces = None
try:
from PIL import Image
print("[BOOT 04] PIL imported", flush=True)
except Exception:
print("[BOOT ERROR] PIL import failed", flush=True)
raise
# Route all logging to stdout (force=True replaces handlers installed by imported libs).
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    level=logging.INFO,
    force=True,
    handlers=[logging.StreamHandler(sys.stdout)],
)
log = logging.getLogger("seqtex-app")
# Echo the resolved configuration once at boot so Space logs show exactly which
# paths, repos and switches this process is running with.
for _cfg_name, _cfg_value in (
    ("PERSISTENT_ROOT", PERSISTENT_ROOT),
    ("WAN_BUCKET_ROOT", WAN_BUCKET_ROOT),
    ("HF_HOME", os.getenv("HF_HOME")),
    ("HF_HUB_CACHE", os.getenv("HF_HUB_CACHE")),
    ("TORCH_HOME", os.getenv("TORCH_HOME")),
    ("TORCH_EXTENSIONS_DIR", os.getenv("TORCH_EXTENSIONS_DIR")),
    ("SEQTEX_SPACE_REPO", SEQTEX_SPACE_REPO),
    ("SEQTEX_MODEL_REPO", SEQTEX_MODEL_REPO),
    ("DEFAULT_WAN_MODEL_REPO", DEFAULT_WAN_MODEL_REPO),
    ("WAN_LOCAL_MODEL_DIR", WAN_LOCAL_MODEL_DIR),
    ("SEQTEX_SPACE_DIR", SEQTEX_SPACE_DIR),
    ("AUTO_PREWARM", AUTO_PREWARM),
    ("EXTRA_MODEL_REPOS", EXTRA_MODEL_REPOS),
    ("FORCE_LOCAL_GENERATE", FORCE_LOCAL_GENERATE),
):
    log.info(f"{_cfg_name}=%s", _cfg_value)
del _cfg_name, _cfg_value
def _patch_zero_startup_report_timeout() -> None:
    """
    Avoid rare startup crash when ZeroGPU local startup-report API times out.

    Wraps ``spaces.zero.client.startup_report`` so that timeout errors (an
    ``httpx.TimeoutException``, or any exception whose class name contains
    "timeout" or whose message contains "timed out") are retried up to three
    times, one second apart, and then swallowed so the launch continues. Any
    other exception is re-raised unchanged. This does not disable @spaces.GPU
    for Generate.

    No-op when ``spaces`` failed to import, when the client module cannot be
    imported, or when the wrapper is already installed (marker attribute
    ``_seqtex_timeout_safe``).
    """
    import functools

    if spaces is None:
        return
    try:
        import httpx
        import spaces.zero.client as zero_client
    except Exception as exc:
        print(f"[BOOT WARN] could not import ZeroGPU client for timeout patch: {exc}", flush=True)
        return
    original_startup_report = getattr(zero_client, "startup_report", None)
    if original_startup_report is None or getattr(original_startup_report, "_seqtex_timeout_safe", False):
        return
    max_attempts = 3

    def _is_timeout_like(exc: Exception) -> bool:
        return "timeout" in exc.__class__.__name__.lower() or "timed out" in str(exc).lower()

    @functools.wraps(original_startup_report)
    def timeout_safe_startup_report(*args, **kwargs):
        last_exc: Exception | None = None
        for attempt in range(1, max_attempts + 1):
            try:
                print(f"[BOOT ZG] ZeroGPU startup_report attempt {attempt}/{max_attempts}", flush=True)
                return original_startup_report(*args, **kwargs)
            except httpx.TimeoutException as exc:
                last_exc = exc
                print(
                    f"[BOOT WARN] ZeroGPU startup_report timed out on attempt {attempt}/{max_attempts}: {exc}",
                    flush=True,
                )
            except Exception as exc:
                if not _is_timeout_like(exc):
                    raise
                last_exc = exc
                print(
                    f"[BOOT WARN] ZeroGPU startup_report timeout-like error on attempt {attempt}/{max_attempts}: {exc}",
                    flush=True,
                )
            # Back off before the next attempt; there is nothing to wait for after the last one.
            if attempt < max_attempts:
                time.sleep(1.0)
        print(
            f"[BOOT WARN] ZeroGPU startup_report failed after retries; continuing launch anyway. Last error: {last_exc}",
            flush=True,
        )
        return None

    timeout_safe_startup_report._seqtex_timeout_safe = True
    zero_client.startup_report = timeout_safe_startup_report
    print("[BOOT 04B] ZeroGPU startup_report timeout patch installed", flush=True)
_patch_zero_startup_report_timeout()
# ---------------------------------------------------------------------------
# 2. Global state and helpers
# ---------------------------------------------------------------------------
# Lazily populated by _load_seqtex_modules / _get_seqtex_pipe.
_seqtex_modules: dict[str, Any] | None = None
_seqtex_pipe: Any | None = None

# Cache-prewarm bookkeeping shared by the UI button and the auto-prewarm thread.
_prewarm_lock = threading.Lock()
_prewarm_thread: threading.Thread | None = None
_prewarm_status: dict[str, Any] = dict(
    state="idle",
    started_at=None,
    finished_at=None,
    last_error=None,
    log=[],
)
MAX_PREWARM_LOG_LINES = 400

# One-shot guards so each monkey patch is installed at most once per process.
_CPP_EXTENSION_LOAD_PATCHED = False
_DIFFUSERS_WAN_CONFIG_PATCHED = False
_CACHE_PATCHED = False
# True only inside a _LocalGenerateMode block.
_LOCAL_GENERATE_MODE = False
# (repo_id, repo_type) -> resolved local snapshot directory.
_CACHED_REPO_PATHS: dict[tuple[str, str], str] = {}
class StartupFixError(RuntimeError):
    """
    Configuration/runtime problem whose message is written for the end user.

    Raised for things the Space operator can fix (missing token secret,
    incomplete model cache) and shown in the UI.
    """
def _gpu_decorator(duration: int = 120):
    """
    Build a decorator that applies ``spaces.GPU(duration=duration)`` to a function.

    If the ``spaces`` package failed to import, the returned decorator logs a
    warning and hands the function back unchanged.
    """
    def _decorator(fn):
        if spaces is not None:
            return spaces.GPU(duration=duration)(fn)
        log.warning("spaces module unavailable; running without ZeroGPU decorator")
        return fn

    return _decorator
def _append_prewarm_log(message: str) -> None:
    """
    Record *message* in the application log and in the in-memory prewarm log.

    The in-memory log keeps only the newest MAX_PREWARM_LOG_LINES entries, each
    prefixed with an HH:MM:SS timestamp.
    """
    stamped = f"{time.strftime('%H:%M:%S')} | {message}"
    log.info("PREWARM: %s", message)
    entries = _prewarm_status["log"]
    entries.append(stamped)
    if len(entries) > MAX_PREWARM_LOG_LINES:
        # Rebind to a trimmed copy rather than mutating in place.
        _prewarm_status["log"] = entries[-MAX_PREWARM_LOG_LINES:]
def _prewarm_log_text() -> str:
    """Render the prewarm state header plus every stored log line as one newline-joined string."""
    status = _prewarm_status
    lines = [f"state: {status.get('state', 'unknown')}"]
    # Unset/empty timestamps and errors are shown as "-".
    lines.extend(f"{field}: {status.get(field) or '-'}" for field in ("started_at", "finished_at", "last_error"))
    lines += ["", "logs:", *status.get("log", [])]
    return "\n".join(lines)
def _get_hf_token() -> str | None:
    """
    Return the first non-empty Hugging Face token found in the environment.

    Variables are checked in order: SEQTEX_SPACE_TOKEN, HF_TOKEN,
    HUGGINGFACE_HUB_TOKEN, HUGGING_FACE_HUB_TOKEN. If none is non-empty, the
    value of the last one is returned (None when unset, "" when set empty).
    When a token is found it is also copied into any of those four variables
    that are not already set, so downstream libraries see it.
    """
    env_names = ("SEQTEX_SPACE_TOKEN", "HF_TOKEN", "HUGGINGFACE_HUB_TOKEN", "HUGGING_FACE_HUB_TOKEN")
    found = None
    for env_name in env_names:
        found = os.getenv(env_name)
        if found:
            break
    if found:
        for env_name in env_names:
            os.environ.setdefault(env_name, found)
    return found
def _ensure_seqtex_token() -> None:
    """
    Require a non-empty Hugging Face token in the environment.

    Raises:
        StartupFixError: no token variable recognised by _get_hf_token is set.
    """
    if _get_hf_token():
        return
    raise StartupFixError(
        "Missing Hugging Face token secret. Add SEQTEX_SPACE_TOKEN or HF_TOKEN "
        "in Settings -> Variables and secrets."
    )
def _prepare_runtime_dirs() -> None:
    """Create every cache/model directory this app writes to (unset env paths are skipped) and log each one."""
    wanted_dirs = [
        os.getenv("HF_HOME"),
        os.getenv("HF_HUB_CACHE"),
        os.getenv("TORCH_HOME"),
        os.getenv("TORCH_EXTENSIONS_DIR"),
        SEQTEX_SPACE_DIR,
        WAN_BUCKET_ROOT,
        WAN_LOCAL_MODEL_DIR,
    ]
    for directory in filter(None, wanted_dirs):
        os.makedirs(directory, exist_ok=True)
        _append_prewarm_log(f"directory ready: {directory}")
def _skip_file(filename: str, ignore_patterns: list[str]) -> bool:
    """Return True when repo path *filename* matches at least one fnmatch-style pattern."""
    for pattern in ignore_patterns:
        if fnmatch.fnmatch(filename, pattern):
            return True
    return False
def _repo_name_tail(repo_id: str) -> str:
    """Return the final path component of *repo_id* (e.g. "org/name/" -> "name")."""
    trimmed = repo_id.rstrip("/")
    return trimmed.rsplit("/", 1)[-1]
def _is_wan_repo(repo_id: str) -> bool:
    """True when *repo_id*, ignoring surrounding whitespace, is exactly DEFAULT_WAN_MODEL_REPO."""
    candidate = repo_id.strip()
    return DEFAULT_WAN_MODEL_REPO == candidate
def _is_wan_local_dir_ready(model_dir: str | None = None) -> bool:
    """
    Return True when the WAN Diffusers snapshot in *model_dir* looks loadable.

    *model_dir* defaults to WAN_LOCAL_MODEL_DIR. Besides the fixed list of
    config/tokenizer/index files, every shard named in the ``weight_map`` of
    the two sharded ``*.safetensors.index.json`` files (text encoder and
    transformer) must also exist. A partially completed prewarm could leave an
    index on disk without all of its shards, and loading would then fail in
    the middle of Generate instead of failing fast here.
    """
    import json

    root = WAN_LOCAL_MODEL_DIR if model_dir is None else model_dir
    required = [
        "model_index.json",
        "scheduler/scheduler_config.json",
        "text_encoder/config.json",
        "text_encoder/model.safetensors.index.json",
        "tokenizer/tokenizer_config.json",
        "transformer/config.json",
        "transformer/diffusion_pytorch_model.safetensors.index.json",
        "vae/config.json",
        "vae/diffusion_pytorch_model.safetensors",
    ]
    if not all(os.path.isfile(os.path.join(root, item)) for item in required):
        return False
    sharded_indexes = (
        "text_encoder/model.safetensors.index.json",
        "transformer/diffusion_pytorch_model.safetensors.index.json",
    )
    for index_rel in sharded_indexes:
        index_path = os.path.join(root, index_rel)
        try:
            with open(index_path, encoding="utf-8") as fh:
                weight_map = json.load(fh).get("weight_map")
        except (OSError, ValueError, AttributeError):
            # Unreadable or malformed index (including a non-object JSON root).
            return False
        if not isinstance(weight_map, dict) or not weight_map:
            return False
        # Shard paths in a safetensors index are relative to the index's folder.
        shard_dir = os.path.dirname(index_path)
        for shard in set(weight_map.values()):
            if not isinstance(shard, str) or not os.path.isfile(os.path.join(shard_dir, shard)):
                return False
    return True
def _download_repo_files_with_logs(
    *,
    repo_id: str,
    repo_type: str,
    local_dir: str | None = None,
    ignore_patterns: list[str] | None = None,
    progress: gr.Progress | None = None,
    progress_start: float = 0.0,
    progress_end: float = 1.0,
) -> None:
    """
    Download/check a Hub repo file-by-file, logging every file to the prewarm log.

    Args:
        repo_id: Hub repository id, e.g. "VAST-AI/SeqTex".
        repo_type: Hub repo type passed straight through ("model", "space", ...).
        local_dir: If None, files go to HF_HUB_CACHE. If set, files go to that
            folder as a normal local snapshot.
        ignore_patterns: fnmatch-style patterns; matching repo paths are skipped.
        progress: Optional Gradio progress callback.
        progress_start, progress_end: Fraction range of the progress bar that
            this repo occupies.

    Exceptions from the Hub client (listing or downloading) propagate to the caller.
    """
    from huggingface_hub import HfApi, hf_hub_download

    token = _get_hf_token()
    ignore_patterns = ignore_patterns or []
    _append_prewarm_log(f"listing {repo_type} repo: {repo_id}")
    api = HfApi(token=token)
    all_files = api.list_repo_files(repo_id=repo_id, repo_type=repo_type)
    files = [f for f in all_files if not _skip_file(f, ignore_patterns)]
    total = len(files)
    _append_prewarm_log(f"{repo_id}: {total} files to check/download")
    if total == 0:
        return
    if local_dir:
        os.makedirs(local_dir, exist_ok=True)
        _append_prewarm_log(f"{repo_id}: local_dir={local_dir}")
    # Exactly one destination: a plain local snapshot folder, or the shared Hub cache.
    destination: dict[str, Any] = (
        {"local_dir": local_dir} if local_dir else {"cache_dir": os.getenv("HF_HUB_CACHE")}
    )
    span = progress_end - progress_start
    for index, filename in enumerate(files, start=1):
        if progress is not None:
            frac = progress_start + ((index - 1) / total) * span
            progress(frac, desc=f"Caching {repo_id}: {index}/{total} {filename}")
        _append_prewarm_log(f"{repo_id}: [{index}/{total}] {filename}")
        # force_download is left at its default (False). resume_download is not
        # passed: it is deprecated in current huggingface_hub (downloads resume
        # automatically) and only produced a warning per file. Any TypeError now
        # surfaces instead of silently triggering a second download attempt.
        hf_hub_download(
            repo_id=repo_id,
            repo_type=repo_type,
            filename=filename,
            token=token,
            **destination,
        )
    if progress is not None:
        progress(progress_end, desc=f"Cached {repo_id}")
    _append_prewarm_log(f"finished caching {repo_id}")
def prewarm_cache_impl(progress: gr.Progress | None = None) -> str:
    """
    Download/verify everything Generate needs, using CPU and network only.

    This does not allocate ZeroGPU and does not load CUDA. Steps, with the
    progress-bar range each one occupies:
      1. SeqTex Space helper code -> SEQTEX_SPACE_DIR              [0.05, 0.25]
      2. SeqTex-Transformer -> HF_HUB_CACHE                        [0.25, 0.45]
      3. WAN base model -> WAN_LOCAL_MODEL_DIR (local snapshot)    [0.45, 0.95]
      4. EXTRA_MODEL_REPOS other than WAN -> HF_HUB_CACHE          [0.95, 0.99]

    Runs are serialized by ``_prewarm_lock``. Errors are not raised: they are
    recorded in ``_prewarm_status`` (state "error", last_error, traceback in
    the app log). The rendered status text is returned in every case.
    """
    def _stamp() -> str:
        return time.strftime("%Y-%m-%d %H:%M:%S")

    def _tick(fraction: float, text: str) -> None:
        if progress is not None:
            progress(fraction, desc=text)

    media_patterns = ["*.png", "*.jpg", "*.jpeg", "*.gif", "*.mp4", "*.webm"]
    with _prewarm_lock:
        _prewarm_status.update(state="running", started_at=_stamp(), finished_at=None, last_error=None)
        try:
            _tick(0.01, "Preparing persistent cache directories...")
            _append_prewarm_log("cache preparation started")
            _prepare_runtime_dirs()
            _ensure_seqtex_token()

            # Step 1: SeqTex Space helper code (non-code assets skipped).
            _download_repo_files_with_logs(
                repo_id=SEQTEX_SPACE_REPO,
                repo_type="space",
                local_dir=SEQTEX_SPACE_DIR,
                ignore_patterns=[".git/*", "__pycache__/*", *media_patterns, "examples/*", "outputs/*"],
                progress=progress,
                progress_start=0.05,
                progress_end=0.25,
            )

            # Step 2: SeqTex transformer into the /data HF cache.
            _download_repo_files_with_logs(
                repo_id=SEQTEX_MODEL_REPO,
                repo_type="model",
                local_dir=None,
                ignore_patterns=[],
                progress=progress,
                progress_start=0.25,
                progress_end=0.45,
            )

            # Step 3: WAN into its own bucket as a plain local snapshot. Diffusers
            # only needs model_index.json plus the component folders.
            if DEFAULT_WAN_MODEL_REPO:
                _download_repo_files_with_logs(
                    repo_id=DEFAULT_WAN_MODEL_REPO,
                    repo_type="model",
                    local_dir=WAN_LOCAL_MODEL_DIR,
                    ignore_patterns=[".git/*", "assets/*", "examples/*", *media_patterns],
                    progress=progress,
                    progress_start=0.45,
                    progress_end=0.95,
                )

            # Step 4: extra repos, sharing the last slice of the bar equally.
            extras = [name for name in EXTRA_MODEL_REPOS if name and name != DEFAULT_WAN_MODEL_REPO]
            if extras:
                slice_width = 0.04 / max(len(extras), 1)
                lower = 0.95
                for name in extras:
                    upper = min(0.99, lower + slice_width)
                    _download_repo_files_with_logs(
                        repo_id=name,
                        repo_type="model",
                        local_dir=None,
                        ignore_patterns=[],
                        progress=progress,
                        progress_start=lower,
                        progress_end=upper,
                    )
                    lower = upper

            _tick(1.0, "Cache preparation complete")
            _prewarm_status["state"] = "done"
            _prewarm_status["finished_at"] = _stamp()
            _append_prewarm_log("cache preparation complete")
            return _prewarm_log_text()
        except Exception as exc:
            trace_text = traceback.format_exc()
            _prewarm_status["state"] = "error"
            _prewarm_status["finished_at"] = _stamp()
            _prewarm_status["last_error"] = str(exc)
            _append_prewarm_log(f"ERROR: {exc}")
            log.error("Prewarm failed:\n%s", trace_text)
            return _prewarm_log_text()
def prewarm_cache_ui(progress: gr.Progress = gr.Progress(track_tqdm=True)) -> str:
    """
    Gradio handler for the "Prepare cache now" action.

    The ``gr.Progress`` default is the Gradio convention for requesting a
    progress tracker; Gradio supplies the live instance at call time.
    Returns the prewarm status text.
    """
    status_text = prewarm_cache_impl(progress=progress)
    return status_text
def get_cache_status_ui() -> str:
    """Gradio handler: return the current prewarm state and log without starting any work."""
    snapshot = _prewarm_log_text()
    return snapshot
def _auto_prewarm_worker() -> None:
    """
    Background-thread target: wait AUTO_PREWARM_DELAY_SECONDS, then run one prewarm.

    Never raises; any unexpected exception is logged with its traceback.
    """
    try:
        delay = AUTO_PREWARM_DELAY_SECONDS
        _append_prewarm_log(f"auto-prewarm will start after {delay}s")
        time.sleep(delay)
        prewarm_cache_impl(progress=None)
    except Exception:
        crash_trace = traceback.format_exc()
        log.error("Auto-prewarm worker crashed:\n%s", crash_trace)
def start_auto_prewarm_once() -> str:
    """
    Start the background auto-prewarm thread unless it is disabled, alive, or already finished.

    Does nothing (besides a log line) when AUTO_PREWARM is off. Otherwise it
    skips starting when a previous thread is still alive or the prewarm state
    is already "running" or "done". Always returns the current status text.
    """
    global _prewarm_thread
    if not AUTO_PREWARM:
        _append_prewarm_log("auto-prewarm disabled by AUTO_PREWARM=0")
        return _prewarm_log_text()
    worker_alive = _prewarm_thread is not None and _prewarm_thread.is_alive()
    if worker_alive or _prewarm_status.get("state") in {"running", "done"}:
        return _prewarm_log_text()
    worker = threading.Thread(target=_auto_prewarm_worker, daemon=True, name="seqtex-auto-prewarm")
    _prewarm_thread = worker
    worker.start()
    _append_prewarm_log("auto-prewarm thread started")
    return _prewarm_log_text()
# ---------------------------------------------------------------------------
# 3. Build/runtime compatibility patches
# ---------------------------------------------------------------------------
def _clean_nvdiffrast_extension_cache() -> None:
    """
    Remove stale nvdiffrast build artifacts so the plugin is rebuilt on next use.

    Drops every ``nvdiffrast_plugin*`` entry from ``sys.modules``. It then
    deletes every directory and file whose name contains "nvdiffrast" under
    TORCH_EXTENSIONS_DIR (default /tmp/torch_extensions). Removal failures
    are logged and skipped; this function never raises for them.
    """
    ext_dir = os.getenv("TORCH_EXTENSIONS_DIR") or "/tmp/torch_extensions"
    log.warning("Cleaning nvdiffrast extension cache under %s", ext_dir)
    for mod_name in list(sys.modules):
        if mod_name.startswith("nvdiffrast_plugin"):
            sys.modules.pop(mod_name, None)
    if not os.path.isdir(ext_dir):
        return
    for root, dirs, files in os.walk(ext_dir):
        # Iterate over a copy so matching names can be pruned from ``dirs`` in
        # place; this stops os.walk from descending into directories deleted here.
        for dirname in list(dirs):
            if "nvdiffrast" not in dirname:
                continue
            dirs.remove(dirname)
            path = os.path.join(root, dirname)
            try:
                # No ignore_errors: a failed removal must reach the except branch
                # instead of being reported as "Removed".
                shutil.rmtree(path)
                log.warning("Removed stale nvdiffrast extension dir: %s", path)
            except OSError as exc:
                log.warning("Could not remove %s: %s", path, exc)
        for filename in files:
            if "nvdiffrast" not in filename:
                continue
            path = os.path.join(root, filename)
            try:
                os.remove(path)
                log.warning("Removed stale nvdiffrast extension file: %s", path)
            except OSError as exc:
                log.warning("Could not remove %s: %s", path, exc)
def _patch_torch_cpp_extension_load(torch_module: Any) -> None:
    """
    Wrap ``torch_module.utils.cpp_extension.load`` so built extensions become importable by name.

    After each successful ``load`` that returns a module, the module is stored
    in ``sys.modules`` under its ``name`` argument (keyword or first positional).
    When the module has a ``__file__``, its directory and that directory's
    parent are also prepended to ``sys.path``. Registration problems are logged,
    never raised. Installed at most once per process (_CPP_EXTENSION_LOAD_PATCHED).
    """
    global _CPP_EXTENSION_LOAD_PATCHED
    if _CPP_EXTENSION_LOAD_PATCHED:
        return
    try:
        cpp_ext = torch_module.utils.cpp_extension
    except Exception as exc:
        log.warning("Could not access torch.utils.cpp_extension: %s", exc)
        return
    original_load = cpp_ext.load

    def _register(ext_name: Any, ext_module: Any) -> None:
        sys.modules[str(ext_name)] = ext_module
        ext_file = getattr(ext_module, "__file__", None)
        if not ext_file:
            log.info("Registered Torch extension module %s", ext_name)
            return
        ext_dir = os.path.dirname(os.path.abspath(ext_file))
        for search_dir in (ext_dir, os.path.dirname(ext_dir)):
            if search_dir and search_dir not in sys.path:
                sys.path.insert(0, search_dir)
        log.info("Registered Torch extension module %s from %s", ext_name, ext_file)

    def load_and_register(*args, **kwargs):
        ext_name = kwargs.get("name")
        if ext_name is None and args:
            ext_name = args[0]
        ext_module = original_load(*args, **kwargs)
        if ext_name and ext_module is not None:
            try:
                _register(ext_name, ext_module)
            except Exception as exc:
                log.warning("Torch extension registration failed for %s: %s", ext_name, exc)
        return ext_module

    cpp_ext.load = load_and_register
    _CPP_EXTENSION_LOAD_PATCHED = True
    log.info("Patched torch.utils.cpp_extension.load for nvdiffrast plugin registration")
def _reset_nvdiffrast_runtime_modules() -> None:
    """
    Forget any loaded nvdiffrast plugin.

    Removes ``nvdiffrast_plugin*`` from ``sys.modules``. If
    ``nvdiffrast.torch.ops`` is already imported and has ``_cached_plugin``,
    that attribute is reset to ``{False: None, True: None}``, presumably so the
    plugin is loaded again on next use (confirm against nvdiffrast's ops.py).
    Failures while resetting are logged, not raised.
    """
    stale_names = [name for name in list(sys.modules) if name.startswith("nvdiffrast_plugin")]
    for name in stale_names:
        sys.modules.pop(name, None)
    try:
        ops_mod = sys.modules.get("nvdiffrast.torch.ops")
        if ops_mod is None or not hasattr(ops_mod, "_cached_plugin"):
            return
        ops_mod._cached_plugin = {False: None, True: None}
        log.warning("Reset nvdiffrast.torch.ops._cached_plugin")
    except Exception as exc:
        log.warning("Could not reset nvdiffrast cached plugin state: %s", exc)
def _patch_diffusers_wan_frozendict_config() -> None:
    """
    Install the Diffusers/SeqTex WAN compatibility patches (once per process).

    Three patches are attempted in order; each logs a warning and is skipped if
    its target cannot be imported or patched:

    1. ``FrozenDict.__getattr__`` falls back to 4 / 8 for missing
       ``scale_factor_temporal`` / ``scale_factor_spatial``.
    2. ``WanPipeline.__init__`` registers those two values on the VAE config
       when absent, then calls the original initializer with its arguments
       forwarded unchanged.
    3. SeqTex ``WanT2TexPipeline.components`` exposes only the five core modules.

    If diffusers' FrozenDict cannot be imported at all, nothing is patched and
    the module flag stays unset, so a later call retries.
    """
    global _DIFFUSERS_WAN_CONFIG_PATCHED
    if _DIFFUSERS_WAN_CONFIG_PATCHED:
        return
    if not _patch_frozendict_scale_factor_defaults():
        return
    _patch_wan_pipeline_init()
    _patch_wan_t2tex_components()
    _DIFFUSERS_WAN_CONFIG_PATCHED = True


def _patch_frozendict_scale_factor_defaults() -> bool:
    """Patch diffusers FrozenDict attribute lookup with WAN scale-factor fallbacks; False only if FrozenDict is unimportable."""
    try:
        from diffusers.configuration_utils import FrozenDict
    except Exception as exc:
        log.warning("Could not import diffusers FrozenDict: %s", exc)
        return False
    original_getattr = getattr(FrozenDict, "__getattr__", None)
    # Fallbacks for configs that lack these keys.
    scale_factor_defaults = {"scale_factor_temporal": 4, "scale_factor_spatial": 8}

    def seqtex_frozendict_getattr(self, name):
        try:
            return self[name]
        except Exception:
            pass
        if name in scale_factor_defaults:
            return scale_factor_defaults[name]
        if original_getattr is not None:
            try:
                return original_getattr(self, name)
            except Exception:
                pass
        raise AttributeError(f"'FrozenDict' object has no attribute '{name}'")

    try:
        FrozenDict.__getattr__ = seqtex_frozendict_getattr
        log.info("Patched diffusers FrozenDict for Wan scale_factor_temporal/scale_factor_spatial compatibility")
    except Exception as exc:
        log.warning("Could not patch FrozenDict.__getattr__: %s", exc)
    return True


def _register_missing_wan_vae_scale_factors(vae: Any) -> None:
    """Best effort: register scale_factor_temporal=4 / scale_factor_spatial=8 on ``vae``'s config when reading them fails."""
    try:
        cfg = getattr(vae, "config", None)

        def lacks(attr_name: str) -> bool:
            if cfg is None:
                return True
            try:
                getattr(cfg, attr_name)
            except Exception:
                return True
            return False

        patch_kwargs = {}
        if lacks("scale_factor_temporal"):
            patch_kwargs["scale_factor_temporal"] = 4
        if lacks("scale_factor_spatial"):
            patch_kwargs["scale_factor_spatial"] = 8
        if patch_kwargs and hasattr(vae, "register_to_config"):
            vae.register_to_config(**patch_kwargs)
            log.info("Registered missing WAN VAE config values: %s", patch_kwargs)
    except Exception as exc:
        log.warning("Could not register WAN VAE scale factors on VAE config: %s", exc)


def _patch_wan_pipeline_init() -> None:
    """
    Wrap Diffusers ``WanPipeline.__init__`` to fix up the VAE config before running.

    Arguments are forwarded exactly as received. Re-ordering them positionally
    would mis-assign components on Diffusers versions whose ``__init__``
    parameter order differs, and a fixed 5-parameter wrapper would reject extra
    keyword arguments. ``functools.wraps`` keeps the original signature visible
    to ``inspect.signature``, which Diffusers uses to discover pipeline components.
    """
    import functools
    import inspect

    try:
        from diffusers.pipelines.wan.pipeline_wan import WanPipeline

        original_init = getattr(WanPipeline, "__init__", None)
        if original_init is None or getattr(original_init, "_seqtex_wan_config_patch", False):
            return
        try:
            init_signature = inspect.signature(original_init)
        except (TypeError, ValueError):
            init_signature = None

        def _find_vae(self, args, kwargs):
            # Locate ``vae`` by parameter name, regardless of positional order.
            if init_signature is not None:
                try:
                    return init_signature.bind_partial(self, *args, **kwargs).arguments.get("vae")
                except TypeError:
                    pass
            return kwargs.get("vae")

        @functools.wraps(original_init)
        def seqtex_wan_init(self, *args, **kwargs):
            vae = _find_vae(self, args, kwargs)
            if vae is not None:
                _register_missing_wan_vae_scale_factors(vae)
            return original_init(self, *args, **kwargs)

        seqtex_wan_init._seqtex_wan_config_patch = True
        WanPipeline.__init__ = seqtex_wan_init
        log.info("Patched Diffusers WanPipeline.__init__ for SeqTex VAE compatibility")
    except Exception as exc:
        log.warning("Could not patch WanPipeline.__init__: %s", exc)


def _patch_wan_t2tex_components() -> None:
    """
    Patch the SeqTex custom ``WanT2TexPipeline.components`` property.

    Diffusers 0.38's ``DiffusionPipeline.components`` is stricter than the
    custom SeqTex pipeline expects. The SeqTex/WAN config contains extra
    non-module keys such as boundary_ratio and expand_timesteps, plus an
    optional transformer_2. During ``TEX_PIPE.to("cuda")`` Diffusers calls
    ``self.components`` and raises "Expected ['scheduler', 'text_encoder',
    'tokenizer', 'transformer', 'vae'] but [...] are defined." Only real
    pipeline modules are needed for ``.to("cuda")``, so expose exactly those
    core modules (the ones present as attributes).
    """
    try:
        import wan.pipeline_wan_t2tex_extra as seqtex_wan_extra

        WanT2TexPipeline = getattr(seqtex_wan_extra, "WanT2TexPipeline", None)
        if WanT2TexPipeline is not None and not getattr(WanT2TexPipeline, "_seqtex_components_patch", False):
            def seqtex_components(self):
                component_names = ["scheduler", "text_encoder", "tokenizer", "transformer", "vae"]
                return {name: getattr(self, name) for name in component_names if hasattr(self, name)}

            WanT2TexPipeline.components = property(seqtex_components)
            WanT2TexPipeline._seqtex_components_patch = True
            log.info("Patched SeqTex WanT2TexPipeline.components to ignore non-module config keys")
    except Exception as exc:
        log.warning("Could not patch SeqTex WanT2TexPipeline.components: %s", exc)
# ---------------------------------------------------------------------------
# 4. Cached/local-only model loading
# ---------------------------------------------------------------------------
def _required_model_repos() -> list[str]:
    """
    Return the model repos Generate needs, stripped, non-empty and de-duplicated.

    Order is SEQTEX_MODEL_REPO, DEFAULT_WAN_MODEL_REPO, then EXTRA_MODEL_REPOS;
    the first occurrence of a repeated id wins.
    """
    candidates = (SEQTEX_MODEL_REPO, DEFAULT_WAN_MODEL_REPO, *EXTRA_MODEL_REPOS)
    stripped = ((name or "").strip() for name in candidates)
    # dict.fromkeys keeps insertion order, giving an ordered de-duplication.
    return list(dict.fromkeys(name for name in stripped if name))
def _cached_snapshot_path(repo_id: str, repo_type: str = "model") -> str:
    """
    Resolve *repo_id* to a local directory without touching the network.

    The WAN repo maps to WAN_LOCAL_MODEL_DIR, provided _is_wan_local_dir_ready
    accepts it. Any other repo is looked up in HF_HUB_CACHE with
    ``local_files_only=True``, and the result is memoised in _CACHED_REPO_PATHS
    while the directory still exists.

    NOTE(review): a local-files-only snapshot lookup presumably only proves that
    a snapshot folder exists, not that every repo file was downloaded; confirm.

    Raises:
        FileNotFoundError: the WAN local directory is incomplete.
        Whatever ``snapshot_download`` raises when the repo is not cached.
    """
    if _is_wan_repo(repo_id):
        if _is_wan_local_dir_ready():
            return WAN_LOCAL_MODEL_DIR
        raise FileNotFoundError(
            f"WAN model local directory is incomplete: {WAN_LOCAL_MODEL_DIR}. "
            "Run Cache / Startup -> Prepare cache now."
        )
    cache_key = (repo_id, repo_type)
    remembered = _CACHED_REPO_PATHS.get(cache_key)
    if remembered is not None and os.path.isdir(remembered):
        return remembered
    from huggingface_hub import snapshot_download

    snapshot_dir = snapshot_download(
        repo_id=repo_id,
        repo_type=repo_type,
        cache_dir=os.getenv("HF_HUB_CACHE"),
        token=_get_hf_token(),
        local_files_only=True,
    )
    _CACHED_REPO_PATHS[cache_key] = snapshot_dir
    return snapshot_dir
def _assert_cached_models_ready() -> None:
    """
    Check that every required model repo resolves to a local snapshot.

    Each repo's outcome is logged.

    Raises:
        StartupFixError: listing every repo that could not be resolved locally.
    """
    missing: list[str] = []
    for repo in _required_model_repos():
        try:
            local_path = _cached_snapshot_path(repo, "model")
        except Exception as exc:
            log.warning("Required cached model missing/incomplete: %s (%s)", repo, exc)
            missing.append(repo)
        else:
            log.info("Cached model ready: %s -> %s", repo, local_path)
    if not missing:
        return
    raise StartupFixError(
        f"Required model cache is missing or incomplete: {', '.join(missing)}. "
        "Open Cache / Startup and click Prepare cache now. "
        "Do not click Generate until cache preparation says state: done."
    )
def _repo_to_local_path_if_cached(path_or_repo: Any) -> Any:
    """
    Swap a known model repo id for its cached local directory.

    Non-strings, existing filesystem paths and unknown repo ids are returned
    unchanged. For a known repo id (SeqTex, WAN or EXTRA_MODEL_REPOS) whose
    cache cannot be resolved, StartupFixError is raised when
    FORCE_LOCAL_GENERATE is on. Otherwise a warning is logged and the id is
    returned as-is.
    """
    known_repo = (
        isinstance(path_or_repo, str)
        and not os.path.exists(path_or_repo)
        and path_or_repo in (SEQTEX_MODEL_REPO, DEFAULT_WAN_MODEL_REPO, *EXTRA_MODEL_REPOS)
    )
    if not known_repo:
        return path_or_repo
    repo_id = path_or_repo
    try:
        local_path = _cached_snapshot_path(repo_id, "model")
    except Exception as exc:
        if FORCE_LOCAL_GENERATE:
            raise StartupFixError(
                f"Model repo {repo_id} is not fully cached. "
                "Run Cache / Startup -> Prepare cache now first."
            ) from exc
        log.warning("Could not resolve cached path for %s: %s", repo_id, exc)
        return path_or_repo
    log.info("Using cached model path for %s: %s", repo_id, local_path)
    return local_path
def _patch_cached_model_loading() -> None:
    """
    Route Hub downloads and ``from_pretrained`` calls through the local cache.

    Every wrapped call gets ``cache_dir`` (HF_HUB_CACHE) and ``token`` defaults.
    While ``_LOCAL_GENERATE_MODE`` and FORCE_LOCAL_GENERATE are both true it
    also gets ``local_files_only=True``, has ``force_download`` removed, and
    (for ``from_pretrained``) known repo ids are swapped for local directories.

    Patched targets: ``huggingface_hub.hf_hub_download`` / ``snapshot_download``;
    ``from_pretrained`` on Diffusers DiffusionPipeline/ModelMixin and on
    Transformers AutoTokenizer/PreTrainedModel/PreTrainedTokenizerBase; and the
    hub helpers bound inside ``utils.texture_generation`` if that module is
    already imported. Each group is best-effort (failures are logged). Runs at
    most once per process (_CACHE_PATCHED).

    NOTE(review): modules that did ``from huggingface_hub import hf_hub_download``
    before this runs keep the unpatched function, except texture_generation,
    which is handled explicitly below.
    """
    global _CACHE_PATCHED
    if _CACHE_PATCHED:
        return
    cache_dir = os.getenv("HF_HUB_CACHE")

    def add_cached_kwargs(kwargs: dict[str, Any]) -> dict[str, Any]:
        kwargs.setdefault("cache_dir", cache_dir)
        token = _get_hf_token()
        if token:
            kwargs.setdefault("token", token)
        if _LOCAL_GENERATE_MODE and FORCE_LOCAL_GENERATE:
            kwargs["local_files_only"] = True
            kwargs.pop("force_download", None)
        return kwargs

    def wrap_hub_function(original):
        def cached_hub_function(*args, **kwargs):
            add_cached_kwargs(kwargs)
            return original(*args, **kwargs)

        cached_hub_function._seqtex_cache_patch = True
        return cached_hub_function

    def patch_from_pretrained(cls: Any, attr: str, label: str) -> None:
        # Unwrap the bound classmethod so the original can be re-invoked with the
        # actual subclass (inner_cls) that the caller used.
        current = getattr(cls, attr)
        underlying = getattr(current, "__func__", current)
        if getattr(underlying, "_seqtex_cache_patch", False):
            return

        def cached_from_pretrained(inner_cls, pretrained_model_name_or_path=None, *model_args, **kwargs):
            if _LOCAL_GENERATE_MODE and FORCE_LOCAL_GENERATE:
                pretrained_model_name_or_path = _repo_to_local_path_if_cached(pretrained_model_name_or_path)
            add_cached_kwargs(kwargs)
            return underlying(inner_cls, pretrained_model_name_or_path, *model_args, **kwargs)

        cached_from_pretrained._seqtex_cache_patch = True
        setattr(cls, attr, classmethod(cached_from_pretrained))
        log.info("Patched %s.%s for cached Generate mode", label, attr)

    try:
        import huggingface_hub

        for fn_name in ("hf_hub_download", "snapshot_download"):
            current_fn = getattr(huggingface_hub, fn_name)
            if not getattr(current_fn, "_seqtex_cache_patch", False):
                setattr(huggingface_hub, fn_name, wrap_hub_function(current_fn))
        log.info("Patched huggingface_hub downloads for cached Generate mode")
    except Exception as exc:
        log.warning("Could not patch huggingface_hub cached loading: %s", exc)

    try:
        from diffusers import DiffusionPipeline
        from diffusers.models.modeling_utils import ModelMixin

        for target, label in ((DiffusionPipeline, "DiffusionPipeline"), (ModelMixin, "ModelMixin")):
            patch_from_pretrained(target, "from_pretrained", label)
    except Exception as exc:
        log.warning("Could not patch Diffusers from_pretrained: %s", exc)

    try:
        from transformers import AutoTokenizer, PreTrainedModel, PreTrainedTokenizerBase

        for target, label in (
            (AutoTokenizer, "AutoTokenizer"),
            (PreTrainedModel, "PreTrainedModel"),
            (PreTrainedTokenizerBase, "PreTrainedTokenizerBase"),
        ):
            patch_from_pretrained(target, "from_pretrained", label)
    except Exception as exc:
        log.warning("Could not patch Transformers from_pretrained: %s", exc)

    try:
        tex_mod = sys.modules.get("utils.texture_generation")
        if tex_mod is not None:
            import huggingface_hub

            # Rebind names SeqTex imported directly, so they use the patched versions.
            for fn_name in ("hf_hub_download", "snapshot_download"):
                if hasattr(tex_mod, fn_name):
                    setattr(tex_mod, fn_name, getattr(huggingface_hub, fn_name))
            log.info("Patched SeqTex texture_generation hub helpers for cached mode")
    except Exception as exc:
        log.warning("Could not patch SeqTex module hub helpers: %s", exc)
    _CACHE_PATCHED = True
class _LocalGenerateMode:
    """
    Context manager that turns on local-only model loading for its body.

    On entry it sets the module flag ``_LOCAL_GENERATE_MODE``. When
    FORCE_LOCAL_GENERATE is on, it also sets HF_HUB_OFFLINE, TRANSFORMERS_OFFLINE
    and DIFFUSERS_OFFLINE to "1". On exit the flag and the three variables are
    restored exactly (variables that were unset are removed again), and
    exceptions from the body propagate.

    NOTE(review): some libraries may read these *_OFFLINE variables only at
    import time. The ``local_files_only`` injection driven by the module flag is
    presumably the effective guard; confirm.
    """

    def __enter__(self):
        global _LOCAL_GENERATE_MODE
        offline_vars = ("HF_HUB_OFFLINE", "TRANSFORMERS_OFFLINE", "DIFFUSERS_OFFLINE")
        self.old_local = _LOCAL_GENERATE_MODE
        self.old_env = {name: os.environ.get(name) for name in offline_vars}
        _LOCAL_GENERATE_MODE = True
        if FORCE_LOCAL_GENERATE:
            os.environ.update(dict.fromkeys(offline_vars, "1"))
        return self

    def __exit__(self, exc_type, exc, tb):
        global _LOCAL_GENERATE_MODE
        _LOCAL_GENERATE_MODE = self.old_local
        for name, previous in self.old_env.items():
            if previous is not None:
                os.environ[name] = previous
            else:
                os.environ.pop(name, None)
        # Returning False lets any exception from the with-body propagate.
        return False
# ---------------------------------------------------------------------------
# 5. Lazy SeqTex import/model loading
# ---------------------------------------------------------------------------
def _bootstrap_seqtex_utils() -> None:
    """
    Make the SeqTex Space helper package (``utils.*``) importable.

    If ``utils/mesh_utils.py`` is missing under SEQTEX_SPACE_DIR, only the
    small SeqTex Space repo is fetched, not the full prewarm. The WAN and
    SeqTex-Transformer weights are checked separately by
    _assert_cached_models_ready. This path is reached from mesh processing and
    pipeline loading (presumably inside the ZeroGPU-decorated Generate), where
    a multi-GB model download would burn GPU time. SEQTEX_SPACE_DIR is then
    put at the front of ``sys.path``.

    Raises:
        StartupFixError: the helper code could not be downloaded, or it is
            still missing after the download.
    """
    marker = os.path.join(SEQTEX_SPACE_DIR, "utils", "mesh_utils.py")
    if os.path.isfile(marker):
        log.info("SeqTex utilities already present at %s", SEQTEX_SPACE_DIR)
    else:
        _append_prewarm_log("SeqTex utilities missing; downloading before generation")
        try:
            # Same Space-repo filter as step 1 of prewarm_cache_impl.
            _download_repo_files_with_logs(
                repo_id=SEQTEX_SPACE_REPO,
                repo_type="space",
                local_dir=SEQTEX_SPACE_DIR,
                ignore_patterns=[
                    ".git/*",
                    "__pycache__/*",
                    "*.png",
                    "*.jpg",
                    "*.jpeg",
                    "*.gif",
                    "*.mp4",
                    "*.webm",
                    "examples/*",
                    "outputs/*",
                ],
            )
        except Exception as exc:
            raise StartupFixError(
                f"Could not download SeqTex helper code from {SEQTEX_SPACE_REPO}: {exc}. "
                "Open Cache / Startup and click Prepare cache now."
            ) from exc
        if not os.path.isfile(marker):
            raise StartupFixError(
                f"SeqTex helper code is still missing at {marker} after download. "
                "Open Cache / Startup and click Prepare cache now."
            )
    if SEQTEX_SPACE_DIR not in sys.path:
        sys.path.insert(0, SEQTEX_SPACE_DIR)
def _load_seqtex_modules() -> dict[str, Any]:
    """
    Import the SeqTex helpers (plus torch and numpy) once and return a name -> object table.

    The first call bootstraps the helper code, imports ``utils.mesh_utils``,
    ``utils.render_utils`` and ``utils.texture_generation``, then installs the
    torch-extension, Diffusers/WAN and cached-loading patches. The cached-loading
    patch looks up ``utils.texture_generation`` in ``sys.modules``, so it must
    run after that import. The table is stored only after everything succeeded.
    """
    global _seqtex_modules
    if _seqtex_modules is not None:
        return _seqtex_modules
    _bootstrap_seqtex_utils()
    log.info("Importing SeqTex modules lazily...")
    mesh_utils, render_utils, texture_generation = (
        importlib.import_module(f"utils.{mod_name}")
        for mod_name in ("mesh_utils", "render_utils", "texture_generation")
    )
    torch = importlib.import_module("torch")
    _patch_torch_cpp_extension_load(torch)
    _patch_diffusers_wan_frozendict_config()
    _patch_cached_model_loading()
    np = importlib.import_module("numpy")
    exported: dict[str, Any] = {"Mesh": mesh_utils.Mesh}
    for attr_name in ("get_mvp_matrix", "render_geo_map", "render_geo_views_tensor"):
        exported[attr_name] = getattr(render_utils, attr_name)
    for attr_name in ("get_seqtex_pipe", "encode_images", "decode_images", "convert_img_to_tensor"):
        exported[attr_name] = getattr(texture_generation, attr_name)
    exported.update(texture_generation_module=texture_generation, torch=torch, np=np)
    _seqtex_modules = exported
    log.info("SeqTex modules imported successfully")
    return _seqtex_modules
def _get_seqtex_pipe() -> Any:
    """
    Return the process-wide SeqTex pipeline, building it on first use.

    Always checks for a token and loads the SeqTex modules first. On the first
    build, the compatibility/cache patches are (re)applied. When
    FORCE_LOCAL_GENERATE is on, the model cache is verified (raising
    StartupFixError if incomplete). The pipeline is then constructed inside
    _LocalGenerateMode.
    """
    global _seqtex_pipe
    _ensure_seqtex_token()
    modules = _load_seqtex_modules()
    if _seqtex_pipe is not None:
        return _seqtex_pipe
    _patch_diffusers_wan_frozendict_config()
    _patch_cached_model_loading()
    if FORCE_LOCAL_GENERATE:
        log.info("Checking required cached model snapshots before loading pipeline...")
        _assert_cached_models_ready()
    log.info(
        "Loading SeqTex pipeline onto GPU using cached model files only. "
        "WAN local model dir: %s",
        WAN_LOCAL_MODEL_DIR,
    )
    build_pipe = modules["get_seqtex_pipe"]
    with _LocalGenerateMode():
        _seqtex_pipe = build_pipe()
    log.info("SeqTex pipeline loaded")
    return _seqtex_pipe
# ---------------------------------------------------------------------------
# 6. Mesh processing
# ---------------------------------------------------------------------------
def step1_process_mesh(
    glb_path: str,
    upside_down: bool,
    uv_size: int,
    mv_size: int,
    progress: gr.Progress | None = None,
) -> tuple[dict[str, Any], Image.Image]:
    """Load and normalize a GLB mesh, then render the SeqTex geometry inputs.

    The mesh is loaded on CUDA with xAtlas UV generation, transformed
    (optionally flipped upside-down), normalized, and rendered into multi-view
    position/normal/mask images plus UV-space position/normal maps.

    Args:
        glb_path: Path to the input ``.glb`` file.
        upside_down: Apply ``vertex_transform_upsidedown()`` after the base transform.
        uv_size: Edge length (pixels) of the UV-space geometry maps.
        mv_size: Edge length (pixels) of each multi-view render.
        progress: Optional Gradio progress callback.

    Returns:
        ``(geo_data, preview)``. ``geo_data`` maps ``"pos_imgs"``, ``"norm_imgs"``,
        ``"mask_imgs"``, ``"pos_map"``, ``"norm_map"``, ``"w2c"`` and ``"mvp"``
        to temporary ``.pt`` files, ``"mesh_pkl"`` to a pickled CPU copy of the
        processed mesh, and ``"uv_size"`` / ``"mv_size"`` to ints. The temporary
        files are created with ``delete=False``; the caller owns their cleanup.
        ``preview`` is a horizontal strip of up to four normal-map views.

    Raises:
        ValueError: if rendering produced no views to build a preview from.
    """
    modules = _load_seqtex_modules()
    Mesh = modules["Mesh"]
    get_mvp_matrix = modules["get_mvp_matrix"]
    render_geo_views_tensor = modules["render_geo_views_tensor"]
    render_geo_map = modules["render_geo_map"]
    torch = modules["torch"]
    np = modules["np"]
    device = "cuda"

    log.info("Step 1: loading mesh from %s", glb_path)
    if progress is not None:
        progress(0.08, desc="Loading mesh and generating UVs if needed...")
    # Do not pass Gradio Progress into official Mesh helper.
    mesh = Mesh(glb_path, uv_tool="xAtlas", device=device)

    if progress is not None:
        progress(0.16, desc="Applying Z-UP orientation and normalizing mesh...")
    mesh.vertex_transform()
    if upside_down:
        mesh.vertex_transform_upsidedown()
    mesh.normalize()

    img_size = (int(mv_size), int(mv_size))
    uv_sz = (int(uv_size), int(uv_size))
    try:
        mvp_matrix, w2c = get_mvp_matrix(mesh, num_views=4, width=int(mv_size), height=int(mv_size))
    except TypeError:
        # Older SeqTex utility versions only accept the mesh argument.
        mvp_matrix, w2c = get_mvp_matrix(mesh)
    mvp_matrix = mvp_matrix.to(device)
    w2c = w2c.to(device)

    if progress is not None:
        progress(0.24, desc="Rendering geometry views / compiling rasterizer if needed...")
    try:
        pos_imgs, norm_imgs, mask_imgs = render_geo_views_tensor(mesh, mvp_matrix, img_size)
    except ModuleNotFoundError as exc:
        if "nvdiffrast_plugin" not in str(exc):
            raise
        # A stale/partial JIT build of the rasterizer plugin: wipe it and retry exactly once.
        log.warning("nvdiffrast_plugin not importable. Cleaning extension cache and retrying once.")
        if progress is not None:
            progress(0.25, desc="Cleaning stale nvdiffrast plugin cache and retrying...")
        _clean_nvdiffrast_extension_cache()
        _reset_nvdiffrast_runtime_modules()
        pos_imgs, norm_imgs, mask_imgs = render_geo_views_tensor(mesh, mvp_matrix, img_size)

    if progress is not None:
        progress(0.30, desc="Rendering UV-space geometry maps...")
    pos_map, norm_map = render_geo_map(mesh, map_size=uv_sz)

    def _save_tensor(tensor: Any, prefix: str) -> str:
        # Create and close the temp file first so no handle is left open even if
        # torch.save raises; delete=False keeps the file for later steps.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".pt", prefix=f"{prefix}_") as handle:
            path = handle.name
        torch.save(tensor.detach().cpu(), path)
        return path

    mesh_cpu = mesh.to("cpu")
    # Write the pickle through the temp-file handle itself so it is always closed.
    with tempfile.NamedTemporaryFile(
        delete=False, suffix="_processed_mesh.pkl", prefix="seqtex_mesh_"
    ) as mesh_handle:
        pickle.dump(mesh_cpu, mesh_handle)
        mesh_pkl_path = mesh_handle.name

    result = {
        "pos_imgs": _save_tensor(pos_imgs, "pos_imgs"),
        "norm_imgs": _save_tensor(norm_imgs, "norm_imgs"),
        "mask_imgs": _save_tensor(mask_imgs, "mask_imgs"),
        "pos_map": _save_tensor(pos_map, "pos_map"),
        "norm_map": _save_tensor(norm_map, "norm_map"),
        "w2c": _save_tensor(w2c, "w2c"),
        "mvp": _save_tensor(mvp_matrix, "mvp"),
        "mesh_pkl": mesh_pkl_path,
        "uv_size": int(uv_size),
        "mv_size": int(mv_size),
    }

    # .float() so half/bfloat16 renders can be converted to NumPy.
    norm_np = norm_imgs.detach().float().cpu().numpy()
    # Map normals from [-1, 1] into [0, 1] for display.
    norm_np = (norm_np * 0.5 + 0.5).clip(0, 1)
    n_tiles = min(4, norm_np.shape[0])
    if n_tiles == 0:
        raise ValueError("Geometry rendering produced no views to preview.")
    tiles = [Image.fromarray((norm_np[i] * 255).astype(np.uint8)) for i in range(n_tiles)]
    w, h = tiles[0].size
    preview = Image.new("RGB", (w * len(tiles), h))
    for i, tile in enumerate(tiles):
        preview.paste(tile, (i * w, 0))

    log.info("Step 1 complete")
    return result, preview
# ---------------------------------------------------------------------------
# 7. SeqTex generation
# ---------------------------------------------------------------------------
def step2_generate_texture(
    geo_data: dict[str, Any],
    condition_image: Image.Image,
    text_prompt: str,
    seed: int,
    steps: int,
    guidance_scale: float,
    num_views: int,
    progress: gr.Progress | None = None,
) -> Image.Image:
    """Run SeqTex image-to-texture diffusion and return the decoded UV texture.

    Args:
        geo_data: Output of ``step1_process_mesh`` (paths to saved geometry
            tensors plus ``"uv_size"`` / ``"mv_size"``).
        condition_image: Front-view reference image. It is resized to ``mv_size``
            and used as view 0 (``inference_img_cond_frame=0``).
        text_prompt: Prompt text. Empty or None falls back to a generic prompt.
        seed: Seed for the CUDA generator.
        steps: Number of diffusion steps.
        guidance_scale: Classifier-free guidance scale.
        num_views: Requested number of multi-view frames. It is clamped to the
            number of rendered views, and the geometry is trimmed to the first
            ``num_views`` views so geometry latents and diffusion frames agree.
        progress: Optional Gradio progress callback.

    Returns:
        RGB PIL image of the generated UV texture.

    Raises:
        ValueError: if a geometry tensor does not have rank 3, 4 or 5.
    """
    modules = _load_seqtex_modules()
    torch = modules["torch"]
    np = modules["np"]
    encode_images = modules["encode_images"]
    decode_images = modules["decode_images"]
    convert_img_to_tensor = modules["convert_img_to_tensor"]
    texture_generation_module = modules.get("texture_generation_module")
    device = "cuda"
    mv_size = int(geo_data["mv_size"])
    uv_size = int(geo_data["uv_size"])

    if progress is not None:
        progress(0.36, desc="Loading SeqTex pipeline onto GPU from local cache...")
    pipe = _get_seqtex_pipe()

    def _load_tensor(path: str) -> Any:
        return torch.load(path, map_location=device)

    pos_imgs = _load_tensor(geo_data["pos_imgs"])
    norm_imgs = _load_tensor(geo_data["norm_imgs"])
    pos_map = _load_tensor(geo_data["pos_map"])
    norm_map = _load_tensor(geo_data["norm_map"])

    # Step 1 always renders a fixed set of views. The multi-view geometry latents
    # are concatenated with the diffusion latents, so the number of geometry views
    # must match the number of frames requested from the pipeline. Keep the first
    # views (view 0 is the front/condition view) and clamp the request to what exists.
    # Multi-view tensors are [F, H, W, C] or [B, F, H, W, C]; F is axis -4 in both.
    rendered_views = int(pos_imgs.shape[-4]) if pos_imgs.ndim >= 4 else 1
    requested_views = int(num_views)
    n_views = max(1, min(requested_views, rendered_views))
    if n_views != requested_views:
        log.warning(
            "Requested num_views=%d but %d views were rendered; using %d.",
            requested_views,
            rendered_views,
            n_views,
        )
    if n_views < rendered_views:
        log.info("Using the first %d of %d rendered views.", n_views, rendered_views)
        pos_imgs = pos_imgs[..., :n_views, :, :, :]
        norm_imgs = norm_imgs[..., :n_views, :, :, :]

    if progress is not None:
        progress(0.46, desc="Encoding geometry latents...")

    def _to_bfhwc(frames: Any, name: str) -> Any:
        """Normalize geometry tensors to the encode_images() layout [B, F, H, W, C].

        render_geo_views_tensor() usually returns multi-view tensors as [F, H, W, C].
        render_geo_map() may return UV maps as [H, W, C] or [1, H, W, C],
        depending on the SeqTex utility version. Blindly adding an extra batch
        axis to UV maps produced 6-D tensors, and einops rejected them
        ("expected 5 dims, received 6 dims"). This helper makes every input
        exactly 5-D.
        """
        ndim = getattr(frames, "ndim", None)
        if ndim == 3:
            # [H, W, C] -> [1, 1, H, W, C]
            return frames.unsqueeze(0).unsqueeze(0)
        if ndim == 4:
            # [F, H, W, C] -> [1, F, H, W, C]
            return frames.unsqueeze(0)
        if ndim == 5:
            # Already [B, F, H, W, C]
            return frames
        shape = getattr(frames, "shape", "unknown")
        raise ValueError(f"{name} has unsupported shape for encode_images: {shape}")

    def _reset_wan_vae_encode_cache() -> None:
        """Reset the WAN VAE's stateful feature cache before an independent encode.

        SeqTex calls encode_images() several times on unrelated inputs
        (multi-view geometry, UV geometry, condition image). Resetting between
        calls keeps cached feature maps from one encode out of the next. Prefer
        the VAE's own ``clear_cache()`` when it exists, because it restores a
        valid empty state. Otherwise fall back to clearing known attributes.
        Best-effort: failures are logged, not raised.
        """
        if texture_generation_module is None:
            return
        vae = getattr(texture_generation_module, "VAE", None)
        if vae is None:
            return
        clear_cache = getattr(vae, "clear_cache", None)
        if callable(clear_cache):
            try:
                clear_cache()
                return
            except Exception as exc:
                log.debug("VAE.clear_cache() failed (%s); resetting attributes manually", exc)
        for attr in ("_enc_feat_map", "_enc_conv_idx", "_dec_feat_map", "_dec_conv_idx"):
            if hasattr(vae, attr):
                try:
                    setattr(vae, attr, None)
                except Exception as exc:
                    log.debug("Could not reset VAE attribute %s: %s", attr, exc)

    def _enc(frames: Any, name: str) -> Any:
        frames_5d = _to_bfhwc(frames, name)
        # Force float32 on CUDA. PIL/NumPy condition tensors can become float64,
        # and Conv3D with float64 on CUDA falls back to aten::slow_conv3d_forward,
        # which is CPU-only in the ZeroGPU PyTorch build.
        frames_5d = frames_5d.to(device=device, dtype=torch.float32, non_blocking=True).contiguous()
        log.info("Encoding %s with shape %s dtype=%s device=%s", name, tuple(frames_5d.shape), frames_5d.dtype, frames_5d.device)
        _reset_wan_vae_encode_cache()
        return encode_images(frames_5d, encode_as_first=True)

    nat_pos_lat = _enc(pos_imgs, "pos_imgs")
    nat_norm_lat = _enc(norm_imgs, "norm_imgs")
    uv_pos_lat = _enc(pos_map, "pos_map")
    uv_norm_lat = _enc(norm_map, "norm_map")
    # Position and normal latents are stacked along dim 1.
    nat_geo = torch.cat([nat_pos_lat, nat_norm_lat], dim=1)
    uv_geo = torch.cat([uv_pos_lat, uv_norm_lat], dim=1)
    cond_model_latents = (nat_geo, uv_geo)
    del nat_pos_lat, nat_norm_lat, uv_pos_lat, uv_norm_lat
    torch.cuda.empty_cache()

    if progress is not None:
        progress(0.56, desc="Encoding reference image...")
    cond_pil = condition_image.convert("RGB").resize((mv_size, mv_size), Image.LANCZOS)
    cond_t = convert_img_to_tensor(cond_pil, device=device)
    gt_latent = _enc(cond_t, "condition_image")
    gt_condition = (gt_latent, None)
    del cond_t
    torch.cuda.empty_cache()

    text_prompt = (text_prompt or "high quality texture, clean details").strip()
    # The WAN VAE config stores per-stage booleans ("temperal" is diffusers' spelling).
    # The diffusers default [False, True, True] gives a 4x temporal factor.
    temporal_downsample = getattr(pipe.vae.config, "temperal_downsample", [False, True, True])
    frame_factor = 2 ** sum(temporal_downsample)
    num_frames = n_views * frame_factor
    uv_num_frames = 1 * frame_factor

    if progress is not None:
        progress(0.66, desc=f"Running SeqTex diffusion ({int(steps)} steps)...")
    with torch.inference_mode():
        latents = pipe(
            prompt=text_prompt,
            negative_prompt=None,
            num_frames=num_frames,
            generator=torch.Generator(device=device).manual_seed(int(seed)),
            num_inference_steps=int(steps),
            guidance_scale=float(guidance_scale),
            height=mv_size,
            width=mv_size,
            output_type="latent",
            cond_model_latents=cond_model_latents,
            uv_height=uv_size,
            uv_width=uv_size,
            uv_num_frames=uv_num_frames,
            treat_as_first=True,
            gt_condition=gt_condition,
            inference_img_cond_frame=0,
            use_qk_geometry=True,
            max_sequence_length=1024,
            task_type="img2tex",
        ).frames
    del cond_model_latents, gt_latent, gt_condition
    torch.cuda.empty_cache()
    mv_latents, uv_latents = latents

    if progress is not None:
        progress(0.84, desc="Decoding UV texture...")
    uv_frames = decode_images(uv_latents, decode_as_first=True)
    del uv_latents, mv_latents
    torch.cuda.empty_cache()

    # Take the last decoded frame, drop the batch axis, and convert to HWC uint8.
    # .float() so half/bfloat16 decoder output can be converted to NumPy.
    uv_pred = uv_frames[:, :, -1, ...].squeeze(0).clamp(0.0, 1.0).float().cpu()
    uv_np = (uv_pred.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
    uv_pil = Image.fromarray(uv_np).convert("RGB")
    log.info("Step 2 complete")
    return uv_pil
# ---------------------------------------------------------------------------
# 8. Export
# ---------------------------------------------------------------------------
def step3_export_glb(geo_data: dict[str, Any], uv_texture: Image.Image) -> str:
    """Bake ``uv_texture`` onto the processed mesh and write a textured GLB.

    Args:
        geo_data: Output of ``step1_process_mesh``; only ``"mesh_pkl"`` is read.
        uv_texture: UV texture image passed to ``Mesh.export`` as ``texture_map``.

    Returns:
        Path of a new temporary ``*_textured.glb`` file (not auto-deleted).
    """
    mesh_cls = _load_seqtex_modules()["Mesh"]
    # NOTE: unpickling is acceptable only because step1_process_mesh wrote this
    # file in the same request. Never point "mesh_pkl" at user-supplied data.
    with open(geo_data["mesh_pkl"], "rb") as pkl_file:
        processed_mesh = pickle.load(pkl_file)
    with tempfile.NamedTemporaryFile(delete=False, suffix="_textured.glb", prefix="seqtex_") as tmp:
        glb_out = tmp.name
    mesh_cls.export(processed_mesh, save_path=glb_out, texture_map=uv_texture)
    log.info("Exported textured GLB to %s", glb_out)
    return glb_out
# ---------------------------------------------------------------------------
# 9. Main Generate handler — only ZeroGPU path
# ---------------------------------------------------------------------------
# Keys of step1_process_mesh()'s result that point at intermediate temp files.
_GEO_TEMP_FILE_KEYS = ("pos_imgs", "norm_imgs", "mask_imgs", "pos_map", "norm_map", "w2c", "mvp", "mesh_pkl")


def _cleanup_geo_temp_files(geo_data: dict[str, Any] | None) -> None:
    """Best-effort removal of the intermediate files written by step1_process_mesh().

    Only the known temp-file keys are considered. Missing files are skipped,
    and OS errors are logged rather than raised, so cleanup never masks the
    real result or error.
    """
    if not geo_data:
        return
    for key in _GEO_TEMP_FILE_KEYS:
        path = geo_data.get(key)
        if not isinstance(path, str) or not os.path.isfile(path):
            continue
        try:
            os.remove(path)
        except OSError as exc:
            log.warning("Could not remove temporary file %s (%s): %s", path, key, exc)


@_gpu_decorator(duration=120)
def run(
    glb_file: str | None,
    condition_image: Image.Image | None,
    text_prompt: str,
    seed: int,
    steps: int,
    guidance_scale: float,
    num_views: int,
    upside_down: bool,
    uv_size: int,
    mv_size: int,
    progress: gr.Progress = gr.Progress(track_tqdm=True),
):
    """Generate-button handler: mesh processing -> SeqTex diffusion -> GLB export.

    Returns:
        ``(textured_glb_path, uv_texture, geometry_preview, textured_glb_path)``
        for the Model3D, UV image, preview image and download outputs.

    Raises:
        gr.Error: for missing inputs, configuration problems (``StartupFixError``),
            and any other failure (wrapped with a pointer to the container logs).
    """
    if glb_file is None:
        raise gr.Error("Please upload a GLB mesh.")
    if condition_image is None:
        raise gr.Error("Please upload a front-view reference image.")
    geo_data: dict[str, Any] | None = None
    try:
        glb_path = glb_file if isinstance(glb_file, str) else glb_file.name
        log.info("Generate clicked: glb=%s uv=%s mv=%s steps=%s", glb_path, uv_size, mv_size, steps)
        progress(0.02, desc="Preparing runtime cache directories...")
        _prepare_runtime_dirs()
        if FORCE_LOCAL_GENERATE:
            # Fail fast before any GPU work if the local model cache is incomplete.
            progress(0.03, desc="Verifying cached model snapshots...")
            _assert_cached_models_ready()
        progress(0.04, desc="Preparing SeqTex utilities...")
        _load_seqtex_modules()
        geo_data, preview = step1_process_mesh(
            glb_path=glb_path,
            upside_down=bool(upside_down),
            uv_size=int(uv_size),
            mv_size=int(mv_size),
            progress=progress,
        )
        uv_texture = step2_generate_texture(
            geo_data=geo_data,
            condition_image=condition_image,
            text_prompt=text_prompt,
            seed=int(seed),
            steps=int(steps),
            guidance_scale=float(guidance_scale),
            num_views=int(num_views),
            progress=progress,
        )
        progress(0.92, desc="Baking texture into GLB...")
        textured_glb = step3_export_glb(geo_data, uv_texture)
        progress(1.0, desc="Done")
        return textured_glb, uv_texture, preview, textured_glb
    except StartupFixError as exc:
        log.error("Configuration error: %s", exc)
        raise gr.Error(str(exc)) from exc
    except gr.Error:
        # Already user-facing; do not re-wrap it as a generic failure.
        raise
    except Exception as exc:
        tb = traceback.format_exc()
        log.error("Generation failed:\n%s", tb)
        raise gr.Error(f"Generation failed: {exc}\n\nCheck the Container logs for the full traceback.") from exc
    finally:
        # Intermediate tensors/pickle are not needed once the GLB is exported (or on failure).
        _cleanup_geo_temp_files(geo_data)
# ---------------------------------------------------------------------------
# 10. UI
# ---------------------------------------------------------------------------
print("[BOOT 05] building gradio UI", flush=True)
# Resolved once so the UI shows the cache paths actually in use (including env overrides).
_HUB_CACHE_DIR = os.environ.get("HF_HUB_CACHE", os.path.join(CACHE_ROOT, "hub"))
with gr.Blocks(title="SeqTex Texture Generator", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        f"""
# 🎨 SeqTex Texture Generator
Upload an **untextured GLB mesh** and a **front-view reference image**.
This version stores the WAN base model in a separate bucket mounted at
`{WAN_BUCKET_ROOT}`, while keeping the normal Hugging Face cache under `{PERSISTENT_ROOT}`.
During Generate, model loading is forced to use cached/local files only.
"""
    )
    with gr.Tab("Generate Texture"):
        with gr.Row():
            with gr.Column(scale=1):
                glb_input = gr.File(
                    label="Input GLB Mesh",
                    file_types=[".glb"],
                    type="filepath",
                )
                cond_image = gr.Image(
                    label="Front-View Reference Image",
                    type="pil",
                    sources=["upload", "clipboard"],
                    height=300,
                )
                text_prompt = gr.Textbox(
                    label="Text prompt",
                    placeholder="e.g. anime character, colorful clothing, high quality",
                    value="high quality texture, clean details",
                    lines=2,
                )
                run_btn = gr.Button("✨ Generate Texture", variant="primary", size="lg")
                with gr.Accordion("Advanced settings", open=False):
                    seed = gr.Slider(label="Seed", minimum=0, maximum=2**31 - 1, value=42, step=1)
                    steps = gr.Slider(
                        label="Diffusion steps",
                        minimum=5,
                        maximum=30,
                        value=10,
                        step=1,
                        info="Use 10 on ZeroGPU. Higher values need more time/VRAM.",
                    )
                    guidance_scale = gr.Slider(
                        label="Guidance scale",
                        minimum=1.0,
                        maximum=10.0,
                        value=1.0,
                        step=0.5,
                        info="SeqTex usually works best around 1.0.",
                    )
                    num_views = gr.Slider(
                        label="Multi-view count",
                        minimum=2,
                        maximum=4,
                        value=4,
                        step=1,
                        info="4 matches the reference SeqTex Space.",
                    )
                    upside_down = gr.Checkbox(
                        label="Flip mesh upside-down",
                        value=False,
                        info="Enable only if your mesh appears inverted.",
                    )
                    uv_size = gr.Radio(
                        label="UV texture resolution",
                        choices=[512, 1024],
                        value=1024,
                        info="2048 is disabled for ZeroGPU stability.",
                    )
                    mv_size = gr.Radio(
                        label="Multi-view render resolution",
                        choices=[256, 512],
                        value=512,
                    )
            with gr.Column(scale=1):
                output_3d = gr.Model3D(
                    label="Textured 3-D Model",
                    height=450,
                    clear_color=[0.15, 0.15, 0.15, 1.0],
                )
                uv_preview = gr.Image(
                    label="Generated UV Texture Map",
                    type="pil",
                    interactive=False,
                    height=256,
                )
                geo_preview = gr.Image(
                    label="Geometry Views Preview",
                    type="pil",
                    interactive=False,
                    height=150,
                )
                download_btn = gr.File(label="Download Textured GLB")
        run_btn.click(
            fn=run,
            inputs=[
                glb_input,
                cond_image,
                text_prompt,
                seed,
                steps,
                guidance_scale,
                num_views,
                upside_down,
                uv_size,
                mv_size,
            ],
            outputs=[output_3d, uv_preview, geo_preview, download_btn],
        )
    with gr.Tab("Cache / Startup"):
        gr.Markdown(
            f"""
## Cache preparation
This downloads/checks:
- SeqTex helper code -> `{SEQTEX_SPACE_DIR}`
- SeqTex transformer -> `{_HUB_CACHE_DIR}`
- WAN base model -> `{WAN_LOCAL_MODEL_DIR}`
This uses CPU/network runtime only. It does **not** allocate ZeroGPU.
Press **Prepare cache now** before generating.
"""
        )
        with gr.Row():
            prewarm_btn = gr.Button("Prepare cache now", variant="primary")
            refresh_cache_btn = gr.Button("Refresh cache status")
        cache_log_box = gr.Textbox(
            label="Cache preparation logs",
            value=_prewarm_log_text(),
            lines=22,
            max_lines=35,
            interactive=False,
        )
        prewarm_btn.click(fn=prewarm_cache_ui, inputs=[], outputs=[cache_log_box])
        refresh_cache_btn.click(fn=get_cache_status_ui, inputs=[], outputs=[cache_log_box])
        gr.Markdown(
            f"""
---
**Recommended Space setup**
1. Keep or mount a bucket at `{PERSISTENT_ROOT}`.
2. Create/mount the new WAN bucket at `{WAN_BUCKET_ROOT}`.
3. Add a Space secret named `HF_TOKEN` or `SEQTEX_SPACE_TOKEN`.
4. Keep `AUTO_PREWARM=0` while debugging. Use **Prepare cache now** manually.
5. Keep `FORCE_LOCAL_GENERATE=1`.
Current WAN local model directory:
`{WAN_LOCAL_MODEL_DIR}`
If Generate says the cache is incomplete, run **Prepare cache now** again.
"""
        )
    if AUTO_PREWARM:
        # Kick off background cache preparation once when the page first loads.
        demo.load(fn=start_auto_prewarm_once, inputs=[], outputs=[cache_log_box])
print("[BOOT 06] gradio UI built", flush=True)
if __name__ == "__main__":
port = int(os.getenv("PORT", "7860"))
print(f"[BOOT 07] launching gradio on 0.0.0.0:{port} with SSR disabled", flush=True)
queued_demo = demo.queue(max_size=3)
try:
queued_demo.launch(
server_name="0.0.0.0",
server_port=port,
ssr_mode=False,
show_api=False,
)
except TypeError as exc:
print(f"[BOOT WARN] launch kwargs fallback because: {exc}", flush=True)
queued_demo.launch(
server_name="0.0.0.0",
server_port=port,
)