"""Load trained multi-model predictor from checkpoint.
Reads dimension info from checkpoint for automatic model reconstruction.
Falls back to PixArt-Alpha defaults for legacy checkpoints.
"""
import torch
import torch.nn as nn
from typing import Tuple, Dict, Any
from predictor.models import get_model
from predictor.configs.model_dims import MODEL_DIMS
def load_predictor(
    checkpoint_path: str,
    device: str = "cuda",
) -> Tuple[nn.Module, Dict[str, Any]]:
    """Rebuild a trained predictor from a checkpoint file.

    Model dimensions are taken from the checkpoint's ``model_config`` when
    available; for legacy checkpoints that lack a recognized ``model_type``,
    PixArt-Alpha defaults are used.

    Args:
        checkpoint_path: Path to a ``torch.save``-d checkpoint containing
            ``model_config``, ``model_state_dict`` and ``normalization``.
        device: Device the model (and loaded tensors) are placed on.

    Returns:
        ``(model, normalization)`` — the model in eval mode on ``device``,
        and the normalization dict stored in the checkpoint.
    """
    # NOTE(review): weights_only=False runs the full pickle loader; only
    # load checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    cfg = checkpoint["model_config"]

    # Resolve fallback dimensions: registry entry for known model types,
    # otherwise PixArt-Alpha defaults (legacy checkpoints).
    known = MODEL_DIMS.get(model_type) if (model_type := cfg.get("model_type", None)) else None
    if known is not None:
        fallback = {
            "spatial_size": known["spatial_size"],
            "in_channels": known["latent_shape"][0],
            "embed_dim": known["embed_dim"],
            "seq_len": known["seq_len"],
        }
    else:
        fallback = {
            "spatial_size": 64,
            "in_channels": 4,
            "embed_dim": 4096,
            "seq_len": 120,
        }

    model = get_model(
        noise_enc=cfg["noise_enc"],
        text_enc=cfg["text_enc"],
        dropout=cfg["dropout"],
        num_heads=cfg.get("num_heads", 1),
        spatial_size=cfg.get("spatial_size", fallback["spatial_size"]),
        in_channels=cfg.get("in_channels", fallback["in_channels"]),
        embed_dim=cfg.get("embed_dim", fallback["embed_dim"]),
        seq_len=cfg.get("seq_len", fallback["seq_len"]),
        pos_encoding=cfg.get("pos_encoding", "none"),
    )

    # Checkpoints may have been saved in float16; cast every tensor up to
    # float32 so it matches the freshly constructed model's dtype.
    weights = {name: tensor.float() for name, tensor in checkpoint["model_state_dict"].items()}
    model.load_state_dict(weights)
    model.to(device)
    model.eval()
    return model, checkpoint["normalization"]
def denormalize_prediction(
    pred_normalized: torch.Tensor,
    normalization: Dict[str, Any],
) -> torch.Tensor:
    """Map a normalized prediction back to the original target scale.

    Inverts the standardization ``y_norm = (y - y_mean) / y_std`` using the
    ``y_mean`` / ``y_std`` stats stored in the checkpoint's normalization dict.
    """
    return pred_normalized * normalization["y_std"] + normalization["y_mean"]
def get_checkpoint_info(checkpoint_path: str) -> Dict[str, Any]:
    """Summarize a checkpoint without reconstructing the model.

    Returns a dict with the stored ``model_config`` and ``normalization``
    plus ``param_count``, the total number of elements across all tensors
    in the state dict.
    """
    # map_location="cpu" so inspecting a checkpoint never requires a GPU.
    # NOTE(review): weights_only=False runs the full pickle loader; only
    # inspect checkpoints from trusted sources.
    ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
    total_params = 0
    for tensor in ckpt["model_state_dict"].values():
        total_params += tensor.numel()
    return {
        "model_config": ckpt["model_config"],
        "normalization": ckpt["normalization"],
        "param_count": total_params,
    }