import torch
import torch.nn as nn
import yaml
from typing import Optional, Dict, Any
from .noise_encoders import get_noise_encoder, ResidualConv
class ScorePredictor(nn.Module):
    """Predict a scalar (or per-head) score for a (noise, prompt) pair.

    The noise latent and the prompt embeddings are encoded independently,
    concatenated, and passed through a small MLP fusion backbone. With
    ``num_heads == 1`` a single linear head emits one score per sample;
    otherwise one linear head per objective is applied and the outputs are
    concatenated along dim 1.
    """

    def __init__(
        self,
        noise_encoder: nn.Module,
        text_encoder: nn.Module,
        dropout: float = 0.1,
        num_heads: int = 1,
    ):
        """
        Args:
            noise_encoder: module mapping a noise tensor to a flat feature
                vector; must expose an ``output_dim`` attribute.
            text_encoder: module mapping ``(prompt_embeds, prompt_mask)`` to a
                flat feature vector; must expose an ``output_dim`` attribute.
            dropout: dropout probability used inside the fusion backbone.
            num_heads: number of independent scoring heads.
        """
        super().__init__()
        self.noise_encoder = noise_encoder
        self.text_encoder = text_encoder
        self.num_heads = num_heads
        self.dropout = dropout

        noise_dim = noise_encoder.output_dim
        text_dim = text_encoder.output_dim
        fusion_dim = noise_dim + text_dim
        self.fusion_backbone = nn.Sequential(
            nn.Linear(fusion_dim, 512),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(512, 256),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(256, 64),
            nn.SiLU(),
        )
        # NOTE: the parameter names differ between the two layouts
        # ('head' vs 'heads.N'), so a checkpoint is only loadable into a
        # model constructed with the same num_heads.
        if num_heads == 1:
            self.head = nn.Linear(64, 1)
        else:
            self.heads = nn.ModuleList([nn.Linear(64, 1) for _ in range(num_heads)])

    def forward(
        self,
        noise: torch.Tensor,
        prompt_embeds: torch.Tensor,
        prompt_mask: torch.Tensor,
        mask_noise: bool = False,
    ) -> torch.Tensor:
        """Score a batch of (noise, prompt) pairs.

        Args:
            noise: batch-first noise latents; shape must match what
                ``noise_encoder`` expects.
            prompt_embeds: batch-first per-token prompt embeddings.
            prompt_mask: padding/attention mask for ``prompt_embeds``.
            mask_noise: if True, zero out the noise features so the score
                depends on the prompt alone (useful for ablations).

        Returns:
            Tensor of shape ``(batch, num_heads)`` with raw scores.
        """
        text_feat = self.text_encoder(prompt_embeds, prompt_mask)
        noise_feat = self.noise_encoder(noise)
        if mask_noise:
            noise_feat = torch.zeros_like(noise_feat)
        fused = self.fusion_backbone(torch.cat([noise_feat, text_feat], dim=1))
        if self.num_heads == 1:
            return self.head(fused)
        return torch.cat([head(fused) for head in self.heads], dim=1)

    @torch.no_grad()
    def predict(self, prompt_embeds, noise, prompt_mask=None, mask_noise=False):
        """Inference helper: run :meth:`forward` in eval mode, no gradients.

        Args:
            prompt_embeds: batch-first per-token prompt embeddings.
            noise: noise latents (note: argument order differs from
                :meth:`forward`).
            prompt_mask: optional mask; defaults to all-ones over the first
                two dims of ``prompt_embeds``.
            mask_noise: forwarded to :meth:`forward`.

        Returns:
            Score tensor of shape ``(batch, num_heads)``.
        """
        was_training = self.training
        self.eval()
        try:
            if prompt_mask is None:
                prompt_mask = torch.ones(
                    prompt_embeds.shape[:2],
                    device=prompt_embeds.device,
                )
            return self.forward(noise, prompt_embeds, prompt_mask, mask_noise=mask_noise)
        finally:
            # Bug fix: restore the caller's mode even if forward() raises;
            # the original only restored it on the success path.
            if was_training:
                self.train()

    def save(
        self,
        path: str,
        normalization: Optional[Dict[str, Any]] = None,
        model_type: Optional[str] = None,
        spatial_size: Optional[int] = None,
        in_channels: int = 4,
        embed_dim: Optional[int] = None,
        seq_len: Optional[int] = None,
    ) -> None:
        """Serialize weights plus the config needed to rebuild the model.

        Args:
            path: destination file for ``torch.save``.
            normalization: optional score-normalization statistics stored
                alongside the weights (empty dict when omitted).
            model_type: diffusion-model identifier recorded in the config.
            spatial_size: noise latent spatial size recorded for reload.
            in_channels: noise latent channel count recorded for reload.
            embed_dim: text embedding dim recorded for reload.
            seq_len: text sequence length recorded for reload.
        """
        checkpoint = {
            'model_state_dict': self.state_dict(),
            'model_config': {
                'text_enc': self.text_encoder.__class__.__name__.replace('TextEncoder', '').lower(),
                'noise_enc': self._get_noise_enc_name(),
                'dropout': self.dropout,
                'num_heads': self.num_heads,
                'model_type': model_type,
                'spatial_size': spatial_size,
                'in_channels': in_channels,
                'embed_dim': embed_dim,
                'seq_len': seq_len,
                'pos_encoding': getattr(self.text_encoder, 'pos_encoding_type', 'none'),
            },
            'normalization': normalization or {},
        }
        torch.save(checkpoint, path)

    def _get_noise_enc_name(self) -> str:
        """Registry name of the noise encoder (lowercased class name).

        Bug fix: the original hard-coded ``'residualconv'`` regardless of the
        actual encoder, so checkpoints saved with any other encoder recorded
        the wrong name. Deriving from the class name is backward compatible:
        ``ResidualConv`` still yields ``'residualconv'``.
        """
        return type(self.noise_encoder).__name__.lower()

    @classmethod
    def from_config(cls, config_path: str, device: str = None):
        """Build a ScorePredictor from a YAML config + saved checkpoint.

        Args:
            config_path: path to a YAML file with a ``predictor`` section
                containing ``architecture`` and ``weights`` sub-sections.
            device: overrides ``weights.device`` from the config
                (default ``'cuda'`` when neither is given).

        Returns:
            ``(model, normalization_info)`` — the model in eval mode on
            ``device``, and the normalization dict stored at save time.
        """
        with open(config_path, 'r') as f:
            config = yaml.safe_load(f)
        pred_config = config['predictor']
        arch = pred_config['architecture']
        weights = pred_config['weights']
        device = device or weights.get('device', 'cuda')
        checkpoint_path = weights['checkpoint_path']
        # SECURITY: weights_only=False unpickles arbitrary objects — only
        # load checkpoints from trusted sources. Needed here because the
        # checkpoint stores non-tensor config/normalization dicts.
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
        model_config = checkpoint.get('model_config', {})
        from .text_encoders import get_text_encoder
        from .noise_encoders import get_noise_encoder
        # Checkpoint values win; fall back to the YAML architecture section.
        text_enc_name = model_config.get('text_enc', arch['text_encoder'])
        noise_enc_name = model_config.get('noise_enc', arch['noise_encoder'])
        dropout = model_config.get('dropout', arch.get('dropout', 0.1))
        num_heads = model_config.get('num_heads', arch.get('num_heads', 1))
        # Get dimension info from checkpoint (with backward-compatible defaults)
        spatial_size = model_config.get('spatial_size', 64)
        in_channels = model_config.get('in_channels', 4)
        embed_dim = model_config.get('embed_dim', 4096)
        seq_len = model_config.get('seq_len', 120)
        pos_encoding = model_config.get('pos_encoding', 'none')
        text_encoder = get_text_encoder(text_enc_name, embed_dim=embed_dim, seq_len=seq_len, pos_encoding=pos_encoding)
        noise_encoder = get_noise_encoder(spatial_size=spatial_size, in_channels=in_channels)
        model = cls(
            noise_encoder=noise_encoder,
            text_encoder=text_encoder,
            dropout=dropout,
            num_heads=num_heads,
        )
        # Handle float16 checkpoints: cast to float32 to match model dtype
        state_dict = {k: v.float() for k, v in checkpoint['model_state_dict'].items()}
        model.load_state_dict(state_dict)
        model.to(device)
        model.eval()
        normalization_info = checkpoint.get('normalization', {})
        return model, normalization_info