import os

import torch
import torch.distributed as dist
import torch.nn as nn

local_rank = int(os.getenv("LOCAL_RANK", "0"))
world_size = torch.cuda.device_count()
rank = local_rank


class LayerWrapper(nn.Module):
    """Wraps a decoder layer and, at one chosen depth, compresses the span of
    image tokens (delimited by `img_pattern` matches in `input_ids`) into a
    single mean-pooled "motion" token, re-packing the hidden states and the
    sequence-aligned kwargs (attention_mask, position_ids, position_embeddings,
    cache_position) to the new, shorter length."""

    def __init__(
        self,
        layer,
        layer_idx,
        internal_projection=4,
        img_pattern=(151652,),
        motion_token=1,
    ):
        super().__init__()
        self.layer = layer
        self.layer_idx = layer_idx
        self.internal_projection = internal_projection
        self.motion_token = motion_token
        self.img_pattern = list(img_pattern)
        # Only a single pooled token is supported by the pooling/reshape below.
        assert motion_token == 1

    def get_removing_indices(self, hidden_states, input_ids):
        # Slide a window of the pattern length over input_ids and find, per batch
        # element, the first and last positions where the pattern occurs.
        pat_len = len(self.img_pattern)
        windows = input_ids.unfold(dimension=1, size=pat_len, step=1)
        pattern_tensor = torch.tensor(self.img_pattern, device=hidden_states.device).view(1, 1, -1)
        matches = (windows == pattern_tensor).all(dim=-1)
        match_lists = [
            torch.nonzero(matches[b], as_tuple=False).squeeze(-1)
            for b in range(hidden_states.shape[0])
        ]
        # Assumes every sequence in the batch contains at least one match.
        begin_idx = torch.tensor([m[0] for m in match_lists], device=hidden_states.device).unsqueeze(1)
        end_idx = torch.tensor([m[-1] for m in match_lists], device=hidden_states.device).unsqueeze(1)
        return begin_idx, end_idx

    def left_pad_emb_list(self, emb_list):
        # Left-pad variable-length sequences by flipping, right-padding, flipping back.
        rev = [e.flip(0) for e in emb_list]
        padded_rev = torch.nn.utils.rnn.pad_sequence(rev, batch_first=True, padding_value=0)
        return padded_rev.flip(1)

    def forward(self, hidden_states, input_ids, *args, **kwargs):
        bsz, seq_len, dim = hidden_states.shape
        # During incremental decoding (single-token steps with a cache position)
        # there is nothing left to compress.
        is_incremental = kwargs.get("cache_position") is not None and seq_len == 1
        if self.layer_idx == self.internal_projection and not is_incremental:
            device = hidden_states.device
            token_indices = torch.arange(seq_len, device=device).view(1, -1).expand(bsz, -1)
            begin_idx, end_idx = self.get_removing_indices(hidden_states, input_ids)
            compress_mask = (end_idx > begin_idx).reshape(-1)
            keep_mask_front = token_indices < begin_idx
            keep_mask_back = token_indices >= end_idx
            drop_mask = ~(keep_mask_front | keep_mask_back)

            # Mean-pool the dropped span into a single motion token per sequence.
            motion_token = (
                (hidden_states * drop_mask.unsqueeze(-1)).sum(dim=1)
                / drop_mask.sum(dim=1, keepdim=True).clamp(min=1)
            ).reshape(bsz, self.motion_token, -1)

            # Rebuild each sequence as [front tokens, pooled token, back tokens].
            # An empty slice (`[:0]`) keeps dtype, device, and rank consistent
            # when a sequence has nothing to compress.
            hidden_states = [
                torch.cat(
                    [
                        hidden_states[b][keep_mask_front[b]],
                        motion_token[b] if compress_mask[b] else hidden_states[b][:0],
                        hidden_states[b][keep_mask_back[b]],
                    ],
                    dim=0,
                )
                for b in range(bsz)
            ]
            hidden_states = self.left_pad_emb_list(hidden_states)

            if kwargs.get("attention_mask") is not None:
                # Assumes a 2-D padding mask of shape (bsz, seq_len).
                attention_mask = kwargs["attention_mask"]
                att_list = [
                    torch.cat(
                        [
                            attention_mask[b][keep_mask_front[b]],
                            torch.ones(1, device=attention_mask.device, dtype=attention_mask.dtype)
                            if compress_mask[b]
                            else attention_mask[b][:0],
                            attention_mask[b][keep_mask_back[b]],
                        ]
                    )
                    for b in range(bsz)
                ]
                kwargs["attention_mask"] = self.left_pad_emb_list(att_list)

            if kwargs.get("position_ids") is not None:
                # The pooled token inherits the position id of the first dropped token.
                position_ids = kwargs["position_ids"]
                pos_list = [
                    torch.cat(
                        [
                            position_ids[b][keep_mask_front[b]],
                            position_ids[b][begin_idx[b] : begin_idx[b] + 1]
                            if compress_mask[b]
                            else position_ids[b][:0],
                            position_ids[b][keep_mask_back[b]],
                        ]
                    )
                    for b in range(bsz)
                ]
                kwargs["position_ids"] = self.left_pad_emb_list(pos_list)

            if kwargs.get("position_embeddings") is not None:
                # Rotary embeddings arrive as a pair of sequence-aligned tensors.
                pos_emb_x, pos_emb_y = kwargs["position_embeddings"]
                emb_x_list = [
                    torch.cat(
                        [
                            pos_emb_x[b][keep_mask_front[b]],
                            pos_emb_x[b][begin_idx[b] : begin_idx[b] + 1]
                            if compress_mask[b]
                            else pos_emb_x[b][:0],
                            pos_emb_x[b][keep_mask_back[b]],
                        ],
                        dim=0,
                    )
                    for b in range(bsz)
                ]
                emb_y_list = [
                    torch.cat(
                        [
                            pos_emb_y[b][keep_mask_front[b]],
                            pos_emb_y[b][begin_idx[b] : begin_idx[b] + 1]
                            if compress_mask[b]
                            else pos_emb_y[b][:0],
                            pos_emb_y[b][keep_mask_back[b]],
                        ],
                        dim=0,
                    )
                    for b in range(bsz)
                ]
                emb_x_padded = self.left_pad_emb_list(emb_x_list)
                emb_y_padded = self.left_pad_emb_list(emb_y_list)
                kwargs["position_embeddings"] = (emb_x_padded, emb_y_padded)

            if kwargs.get("cache_position") is not None:
                kwargs["cache_position"] = kwargs["cache_position"][: hidden_states.shape[1]]

        return self.layer(hidden_states, *args, **kwargs), kwargs