# PAINE / predictor / models / model.py
import torch
import torch.nn as nn
import yaml
from typing import Optional, Dict, Any
from .noise_encoders import get_noise_encoder, ResidualConv
class ScorePredictor(nn.Module):
    """Score a (noise, text-prompt) pair with an MLP over fused encoder features.

    A noise encoder and a text encoder each map their input to a fixed-size
    feature vector (exposed via their ``output_dim`` attribute). The two
    vectors are concatenated and passed through a small MLP backbone, then
    through one or more linear heads that each emit a single scalar score.
    """

    def __init__(
        self,
        noise_encoder: nn.Module,
        text_encoder: nn.Module,
        dropout: float = 0.1,
        num_heads: int = 1,
    ):
        """
        Args:
            noise_encoder: module mapping a noise tensor to
                ``(batch, noise_encoder.output_dim)`` features.
            text_encoder: module mapping ``(prompt_embeds, prompt_mask)`` to
                ``(batch, text_encoder.output_dim)`` features.
            dropout: dropout probability used inside the fusion backbone.
            num_heads: number of independent scoring heads. With 1 head the
                model exposes ``self.head``; with more it exposes
                ``self.heads`` (an ``nn.ModuleList``).
        """
        super().__init__()
        self.noise_encoder = noise_encoder
        self.text_encoder = text_encoder
        self.num_heads = num_heads
        self.dropout = dropout

        # Both encoders must expose `output_dim`; the fused vector feeds the MLP.
        noise_dim = noise_encoder.output_dim
        text_dim = text_encoder.output_dim
        fusion_dim = noise_dim + text_dim

        self.fusion_backbone = nn.Sequential(
            nn.Linear(fusion_dim, 512),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(512, 256),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(256, 64),
            nn.SiLU(),
        )

        if num_heads == 1:
            self.head = nn.Linear(64, 1)
        else:
            self.heads = nn.ModuleList([nn.Linear(64, 1) for _ in range(num_heads)])

    def forward(
        self,
        noise: torch.Tensor,
        prompt_embeds: torch.Tensor,
        prompt_mask: torch.Tensor,
        mask_noise: bool = False,
    ) -> torch.Tensor:
        """Compute scores for a batch.

        Args:
            noise: noise tensor accepted by ``noise_encoder``.
            prompt_embeds: prompt-embedding tensor accepted by ``text_encoder``.
            prompt_mask: attention/padding mask forwarded to ``text_encoder``.
            mask_noise: if True, the encoded noise features are replaced with
                zeros (text-only ablation) while keeping the fusion shape.

        Returns:
            Tensor of shape ``(batch, 1)`` for a single head, or
            ``(batch, num_heads)`` when multiple heads are configured.
        """
        text_feat = self.text_encoder(prompt_embeds, prompt_mask)
        noise_feat = self.noise_encoder(noise)
        if mask_noise:
            # Zero *after* encoding so the fused feature layout is unchanged.
            noise_feat = torch.zeros_like(noise_feat)

        combined = torch.cat([noise_feat, text_feat], dim=1)
        backbone_out = self.fusion_backbone(combined)

        if self.num_heads == 1:
            return self.head(backbone_out)
        outputs = [head(backbone_out) for head in self.heads]
        return torch.cat(outputs, dim=1)

    @torch.no_grad()
    def predict(self, prompt_embeds, noise, prompt_mask=None, mask_noise=False):
        """Inference convenience wrapper: eval mode, no grad, default mask.

        Temporarily switches the module to ``eval()`` (disabling dropout) and
        restores the previous training state afterwards — even if ``forward``
        raises (the original code restored it only on the success path).

        Note the argument order (``prompt_embeds`` first) differs from
        ``forward``.
        """
        was_training = self.training
        self.eval()
        try:
            if prompt_mask is None:
                # Default: attend to every prompt position.
                # NOTE(review): this produces a float mask; assumes the text
                # encoder accepts float masks — confirm against its forward().
                prompt_mask = torch.ones(
                    prompt_embeds.shape[:2],
                    device=prompt_embeds.device
                )
            scores = self.forward(noise, prompt_embeds, prompt_mask, mask_noise=mask_noise)
        finally:
            # Restore the caller's training state on every exit path.
            if was_training:
                self.train()
        return scores

    def save(
        self,
        path: str,
        normalization: Optional[Dict[str, Any]] = None,
        model_type: Optional[str] = None,
        spatial_size: Optional[int] = None,
        in_channels: int = 4,
        embed_dim: Optional[int] = None,
        seq_len: Optional[int] = None,
    ) -> None:
        """Serialize weights plus the metadata ``from_config`` needs to rebuild.

        Args:
            path: destination file passed to ``torch.save``.
            normalization: optional score-normalization statistics, stored
                alongside the weights (empty dict when omitted).
            model_type: free-form model identifier recorded in the config.
            spatial_size: spatial size recorded for rebuilding the noise encoder.
            in_channels: channel count recorded for the noise encoder.
            embed_dim: embedding dim recorded for the text encoder.
            seq_len: sequence length recorded for the text encoder.
        """
        checkpoint = {
            'model_state_dict': self.state_dict(),
            'model_config': {
                # e.g. "FooTextEncoder" -> "foo"; mirrors from_config's lookup.
                'text_enc': self.text_encoder.__class__.__name__.replace('TextEncoder', '').lower(),
                'noise_enc': self._get_noise_enc_name(),
                'dropout': self.dropout,
                'num_heads': self.num_heads,
                'model_type': model_type,
                'spatial_size': spatial_size,
                'in_channels': in_channels,
                'embed_dim': embed_dim,
                'seq_len': seq_len,
                'pos_encoding': getattr(self.text_encoder, 'pos_encoding_type', 'none'),
            },
            'normalization': normalization or {},
        }
        torch.save(checkpoint, path)

    def _get_noise_enc_name(self) -> str:
        """Registry name of the noise encoder, derived from its class name.

        Generalized from a hard-coded ``'residualconv'``: lowercasing the
        class name yields the same value for the stock ``ResidualConv``
        encoder, but stays correct if another encoder class is plugged in.
        """
        return type(self.noise_encoder).__name__.lower()

    @classmethod
    def from_config(cls, config_path: str, device: Optional[str] = None):
        """Build a model from a YAML config and its referenced checkpoint.

        Expects ``config['predictor']`` with an ``architecture`` section
        (fallback hyper-parameters) and a ``weights`` section containing
        ``checkpoint_path`` and optionally ``device``. Values stored in the
        checkpoint's ``model_config`` take precedence over the YAML.

        Returns:
            ``(model, normalization_info)`` — the model is moved to ``device``
            and set to eval mode; ``normalization_info`` is the dict stored at
            save time (empty if absent).
        """
        with open(config_path, 'r') as f:
            config = yaml.safe_load(f)

        pred_config = config['predictor']
        arch = pred_config['architecture']
        weights = pred_config['weights']

        device = device or weights.get('device', 'cuda')
        checkpoint_path = weights['checkpoint_path']
        # SECURITY: weights_only=False unpickles arbitrary objects — only load
        # checkpoints from trusted sources.
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
        model_config = checkpoint.get('model_config', {})

        # Imported at call time, presumably to avoid circular imports.
        from .text_encoders import get_text_encoder
        from .noise_encoders import get_noise_encoder

        # Checkpoint metadata wins; the YAML `architecture` section is the fallback.
        text_enc_name = model_config.get('text_enc', arch['text_encoder'])
        # NOTE(review): noise_enc_name is read but never passed to
        # get_noise_encoder below — confirm the factory only builds one type.
        noise_enc_name = model_config.get('noise_enc', arch['noise_encoder'])
        dropout = model_config.get('dropout', arch.get('dropout', 0.1))
        num_heads = model_config.get('num_heads', arch.get('num_heads', 1))

        # Dimension info from the checkpoint, with backward-compatible defaults
        # for checkpoints written before these fields existed.
        spatial_size = model_config.get('spatial_size', 64)
        in_channels = model_config.get('in_channels', 4)
        embed_dim = model_config.get('embed_dim', 4096)
        seq_len = model_config.get('seq_len', 120)
        pos_encoding = model_config.get('pos_encoding', 'none')

        text_encoder = get_text_encoder(text_enc_name, embed_dim=embed_dim, seq_len=seq_len, pos_encoding=pos_encoding)
        noise_encoder = get_noise_encoder(spatial_size=spatial_size, in_channels=in_channels)

        model = cls(
            noise_encoder=noise_encoder,
            text_encoder=text_encoder,
            dropout=dropout,
            num_heads=num_heads,
        )

        # Handle float16 checkpoints: cast to float32 to match model dtype.
        state_dict = {k: v.float() for k, v in checkpoint['model_state_dict'].items()}
        model.load_state_dict(state_dict)
        model.to(device)
        model.eval()

        normalization_info = checkpoint.get('normalization', {})
        return model, normalization_info