open3dforge / src /workspace.py
Reverb's picture
Fix state persistence across ZeroGPU subprocess boundary
299ba9b
Raw
History Blame Contribute Delete
10.2 kB
"""
Workspace management for Open3DForge.
Single-user pattern: one persistent workspace folder.
No session IDs, no multi-tenancy.
Layout:
    workspace/
        current/   -- active work, overwritten per generation
        exports/   -- finished asset zips, kept indefinitely
        presets/   -- saved parameter configurations (JSON)
        history/   -- thumbnails + metadata of past assets
"""
from __future__ import annotations
import json
import os
import shutil
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
# ---------------------------------------------------------------------------
# Paths
# ---------------------------------------------------------------------------
# Everything hangs off the directory two levels above this file.
ROOT = Path(__file__).resolve().parent.parent
WORKSPACE = ROOT.joinpath("workspace")
# The four top-level areas of the workspace.
CURRENT, EXPORTS, PRESETS, HISTORY = (
    WORKSPACE.joinpath(area) for area in ("current", "exports", "presets", "history")
)
# Subdirectories of `current/` that are (re)created for every generation.
CURRENT_TEXTURES = CURRENT.joinpath("textures")
CURRENT_LODS = CURRENT.joinpath("lods")
def ensure_dirs() -> None:
    """Make sure every workspace folder exists.

    Idempotent: folders that are already present are left untouched.
    """
    required = [
        WORKSPACE, CURRENT, EXPORTS, PRESETS, HISTORY,
        CURRENT_TEXTURES, CURRENT_LODS,
    ]
    for directory in required:
        directory.mkdir(parents=True, exist_ok=True)
def reset_current() -> None:
    """Wipe `current/` and recreate its empty skeleton.

    Invoked at the start of a generation so nothing from the previous asset
    survives; `textures/` and `lods/` are recreated empty.
    """
    if CURRENT.exists():
        shutil.rmtree(CURRENT)
    for directory in (CURRENT, CURRENT_TEXTURES, CURRENT_LODS):
        directory.mkdir(parents=True, exist_ok=True)
# ---------------------------------------------------------------------------
# Asset state (in-memory representation of what's in `current/`)
# ---------------------------------------------------------------------------
@dataclass
class AssetState:
    """Snapshot of the files present in `current/` plus pipeline progress.

    Pipeline stages update it as they finish; the UI reads it to decide which
    buttons to enable and which status indicators to show. Path fields default
    to ``None`` and list fields to an empty list.
    """
    # --- Input ---
    input_images: list[Path] = field(default_factory=list)
    # --- Stage 1 outputs ---
    high_poly_glb: Path | None = None  # raw TRELLIS.2 output, kept for baking
    raw_gen_glb: Path | None = None  # decimated base from generation
    # --- Stage 2 outputs ---
    repaired_glb: Path | None = None
    cleaned_glb: Path | None = None
    low_poly_glb: Path | None = None  # post-decimation working mesh
    unwrapped_glb: Path | None = None  # has UV coordinates
    final_glb: Path | None = None  # all post-processing done
    # --- Textures (Stage 2 baking) ---
    albedo_png: Path | None = None
    normal_gl_png: Path | None = None
    normal_dx_png: Path | None = None
    roughness_png: Path | None = None
    metallic_png: Path | None = None
    ao_png: Path | None = None
    orm_png: Path | None = None  # UE5-packed AO/Rough/Metal
    metallic_smoothness_png: Path | None = None  # Unity-packed
    # --- LODs ---
    lod_glbs: list[Path] = field(default_factory=list)
    # --- Collision ---
    collision_glb: Path | None = None
    # --- Rigging (Stage 3) ---
    rigged_glb: Path | None = None
    rigged_fbx: Path | None = None
    # --- Metadata ---
    asset_name: str = "untitled"
    generated_at: float = field(default_factory=time.time)  # epoch seconds
    model_used: str = ""  # "TRELLIS.2" or "Hunyuan3D-2"
    face_count: int = 0
    vertex_count: int = 0

    def to_dict(self) -> dict[str, Any]:
        """Return the fields as a plain dict for status display / debugging.

        Path values become strings, list values become lists of strings, and
        everything else (including ``None``) is passed through unchanged.
        Keys appear in field-declaration order.
        """
        snapshot: dict[str, Any] = {}
        for name, value in vars(self).items():
            if isinstance(value, list):
                snapshot[name] = list(map(str, value))
            elif isinstance(value, Path):
                # Path objects are always truthy, so they always stringify.
                snapshot[name] = str(value)
            else:
                snapshot[name] = value
        return snapshot
# ---------------------------------------------------------------------------
# Filesystem-based state persistence
#
# ZeroGPU runs @spaces.GPU functions in a forked subprocess, so assignments to
# module-level variables (such as _state) made there never reach the parent
# Gradio process. To cross that boundary, state lives on disk: the files in
# CURRENT/ plus a small JSON metadata file, and readers always rebuild from
# those rather than trusting memory.
# ---------------------------------------------------------------------------
_META_FILE = CURRENT / ".meta.json"

# (AssetState attribute, expected file) pairs scanned when rebuilding state.
# An attribute may appear more than once (final_glb); the earlier entry has
# priority because the first existing file wins.
# NOTE(review): metallic_smoothness_png and input_images have no entry, so
# they are never restored from disk -- confirm whether that is intended.
_PATH_ATTRS: list[tuple[str, Path]] = [
    (attr, CURRENT / filename)
    for attr, filename in (
        ("rigged_fbx", "rigged.fbx"),
        ("rigged_glb", "rigged.glb"),
        ("final_glb", "scaled.glb"),
        ("final_glb", "pivoted.glb"),
        ("unwrapped_glb", "unwrapped.glb"),
        ("low_poly_glb", "low_poly.glb"),
        ("cleaned_glb", "cleaned.glb"),
        ("repaired_glb", "repaired.glb"),
        ("raw_gen_glb", "raw_gen.glb"),
        ("high_poly_glb", "high_poly.glb"),
    )
] + [
    (attr, CURRENT_TEXTURES / filename)
    for attr, filename in (
        ("normal_dx_png", "normal_dx.png"),
        ("normal_gl_png", "normal_gl.png"),
        ("albedo_png", "albedo.png"),
        ("roughness_png", "roughness.png"),
        ("metallic_png", "metallic.png"),
        ("ao_png", "ao.png"),
        ("orm_png", "orm.png"),
    )
] + [
    ("collision_glb", CURRENT / "collision.glb"),
]
def _lod_sort_key(path: Path) -> tuple[int, int, str]:
    """Sort key putting ``LOD<n>.glb`` files in numeric order of ``n``.

    Plain string sorting would place ``LOD10`` before ``LOD2``. Files whose
    text after the ``LOD`` prefix is not a number sort after the numbered
    ones, by name.
    """
    level = path.stem[3:]  # the "LOD*.glb" glob guarantees a 3-char prefix
    if level.isdecimal():
        return (0, int(level), path.name)
    return (1, 0, path.name)


def _build_state_from_disk() -> AssetState:
    """Reconstruct an AssetState by scanning CURRENT/ and reading .meta.json.

    Scalar metadata (name, model, face/vertex counts and, when present, the
    generation timestamp) comes from ``_META_FILE``. A missing, unreadable or
    malformed file leaves the dataclass defaults in place. Path attributes
    are filled from ``_PATH_ATTRS``; when an attribute is listed more than
    once, the first existing file wins. LOD meshes are collected from the
    ``lods/`` folder in numeric level order.

    Returns:
        A fresh AssetState reflecting the current disk contents.
    """
    state = AssetState()

    # EAFP: read directly instead of exists()+read, which also covers the file
    # disappearing in between. JSONDecodeError/UnicodeDecodeError are ValueErrors.
    try:
        meta = json.loads(_META_FILE.read_text(encoding="utf-8"))
    except (OSError, ValueError):
        meta = {}
    if isinstance(meta, dict):
        state.asset_name = meta.get("asset_name", "untitled")
        state.model_used = meta.get("model_used", "")
        state.face_count = meta.get("face_count", 0)
        state.vertex_count = meta.get("vertex_count", 0)
        # Without this, every rebuild would stamp generated_at with "now".
        # Files written without a timestamp keep the default.
        generated_at = meta.get("generated_at")
        if isinstance(generated_at, (int, float)) and not isinstance(generated_at, bool):
            state.generated_at = float(generated_at)

    # Populate path attributes from the filesystem (first existing file wins).
    filled: set[str] = set()
    for attr, path in _PATH_ATTRS:
        if attr in filled or not path.exists():
            continue
        setattr(state, attr, path)
        filled.add(attr)

    if CURRENT_LODS.is_dir():
        state.lod_glbs = sorted(CURRENT_LODS.glob("LOD*.glb"), key=_lod_sort_key)
    return state
def flush_meta(state: AssetState) -> None:
    """Write *state*'s scalar metadata to ``.meta.json`` in `current/`.

    In-memory changes made inside a ZeroGPU subprocess are invisible to the
    parent Gradio process, so this file carries the asset name, model, counts
    and generation time across that boundary.

    Best effort: any failure to encode or write is logged as a warning and
    otherwise ignored, so a metadata problem never aborts a pipeline stage.
    """
    import logging

    def _to_builtin(obj: Any) -> Any:
        # json cannot encode numpy-style integer scalars (e.g. a face count
        # taken straight from a mesh library); objects exposing .item() are
        # converted to the equivalent plain Python value.
        item = getattr(obj, "item", None)
        if callable(item):
            return item()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    tmp_file = _META_FILE.with_name(_META_FILE.name + ".tmp")
    try:
        payload = json.dumps({
            "asset_name": state.asset_name,
            "model_used": state.model_used,
            "face_count": state.face_count,
            "vertex_count": state.vertex_count,
            "generated_at": state.generated_at,
        }, default=_to_builtin)
        _META_FILE.parent.mkdir(parents=True, exist_ok=True)
        tmp_file.write_text(payload, encoding="utf-8")
        # Swap the finished file in with one rename (atomic on POSIX) so a
        # reader in the other process never sees a half-written file.
        os.replace(tmp_file, _META_FILE)
    except (OSError, TypeError, ValueError) as exc:
        logging.getLogger(__name__).warning(
            "Could not write workspace metadata to %s: %s", _META_FILE, exc)
        try:
            tmp_file.unlink()
        except OSError:
            pass  # temp file was never created or is already gone
# Module-level singleton, kept for code running inside the Gradio process
# itself (e.g. stage2 steps). It is only a cache: get_state() overwrites it
# with a fresh read from disk on every call, which is what keeps it correct
# across the ZeroGPU process boundary.
_state: AssetState
_state = AssetState()
def get_state() -> AssetState:
    """Rebuild the asset state from `current/` and return it.

    Reading from disk on every call is what survives ZeroGPU's subprocess
    isolation. The module-level ``_state`` is replaced as a side effect, so
    any changes made only to the previous in-memory object are discarded.
    """
    global _state
    fresh = _build_state_from_disk()
    _state = fresh
    return fresh
def reset_state() -> AssetState:
    """Install a blank AssetState as the module-level state and return it.

    NOTE(review): this does not touch `current/`. The next get_state() call
    rebuilds from disk, so any files still there reappear unless
    reset_current() is also called -- confirm that callers expect this.
    """
    global _state
    blank = AssetState()
    _state = blank
    return blank
# ---------------------------------------------------------------------------
# Presets (JSON-on-disk, loaded as dicts)
# ---------------------------------------------------------------------------
def list_presets() -> list[str]:
    """Names of all saved presets (file stems without `.json`), alphabetical.

    Returns an empty list when the presets folder does not exist.
    """
    if PRESETS.exists():
        names = [entry.stem for entry in PRESETS.glob("*.json")]
        names.sort()
        return names
    return []
def load_preset(name: str) -> dict[str, Any]:
    """Load a preset by name.

    The name is tried as given first, then in the sanitised form that
    ``save_preset`` writes, so ``load_preset(n)`` finds what
    ``save_preset(n, ...)`` stored even when ``n`` contains spaces or other
    stripped characters. A name that is not a bare file stem (contains a path
    separator or is absolute) is never used verbatim, so the lookup cannot
    escape ``PRESETS``.

    Returns:
        The decoded JSON content of the preset file.

    Raises:
        FileNotFoundError: if no matching preset file exists.
    """
    candidates: list[str] = []
    # `PRESETS / "/abs/x.json"` would discard PRESETS entirely and "../x"
    # would climb out of it; only bare stems are used as-is.
    if name and Path(name).name == name:
        candidates.append(name)
    sanitized = _sanitize_filename(name)
    if sanitized not in candidates:
        candidates.append(sanitized)

    for stem in candidates:
        path = PRESETS / f"{stem}.json"
        if path.is_file():
            with path.open("r", encoding="utf-8") as f:
                return json.load(f)
    raise FileNotFoundError(f"Preset not found: {name}")
def save_preset(name: str, config: dict[str, Any]) -> Path:
    """Save *config* as a JSON preset, overwriting any existing one.

    *name* is passed through ``_sanitize_filename``, so the stored file stem
    (and the name ``list_presets`` reports) may differ from *name*.

    The JSON text is produced before the file is opened. A config that cannot
    be encoded therefore raises without truncating an existing preset of the
    same name; opening with "w" first would leave it empty or half-written.

    Returns:
        The path of the written preset file.

    Raises:
        TypeError / ValueError: if *config* is not JSON-serialisable.
        OSError: if the file cannot be written.
    """
    safe_name = _sanitize_filename(name)
    path = PRESETS / f"{safe_name}.json"
    text = json.dumps(config, indent=2, sort_keys=True)
    # ensure_dirs() runs at import, but the folder may have been removed since.
    PRESETS.mkdir(parents=True, exist_ok=True)
    path.write_text(text, encoding="utf-8")
    return path
def delete_preset(name: str) -> bool:
    """Delete a preset by its file stem.

    Returns True if a preset file was removed, False if it did not exist.
    Empty names and names that are not a bare file stem (containing a path
    separator, or absolute) are treated as non-existent. This means the
    function can never remove a ``.json`` file outside ``PRESETS``.
    """
    # `PRESETS / "/abs/x.json"` would discard PRESETS entirely and "../x"
    # would climb out of it; only accept a bare stem.
    if not name or Path(name).name != name:
        return False
    path = PRESETS / f"{name}.json"
    if not path.is_file():
        return False
    try:
        path.unlink()
    except FileNotFoundError:
        # Removed by someone else between the check and the unlink.
        return False
    return True
def _sanitize_filename(name: str) -> str:
"""Strip path separators and unsafe chars from a filename stem."""
return "".join(c for c in name if c.isalnum() or c in "_-").strip("_-") or "preset"
# ---------------------------------------------------------------------------
# Workspace stats (for UI status bar)
# ---------------------------------------------------------------------------
def _dir_size_bytes(root: Path) -> int:
    """Total bytes of regular files below *root*; 0 if it does not exist.

    Files or folders that vanish mid-walk (e.g. `current/` being wiped by
    reset_current() while the status bar polls) are skipped instead of
    raising.
    """
    total = 0
    # os.walk ignores directory-listing errors by default (onerror=None) and
    # does not descend into symlinked directories (followlinks=False).
    for dirpath, _dirnames, filenames in os.walk(root):
        for fname in filenames:
            fpath = os.path.join(dirpath, fname)
            try:
                if os.path.isfile(fpath):
                    total += os.path.getsize(fpath)
            except OSError:
                # Deleted between listing and stat: nothing to count.
                continue
    return total


def workspace_size_mb() -> float:
    """Total size of the whole workspace, in MiB (1024 * 1024 bytes)."""
    return _dir_size_bytes(WORKSPACE) / (1024 * 1024)


def current_size_mb() -> float:
    """Size of `current/` only (active work), in MiB (1024 * 1024 bytes)."""
    return _dir_size_bytes(CURRENT) / (1024 * 1024)
def export_count() -> int:
    """Number of `.zip` archives in `exports/` (0 if the folder is missing)."""
    if EXPORTS.exists():
        return sum(1 for _ in EXPORTS.glob("*.zip"))
    return 0
# ---------------------------------------------------------------------------
# Module init
# ---------------------------------------------------------------------------
# Import-time side effect: create every workspace folder so the app never
# crashes on a fresh checkout. ensure_dirs() uses exist_ok=True, so importing
# this module repeatedly (or alongside other calls) is harmless.
ensure_dirs()