| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torchvision.models import convnext_tiny, ConvNeXt_Tiny_Weights |
| import math |
| from typing import Optional |
|
|
|
|
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization.

    Divides the input by the RMS computed over its last dimension and then
    multiplies by a learnable per-feature gain. There is no mean subtraction
    and no bias term.

    Args:
        dim: Size of the last (feature) dimension; length of the gain vector.
        eps: Constant added to the mean square before taking the square root.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # Gain starts at one, so the layer initially only rescales by 1/RMS.
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        """Return ``x`` scaled by ``1 / RMS(x)`` over the last axis, times the gain."""
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        denom = (mean_square + self.eps).sqrt()
        return (x / denom) * self.weight
|
|
|
|
class RotaryPositionalEmbedding(nn.Module):
    """Rotary Position Embedding (RoPE) cos/sin table generator.

    Builds, and lazily caches, the cosine and sine tables used to rotate
    query/key vectors. The tables have shape ``(seq_len, dim)``: the
    ``dim // 2`` angular frequencies are concatenated with themselves so the
    tables line up with a "rotate half" formulation.

    Args:
        dim: Per-head feature dimension the tables are built for.
        max_seq_len: Stored on the module for reference; the cache grows on
            demand and is not capped by this value.
        base: Base of the geometric frequency progression.
    """

    def __init__(self, dim: int, max_seq_len: int = 8192, base: float = 10000.0):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.base = base

        # Inverse frequencies base^(-2i/dim) for i = 0 .. dim/2 - 1.
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        # Non-persistent: follows .to()/.cuda() but is not saved in state_dict.
        self.register_buffer('inv_freq', inv_freq, persistent=False)

        # Lazily built tables and the length they currently cover.
        self._cached_seq_len = 0
        self._cached_cos = None
        self._cached_sin = None

    def _update_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype):
        """Rebuild the cos/sin tables if they are missing, too short, or stale.

        The tables are also rebuilt when the requested device or dtype differs
        from the cached one; otherwise a module used first on CPU/float32 and
        later on GPU or in half precision would return mismatched tensors.
        """
        cached = self._cached_cos
        up_to_date = (
            cached is not None
            and seq_len <= self._cached_seq_len
            and cached.device == device
            and cached.dtype == dtype
        )
        if up_to_date:
            return

        # Never shrink: keep covering the longest length seen so far.
        table_len = max(seq_len, self._cached_seq_len)

        # Compute angles in float32 on the target device so that a module cast
        # to half precision does not lose position resolution while building
        # the table; the result is cast to the requested dtype at the end.
        inv_freq = self.inv_freq.to(device=device, dtype=torch.float32)
        positions = torch.arange(table_len, device=device, dtype=torch.float32)

        # Outer product position x frequency -> (table_len, dim / 2).
        freqs = torch.einsum('i,j->ij', positions, inv_freq)
        # Duplicate along the feature axis -> (table_len, dim).
        emb = torch.cat([freqs, freqs], dim=-1)

        self._cached_cos = emb.cos().to(dtype)
        self._cached_sin = emb.sin().to(dtype)
        self._cached_seq_len = table_len

    def forward(self, x: torch.Tensor, seq_len: Optional[int] = None):
        """Return ``(cos, sin)`` tables for the first ``seq_len`` positions.

        Args:
            x: Tensor whose device and dtype the tables should match. When
                ``seq_len`` is omitted, the length is read from ``x.shape[-2]``.
            seq_len: Number of positions to return.

        Returns:
            Tuple ``(cos, sin)``, each of shape ``(seq_len, dim)``.
        """
        if seq_len is None:
            seq_len = x.shape[-2]

        self._update_cache(seq_len, x.device, x.dtype)

        return self._cached_cos[:seq_len], self._cached_sin[:seq_len]
|
|
|
|
def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
    """Rotate query and key tensors with precomputed RoPE tables.

    Each tensor ``t`` is mapped to ``t * cos + rotate_half(t) * sin``, where
    ``rotate_half`` splits the last dimension into two chunks ``(a, b)`` and
    returns ``(-b, a)``.

    Args:
        q: Query tensor.
        k: Key tensor.
        cos: Cosine table, broadcastable against ``q`` and ``k``.
        sin: Sine table, broadcastable against ``q`` and ``k``.

    Returns:
        Tuple ``(q_embed, k_embed)`` of rotated tensors.
    """
    def rotate_half(x):
        """Swap the two halves of the last dimension, negating the second."""
        front, back = torch.chunk(x, 2, dim=-1)
        return torch.cat((back.neg(), front), dim=-1)

    rotated = [t * cos + rotate_half(t) * sin for t in (q, k)]
    return rotated[0], rotated[1]
|
|
|
|
class MultiHeadAttentionWithRoPE(nn.Module):
    """Multi-head scaled dot-product attention with rotary position embedding.

    Queries and keys are rotated with RoPE (positions ``0 .. T-1``) before the
    attention scores are computed. The q/k/v projections have no bias; the
    output projection does.

    Args:
        d_model: Model (embedding) dimension; must be divisible by ``nhead``.
        nhead: Number of attention heads.
        dropout: Dropout probability used on the attention weights and on the
            projected output.
        rope_base: Frequency base forwarded to the rotary embedding.
    """

    def __init__(self, d_model: int, nhead: int, dropout: float = 0.1, rope_base: float = 10000.0):
        super().__init__()
        assert d_model % nhead == 0

        self.d_model = d_model
        self.nhead = nhead
        self.head_dim = d_model // nhead
        self.dropout = dropout

        # Input projections (bias-free) and output projection.
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model)

        # RoPE operates per head, hence head_dim rather than d_model.
        self.rope = RotaryPositionalEmbedding(self.head_dim, base=rope_base)

        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)

        # 1 / sqrt(head_dim) scaling of the dot products.
        self.scale = 1.0 / math.sqrt(self.head_dim)

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
                attn_mask: Optional[torch.Tensor] = None, key_padding_mask: Optional[torch.Tensor] = None):
        """Compute attention output of shape ``(B, T, d_model)``.

        Args:
            query: Tensor of shape ``(B, T, d_model)``.
            key: Tensor reshaped with the same ``B`` and ``T`` as ``query``.
            value: Tensor reshaped with the same ``B`` and ``T`` as ``query``.
            attn_mask: Optional mask broadcastable to ``(B, nhead, T, T)``.
                A boolean mask marks positions that must NOT be attended to
                (``True`` -> ``-inf``); any other dtype is added to the scores.
            key_padding_mask: Optional boolean mask of shape ``(B, T)``; ``True``
                entries are filled with ``-inf`` for every head and query.

        Returns:
            Tensor of shape ``(B, T, d_model)``.
        """
        B, T, C = query.shape

        # Project, then split into heads: (B, T, C) -> (B, nhead, T, head_dim).
        def split_heads(t):
            return t.view(B, T, self.nhead, self.head_dim).transpose(1, 2)

        q = split_heads(self.q_proj(query))
        k = split_heads(self.k_proj(key))
        v = split_heads(self.v_proj(value))

        # (T, head_dim) tables -> (1, 1, T, head_dim) to broadcast over batch/heads.
        cos, sin = self.rope(q, T)
        cos = cos.unsqueeze(0).unsqueeze(0)
        sin = sin.unsqueeze(0).unsqueeze(0)

        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        # Scaled dot-product scores: (B, nhead, T, T).
        attn_weights = torch.matmul(q, k.transpose(-2, -1)) * self.scale

        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                # Adding a boolean mask would just add 1.0 at True positions;
                # treat it as "True = masked out" instead.
                attn_weights = attn_weights.masked_fill(attn_mask, float('-inf'))
            else:
                # Additive mask; cast so mixed precision does not promote the
                # scores, and add out of place to avoid mutating them.
                attn_weights = attn_weights + attn_mask.to(attn_weights.dtype)

        if key_padding_mask is not None:
            # (B, T) -> (B, 1, 1, T): mask the same keys for all heads/queries.
            attn_weights = attn_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2), float('-inf')
            )

        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_weights = self.attn_dropout(attn_weights)

        # Weighted sum of values, then merge heads back to (B, T, C).
        out = torch.matmul(attn_weights, v)
        out = out.transpose(1, 2).contiguous().view(B, T, C)

        out = self.out_proj(out)
        out = self.resid_dropout(out)

        return out
|
|
|
|
class TransformerEncoderLayerWithRoPE(nn.Module):
    """Pre-norm Transformer encoder layer built from RMSNorm and RoPE attention.

    The layer applies two residual sub-blocks in sequence:
    ``src + dropout1(self_attn(norm1(src)))`` followed by
    ``src + dropout2(linear2(dropout2(activation(linear1(norm2(src))))))``.

    Args:
        d_model: Model dimension.
        nhead: Number of attention heads.
        dim_feedforward: Hidden width of the feed-forward sub-block.
        dropout: Dropout probability for attention and feed-forward paths.
        activation: Name of a function in ``torch.nn.functional``.
        rope_base: Frequency base forwarded to the attention's RoPE.
    """

    def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048,
                 dropout: float = 0.1, activation: str = 'gelu', rope_base: float = 10000.0):
        super().__init__()

        # Self-attention sub-block.
        self.self_attn = MultiHeadAttentionWithRoPE(d_model, nhead, dropout, rope_base)

        # Position-wise feed-forward sub-block.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        # One normalization in front of each sub-block.
        self.norm1 = RMSNorm(d_model)
        self.norm2 = RMSNorm(d_model)

        # Resolved by name from torch.nn.functional (e.g. F.gelu).
        self.activation = getattr(F, activation)

    def forward(self, src: torch.Tensor, src_mask: Optional[torch.Tensor] = None,
                src_key_padding_mask: Optional[torch.Tensor] = None):
        """Run attention and feed-forward sub-blocks with residual connections.

        Args:
            src: Input sequence tensor.
            src_mask: Optional attention mask passed to the attention module.
            src_key_padding_mask: Optional key padding mask passed to the
                attention module.

        Returns:
            Tensor with the same shape as ``src``.
        """
        # Attention sub-block: normalize, attend to self, add back.
        normed = self.norm1(src)
        attended = self.self_attn(
            normed, normed, normed,
            attn_mask=src_mask,
            key_padding_mask=src_key_padding_mask,
        )
        src = src + self.dropout1(attended)

        # Feed-forward sub-block; dropout2 is used both inside and on the output.
        hidden = self.linear1(self.norm2(src))
        hidden = self.dropout2(self.activation(hidden))
        return src + self.dropout2(self.linear2(hidden))
|
|
|
|
class TransformerEncoderWithRoPE(nn.Module):
    """Stack of independent encoder layers applied one after another.

    Args:
        encoder_layer: Prototype layer. It is deep-copied ``num_layers`` times
            so every layer in the stack owns its own parameters.
        num_layers: Number of layers in the stack.
    """

    def __init__(self, encoder_layer: "TransformerEncoderLayerWithRoPE", num_layers: int):
        super().__init__()
        import copy  # local import: only needed while building the stack

        # Repeating the same instance would register ONE layer N times, so all
        # "layers" would share weights. Clone the prototype instead (same
        # approach as torch.nn.TransformerEncoder). State-dict keys stay
        # ``layers.<i>.*``, so existing checkpoints still load.
        self.layers = nn.ModuleList(
            copy.deepcopy(encoder_layer) for _ in range(num_layers)
        )
        self.num_layers = num_layers

    def forward(self, src: torch.Tensor, mask: Optional[torch.Tensor] = None,
                src_key_padding_mask: Optional[torch.Tensor] = None):
        """Pass ``src`` through every layer in order.

        Args:
            src: Input sequence tensor.
            mask: Optional attention mask, forwarded to each layer as ``src_mask``.
            src_key_padding_mask: Optional key padding mask, forwarded to each layer.

        Returns:
            Output of the last layer.
        """
        output = src

        for layer in self.layers:
            output = layer(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask)

        return output
|
|
|
|
class FourierFeatureEncoder(nn.Module):
    """Sin/cos Fourier feature expansion for a ``(B, C, H, W)`` map.

    Every input channel is multiplied by the frequencies ``pi * 2**k`` for
    ``k = 0 .. num_freq_bands - 1`` and passed through sine and cosine. The
    raw input channels come first in the output, followed, for each input
    channel in turn, by its sine features and then its cosine features.

    Args:
        input_channels: Number of input channels, used to compute
            ``output_channels``.
        num_freq_bands: Number of frequency bands per channel.
    """

    def __init__(self, input_channels=3, num_freq_bands=10):
        super().__init__()
        self.num_freq_bands = num_freq_bands
        # One raw value plus a sin and a cos per band, for every input channel.
        self.output_channels = input_channels * (2 * num_freq_bands + 1)
        # Fixed (non-trainable) frequencies pi, 2*pi, 4*pi, ...
        octaves = 2.0 ** torch.arange(num_freq_bands)
        self.freq_bands = nn.Parameter(octaves * torch.pi, requires_grad=False)

    def forward(self, x):
        """Return the encoded map of shape ``(B, C * (2 * num_freq_bands + 1), H, W)``."""
        batch, _, height, width = x.shape
        # Channels-last so the frequency axis can be appended per channel.
        channels_last = x.permute(0, 2, 3, 1)
        phases = channels_last.unsqueeze(-1) * self.freq_bands
        waves = torch.cat((torch.sin(phases), torch.cos(phases)), dim=-1)
        # Collapse (C, 2 * bands) into a single feature axis.
        waves = waves.reshape(batch, height, width, -1)
        encoded = torch.cat((channels_last, waves), dim=-1)
        return encoded.permute(0, 3, 1, 2)
|
|
|
|
class DirectionGuidedFusion(nn.Module):
    """Attention-guided fusion of direction features into visual features.

    Steps performed in ``forward``:
      1. Project the direction features to ``visual_dim`` channels.
      2. Predict a per-channel sigmoid gate from the concatenated visual and
         direction features and multiply the visual features by it.
      3. Concatenate the gated visual features with the projected direction
         features and map them to ``out_dim`` channels.
      4. Add a residual projection of the original visual features.

    Args:
        visual_dim: Channel count of the visual feature map.
        dir_dim: Channel count of the direction feature map.
        out_dim: Channel count of the fused output.
    """

    @staticmethod
    def _num_groups(num_channels, max_groups):
        """Return a GroupNorm group count that is valid for ``num_channels``.

        Starts from ``min(max_groups, num_channels // 4)`` and walks down to
        the nearest value that divides ``num_channels`` (at least 1).
        GroupNorm requires the channel count to be divisible by the group
        count, so channel counts such as 100 or anything below 4 would
        otherwise fail at construction time. When the starting value already
        divides ``num_channels`` it is returned unchanged.
        """
        groups = max(1, min(max_groups, num_channels // 4))
        while num_channels % groups:
            groups -= 1
        return groups

    def __init__(self, visual_dim, dir_dim, out_dim):
        super().__init__()

        # Lift direction features into the visual feature space.
        self.dir_adapter = nn.Sequential(
            nn.Conv2d(dir_dim, visual_dim, kernel_size=1),
            nn.GroupNorm(self._num_groups(visual_dim, 8), visual_dim),
            nn.GELU()
        )

        # Bottleneck that predicts a sigmoid gate for every visual channel.
        hidden_dim = max(16, visual_dim // 8)
        self.attention = nn.Sequential(
            nn.Conv2d(visual_dim + dir_dim, hidden_dim, kernel_size=1),
            nn.GroupNorm(self._num_groups(hidden_dim, 4), hidden_dim),
            nn.GELU(),
            nn.Conv2d(hidden_dim, visual_dim, kernel_size=1),
            nn.Sigmoid()
        )

        # Mix gated visual features with the adapted direction features.
        self.fusion = nn.Sequential(
            nn.Conv2d(visual_dim * 2, out_dim, kernel_size=1),
            nn.GroupNorm(self._num_groups(out_dim, 8), out_dim),
            nn.GELU()
        )

        # Residual path; a 1x1 conv is only needed when channel counts differ.
        self.residual_proj = nn.Conv2d(visual_dim, out_dim, kernel_size=1) if visual_dim != out_dim else nn.Identity()

    def forward(self, visual_feat, dir_feat):
        """Fuse ``dir_feat`` into ``visual_feat``.

        Args:
            visual_feat: Tensor of shape ``(B, visual_dim, H, W)``.
            dir_feat: Tensor of shape ``(B, dir_dim, H, W)`` with the same
                spatial size as ``visual_feat``.

        Returns:
            Tensor of shape ``(B, out_dim, H, W)``.
        """
        dir_adapted = self.dir_adapter(dir_feat)

        # The gate sees the raw (un-adapted) direction features.
        attention_input = torch.cat([visual_feat, dir_feat], dim=1)
        attention_weight = self.attention(attention_input)

        attended_visual = visual_feat * attention_weight

        fused_input = torch.cat([attended_visual, dir_adapted], dim=1)
        fused = self.fusion(fused_input)

        residual = self.residual_proj(visual_feat)
        output = fused + residual

        return output
|
|
|
|
class Hdri_Encoder(nn.Module):
    """Encode a 9-channel environment-map context into a token sequence.

    The input holds 3 LDR, 3 HDR and 3 view-direction channels. The LDR and
    HDR channels go through a ConvNeXt-Tiny backbone whose stem is widened to
    6 input channels. The direction channels are Fourier-encoded and fused
    into three backbone feature levels (192, 384 and 768 channels). The fused
    maps are resized to a ``target_h x target_w`` grid, concatenated,
    projected to ``output_dim`` and refined by a RoPE Transformer encoder.

    Args:
        output_dim: Channel width of the output tokens.
        num_tokens: Number of output tokens; must be a perfect square.
        cnn_out_channels: Channel width of each fused pyramid level.
        n_heads: Number of attention heads in the Transformer layers.
        num_transformer_layers: Number of Transformer encoder layers.
        rope_base: RoPE frequency base.
        pretrained_path: Optional path to a state-dict checkpoint for the
            whole encoder; loaded at the end of construction.
    """

    def __init__(self, output_dim=768, num_tokens=4096, cnn_out_channels=256,
                 n_heads=8, num_transformer_layers=2, rope_base=10000.0, pretrained_path=None):
        super().__init__()
        self.output_dim = output_dim
        self.num_tokens = num_tokens

        # Tokens are laid out on a square grid.
        self.target_h = int(math.sqrt(num_tokens))
        self.target_w = self.target_h
        assert self.target_h * self.target_w == num_tokens, f"num_tokens must be a perfect square, got {num_tokens}"

        print(f"Target resolution: {self.target_h}x{self.target_w} = {num_tokens} tokens")

        # Fourier encoding of the 3 view-direction channels.
        self.dir_encoder = FourierFeatureEncoder(input_channels=3, num_freq_bands=8)
        dir_enc_channels = self.dir_encoder.output_channels

        # ImageNet-pretrained backbone, split into four stages.
        convnext = convnext_tiny(weights=ConvNeXt_Tiny_Weights.IMAGENET1K_V1)

        self.stage1 = nn.Sequential(convnext.features[0], convnext.features[1])
        self.stage2 = nn.Sequential(convnext.features[2], convnext.features[3])
        self.stage3 = nn.Sequential(convnext.features[4], convnext.features[5])
        self.stage4 = nn.Sequential(convnext.features[6], convnext.features[7])

        # Widen the stem from 3 to 6 input channels (LDR + HDR). The pretrained
        # RGB filters are copied onto both halves.
        orig_stem_conv = self.stage1[0][0]
        new_stem_conv = nn.Conv2d(6, orig_stem_conv.out_channels,
                                  kernel_size=orig_stem_conv.kernel_size,
                                  stride=orig_stem_conv.stride,
                                  padding=orig_stem_conv.padding,
                                  bias=orig_stem_conv.bias is not None)

        with torch.no_grad():
            new_stem_conv.weight[:, :3].copy_(orig_stem_conv.weight)
            new_stem_conv.weight[:, 3:6].copy_(orig_stem_conv.weight)
            if orig_stem_conv.bias is not None:
                new_stem_conv.bias.copy_(orig_stem_conv.bias)

        self.stage1[0][0] = new_stem_conv

        # Direction-guided fusion for the three deepest backbone levels.
        self.fusion_c3 = DirectionGuidedFusion(192, dir_enc_channels, cnn_out_channels)
        self.fusion_c4 = DirectionGuidedFusion(384, dir_enc_channels, cnn_out_channels)
        self.fusion_c5 = DirectionGuidedFusion(768, dir_enc_channels, cnn_out_channels)

        # Project the concatenated pyramid to the token width.
        self.projection_conv = nn.Conv2d(cnn_out_channels * 3, output_dim, kernel_size=1)

        encoder_layer = TransformerEncoderLayerWithRoPE(
            d_model=output_dim,
            nhead=n_heads,
            dim_feedforward=output_dim * 4,
            dropout=0.1,
            activation='gelu',
            rope_base=rope_base
        )
        self.transformer_encoder = TransformerEncoderWithRoPE(encoder_layer, num_transformer_layers)
        self.ln_final = RMSNorm(output_dim)

        self._initialize_weights()

        self.load_weights(pretrained_path)

    def convert_to_fp16(self):
        """No-op placeholder; the module is left unchanged."""
        pass

    def convert_to_fp32(self):
        """No-op placeholder; the module is left unchanged."""
        pass

    def _initialize_weights(self):
        """Initialize the newly added convolutional heads.

        Only the fusion modules and the projection conv are touched.
        Iterating over ``self.modules()`` would also re-initialize every
        Conv2d and LayerNorm inside the ConvNeXt stages, throwing away the
        pretrained backbone weights, including the stem filters that were
        just copied from the checkpoint.
        """
        new_heads = (self.fusion_c3, self.fusion_c4, self.fusion_c5, self.projection_conv)
        for head in new_heads:
            for m in head.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                    if m.bias is not None:
                        nn.init.constant_(m.bias, 0)
                elif isinstance(m, (nn.GroupNorm, nn.LayerNorm)):
                    if m.weight is not None:
                        nn.init.constant_(m.weight, 1)
                    if m.bias is not None:
                        nn.init.constant_(m.bias, 0)

    def load_weights(self, pretrained_path):
        """Load a full state dict from ``pretrained_path``; no-op when it is ``None``."""
        if pretrained_path is None:
            return

        # SECURITY NOTE(review): weights_only=False unpickles arbitrary Python
        # objects. Only load checkpoints from trusted sources; if checkpoints
        # are plain state dicts, consider switching to weights_only=True.
        checkpoint = torch.load(
            pretrained_path,
            map_location=torch.device("cpu"),
            weights_only=False,
        )
        self.load_state_dict(checkpoint)

    def _extract_pyramid(self, context):
        """Run the backbone and resize the direction encoding to each level.

        Returns:
            Tuple ``(visual_feats, dir_feats)``. ``visual_feats`` holds the
            outputs of stage2, stage3 and stage4 in that order, and
            ``dir_feats`` holds the Fourier-encoded directions bilinearly
            resized to the matching spatial sizes.
        """
        ldr_map, hdr_map, view_dirs_map = torch.tensor_split(context, [3, 6], dim=1)

        dir_encoding = self.dir_encoder(view_dirs_map)
        visual_input = torch.cat([ldr_map, hdr_map], dim=1)

        c2 = self.stage1(visual_input)
        c3 = self.stage2(c2)
        c4 = self.stage3(c3)
        c5 = self.stage4(c4)

        visual_feats = (c3, c4, c5)
        dir_feats = tuple(
            F.interpolate(dir_encoding, size=feat.shape[2:], mode='bilinear', align_corners=False)
            for feat in visual_feats
        )
        return visual_feats, dir_feats

    def forward(self, context):
        """Encode ``context`` into tokens.

        Args:
            context: Tensor of shape ``(B, 9, H, W)`` holding 3 LDR, 3 HDR and
                3 direction channels, in that order.

        Returns:
            Tensor of shape ``(B, target_h * target_w, output_dim)``.
        """
        B, C, H, W = context.shape
        assert C == 9, f"Expected 9 channels (3 LDR + 3 HDR + 3 directions), got {C}"

        visual_feats, dir_feats = self._extract_pyramid(context)

        # Fuse every level with its direction map, then bring all levels to
        # the common token grid.
        target_size = (self.target_h, self.target_w)
        fusers = (self.fusion_c3, self.fusion_c4, self.fusion_c5)
        levels = [
            F.interpolate(fuse(feat, dirs), size=target_size, mode='bilinear', align_corners=False)
            for fuse, feat, dirs in zip(fusers, visual_feats, dir_feats)
        ]

        multi_scale_feat = torch.cat(levels, dim=1)

        projected = self.projection_conv(multi_scale_feat)

        # (B, C, H, W) -> (B, H*W, C). flatten() also copes with conv outputs
        # that are not contiguous in NCHW order (e.g. channels_last).
        tokens = projected.flatten(2).transpose(1, 2)

        tokens = self.transformer_encoder(tokens)
        tokens = self.ln_final(tokens)

        return tokens

    def get_attention_maps(self, context):
        """Return the fusion gate maps for visualization.

        Runs the backbone without gradient tracking and evaluates only the
        ``attention`` branch of each fusion module.

        Returns:
            Dict with keys ``'attention_c3'``, ``'attention_c4'`` and
            ``'attention_c5'`` mapping to the sigmoid gate tensors.
        """
        with torch.no_grad():
            (c3, c4, c5), (dir_c3, dir_c4, dir_c5) = self._extract_pyramid(context)

            att_c3 = self.fusion_c3.attention(torch.cat([c3, dir_c3], dim=1))
            att_c4 = self.fusion_c4.attention(torch.cat([c4, dir_c4], dim=1))
            att_c5 = self.fusion_c5.attention(torch.cat([c5, dir_c5], dim=1))

            return {
                'attention_c3': att_c3,
                'attention_c4': att_c4,
                'attention_c5': att_c5,
            }
|
|
|
|
if __name__ == "__main__":
    # Smoke test: exercise each component once, run a small LayerNorm vs
    # RMSNorm timing comparison, then run the full encoder.
    import time

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _time_calls(fn, arg, iters=100):
        """Return the wall-clock seconds taken by ``iters`` calls of ``fn(arg)``.

        One warm-up call is made first. On CUDA the device is synchronized
        before and after the loop, because kernels launch asynchronously and
        the timer would otherwise only measure launch overhead.
        """
        fn(arg)
        if device.type == "cuda":
            torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(iters):
            fn(arg)
        if device.type == "cuda":
            torch.cuda.synchronize()
        return time.perf_counter() - start

    print("Testing individual components...")

    # RMSNorm
    rms_norm = RMSNorm(768).to(device)
    x = torch.randn(2, 1024, 768).to(device)
    out_rms = rms_norm(x)
    print(f"RMSNorm output shape: {out_rms.shape}")

    # RoPE tables applied to per-head query/key tensors
    rope = RotaryPositionalEmbedding(64).to(device)
    q = torch.randn(2, 8, 1024, 64).to(device)
    k = torch.randn(2, 8, 1024, 64).to(device)
    cos, sin = rope(q)
    q_rope, k_rope = apply_rotary_pos_emb(q, k, cos.unsqueeze(0).unsqueeze(0), sin.unsqueeze(0).unsqueeze(0))
    print(f"RoPE output shapes - q: {q_rope.shape}, k: {k_rope.shape}")

    # Multi-head attention
    mha = MultiHeadAttentionWithRoPE(768, 8).to(device)
    x = torch.randn(2, 1024, 768).to(device)
    out_mha = mha(x, x, x)
    print(f"MultiHeadAttentionWithRoPE output shape: {out_mha.shape}")

    # Single encoder layer
    transformer_layer = TransformerEncoderLayerWithRoPE(768, 8).to(device)
    out_layer = transformer_layer(x)
    print(f"TransformerEncoderLayerWithRoPE output shape: {out_layer.shape}")

    print("\nPerformance comparison...")

    layer_norm = nn.LayerNorm(768).to(device)

    ln_time = _time_calls(layer_norm, x)
    rms_time = _time_calls(rms_norm, x)

    print(f"LayerNorm time: {ln_time:.4f}s")
    print(f"RMSNorm time: {rms_time:.4f}s")
    print(f"RMSNorm speedup: {ln_time/rms_time:.2f}x")

    # Parameter counts: LayerNorm has weight + bias, RMSNorm only a weight.
    ln_params = sum(p.numel() for p in layer_norm.parameters())
    rms_params = sum(p.numel() for p in rms_norm.parameters())
    print(f"LayerNorm params: {ln_params}")
    print(f"RMSNorm params: {rms_params}")
    print(f"Parameter reduction: {(ln_params - rms_params) / ln_params * 100:.1f}%")

    print("✓ All tests passed!")

    # Full encoder on a random 9-channel context
    context = torch.randn(2, 9, 512, 512).to(device)
    encoder = Hdri_Encoder(output_dim=768, num_tokens=4096, cnn_out_channels=128,
                           n_heads=8, num_transformer_layers=2, rope_base=10000.0).to(device)
    output = encoder(context)
    print(f"Hdri_Encoder output shape: {output.shape}")

    # Fusion gate maps
    attention_maps = encoder.get_attention_maps(context)
    print(f"Attention maps: {attention_maps}")

    # Parameter count in millions
    print(f"Hdri_Encoder parameters: {sum(p.numel()/1e6 for p in encoder.parameters())}M")