from __future__ import annotations import json import os import sys from dataclasses import dataclass from typing import Dict, Optional, Tuple import numpy as np import pickle import torch import torch.nn.functional as F from . import models as compat_models from .models import MaterialHybridDenoiser, ModelConfig VF_CATEGORIES = [0.1000, 0.2000, 0.3000, 0.4000, 0.5000] MATERIAL_NAMES = ["CPP", "CHDPE", "GPP", "GHDPE"] @dataclass(frozen=True) class DiscreteMaskDiffusion: T: int class GaussianDiffusion: def __init__(self, T: int, beta_start: float = 1e-4, beta_end: float = 2e-2, device: str = "cpu"): self.T = int(T) betas = torch.linspace(beta_start, beta_end, self.T, device=device) alphas = 1.0 - betas alpha_bar = torch.cumprod(alphas, dim=0) self.betas = betas self.alphas = alphas self.sqrt_alpha_bar = torch.sqrt(alpha_bar) self.sqrt_one_minus_alpha_bar = torch.sqrt(1.0 - alpha_bar) self.sqrt_recip_alphas = torch.sqrt(1.0 / alphas) @torch.no_grad() def p_sample_step(self, x_t: torch.Tensor, t: torch.Tensor, eps_pred: torch.Tensor) -> torch.Tensor: beta_t = self.betas[t].view(-1, 1, 1) sqrt_recip_alpha_t = self.sqrt_recip_alphas[t].view(-1, 1, 1) sqrt_one_minus_a_bar = self.sqrt_one_minus_alpha_bar[t].view(-1, 1, 1) mu = sqrt_recip_alpha_t * (x_t - (beta_t / sqrt_one_minus_a_bar.clamp_min(1e-8)) * eps_pred) noise = torch.randn_like(x_t) nonzero_mask = (t != 0).float().view(-1, 1, 1) return mu + nonzero_mask * torch.sqrt(beta_t) * noise @torch.no_grad() def sample( model: MaterialHybridDenoiser, disc_diff_mat: DiscreteMaskDiffusion, disc_diff_vf_category: DiscreteMaskDiffusion, disc_diff_layer: DiscreteMaskDiffusion, disc_diff_angle: Optional[DiscreteMaskDiffusion] = None, cont_diff: Optional[GaussianDiffusion] = None, cond: torch.Tensor = None, mask_ids: Dict[str, int] = None, device: str = "cpu", remask_prob: float = 0.1, use_discrete_angles: bool = True, ) -> Dict[str, torch.Tensor]: model.eval() B = cond.shape[0] L = model.cfg.n_max_layer x_material_t = torch.full((B,), mask_ids["material"], dtype=torch.long, device=device) x_vf_category_t = torch.full((B,), mask_ids["vf_category"], dtype=torch.long, device=device) x_layer_t = torch.full((B, L), mask_ids["layer"], dtype=torch.long, device=device) if use_discrete_angles: if disc_diff_angle is None: raise ValueError("disc_diff_angle is required when use_discrete_angles=True") if "angle" not in mask_ids: raise ValueError("mask_ids must include 'angle' when use_discrete_angles=True") x_angle_t = torch.full((B, L), mask_ids["angle"], dtype=torch.long, device=device) T = disc_diff_angle.T else: if cont_diff is None: raise ValueError("cont_diff is required when use_discrete_angles=False") x_angle_t = torch.randn(B, L, 1, device=device) T = cont_diff.T for t_int in reversed(range(T)): t = torch.full((B,), t_int, dtype=torch.long, device=device) outputs = model(x_material_t, x_vf_category_t, x_layer_t, x_angle_t, cond, t) probs_mat = F.softmax(outputs["material_logits"], dim=-1) remask_mat = (x_material_t == mask_ids["material"]) | (torch.rand(B, device=device) < remask_prob) if remask_mat.any(): new_material = torch.multinomial(probs_mat[remask_mat], 1).squeeze(-1) x_material_t = x_material_t.clone() x_material_t[remask_mat] = new_material probs_vf = F.softmax(outputs["vf_category_logits"], dim=-1) remask_vf = (x_vf_category_t == mask_ids["vf_category"]) | (torch.rand(B, device=device) < remask_prob) if remask_vf.any(): new_vf = torch.multinomial(probs_vf[remask_vf], 1).squeeze(-1) x_vf_category_t = x_vf_category_t.clone() x_vf_category_t[remask_vf] = new_vf if use_discrete_angles: probs_angle = F.softmax(outputs["angle_logits"], dim=-1) # (B,L,K+1) remask_angle = (torch.rand(B, L, device=device) < remask_prob) masked = (x_angle_t == mask_ids["angle"]) | remask_angle if masked.any(): flat_probs = probs_angle.view(-1, probs_angle.size(-1))[masked.view(-1)] new_angle = torch.multinomial(flat_probs, 1).squeeze(-1) x_angle_t = x_angle_t.clone() x_angle_t[masked] = new_angle dead_category = probs_angle.size(-1) - 1 x_layer_t = (x_angle_t != dead_category).long() else: probs_layer = F.softmax(outputs["layer_logits"], dim=-1) # (B,L,2) remask_layer = (torch.rand(B, L, device=device) < remask_prob) masked = (x_layer_t == mask_ids["layer"]) | remask_layer if masked.any(): flat_probs = probs_layer.view(-1, 2)[masked.view(-1)] new_layer = torch.multinomial(flat_probs, 1).squeeze(-1) x_layer_t = x_layer_t.clone() x_layer_t[masked] = new_layer angle_pred = outputs["angle"] sqrt_alpha_bar_t = cont_diff.sqrt_alpha_bar[t].view(-1, 1, 1) sqrt_one_minus_alpha_bar_t = cont_diff.sqrt_one_minus_alpha_bar[t].view(-1, 1, 1) eps_pred = (x_angle_t - sqrt_alpha_bar_t * angle_pred) / sqrt_one_minus_alpha_bar_t.clamp_min(1e-8) x_angle_t = cont_diff.p_sample_step(x_angle_t, t, eps_pred) return {"material_t": x_material_t, "vf_category_t": x_vf_category_t, "layer_t": x_layer_t, "angle_t": x_angle_t} @dataclass class ModelBundle: model: MaterialHybridDenoiser disc_diff_mat: DiscreteMaskDiffusion disc_diff_vf_category: DiscreteMaskDiffusion disc_diff_layer: DiscreteMaskDiffusion disc_diff_angle: Optional[DiscreteMaskDiffusion] cont_diff: Optional[GaussianDiffusion] mask_ids: Dict[str, int] use_discrete_angles: bool T: int beta_start: float beta_end: float angle_resolution: float def _load_state_dict_safely(obj) -> Dict[str, torch.Tensor]: if isinstance(obj, dict): for k in ("model_state_dict", "state_dict", "model"): if k in obj and isinstance(obj[k], dict): return obj[k] # If it already looks like a state_dict if all(isinstance(v, torch.Tensor) for v in obj.values()): return obj # type: ignore[return-value] raise ValueError("Unrecognized checkpoint format (expected dict with model state_dict)") def _install_pickle_compat_aliases() -> None: # Old checkpoints may reference top-level `models.ModelConfig`. sys.modules.setdefault("models", compat_models) def load_model_bundle(checkpoint_dir: str, device: Optional[str] = None, angle_resolution: float = 1.0) -> ModelBundle: cfg_path = os.path.join(checkpoint_dir, "training_config.json") with open(cfg_path, "r") as f: train_cfg = json.load(f) use_discrete_angles = bool(train_cfg.get("use_discrete_angles", True)) model_cfg = ModelConfig(**train_cfg["model_config"]) if device is None: device = "cuda" if torch.cuda.is_available() else "cpu" mask_ids = dict(train_cfg.get("mask_ids", {})) # Angle category count only matters for discrete mode. We default to 7 if not present. n_angle_categories = int(train_cfg.get("n_angle_categories", 7)) # IMPORTANT: match training model signature exactly (cfg, mask_ids, use_discrete_angles, n_angle_categories) model = MaterialHybridDenoiser( model_cfg, mask_ids=mask_ids, use_discrete_angles=use_discrete_angles, n_angle_categories=n_angle_categories, ).to(device) # Load checkpoint (prefer best_model.pt) ckpt_path = os.path.join(checkpoint_dir, "best_model.pt") if not os.path.exists(ckpt_path): ckpt_path = os.path.join(checkpoint_dir, "checkpoint_epoch_1.pt") _install_pickle_compat_aliases() # Some checkpoints may contain pickled objects referencing the original training module # layout (e.g. top-level `models`). We ship a compatibility `models.py` in the Space. # Also prefer weights-only loading when supported to avoid unpickling non-tensor objects. try: ckpt = torch.load(ckpt_path, map_location=device, weights_only=True) # type: ignore[call-arg] except TypeError: ckpt = torch.load(ckpt_path, map_location=device) except pickle.UnpicklingError: # PyTorch raised because weights-only loader encountered non-tensor objects # (e.g. models.ModelConfig). We trust this checkpoint (bundled by us), so fall back. ckpt = torch.load(ckpt_path, map_location=device, weights_only=False) # type: ignore[call-arg] state_dict = _load_state_dict_safely(ckpt) # Enforce exact match with checkpoint; otherwise generation can look arbitrarily bad. model.load_state_dict(state_dict, strict=True) T = int(train_cfg.get("T", 100)) beta_start = float(train_cfg.get("beta_start", 1e-4)) beta_end = float(train_cfg.get("beta_end", 2e-2)) disc = DiscreteMaskDiffusion(T=T) disc_angle = DiscreteMaskDiffusion(T=T) if use_discrete_angles else None cont = GaussianDiffusion(T=T, beta_start=beta_start, beta_end=beta_end, device=device) if not use_discrete_angles else None # Some training configs omit angle mask when continuous; for discrete, define a reasonable default if missing. if use_discrete_angles and "angle" not in mask_ids: # categories: n_angle_categories + dead (1) -> +1 ; then mask id is last index mask_ids["angle"] = n_angle_categories + 1 return ModelBundle( model=model, disc_diff_mat=disc, disc_diff_vf_category=disc, disc_diff_layer=disc, disc_diff_angle=disc_angle, cont_diff=cont, mask_ids=mask_ids, use_discrete_angles=use_discrete_angles, T=T, beta_start=beta_start, beta_end=beta_end, angle_resolution=float(angle_resolution), ) def vf_category_to_volume_fraction(vf_category: int) -> float: vf_category = int(vf_category) if 0 <= vf_category < len(VF_CATEGORIES): return float(VF_CATEGORIES[vf_category]) return float(VF_CATEGORIES[0]) def postprocess_sample( out: Dict[str, torch.Tensor], use_discrete_angles: bool, angle_categories_deg: Optional[np.ndarray], ) -> Tuple[str, float, list[float]]: mat_idx = int(out["material_t"].item()) vf_idx = int(out["vf_category_t"].item()) mat = MATERIAL_NAMES[mat_idx] if 0 <= mat_idx < len(MATERIAL_NAMES) else f"MAT{mat_idx}" vf = vf_category_to_volume_fraction(vf_idx) layer = out["layer_t"].detach().cpu().numpy()[0] angle = out["angle_t"].detach().cpu().numpy()[0] if use_discrete_angles: if angle_categories_deg is None: raise ValueError("angle_categories_deg is required for discrete angles") dead_category = int(len(angle_categories_deg)) alive_mask = angle != dead_category angle_cats = angle[alive_mask].astype(int) angles_deg = [float(angle_categories_deg[c]) for c in angle_cats] else: alive_mask = layer == 1 vals = angle[alive_mask, 0] if angle.ndim > 1 else angle[alive_mask] angles_deg = np.rad2deg(vals).astype(np.float32).tolist() return mat, vf, sorted([float(a) for a in angles_deg])