| import torch
|
| import torch.nn as nn
|
| import yaml
|
| from typing import Optional, Dict, Any
|
|
|
| from .noise_encoders import get_noise_encoder, ResidualConv
|
|
|
|
|
class ScorePredictor(nn.Module):
    """Predict scalar quality scores for a (noise, prompt) pair.

    The initial noise latent and the prompt embeddings are encoded by the
    supplied encoder modules, concatenated, passed through a small MLP
    fusion backbone, and projected to one scalar per head.

    Args:
        noise_encoder: Module mapping a noise tensor to a ``(B, noise_dim)``
            feature; must expose an ``output_dim`` attribute.
        text_encoder: Module mapping ``(prompt_embeds, prompt_mask)`` to a
            ``(B, text_dim)`` feature; must expose an ``output_dim`` attribute.
        dropout: Dropout probability inside the fusion backbone.
        num_heads: Number of independent scalar prediction heads.
    """

    def __init__(
        self,
        noise_encoder: nn.Module,
        text_encoder: nn.Module,
        dropout: float = 0.1,
        num_heads: int = 1,
    ):
        super().__init__()

        self.noise_encoder = noise_encoder
        self.text_encoder = text_encoder
        self.num_heads = num_heads
        self.dropout = dropout

        noise_dim = noise_encoder.output_dim
        text_dim = text_encoder.output_dim
        fusion_dim = noise_dim + text_dim

        # Shared MLP trunk: fusion_dim -> 512 -> 256 -> 64.
        self.fusion_backbone = nn.Sequential(
            nn.Linear(fusion_dim, 512),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(512, 256),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(256, 64),
            nn.SiLU(),
        )

        # Attribute names ('head' vs 'heads') are part of the checkpoint
        # format — keep them so existing state dicts load without remapping.
        if num_heads == 1:
            self.head = nn.Linear(64, 1)
        else:
            self.heads = nn.ModuleList(
                [nn.Linear(64, 1) for _ in range(num_heads)]
            )

    def forward(
        self,
        noise: torch.Tensor,
        prompt_embeds: torch.Tensor,
        prompt_mask: torch.Tensor,
        mask_noise: bool = False,
    ) -> torch.Tensor:
        """Return per-head scores of shape ``(B, num_heads)``.

        Args:
            noise: Noise latent batch fed to ``noise_encoder``.
            prompt_embeds: Prompt embedding batch fed to ``text_encoder``.
            prompt_mask: Attention/padding mask forwarded to ``text_encoder``.
            mask_noise: If True, zero the noise branch so the prediction
                depends on the text features only (ablation path). The noise
                encoder still runs, so the autograd graph shape is unchanged.
        """
        text_feat = self.text_encoder(prompt_embeds, prompt_mask)

        noise_out = self.noise_encoder(noise)
        if mask_noise:
            noise_out = torch.zeros_like(noise_out)

        combined = torch.cat([noise_out, text_feat], dim=1)
        backbone_out = self.fusion_backbone(combined)

        if self.num_heads == 1:
            return self.head(backbone_out)
        outputs = [head(backbone_out) for head in self.heads]
        return torch.cat(outputs, dim=1)

    @torch.no_grad()
    def predict(self, prompt_embeds, noise, prompt_mask=None, mask_noise=False):
        """Inference helper: run in eval mode without gradients.

        Temporarily switches the module to eval (disabling dropout), builds
        an all-ones mask when ``prompt_mask`` is None, and restores the
        previous training mode afterwards.

        Note: argument order differs from ``forward`` (prompt first) — kept
        as-is for caller compatibility.
        """
        was_training = self.training
        self.eval()
        # try/finally so an exception in forward() cannot leave the module
        # stuck in eval mode (the original restored only on success).
        try:
            if prompt_mask is None:
                # Default: attend to every position.
                prompt_mask = torch.ones(
                    prompt_embeds.shape[:2],
                    device=prompt_embeds.device,
                )
            return self.forward(
                noise, prompt_embeds, prompt_mask, mask_noise=mask_noise
            )
        finally:
            if was_training:
                self.train()

    def save(
        self,
        path: str,
        normalization: Optional[Dict[str, Any]] = None,
        model_type: Optional[str] = None,
        spatial_size: Optional[int] = None,
        in_channels: int = 4,
        embed_dim: Optional[int] = None,
        seq_len: Optional[int] = None,
    ) -> None:
        """Serialize weights plus the config needed to rebuild via ``from_config``.

        Args:
            path: Destination file for ``torch.save``.
            normalization: Optional score-normalization stats stored alongside
                the weights (returned by ``from_config`` at load time).
            model_type / spatial_size / in_channels / embed_dim / seq_len:
                Architecture metadata recorded in ``model_config`` so loading
                does not depend on external configuration.
        """
        checkpoint = {
            'model_state_dict': self.state_dict(),
            'model_config': {
                # e.g. 'FooTextEncoder' -> 'foo'
                'text_enc': self.text_encoder.__class__.__name__.replace('TextEncoder', '').lower(),
                'noise_enc': self._get_noise_enc_name(),
                'dropout': self.dropout,
                'num_heads': self.num_heads,
                'model_type': model_type,
                'spatial_size': spatial_size,
                'in_channels': in_channels,
                'embed_dim': embed_dim,
                'seq_len': seq_len,
                'pos_encoding': getattr(self.text_encoder, 'pos_encoding_type', 'none'),
            },
            'normalization': normalization or {},
        }
        torch.save(checkpoint, path)

    def _get_noise_enc_name(self) -> str:
        """Checkpoint identifier for the noise-encoder architecture.

        Derived from the encoder's class name instead of a hard-coded
        'residualconv' literal, so other encoder types round-trip through
        ``save``; for ``ResidualConv`` this yields the historical value.
        """
        return self.noise_encoder.__class__.__name__.lower()

    @classmethod
    def from_config(cls, config_path: str, device: Optional[str] = None):
        """Build a predictor from a YAML config and its referenced checkpoint.

        Args:
            config_path: Path to a YAML file with a ``predictor`` section
                containing ``architecture`` and ``weights`` sub-sections.
            device: Target device; falls back to ``weights.device`` in the
                config, then to 'cuda'.

        Returns:
            Tuple of (model in eval mode on ``device``, normalization dict
            stored in the checkpoint, ``{}`` if absent).
        """
        with open(config_path, 'r') as f:
            config = yaml.safe_load(f)

        pred_config = config['predictor']
        arch = pred_config['architecture']
        weights = pred_config['weights']

        device = device or weights.get('device', 'cuda')

        checkpoint_path = weights['checkpoint_path']
        # SECURITY: weights_only=False unpickles arbitrary objects from the
        # checkpoint (needed here for the config/normalization dicts) — only
        # load checkpoints from trusted sources.
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

        # Prefer metadata saved in the checkpoint; fall back to the YAML arch.
        model_config = checkpoint.get('model_config', {})

        # Lazy imports keep module import light and avoid circular imports.
        from .text_encoders import get_text_encoder
        from .noise_encoders import get_noise_encoder

        text_enc_name = model_config.get('text_enc', arch['text_encoder'])
        # NOTE(review): model_config['noise_enc'] / arch['noise_encoder'] are
        # not consulted — get_noise_encoder currently builds its single
        # variant from spatial_size/in_channels alone.
        dropout = model_config.get('dropout', arch.get('dropout', 0.1))
        num_heads = model_config.get('num_heads', arch.get('num_heads', 1))

        spatial_size = model_config.get('spatial_size', 64)
        in_channels = model_config.get('in_channels', 4)
        embed_dim = model_config.get('embed_dim', 4096)
        seq_len = model_config.get('seq_len', 120)

        pos_encoding = model_config.get('pos_encoding', 'none')
        text_encoder = get_text_encoder(
            text_enc_name, embed_dim=embed_dim, seq_len=seq_len, pos_encoding=pos_encoding
        )
        noise_encoder = get_noise_encoder(
            spatial_size=spatial_size, in_channels=in_channels
        )

        model = cls(
            noise_encoder=noise_encoder,
            text_encoder=text_encoder,
            dropout=dropout,
            num_heads=num_heads,
        )

        # Cast to float32 in case the checkpoint was saved in half precision.
        state_dict = {k: v.float() for k, v in checkpoint['model_state_dict'].items()}
        model.load_state_dict(state_dict)
        model.to(device)
        model.eval()

        normalization_info = checkpoint.get('normalization', {})
        return model, normalization_info
|
|
|