| from __future__ import annotations |
|
|
| import json |
| import os |
| import sys |
| from dataclasses import dataclass |
| from typing import Dict, Optional, Tuple |
|
|
| import numpy as np |
| import pickle |
| import torch |
| import torch.nn.functional as F |
|
|
| from . import models as compat_models |
| from .models import MaterialHybridDenoiser, ModelConfig |
|
|
|
|
# Volume-fraction value for each vf-category index (looked up by
# vf_category_to_volume_fraction).
VF_CATEGORIES = [
    0.1,
    0.2,
    0.3,
    0.4,
    0.5,
]

# Material name for each material index (looked up by postprocess_sample).
MATERIAL_NAMES = [
    "CPP",
    "CHDPE",
    "GPP",
    "GHDPE",
]
|
|
|
|
@dataclass(frozen=True)
class DiscreteMaskDiffusion:
    """Immutable settings record for a discrete mask-diffusion process.

    Only the step count is stored; the sampling logic lives in ``sample``.
    """

    # Number of diffusion steps.
    T: int
|
|
|
|
class GaussianDiffusion:
    """Gaussian diffusion with a linear beta schedule and a reverse sampling step.

    The following 1-D tensors of length ``T`` are precomputed on ``device``:

    * ``betas``: linearly spaced from ``beta_start`` to ``beta_end``.
    * ``alphas``: ``1 - betas``.
    * ``sqrt_alpha_bar``: ``sqrt(cumprod(alphas))``.
    * ``sqrt_one_minus_alpha_bar``: ``sqrt(1 - cumprod(alphas))``.
    * ``sqrt_recip_alphas``: ``sqrt(1 / alphas)``.
    """

    def __init__(self, T: int, beta_start: float = 1e-4, beta_end: float = 2e-2, device: str = "cpu"):
        """Build the schedule.

        Args:
            T: Number of diffusion steps.
            beta_start: First value of the linear beta schedule.
            beta_end: Last value of the linear beta schedule.
            device: Device on which the schedule tensors are created.
        """
        self.T = int(T)
        betas = torch.linspace(beta_start, beta_end, self.T, device=device)
        alphas = 1.0 - betas
        alpha_bar = torch.cumprod(alphas, dim=0)
        self.betas = betas
        self.alphas = alphas
        self.sqrt_alpha_bar = torch.sqrt(alpha_bar)
        self.sqrt_one_minus_alpha_bar = torch.sqrt(1.0 - alpha_bar)
        self.sqrt_recip_alphas = torch.sqrt(1.0 / alphas)

    @torch.no_grad()
    def p_sample_step(self, x_t: torch.Tensor, t: torch.Tensor, eps_pred: torch.Tensor) -> torch.Tensor:
        """Draw ``x_{t-1}`` from ``x_t`` given the predicted noise.

        Args:
            x_t: Current noisy sample; the first dimension is the batch.
            t: Integer step index per batch element, shape ``(batch,)``.
            eps_pred: Predicted noise, same shape as ``x_t``.

        Returns:
            A tensor shaped like ``x_t``. Batch elements with ``t == 0`` get the
            posterior mean only (no noise is added).
        """
        # Reshape the per-sample coefficients to (batch, 1, ..., 1) matching the
        # rank of x_t so they broadcast over every trailing dimension. A fixed
        # (-1, 1, 1) view is only correct for rank-3 inputs.
        coeff_shape = (-1,) + (1,) * (x_t.dim() - 1)
        beta_t = self.betas[t].reshape(coeff_shape)
        sqrt_recip_alpha_t = self.sqrt_recip_alphas[t].reshape(coeff_shape)
        sqrt_one_minus_a_bar = self.sqrt_one_minus_alpha_bar[t].reshape(coeff_shape)

        # Posterior mean; the clamp guards against division by ~0.
        mu = sqrt_recip_alpha_t * (x_t - (beta_t / sqrt_one_minus_a_bar.clamp_min(1e-8)) * eps_pred)

        noise = torch.randn_like(x_t)
        nonzero_mask = (t != 0).float().reshape(coeff_shape)
        return mu + nonzero_mask * torch.sqrt(beta_t) * noise
|
|
|
|
def _resample_masked_tokens(
    tokens: torch.Tensor,
    logits: torch.Tensor,
    mask_id: int,
    remask_prob: float,
    device: str,
) -> torch.Tensor:
    """Redraw masked (and randomly re-masked) entries of ``tokens`` from ``softmax(logits)``.

    An entry is redrawn when it equals ``mask_id`` or when an independent uniform
    draw falls below ``remask_prob``. ``logits`` must have the shape of ``tokens``
    plus one trailing class dimension. Returns ``tokens`` itself when nothing is
    redrawn, otherwise an updated copy.
    """
    probs = F.softmax(logits, dim=-1)
    redraw = torch.rand(tokens.shape, device=device) < remask_prob
    selected = (tokens == mask_id) | redraw
    if not selected.any():
        return tokens
    # Flatten so the same code serves (B,) and (B, L) token tensors; boolean
    # indexing and assignment both walk the entries in row-major order.
    flat_probs = probs.reshape(-1, probs.size(-1))[selected.reshape(-1)]
    drawn = torch.multinomial(flat_probs, 1).squeeze(-1)
    updated = tokens.clone()
    updated[selected] = drawn
    return updated


@torch.no_grad()
def sample(
    model: MaterialHybridDenoiser,
    disc_diff_mat: DiscreteMaskDiffusion,
    disc_diff_vf_category: DiscreteMaskDiffusion,
    disc_diff_layer: DiscreteMaskDiffusion,
    disc_diff_angle: Optional[DiscreteMaskDiffusion] = None,
    cont_diff: Optional[GaussianDiffusion] = None,
    cond: torch.Tensor = None,
    mask_ids: Dict[str, int] = None,
    device: str = "cpu",
    remask_prob: float = 0.1,
    use_discrete_angles: bool = True,
) -> Dict[str, torch.Tensor]:
    """Run the reverse diffusion loop and return the final sampled state.

    Starting from fully masked discrete tokens (and Gaussian noise for the
    angles in continuous mode), the model is called once per step from ``T - 1``
    down to ``0``. At each step, masked entries and a random ``remask_prob``
    fraction of entries are redrawn from the model's predicted distributions.

    Args:
        model: Denoiser called as ``model(material, vf_category, layer, angle,
            cond, t)``. It must return a dict with ``"material_logits"`` and
            ``"vf_category_logits"``, plus ``"angle_logits"`` in discrete mode or
            ``"layer_logits"`` and ``"angle"`` in continuous mode.
            ``model.cfg.n_max_layer`` gives the number of layer slots ``L``.
        disc_diff_mat, disc_diff_vf_category, disc_diff_layer: Accepted for
            interface compatibility; not read by this function.
        disc_diff_angle: Supplies the step count ``T`` in discrete-angle mode.
        cont_diff: Supplies ``T`` and the Gaussian schedule in continuous mode.
        cond: Conditioning tensor; its first dimension is the batch size ``B``.
        mask_ids: Mask token id per variable; keys ``"material"``,
            ``"vf_category"``, ``"layer"`` and, in discrete mode, ``"angle"``.
        device: Device on which the sampling tensors are created.
        remask_prob: Per-entry probability of redrawing an already-set token.
        use_discrete_angles: Selects categorical or continuous angle sampling.

    Returns:
        Dict with ``"material_t"`` (B,), ``"vf_category_t"`` (B,),
        ``"layer_t"`` (B, L) and ``"angle_t"``: (B, L) long in discrete mode,
        (B, L, 1) float in continuous mode.

    Raises:
        ValueError: If ``cond`` or ``mask_ids`` is ``None``, or if the argument
            required by the selected angle mode is missing.
    """
    # Both default to None in the signature but are always needed; fail with a
    # clear message instead of an AttributeError/TypeError further down.
    if cond is None:
        raise ValueError("cond is required")
    if mask_ids is None:
        raise ValueError("mask_ids is required")

    model.eval()
    B = cond.shape[0]
    L = model.cfg.n_max_layer

    # Every discrete variable starts fully masked.
    x_material_t = torch.full((B,), mask_ids["material"], dtype=torch.long, device=device)
    x_vf_category_t = torch.full((B,), mask_ids["vf_category"], dtype=torch.long, device=device)
    x_layer_t = torch.full((B, L), mask_ids["layer"], dtype=torch.long, device=device)

    if use_discrete_angles:
        if disc_diff_angle is None:
            raise ValueError("disc_diff_angle is required when use_discrete_angles=True")
        if "angle" not in mask_ids:
            raise ValueError("mask_ids must include 'angle' when use_discrete_angles=True")
        x_angle_t = torch.full((B, L), mask_ids["angle"], dtype=torch.long, device=device)
        T = disc_diff_angle.T
    else:
        if cont_diff is None:
            raise ValueError("cont_diff is required when use_discrete_angles=False")
        x_angle_t = torch.randn(B, L, 1, device=device)
        T = cont_diff.T

    for t_int in reversed(range(T)):
        t = torch.full((B,), t_int, dtype=torch.long, device=device)
        outputs = model(x_material_t, x_vf_category_t, x_layer_t, x_angle_t, cond, t)

        x_material_t = _resample_masked_tokens(
            x_material_t, outputs["material_logits"], mask_ids["material"], remask_prob, device
        )
        x_vf_category_t = _resample_masked_tokens(
            x_vf_category_t, outputs["vf_category_logits"], mask_ids["vf_category"], remask_prob, device
        )

        if use_discrete_angles:
            x_angle_t = _resample_masked_tokens(
                x_angle_t, outputs["angle_logits"], mask_ids["angle"], remask_prob, device
            )
            # The last angle class marks an absent ("dead") layer; layer
            # occupancy is derived from the angles rather than sampled.
            dead_category = outputs["angle_logits"].size(-1) - 1
            x_layer_t = (x_angle_t != dead_category).long()
        else:
            x_layer_t = _resample_masked_tokens(
                x_layer_t, outputs["layer_logits"], mask_ids["layer"], remask_prob, device
            )
            # outputs["angle"] is used as a clean-sample prediction: convert it
            # to the implied noise, then take one reverse Gaussian step.
            angle_pred = outputs["angle"]
            sqrt_alpha_bar_t = cont_diff.sqrt_alpha_bar[t].view(-1, 1, 1)
            sqrt_one_minus_alpha_bar_t = cont_diff.sqrt_one_minus_alpha_bar[t].view(-1, 1, 1)
            eps_pred = (x_angle_t - sqrt_alpha_bar_t * angle_pred) / sqrt_one_minus_alpha_bar_t.clamp_min(1e-8)
            x_angle_t = cont_diff.p_sample_step(x_angle_t, t, eps_pred)

    return {
        "material_t": x_material_t,
        "vf_category_t": x_vf_category_t,
        "layer_t": x_layer_t,
        "angle_t": x_angle_t,
    }
|
|
|
|
@dataclass
class ModelBundle:
    """Everything ``load_model_bundle`` assembles for running inference."""

    # Denoiser network with checkpoint weights loaded.
    model: MaterialHybridDenoiser

    # Discrete diffusion settings for material, volume-fraction category and layers.
    disc_diff_mat: DiscreteMaskDiffusion
    disc_diff_vf_category: DiscreteMaskDiffusion
    disc_diff_layer: DiscreteMaskDiffusion

    # Exactly one of the next two is set by load_model_bundle, depending on
    # use_discrete_angles; the other is None.
    disc_diff_angle: Optional[DiscreteMaskDiffusion]
    cont_diff: Optional[GaussianDiffusion]

    # Mask token id per variable name.
    mask_ids: Dict[str, int]

    # True when angles are sampled as categories instead of continuous values.
    use_discrete_angles: bool

    # Diffusion step count and the linear beta schedule end points.
    T: int
    beta_start: float
    beta_end: float

    # Stored as passed to load_model_bundle; not interpreted in this module.
    angle_resolution: float
|
|
|
|
| def _load_state_dict_safely(obj) -> Dict[str, torch.Tensor]: |
| if isinstance(obj, dict): |
| for k in ("model_state_dict", "state_dict", "model"): |
| if k in obj and isinstance(obj[k], dict): |
| return obj[k] |
| |
| if all(isinstance(v, torch.Tensor) for v in obj.values()): |
| return obj |
| raise ValueError("Unrecognized checkpoint format (expected dict with model state_dict)") |
|
|
|
|
def _install_pickle_compat_aliases() -> None:
    """Expose this package's ``models`` module under the top-level name ``models``.

    The alias is only added when ``sys.modules`` has no ``"models"`` entry yet,
    so an already-imported module of that name is left untouched.
    Presumably this lets checkpoints pickled with ``models`` as a top-level
    module be unpickled here -- confirm against the training script.
    """
    if "models" not in sys.modules:
        sys.modules["models"] = compat_models
|
|
|
|
def load_model_bundle(checkpoint_dir: str, device: Optional[str] = None, angle_resolution: float = 1.0) -> ModelBundle:
    """Load the trained denoiser and its diffusion settings from a checkpoint directory.

    The directory must contain ``training_config.json`` and a weights file:
    ``best_model.pt`` if present, otherwise ``checkpoint_epoch_1.pt``.

    Args:
        checkpoint_dir: Directory holding the config and checkpoint files.
        device: Target device; when ``None``, ``"cuda"`` is used if available,
            else ``"cpu"``.
        angle_resolution: Stored on the returned bundle as a float; not
            otherwise used here.

    Returns:
        A ``ModelBundle`` with the model on ``device`` and weights loaded.

    Raises:
        FileNotFoundError: If the config file or both checkpoint files are missing.
        KeyError: If the config has no ``"model_config"`` entry.
        ValueError: If the checkpoint has an unrecognized layout.
        RuntimeError: From ``load_state_dict(strict=True)`` on a key mismatch.
    """
    cfg_path = os.path.join(checkpoint_dir, "training_config.json")
    # JSON is UTF-8 by specification; do not depend on the platform locale.
    with open(cfg_path, "r", encoding="utf-8") as f:
        train_cfg = json.load(f)

    use_discrete_angles = bool(train_cfg.get("use_discrete_angles", True))
    model_cfg = ModelConfig(**train_cfg["model_config"])

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    mask_ids = dict(train_cfg.get("mask_ids", {}))
    n_angle_categories = int(train_cfg.get("n_angle_categories", 7))

    model = MaterialHybridDenoiser(
        model_cfg,
        mask_ids=mask_ids,
        use_discrete_angles=use_discrete_angles,
        n_angle_categories=n_angle_categories,
    ).to(device)

    # Prefer the best checkpoint, fall back to the first-epoch one, and fail
    # with a message naming both candidates when neither exists.
    candidates = [
        os.path.join(checkpoint_dir, name)
        for name in ("best_model.pt", "checkpoint_epoch_1.pt")
    ]
    ckpt_path = next((path for path in candidates if os.path.exists(path)), None)
    if ckpt_path is None:
        raise FileNotFoundError(
            f"No checkpoint found in {checkpoint_dir!r}; looked for: {', '.join(candidates)}"
        )

    _install_pickle_compat_aliases()

    try:
        # Safe path: only tensors and primitive containers are unpickled.
        ckpt = torch.load(ckpt_path, map_location=device, weights_only=True)
    except TypeError:
        # torch versions whose torch.load has no ``weights_only`` parameter.
        ckpt = torch.load(ckpt_path, map_location=device)
    except pickle.UnpicklingError:
        # SECURITY NOTE: weights_only=False runs the full pickle machinery and
        # can execute arbitrary code embedded in the file. Only load
        # checkpoints from a trusted source.
        ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)

    state_dict = _load_state_dict_safely(ckpt)
    model.load_state_dict(state_dict, strict=True)

    T = int(train_cfg.get("T", 100))
    beta_start = float(train_cfg.get("beta_start", 1e-4))
    beta_end = float(train_cfg.get("beta_end", 2e-2))

    # One shared settings object serves material, vf-category and layer.
    disc = DiscreteMaskDiffusion(T=T)
    disc_angle = DiscreteMaskDiffusion(T=T) if use_discrete_angles else None
    cont = (
        GaussianDiffusion(T=T, beta_start=beta_start, beta_end=beta_end, device=device)
        if not use_discrete_angles
        else None
    )

    if use_discrete_angles and "angle" not in mask_ids:
        # NOTE(review): this default is added after the model was constructed
        # with the same dict object; whether the model observes it depends on
        # MaterialHybridDenoiser, which is not visible here -- confirm.
        mask_ids["angle"] = n_angle_categories + 1

    return ModelBundle(
        model=model,
        disc_diff_mat=disc,
        disc_diff_vf_category=disc,
        disc_diff_layer=disc,
        disc_diff_angle=disc_angle,
        cont_diff=cont,
        mask_ids=mask_ids,
        use_discrete_angles=use_discrete_angles,
        T=T,
        beta_start=beta_start,
        beta_end=beta_end,
        angle_resolution=float(angle_resolution),
    )
|
|
|
|
def vf_category_to_volume_fraction(vf_category: int) -> float:
    """Return the volume fraction for a vf-category index.

    Indices outside ``[0, len(VF_CATEGORIES))``, including negative ones, fall
    back to the first entry of ``VF_CATEGORIES``.
    """
    index = int(vf_category)
    in_range = 0 <= index < len(VF_CATEGORIES)
    return float(VF_CATEGORIES[index if in_range else 0])
|
|
|
|
def postprocess_sample(
    out: Dict[str, torch.Tensor],
    use_discrete_angles: bool,
    angle_categories_deg: Optional[np.ndarray],
) -> Tuple[str, float, list[float]]:
    """Convert one raw sample into a material name, volume fraction and sorted angles.

    Args:
        out: Sampler output with keys ``"material_t"``, ``"vf_category_t"``,
            ``"layer_t"`` and ``"angle_t"``. ``material_t`` and
            ``vf_category_t`` must hold exactly one element (``.item()`` is
            called on them); only the first batch row of ``layer_t`` and
            ``angle_t`` is read.
        use_discrete_angles: Whether ``angle_t`` holds category indices or
            continuous values.
        angle_categories_deg: Angle in degrees for each category index;
            required in discrete mode. An index equal to its length marks an
            absent layer.

    Returns:
        ``(material, volume_fraction, angles_deg)`` with ``angles_deg`` sorted
        ascending. Unknown material indices are rendered as ``"MAT<idx>"``.

    Raises:
        ValueError: In discrete mode when ``angle_categories_deg`` is ``None``.
    """
    mat_idx = int(out["material_t"].item())
    vf_idx = int(out["vf_category_t"].item())
    mat = MATERIAL_NAMES[mat_idx] if 0 <= mat_idx < len(MATERIAL_NAMES) else f"MAT{mat_idx}"
    vf = vf_category_to_volume_fraction(vf_idx)

    layer = out["layer_t"].detach().cpu().numpy()[0]
    angle = out["angle_t"].detach().cpu().numpy()[0]

    if use_discrete_angles:
        if angle_categories_deg is None:
            raise ValueError("angle_categories_deg is required for discrete angles")
        dead_category = int(len(angle_categories_deg))
        # Keep only codes that index a real angle. Besides the dead category,
        # this drops a leftover mask id (which would raise IndexError) and
        # negative codes (which would silently wrap around).
        alive_mask = (angle >= 0) & (angle < dead_category)
        angle_cats = angle[alive_mask].astype(int)
        angles_deg = [float(angle_categories_deg[c]) for c in angle_cats]
    else:
        alive_mask = layer == 1
        vals = angle[alive_mask, 0] if angle.ndim > 1 else angle[alive_mask]
        # Continuous angles are converted from radians to degrees.
        angles_deg = np.rad2deg(vals).astype(np.float32).tolist()

    return mat, vf, sorted(float(a) for a in angles_deg)
|
|