feiyang-cai's picture
Initial ZeroGPU Gradio Space
01a8278
Raw
History Blame Contribute Delete
11.5 kB
from __future__ import annotations
import json
import os
import sys
from dataclasses import dataclass
from typing import Dict, Optional, Tuple
import numpy as np
import pickle
import torch
import torch.nn.functional as F
from . import models as compat_models
from .models import MaterialHybridDenoiser, ModelConfig
# Volume-fraction value for each vf_category index (index -> value).
VF_CATEGORIES = [0.1, 0.2, 0.3, 0.4, 0.5]
# Display names for material ids, indexed by the sampled material token.
MATERIAL_NAMES = ["CPP", "CHDPE", "GPP", "GHDPE"]
@dataclass(frozen=True)
class DiscreteMaskDiffusion:
    """Immutable descriptor of a discrete masked-diffusion process.

    Only the number of diffusion steps is stored; instances are hashable and
    compare equal when their step counts match.
    """

    T: int  # number of diffusion steps
class GaussianDiffusion:
    """Linear-beta Gaussian diffusion schedule with a DDPM-style ancestral step.

    ``betas`` are ``torch.linspace(beta_start, beta_end, T)`` on ``device``; the
    derived per-step tensors below are precomputed once so sampling only indexes them.
    """

    def __init__(self, T: int, beta_start: float = 1e-4, beta_end: float = 2e-2, device: str = "cpu"):
        self.T = int(T)
        betas = torch.linspace(beta_start, beta_end, self.T, device=device)
        alphas = 1.0 - betas
        alpha_bar = torch.cumprod(alphas, dim=0)
        self.betas = betas
        self.alphas = alphas
        self.alpha_bar = alpha_bar
        self.sqrt_alpha_bar = torch.sqrt(alpha_bar)
        self.sqrt_one_minus_alpha_bar = torch.sqrt(1.0 - alpha_bar)
        self.sqrt_recip_alphas = torch.sqrt(1.0 / alphas)

    @staticmethod
    def _expand_like(per_sample: torch.Tensor, like: torch.Tensor) -> torch.Tensor:
        """Reshape a per-sample vector ``(B,)`` to ``(B, 1, ..., 1)`` with ``like.dim()`` axes.

        A fixed ``(B, 1, 1)`` view only broadcasts correctly against rank-3 inputs; for
        e.g. ``(B, D)`` it would broadcast to ``(B, B, D)``.
        """
        return per_sample.reshape(-1, *([1] * (like.dim() - 1)))

    @torch.no_grad()
    def p_sample_step(self, x_t: torch.Tensor, t: torch.Tensor, eps_pred: torch.Tensor) -> torch.Tensor:
        """Take one reverse step ``x_t -> x_{t-1}`` given the predicted noise.

        Args:
            x_t: Current noisy sample; the leading axis is the batch, any rank >= 1.
            t: Integer step index per batch element, shape ``(B,)``.
            eps_pred: Predicted noise, same shape as ``x_t``.

        Returns:
            ``mu + sqrt(beta_t) * z`` with ``z ~ N(0, I)``, where
            ``mu = (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps_pred) / sqrt(alpha_t)``;
            the noise term is dropped for elements with ``t == 0``.
        """
        beta_t = self._expand_like(self.betas[t], x_t)
        sqrt_recip_alpha_t = self._expand_like(self.sqrt_recip_alphas[t], x_t)
        sqrt_one_minus_a_bar = self._expand_like(self.sqrt_one_minus_alpha_bar[t], x_t)
        # clamp_min guards the division if 1 - alpha_bar underflows to 0.
        mu = sqrt_recip_alpha_t * (x_t - (beta_t / sqrt_one_minus_a_bar.clamp_min(1e-8)) * eps_pred)
        noise = torch.randn_like(x_t)
        # The final step (t == 0) returns the mean without added noise.
        nonzero_mask = self._expand_like((t != 0).float(), x_t)
        return mu + nonzero_mask * torch.sqrt(beta_t) * noise
def _resample_masked_tokens(
    tokens: torch.Tensor,
    probs: torch.Tensor,
    mask_id: int,
    remask_prob: float,
) -> torch.Tensor:
    """Resample masked and randomly re-masked positions of ``tokens`` from ``probs``.

    ``probs`` has the shape of ``tokens`` plus a trailing class axis. Positions equal
    to ``mask_id`` are always resampled; every other position is independently
    resampled with probability ``remask_prob``. Returns ``tokens`` unchanged when no
    position is selected, otherwise a modified copy.
    """
    remask = torch.rand(tokens.shape, device=tokens.device) < remask_prob
    masked = (tokens == mask_id) | remask
    if not masked.any():
        return tokens
    # Rows are taken in row-major order, matching the order boolean-mask assignment writes them.
    flat_probs = probs.reshape(-1, probs.size(-1))[masked.reshape(-1)]
    new_tokens = torch.multinomial(flat_probs, 1).squeeze(-1)
    # Write into a copy so the tensor passed in is not mutated in place.
    tokens = tokens.clone()
    tokens[masked] = new_tokens
    return tokens


@torch.no_grad()
def sample(
    model: MaterialHybridDenoiser,
    disc_diff_mat: DiscreteMaskDiffusion,
    disc_diff_vf_category: DiscreteMaskDiffusion,
    disc_diff_layer: DiscreteMaskDiffusion,
    disc_diff_angle: Optional[DiscreteMaskDiffusion] = None,
    cont_diff: Optional[GaussianDiffusion] = None,
    cond: torch.Tensor = None,
    mask_ids: Dict[str, int] = None,
    device: str = "cpu",
    remask_prob: float = 0.1,
    use_discrete_angles: bool = True,
) -> Dict[str, torch.Tensor]:
    """Jointly sample material, volume-fraction category, layer and angle tokens.

    All discrete tokens start fully masked (continuous angles start as N(0, 1) noise).
    For each step ``t = T-1, ..., 0`` the model is run once and every masked position,
    plus a random ``remask_prob`` fraction of already-decoded ones, is resampled from
    the model's softmax.

    Args:
        model: Called as ``model(x_material, x_vf_category, x_layer, x_angle, cond, t)``;
            must expose ``cfg.n_max_layer`` and return a dict with ``material_logits``,
            ``vf_category_logits`` and either ``angle_logits`` (discrete mode) or
            ``layer_logits`` and ``angle`` (continuous mode).
        disc_diff_mat, disc_diff_vf_category, disc_diff_layer: Not read by this function.
        disc_diff_angle: Required in discrete mode; its ``T`` sets the number of steps.
        cont_diff: Required in continuous mode; its ``T`` sets the number of steps.
        cond: Conditioning tensor; its leading axis is the batch size ``B``.
        mask_ids: Mask-token id per field. Must contain ``material``, ``vf_category``,
            ``layer``, and additionally ``angle`` in discrete mode.
        device: Device on which token tensors are created.
        remask_prob: Per-position probability of resampling an already-decoded token.
        use_discrete_angles: Select categorical (True) or Gaussian (False) angles.

    Returns:
        Dict with ``material_t`` (B,), ``vf_category_t`` (B,), ``layer_t`` (B, L) and
        ``angle_t`` ((B, L) category ids, or (B, L, 1) continuous values).

    Raises:
        ValueError: If ``cond``/``mask_ids`` are missing or incomplete, or the schedule
            required by the selected angle mode is not provided.
    """
    if cond is None:
        raise ValueError("cond is required")
    if mask_ids is None:
        raise ValueError("mask_ids is required")
    missing = [key for key in ("material", "vf_category", "layer") if key not in mask_ids]
    if missing:
        raise ValueError(f"mask_ids is missing required keys: {missing}")

    model.eval()
    B = cond.shape[0]
    L = model.cfg.n_max_layer
    x_material_t = torch.full((B,), mask_ids["material"], dtype=torch.long, device=device)
    x_vf_category_t = torch.full((B,), mask_ids["vf_category"], dtype=torch.long, device=device)
    x_layer_t = torch.full((B, L), mask_ids["layer"], dtype=torch.long, device=device)
    if use_discrete_angles:
        if disc_diff_angle is None:
            raise ValueError("disc_diff_angle is required when use_discrete_angles=True")
        if "angle" not in mask_ids:
            raise ValueError("mask_ids must include 'angle' when use_discrete_angles=True")
        x_angle_t = torch.full((B, L), mask_ids["angle"], dtype=torch.long, device=device)
        T = disc_diff_angle.T
    else:
        if cont_diff is None:
            raise ValueError("cont_diff is required when use_discrete_angles=False")
        x_angle_t = torch.randn(B, L, 1, device=device)
        T = cont_diff.T

    for t_int in reversed(range(T)):
        t = torch.full((B,), t_int, dtype=torch.long, device=device)
        outputs = model(x_material_t, x_vf_category_t, x_layer_t, x_angle_t, cond, t)

        x_material_t = _resample_masked_tokens(
            x_material_t, F.softmax(outputs["material_logits"], dim=-1), mask_ids["material"], remask_prob
        )
        x_vf_category_t = _resample_masked_tokens(
            x_vf_category_t, F.softmax(outputs["vf_category_logits"], dim=-1), mask_ids["vf_category"], remask_prob
        )

        if use_discrete_angles:
            probs_angle = F.softmax(outputs["angle_logits"], dim=-1)  # assumed (B, L, K + 1)
            x_angle_t = _resample_masked_tokens(x_angle_t, probs_angle, mask_ids["angle"], remask_prob)
            # The last angle class is the "dead" category; layer_t is 1 wherever the angle is not dead.
            dead_category = probs_angle.size(-1) - 1
            x_layer_t = (x_angle_t != dead_category).long()
        else:
            probs_layer = F.softmax(outputs["layer_logits"], dim=-1)
            x_layer_t = _resample_masked_tokens(x_layer_t, probs_layer, mask_ids["layer"], remask_prob)
            # outputs["angle"] is treated as the clean-sample (x0) prediction; convert it to a
            # noise estimate so the standard ancestral step can be applied.
            angle_pred = outputs["angle"]
            sqrt_alpha_bar_t = cont_diff.sqrt_alpha_bar[t].view(-1, 1, 1)
            sqrt_one_minus_alpha_bar_t = cont_diff.sqrt_one_minus_alpha_bar[t].view(-1, 1, 1)
            eps_pred = (x_angle_t - sqrt_alpha_bar_t * angle_pred) / sqrt_one_minus_alpha_bar_t.clamp_min(1e-8)
            x_angle_t = cont_diff.p_sample_step(x_angle_t, t, eps_pred)

    return {"material_t": x_material_t, "vf_category_t": x_vf_category_t, "layer_t": x_layer_t, "angle_t": x_angle_t}
@dataclass
class ModelBundle:
    """Everything ``sample`` needs for one loaded checkpoint.

    Groups the denoiser, its diffusion schedules, the mask-token ids and the
    schedule hyper-parameters read from the training config.
    """

    # Denoising network.
    model: MaterialHybridDenoiser
    # Discrete schedules for the material, volume-fraction-category and layer tokens.
    disc_diff_mat: DiscreteMaskDiffusion
    disc_diff_vf_category: DiscreteMaskDiffusion
    disc_diff_layer: DiscreteMaskDiffusion
    # Angle schedule: load_model_bundle sets disc_diff_angle in discrete mode, cont_diff otherwise.
    disc_diff_angle: Optional[DiscreteMaskDiffusion]
    cont_diff: Optional[GaussianDiffusion]
    # Mask-token id per field name.
    mask_ids: Dict[str, int]
    use_discrete_angles: bool
    # Number of diffusion steps and the linear beta range.
    T: int
    beta_start: float
    beta_end: float
    # Caller-supplied angle resolution (presumably degrees per step — TODO confirm).
    angle_resolution: float
def _load_state_dict_safely(obj) -> Dict[str, torch.Tensor]:
if isinstance(obj, dict):
for k in ("model_state_dict", "state_dict", "model"):
if k in obj and isinstance(obj[k], dict):
return obj[k]
# If it already looks like a state_dict
if all(isinstance(v, torch.Tensor) for v in obj.values()):
return obj # type: ignore[return-value]
raise ValueError("Unrecognized checkpoint format (expected dict with model state_dict)")
def _install_pickle_compat_aliases() -> None:
    """Register the bundled compatibility module under the top-level name ``models``.

    Old checkpoints may reference ``models.ModelConfig``; the alias lets unpickling
    resolve that path. An already-registered ``models`` module is left untouched.
    """
    if "models" not in sys.modules:
        sys.modules["models"] = compat_models
def _resolve_checkpoint_path(checkpoint_dir: str) -> str:
    """Return ``best_model.pt`` if it exists, otherwise ``checkpoint_epoch_1.pt``.

    Raises:
        FileNotFoundError: If neither file exists; the message lists every path tried.
    """
    candidates = [os.path.join(checkpoint_dir, name) for name in ("best_model.pt", "checkpoint_epoch_1.pt")]
    for path in candidates:
        if os.path.exists(path):
            return path
    raise FileNotFoundError(f"No checkpoint found; tried: {', '.join(candidates)}")


def _torch_load_checkpoint(ckpt_path: str, device: str):
    """Load ``ckpt_path``, preferring weights-only unpickling when torch supports it."""
    try:
        return torch.load(ckpt_path, map_location=device, weights_only=True)  # type: ignore[call-arg]
    except TypeError:
        # Installed torch predates the `weights_only` argument.
        return torch.load(ckpt_path, map_location=device)
    except pickle.UnpicklingError:
        # The weights-only loader hit non-tensor objects (e.g. models.ModelConfig).
        # SECURITY: full unpickling can execute arbitrary code. This is acceptable only
        # because these checkpoints ship with the app; never use it for user-supplied files.
        return torch.load(ckpt_path, map_location=device, weights_only=False)  # type: ignore[call-arg]


def load_model_bundle(checkpoint_dir: str, device: Optional[str] = None, angle_resolution: float = 1.0) -> ModelBundle:
    """Build a ready-to-sample ModelBundle from a training output directory.

    Reads ``training_config.json``, constructs ``MaterialHybridDenoiser`` with the
    recorded settings, loads its weights strictly from ``best_model.pt`` (falling back
    to ``checkpoint_epoch_1.pt``), and builds the diffusion schedules for the
    configured angle mode.

    Args:
        checkpoint_dir: Directory holding the config and checkpoint files.
        device: Torch device string; defaults to ``"cuda"`` if available, else ``"cpu"``.
        angle_resolution: Stored on the bundle as a float; not used by this function.

    Returns:
        The populated ModelBundle.

    Raises:
        FileNotFoundError: If the config or both checkpoint files are missing.
        KeyError: If the config has no ``model_config`` entry.
        ValueError: If the checkpoint layout is not recognized.
        RuntimeError: If the checkpoint weights do not match the model exactly.
    """
    cfg_path = os.path.join(checkpoint_dir, "training_config.json")
    # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open(cfg_path, "r", encoding="utf-8") as f:
        train_cfg = json.load(f)
    # Fail fast, before building the model, if there is nothing to load.
    ckpt_path = _resolve_checkpoint_path(checkpoint_dir)

    use_discrete_angles = bool(train_cfg.get("use_discrete_angles", True))
    model_cfg = ModelConfig(**train_cfg["model_config"])
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    mask_ids = dict(train_cfg.get("mask_ids", {}))
    # Angle category count only matters for discrete mode. We default to 7 if not present.
    n_angle_categories = int(train_cfg.get("n_angle_categories", 7))
    # IMPORTANT: match training model signature exactly (cfg, mask_ids, use_discrete_angles, n_angle_categories)
    model = MaterialHybridDenoiser(
        model_cfg,
        mask_ids=mask_ids,
        use_discrete_angles=use_discrete_angles,
        n_angle_categories=n_angle_categories,
    ).to(device)

    # Old checkpoints may pickle objects from the training-time top-level `models` module.
    _install_pickle_compat_aliases()
    ckpt = _torch_load_checkpoint(ckpt_path, device)
    state_dict = _load_state_dict_safely(ckpt)
    # Enforce exact match with checkpoint; otherwise generation can look arbitrarily bad.
    model.load_state_dict(state_dict, strict=True)

    T = int(train_cfg.get("T", 100))
    beta_start = float(train_cfg.get("beta_start", 1e-4))
    beta_end = float(train_cfg.get("beta_end", 2e-2))
    disc = DiscreteMaskDiffusion(T=T)
    disc_angle = DiscreteMaskDiffusion(T=T) if use_discrete_angles else None
    cont = GaussianDiffusion(T=T, beta_start=beta_start, beta_end=beta_end, device=device) if not use_discrete_angles else None

    # Some training configs omit the angle mask id; in discrete mode default it to the
    # id after the dead category (categories 0..K-1, dead = K, mask = K + 1).
    # NOTE(review): this runs after the model was built with the same dict object; whether
    # the model sees the key depends on it keeping a reference — TODO confirm.
    if use_discrete_angles and "angle" not in mask_ids:
        mask_ids["angle"] = n_angle_categories + 1

    return ModelBundle(
        model=model,
        disc_diff_mat=disc,
        disc_diff_vf_category=disc,
        disc_diff_layer=disc,
        disc_diff_angle=disc_angle,
        cont_diff=cont,
        mask_ids=mask_ids,
        use_discrete_angles=use_discrete_angles,
        T=T,
        beta_start=beta_start,
        beta_end=beta_end,
        angle_resolution=float(angle_resolution),
    )
def vf_category_to_volume_fraction(vf_category: int) -> float:
    """Return the volume fraction for a category index from ``VF_CATEGORIES``.

    Indices outside ``range(len(VF_CATEGORIES))`` (including negatives) fall back
    to the first entry.
    """
    index = int(vf_category)
    in_range = index in range(len(VF_CATEGORIES))
    return float(VF_CATEGORIES[index if in_range else 0])
def postprocess_sample(
    out: Dict[str, torch.Tensor],
    use_discrete_angles: bool,
    angle_categories_deg: Optional[np.ndarray],
) -> Tuple[str, float, list[float]]:
    """Decode the first sample of a ``sample()`` result into readable values.

    Args:
        out: Dict with ``material_t``, ``vf_category_t``, ``layer_t`` and ``angle_t``
            tensors whose leading axis is the batch; only element 0 is decoded.
        use_discrete_angles: True if ``angle_t`` holds category ids, False if it
            holds continuous values.
        angle_categories_deg: Angle in degrees for each category id; required in
            discrete mode. Its length is used as the "dead" category id.

    Returns:
        ``(material_name, volume_fraction, sorted_angles_deg)``. Unknown material ids
        are rendered as ``"MAT<id>"``.

    Raises:
        ValueError: If discrete mode is requested without ``angle_categories_deg``.
    """
    # Take element 0 explicitly (not .item(), which requires a single element) so a
    # batched result decodes the same way as layer_t / angle_t below.
    mat_idx = int(out["material_t"].reshape(-1)[0].item())
    vf_idx = int(out["vf_category_t"].reshape(-1)[0].item())
    mat = MATERIAL_NAMES[mat_idx] if 0 <= mat_idx < len(MATERIAL_NAMES) else f"MAT{mat_idx}"
    vf = vf_category_to_volume_fraction(vf_idx)
    layer = out["layer_t"].detach().cpu().numpy()[0]
    angle = out["angle_t"].detach().cpu().numpy()[0]
    if use_discrete_angles:
        if angle_categories_deg is None:
            raise ValueError("angle_categories_deg is required for discrete angles")
        # The id one past the last real category marks a dead (absent) layer.
        dead_category = int(len(angle_categories_deg))
        alive_mask = angle != dead_category
        angle_cats = angle[alive_mask].astype(int)
        angles_deg = [float(angle_categories_deg[c]) for c in angle_cats]
    else:
        # Continuous mode: layer_t == 1 marks present layers; angle values are treated as radians.
        alive_mask = layer == 1
        vals = angle[alive_mask, 0] if angle.ndim > 1 else angle[alive_mask]
        angles_deg = np.rad2deg(vals).astype(np.float32).tolist()
    return mat, vf, sorted(float(a) for a in angles_deg)