import math
import os
import random
import sys
from copy import deepcopy
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from hydra.utils import instantiate
from omegaconf import OmegaConf
from torch.nn.init import trunc_normal_
from torch.utils.checkpoint import checkpoint

from .clip_encoder import CLIPVisionTower
from .custom_clip import _expand_mask


def _verbose_log(message):
    """Print ``message`` only when the ``PIXDLM_VERBOSE`` env var equals "1"."""
    if os.environ.get("PIXDLM_VERBOSE", "0") == "1":
        print(message)


def _copy_matching_params(module, weights_dict, prefix=""):
    """Copy into ``module`` every parameter found in ``weights_dict``.

    Keys are looked up as ``prefix + parameter_name``. Missing keys are
    skipped silently so partially populated checkpoints still load. A
    ``None`` or empty ``weights_dict`` is a no-op.
    """
    if not weights_dict:
        return
    for name, param in module.named_parameters():
        full_key = prefix + name
        if full_key in weights_dict:
            param.data.copy_(weights_dict[full_key].to(param.dtype))
            _verbose_log(f"Loaded {full_key}")


def _match_token_grid(src_feat, dst_len):
    """Resample a ``(B, L, C)`` token sequence so that ``L == dst_len``.

    Tokens are treated as a square grid (``sqrt(L)`` per side). Fewer tokens
    than requested are bilinearly upsampled (computed in float32, then cast
    back); more tokens are average-pooled with an integer stride. One length
    must divide the other.
    """
    src_len = src_feat.shape[1]
    assert src_len % dst_len == 0 or dst_len % src_len == 0
    if src_len == dst_len:
        return src_feat

    b, _, c = src_feat.shape
    src_size = int(math.sqrt(src_len))
    dst_size = int(math.sqrt(dst_len))
    grid = src_feat.transpose(1, 2).view(b, c, src_size, src_size)
    if src_len < dst_len:
        grid = F.interpolate(
            grid.float(),
            size=(dst_size, dst_size),
            mode='bilinear',
            align_corners=True,
        ).to(dtype=src_feat.dtype)
    else:
        stride = src_size // dst_size
        grid = F.avg_pool2d(grid, stride, stride)
    return grid.view(b, c, -1).transpose(1, 2)


def _resize_bilinear(x, size):
    """Bilinearly resize ``x`` to ``size x size`` in float32, keeping ``x``'s dtype."""
    return F.interpolate(
        x.float(), size=(size, size), mode='bilinear', align_corners=True
    ).to(dtype=x.dtype)


class MultiPathAlignModule(nn.Module):
    """Fuse fast- and slow-branch tokens with two linear projections.

    Both branches are projected to ``fast_vision_dim`` and summed.

    Args:
        fast_vision_dim: channel width of the fast-branch tokens.
        slow_vision_dim: channel width of the slow-branch features.
        pretrained_weights: optional mapping of checkpoint keys to tensors.
        prefix: string prepended to parameter names when looking up keys.
    """

    def __init__(self, fast_vision_dim, slow_vision_dim, pretrained_weights=None, prefix=""):
        super().__init__()
        self.fast_proj = nn.Linear(fast_vision_dim, fast_vision_dim)
        self.slow_proj = nn.Linear(slow_vision_dim, fast_vision_dim)
        self.load_pretrained_weights(pretrained_weights, prefix)

    def load_pretrained_weights(self, weights_dict, prefix=""):
        """Load ``fast_proj`` / ``slow_proj`` weights and biases when present.

        ``weights_dict`` may be ``None`` (the constructor default); in that
        case the freshly initialised parameters are kept.
        """
        _copy_matching_params(self, weights_dict, prefix)

    def forward(self, fast_feat, slow_feat):
        """Return ``fast_proj(fast_feat) + slow_proj(slow_feat)``.

        ``slow_feat`` may be ``(B, C, H, W)`` or ``(B, L, C)``; it is flattened
        and resampled to the token count of ``fast_feat`` before projection.
        """
        # Follow the projection weights instead of a hard-coded bfloat16 so the
        # module also works when the model is held in fp32 / fp16.
        target_dtype = self.fast_proj.weight.dtype
        fast_feat = fast_feat.to(target_dtype)
        slow_feat = slow_feat.to(target_dtype)

        if slow_feat.ndim == 4:
            b, c, _, _ = slow_feat.shape
            slow_feat = slow_feat.view(b, c, -1).transpose(1, 2)

        slow_feat = _match_token_grid(slow_feat, fast_feat.shape[1])
        return self.fast_proj(fast_feat) + self.slow_proj(slow_feat)


class S2FStitchAlignModuleV2(nn.Module):
    """Gated residual injection of slow-branch features into fast-branch tokens.

    Args:
        fast_vision_dim: channel width of the fast-branch tokens.
        slow_vision_dim: channel width of the slow-branch feature map.
        zero_init: when True the output projections start at zero, so the
            module initially returns ``fast_feat`` unchanged.
    """

    def __init__(self, fast_vision_dim, slow_vision_dim, zero_init=True):
        super().__init__()
        self.slow_conv = nn.Conv2d(slow_vision_dim, slow_vision_dim, 1)
        self.slow_proj = nn.Conv2d(slow_vision_dim, fast_vision_dim, 1)

        # Depthwise 7x7 convolution followed by a pointwise projection.
        self.fast_conv = nn.Conv2d(
            fast_vision_dim, fast_vision_dim, 7, padding=3, groups=fast_vision_dim
        )
        self.fast_proj = nn.Conv2d(fast_vision_dim, fast_vision_dim, 1)

        self.gate = nn.Sequential(
            nn.Linear(fast_vision_dim * 2, fast_vision_dim // 2),
            nn.GELU(),
            nn.Linear(fast_vision_dim // 2, 1),
        )

        nn.init.xavier_uniform_(self.slow_conv.weight)
        nn.init.xavier_uniform_(self.fast_conv.weight)
        nn.init.zeros_(self.slow_conv.bias)
        nn.init.zeros_(self.fast_conv.bias)
        if zero_init:
            nn.init.zeros_(self.slow_proj.weight)
            nn.init.zeros_(self.fast_proj.weight)
        else:
            nn.init.xavier_uniform_(self.slow_proj.weight)
            nn.init.xavier_uniform_(self.fast_proj.weight)
        # Biases are zeroed in both modes; zero_init is only an identity
        # mapping at start-up if the projection biases are zero as well.
        nn.init.zeros_(self.slow_proj.bias)
        nn.init.zeros_(self.fast_proj.bias)

    def load_pretrained_weights(self, weights_dict, prefix=""):
        """Copy every parameter whose ``prefix + name`` key is in ``weights_dict``."""
        _copy_matching_params(self, weights_dict, prefix)

    def src2dst_align(self, src_feat, dst_feat):
        """Resample ``src_feat`` tokens to the token count of ``dst_feat``.

        Returns:
            ``(aligned_src_feat, dst_size)`` where ``dst_size`` is the side
            length of the square grid formed by the ``dst_feat`` tokens.
        """
        dst_len = dst_feat.shape[1]
        dst_size = int(math.sqrt(dst_len))
        return _match_token_grid(src_feat, dst_len), dst_size

    def forward(self, fast_feat, slow_feat):
        """Refine ``fast_feat`` (B, L, D) using ``slow_feat`` (B, C, H, W)."""
        b = slow_feat.shape[0]
        d = fast_feat.shape[2]

        slow_feat = self.slow_proj(F.gelu(self.slow_conv(slow_feat)))
        slow_feat = slow_feat.view(b, d, -1).transpose(1, 2)
        slow_feat_align, dst_size = self.src2dst_align(slow_feat, fast_feat)

        # Local refinement of the fast tokens on their 2-D grid (residual).
        fast_feat = fast_feat.transpose(1, 2).view(b, d, dst_size, dst_size)
        fast_feat = fast_feat + self.fast_proj(F.gelu(self.fast_conv(fast_feat)))
        fast_feat = fast_feat.view(b, d, dst_size * dst_size).transpose(1, 2)

        # One scalar gate per sample, computed from token-averaged features.
        pooled = torch.cat([fast_feat, slow_feat_align], -1).mean(1)
        gate = self.gate(pooled).unsqueeze(1)
        return fast_feat + slow_feat_align * gate.tanh()


class MultiPathCLIPVisionTower(nn.Module):
    """Two-branch vision tower: a CLIP "fast" branch fused with a SAM2 "slow" branch.

    The SAM2 image encoder is built and loaded from files under the deployment
    root (``PIXDLM_ROOT`` or a directory inferred from this file's location).
    """

    # Input resolution fed to the CLIP (fast) branch.
    FAST_IMAGE_SIZE = 448
    # Input resolution fed to the SAM2 (slow) branch.
    SLOW_IMAGE_SIZE = 1024
    # Channel width of the SAM2 features (per the original note: the neck's d_model).
    SAM2_HIDDEN_SIZE = 256
    # Number of intermediate stitch modules (one after each of the first three stages).
    NUM_LATENT_STAGES = 3

    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__()
        self.is_loaded = False

        # Use PIXDLM_ROOT when set; otherwise infer the deployment root from
        # this file. The fallback is only evaluated when needed, so a shallow
        # file path cannot raise while the override is present.
        root_override = os.environ.get("PIXDLM_ROOT")
        if root_override is not None:
            base_dir = Path(root_override)
        else:
            base_dir = Path(__file__).resolve().parents[4]

        self.slow_vision_tower = self._build_sam2_encoder(base_dir)
        _verbose_log("Initialized SAM2 slow vision tower")

        # The fast branch is an unmodified CLIP tower (input size was 336 originally).
        args_ = deepcopy(args)
        args_.input_image_size = self.FAST_IMAGE_SIZE
        self.fast_vision_tower = CLIPVisionTower(vision_tower, args_, delay_load=delay_load)
        _verbose_log("Initialized CLIP fast vision tower")

        self.load_model()

        self.vision_tower_name = vision_tower
        self.select_layer = args.mm_vision_select_layer
        self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
        self.splits = self.select_layer // 100 if self.select_layer > 100 else 1
        self.enable_adapter = not args.freeze_vision
        _verbose_log(f"enable_adapter={self.enable_adapter}")

        # Not passed in through args yet.
        self.image_size = 800

        fast_dim = self.fast_vision_tower.hidden_size
        if self.enable_adapter:
            self.align_stages_latent = nn.ModuleList(
                [
                    S2FStitchAlignModuleV2(fast_dim, self.SAM2_HIDDEN_SIZE, True)
                    for _ in range(self.NUM_LATENT_STAGES)
                ]
            )

        align_weights_path = base_dir / "models" / "pixdlm_align_stages.pth"
        if align_weights_path.exists():
            # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
            weights_dict = torch.load(str(align_weights_path), map_location="cpu")
        else:
            _verbose_log(f"align weights not found at {align_weights_path}; using initialized align modules")
            weights_dict = {}

        self.align_stages = nn.ModuleList(
            [
                MultiPathAlignModule(
                    fast_dim,
                    self.SAM2_HIDDEN_SIZE,
                    pretrained_weights=weights_dict,
                    prefix="base_model.model.model.vision_tower.align_stages.0.",
                )
            ]
        )

        # The latent stitch modules only exist when the adapter is enabled.
        if self.enable_adapter:
            for i, stage in enumerate(self.align_stages_latent):
                stage.load_pretrained_weights(
                    weights_dict,
                    prefix=f"base_model.model.model.vision_tower.align_stages_latent.{i}.",
                )

    @staticmethod
    def _build_sam2_encoder(base_dir):
        """Instantiate SAM2 from its config, load the checkpoint, return its image encoder."""
        sam2_source_dir = base_dir / "models" / "sam2"
        # The SAM2 sources must be importable for hydra to instantiate the config.
        if str(sam2_source_dir) not in sys.path:
            sys.path.append(str(sam2_source_dir))

        cfg = OmegaConf.load(
            str(sam2_source_dir / "sam2" / "configs" / "sam2.1" / "sam2.1_hiera_l.yaml")
        )
        model = instantiate(cfg.model)

        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        ckpt = torch.load(
            str(base_dir / "models" / "sam2_checkpoints" / "sam2.1_hiera_large.pt"),
            map_location="cpu",
        )
        load_result = model.load_state_dict(ckpt["model"], strict=False)
        # strict=False hides mismatches; surface them in verbose mode.
        if load_result.missing_keys or load_result.unexpected_keys:
            _verbose_log(
                f"SAM2 checkpoint key mismatch: missing={len(load_result.missing_keys)}, "
                f"unexpected={len(load_result.unexpected_keys)}"
            )
        return model.image_encoder

    def load_model(self):
        """Load the CLIP branch and adopt its image processor.

        The SAM2 encoder is already loaded in ``__init__``.
        """
        self.fast_vision_tower.load_model()
        self.image_processor = self.fast_vision_tower.image_processor
        self.is_loaded = True

    def _run_blocks(self, blocks, hidden, mask):
        """Run ``hidden`` through ``blocks``, with activation checkpointing in training."""
        for blk in blocks:
            if self.training:
                hidden = checkpoint(blk.__call__, hidden, mask, None)[0]
            else:
                hidden = blk(hidden, mask, None)[0]
        return hidden

    def forward(self, x, attention_mask=None, output_attentions=False, output_keys=False):
        """Encode ``x`` with the CLIP branch, fusing SAM2 features along the way.

        Args:
            x: image batch; it is resized internally for each branch.
            attention_mask: optional patch-token mask. A mask with one extra
                leading (class-token) column is trimmed. ``None`` passes no
                mask to the encoder layers.
            output_attentions: accepted for interface compatibility; attention
                maps are not returned.
            output_keys: when True, also return the last layer's keys averaged
                over heads (only if the layer supports ``output_keys``).

        Returns:
            ``(y, [y])`` or, when ``output_keys`` is True, ``(y, [y], last_keys)``.
        """
        vision_model = self.fast_vision_tower.vision_tower.vision_model

        # Fast branch: embed, drop the class token, apply the pre-layernorm.
        y = _resize_bilinear(x, self.FAST_IMAGE_SIZE)
        y = vision_model.embeddings(y)
        y = vision_model.pre_layrnorm(y[:, 1:])

        # Slow branch: SAM2 encoder, run without gradients.
        x_resized = _resize_bilinear(x, self.SLOW_IMAGE_SIZE)
        with torch.no_grad():
            sam_backbone_out = self.slow_vision_tower(x_resized)
        sam_features = sam_backbone_out["vision_features"]  # [B, C, H, W] per the original note
        # NOTE(review): dtype is hard-coded here; assumes the adapters run in bfloat16.
        sam_features = sam_features.to(torch.bfloat16)

        expanded_mask = None
        if attention_mask is not None:
            num_tokens = y.shape[1]
            # Drop the class-token column so the mask matches the patch tokens.
            if attention_mask.shape[-1] == num_tokens + 1:
                attention_mask = attention_mask[:, 1:]
            expanded_mask = _expand_mask(attention_mask, attention_mask.dtype, tgt_len=num_tokens)

        # The encoder is split into four equal stages.
        fast_blk = vision_model.encoder.layers
        n_blks = len(fast_blk) // 4
        assert len(fast_blk) == n_blks * 4

        # Stages 1-3, each optionally followed by a stitch module.
        for stage_idx in range(self.NUM_LATENT_STAGES):
            stage_blocks = fast_blk[stage_idx * n_blks:(stage_idx + 1) * n_blks]
            y = self._run_blocks(stage_blocks, y, expanded_mask)
            if self.enable_adapter:
                y = self.align_stages_latent[stage_idx](y, sam_features)

        # Stage 4: the final block is called directly (never checkpointed) so
        # that its extra outputs can be captured.
        last_keys = None
        final_stage = fast_blk[3 * n_blks:]
        if len(final_stage) > 0:
            y = self._run_blocks(final_stage[:-1], y, expanded_mask)
            last_blk = final_stage[-1]
            # Some CLIP layers do not accept output_keys; request it only when supported.
            try:
                outputs = last_blk(
                    y, expanded_mask, None, output_attentions=False, output_keys=output_keys
                )
            except TypeError:
                outputs = last_blk(y, expanded_mask, None, output_attentions=False)
            y = outputs[0]
            if output_keys and len(outputs) > 1:
                last_keys = outputs[-1]

        # Final feature fusion.
        y = self.align_stages[0](y, sam_features)

        if last_keys is not None:
            # Average over heads: [B, num_heads, N, head_dim] -> [B, N, head_dim]
            # (shape per the original note).
            last_keys = last_keys.mean(dim=1)

        if output_keys:
            return y, [y], last_keys
        return y, [y]

    def forward_sam_multilayer_features(self, x):
        """Extract multi-level SAM2 encoder features for a downstream decoder.

        Each ``backbone_fpn`` level is passed through the matching
        ``neck.convs`` layer (paired in reverse order). Levels whose channel
        count does not match the conv's ``in_channels`` are skipped. If no
        level can be processed, ``vision_features`` is repeated
        ``min(len(neck.convs), 4)`` times instead.

        Per the original note, up to 4 levels are returned (roughly 256x256,
        128x128, 64x64, 32x32), each with 256 channels.
        """
        x_resized = _resize_bilinear(x, self.SLOW_IMAGE_SIZE)
        with torch.no_grad():
            sam_backbone_out = self.slow_vision_tower(x_resized)

        backbone_fpn = sam_backbone_out.get("backbone_fpn", None)
        if backbone_fpn is not None and len(backbone_fpn) >= 1:
            # Original notes on the expected layout (not verified here):
            #   backbone_fpn[0]: (B, 144, 256, 256) - highest resolution
            #   backbone_fpn[1]: (B, 288, 128, 128)
            #   backbone_fpn[2]: (B, 576, 64, 64)
            #   backbone_fpn[3]: (B, 1152, 32, 32) - lowest resolution
            # neck.convs is built from low to high resolution, so
            # convs[0] pairs with backbone_fpn[3], ..., convs[3] with backbone_fpn[0].
            neck = self.slow_vision_tower.neck
            num_convs = len(neck.convs)
            processed_features = []
            for backbone_idx in range(min(len(backbone_fpn), num_convs)):
                conv_idx = num_convs - 1 - backbone_idx
                backbone_feat = backbone_fpn[backbone_idx]
                conv_layer = neck.convs[conv_idx]

                # If the layer wraps an inner conv, check the channel count first.
                if hasattr(conv_layer, "conv"):
                    expected_in_channels = conv_layer.conv.in_channels
                    actual_channels = backbone_feat.shape[1]
                    if actual_channels != expected_in_channels:
                        # Report from rank 0 only, and only when local_rank is set.
                        if getattr(self, "local_rank", None) == 0:
                            print(
                                f"Error: backbone_fpn[{backbone_idx}] has {actual_channels} channels, "
                                f"but convs[{conv_idx}] expects {expected_in_channels} channels. "
                                f"backbone_fpn shape: {backbone_feat.shape}"
                            )
                        continue

                processed_features.append(conv_layer(backbone_feat))

            if len(processed_features) > 0:
                return processed_features

        # Fallback: no usable multi-level features, so repeat the single-level
        # vision_features once per expected level.
        sam_features = sam_backbone_out["vision_features"]
        fallback_layers = min(len(self.slow_vision_tower.neck.convs), 4)
        return [sam_features for _ in range(fallback_layers)]

    def forward_features(self, x):
        """Not implemented for this tower."""
        raise NotImplementedError

    @property
    def dummy_feature(self):
        """A single all-zero feature row on this tower's device and dtype."""
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        """Dtype of the fast (CLIP) tower's first parameter."""
        return next(self.fast_vision_tower.parameters()).dtype

    @property
    def device(self):
        """Device of the fast (CLIP) tower's first parameter."""
        return next(self.fast_vision_tower.parameters()).device

    @property
    def config(self):
        """Not implemented for this tower."""
        raise NotImplementedError

    @property
    def hidden_size(self):
        """Hidden size of the fast (CLIP) tower."""
        return self.fast_vision_tower.hidden_size

    @property
    def num_patches(self):
        """Patch count reported by the fast (CLIP) tower."""
        return self.fast_vision_tower.num_patches