| import os | |
| import torch.nn as nn | |
| import torch | |
| import torch.distributed as dist | |
# Distributed-launcher bookkeeping. LOCAL_RANK is set by torchrun-style
# launchers; defaults to "0" so single-process runs still work.
local_rank = int(os.getenv("LOCAL_RANK", "0"))
# NOTE(review): device_count() is the number of GPUs on THIS node, not the
# global world size — confirm this matches the launcher's expectations
# (WORLD_SIZE env var would be the usual source for multi-node runs).
world_size = torch.cuda.device_count()
# presumably single-node, so global rank == local rank — verify against launcher
rank = local_rank
| class LayerWrapper(nn.Module): | |
| def __init__( | |
| self, | |
| layer, | |
| layer_idx, | |
| internal_projection=4, | |
| img_pattern=[151652], | |
| motion_token=0 | |
| ): | |
| super().__init__() | |
| self.layer = layer | |
| self.layer_idx = layer_idx | |
| self.internal_projection = internal_projection | |
| self.motion_token = motion_token | |
| self.img_pattern = img_pattern | |
| assert motion_token == 1 | |
| def get_removing_indices(self, hidden_states, input_ids): | |
| pat_len = len(self.img_pattern) | |
| windows = input_ids.unfold(dimension=1, size=pat_len, step=1) | |
| pattern_tensor = torch.tensor(self.img_pattern, device=hidden_states.device).view(1, 1, -1) | |
| matches = (windows == pattern_tensor).all(dim=-1) | |
| match_lists = [torch.nonzero(matches[b], as_tuple=False).squeeze(-1) for b in range(hidden_states.shape[0])] | |
| begin_idx = torch.tensor([m[0] for m in match_lists], device=hidden_states.device).unsqueeze(1) | |
| end_idx = torch.tensor([m[-1] for m in match_lists], device=hidden_states.device).unsqueeze(1) | |
| return begin_idx, end_idx | |
| def left_pad_emb_list(self, emb_list): | |
| rev = [e.flip(0) for e in emb_list] | |
| padded_rev = torch.nn.utils.rnn.pad_sequence(rev, batch_first=True, padding_value=0) | |
| return padded_rev.flip(1) | |
| def forward(self, hidden_states, input_ids, *args, **kwargs): | |
| bsz, seq_len, dim = hidden_states.shape | |
| is_incremental = ( | |
| "cache_position" in kwargs | |
| and kwargs["cache_position"] is not None | |
| and seq_len == 1 | |
| ) | |
| if self.layer_idx == self.internal_projection and not is_incremental: | |
| device = hidden_states.device | |
| token_indices = torch.arange(seq_len, device=device).view(1, -1).expand(bsz, -1) | |
| begin_idx, end_idx = self.get_removing_indices(hidden_states, input_ids) | |
| compress_mask = (end_idx > begin_idx).reshape(-1) | |
| keep_mask_front = token_indices < begin_idx | |
| keep_mask_back = token_indices >= end_idx | |
| drop_mask = ~(keep_mask_front | keep_mask_back) | |
| motion_token = ( | |
| (hidden_states * drop_mask.unsqueeze(-1)).sum(dim=1) | |
| / drop_mask.sum(dim=1, keepdim=True).clamp(min=1) | |
| ).reshape(bsz, self.motion_token, -1) | |
| hidden_states = [ | |
| torch.cat([ | |
| hidden_states[b][keep_mask_front[b]], | |
| motion_token[b] if compress_mask[b] else torch.tensor([], device=hidden_states.device, dtype=hidden_states.dtype), | |
| hidden_states[b][keep_mask_back[b]] | |
| ], dim=0) for b in range(bsz) | |
| ] | |
| hidden_states = self.left_pad_emb_list(hidden_states) | |
| if 'attention_mask' in kwargs and kwargs['attention_mask'] is not None: | |
| att_list = [ | |
| torch.cat([ | |
| kwargs["attention_mask"][b][keep_mask_front[b]], | |
| torch.ones(1, device=kwargs["attention_mask"].device, dtype=kwargs["attention_mask"].dtype) if compress_mask[b] else torch.tensor([], device=kwargs["attention_mask"].device, dtype=kwargs["attention_mask"].dtype), | |
| kwargs["attention_mask"][b][keep_mask_back[b]], | |
| ]) for b in range(bsz) | |
| ] | |
| kwargs["attention_mask"] = self.left_pad_emb_list(att_list) | |
| if 'position_ids' in kwargs.keys() and kwargs['position_ids'] is not None: | |
| pos_list = [ | |
| torch.cat([ | |
| kwargs["position_ids"][b][keep_mask_front[b]], | |
| kwargs["position_ids"][b][begin_idx[b]:begin_idx[b]+1] if compress_mask[b] else torch.tensor([], device=kwargs["position_ids"].device, dtype=kwargs["position_ids"].dtype), | |
| kwargs["position_ids"][b][keep_mask_back[b]], | |
| ]) for b in range(bsz) | |
| ] | |
| kwargs["position_ids"] = self.left_pad_emb_list(pos_list) | |
| if 'position_embeddings' in kwargs.keys() and kwargs['position_embeddings'] is not None: | |
| emb_x_list = [ | |
| torch.cat([ | |
| kwargs["position_embeddings"][0][b][keep_mask_front[b]], | |
| kwargs["position_embeddings"][0][b][begin_idx[b]:begin_idx[b]+1] if compress_mask[b] else torch.tensor([], device=kwargs["position_embeddings"][0].device, dtype=kwargs["position_embeddings"][0].dtype), | |
| kwargs["position_embeddings"][0][b][keep_mask_back[b]], | |
| ], dim=0) for b in range(bsz) | |
| ] | |
| emb_y_list = [ | |
| torch.cat([ | |
| kwargs["position_embeddings"][1][b][keep_mask_front[b]], | |
| kwargs["position_embeddings"][1][b][begin_idx[b]:begin_idx[b]+1] if compress_mask[b] else torch.tensor([], device=kwargs["position_embeddings"][0].device, dtype=kwargs["position_embeddings"][0].dtype), | |
| kwargs["position_embeddings"][1][b][keep_mask_back[b]], | |
| ], dim=0) for b in range(bsz) | |
| ] | |
| emb_x_padded = self.left_pad_emb_list(emb_x_list) | |
| emb_y_padded = self.left_pad_emb_list(emb_y_list) | |
| kwargs["position_embeddings"] = (emb_x_padded, emb_y_padded) | |
| if "cache_position" in kwargs and kwargs["cache_position"] is not None: | |
| kwargs["cache_position"] = kwargs["cache_position"][: hidden_states.shape[1]] | |
| return self.layer(hidden_states, *args, **kwargs), kwargs | |