import os

import torch
import torch.distributed as dist
import torch.nn as nn

# Single-node setup: rank and world size are derived from LOCAL_RANK and the
# visible GPU count rather than from an initialized process group.
local_rank = int(os.getenv("LOCAL_RANK", "0"))
world_size = torch.cuda.device_count()
rank = local_rank
class LayerWrapper(nn.Module):
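    """Wraps a decoder layer. At layer `internal_projection` it compresses the
    token span between the first and last `img_pattern` match in `input_ids`
    into a single mean-pooled "motion token", shortening the sequence (and its
    mask/position metadata) for every subsequent layer."""
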
    def __init__(
        self,
        layer,
        layer_idx,
        internal_projection=4,
        img_pattern=(151652,),  # Qwen2-VL's <|vision_start|> token id; a tuple avoids a mutable default
        motion_token=1,  # only a single motion token is supported (see assert below)
    ):
super().__init__()
self.layer = layer
self.layer_idx = layer_idx
self.internal_projection = internal_projection
self.motion_token = motion_token
self.img_pattern = img_pattern
        assert motion_token == 1, "only a single motion token is supported"
def get_removing_indices(self, hidden_states, input_ids):
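        """Locate the first and last occurrence of `img_pattern` in each row
        of `input_ids`; returns (begin_idx, end_idx), each of shape (bsz, 1).
        Assumes every sequence contains the pattern at least once. Only the
        device and batch size of `hidden_states` are used."""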
        pat_len = len(self.img_pattern)
        # Slide a pattern-length window across each row of token ids.
        windows = input_ids.unfold(dimension=1, size=pat_len, step=1)
        pattern_tensor = torch.tensor(self.img_pattern, device=hidden_states.device).view(1, 1, -1)
        matches = (windows == pattern_tensor).all(dim=-1)
        match_lists = [torch.nonzero(matches[b], as_tuple=False).squeeze(-1) for b in range(hidden_states.shape[0])]
        # First and last match position per sequence, shaped (bsz, 1).
        begin_idx = torch.tensor([m[0] for m in match_lists], device=hidden_states.device).unsqueeze(1)
        end_idx = torch.tensor([m[-1] for m in match_lists], device=hidden_states.device).unsqueeze(1)
        return begin_idx, end_idx
def left_pad_emb_list(self, emb_list):
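        """Left-pad a list of variable-length tensors to a common length along
        dim 0 by flipping, right-padding with zeros, and flipping back."""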
rev = [e.flip(0) for e in emb_list]
padded_rev = torch.nn.utils.rnn.pad_sequence(rev, batch_first=True, padding_value=0)
return padded_rev.flip(1)
def forward(self, hidden_states, input_ids, *args, **kwargs):
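        """On a full prefill pass through layer `internal_projection`, replace
        the vision-token span with a single mean-pooled motion token and
        rewrite the attention mask, position ids, rotary embeddings, and cache
        positions to match the shortened, left-padded batch."""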
        bsz, seq_len, dim = hidden_states.shape
        # Incremental decoding feeds one token at a time with a KV cache;
        # compression only applies to the full prefill pass.
        is_incremental = (
            "cache_position" in kwargs
            and kwargs["cache_position"] is not None
            and seq_len == 1
        )
        if self.layer_idx == self.internal_projection and not is_incremental:
            device = hidden_states.device
            token_indices = torch.arange(seq_len, device=device).view(1, -1).expand(bsz, -1)
            begin_idx, end_idx = self.get_removing_indices(hidden_states, input_ids)
            # Compress only sequences whose span [begin_idx, end_idx) is non-empty.
            compress_mask = (end_idx > begin_idx).reshape(-1)
            keep_mask_front = token_indices < begin_idx
            keep_mask_back = token_indices >= end_idx
            drop_mask = ~(keep_mask_front | keep_mask_back)
            # Mean-pool the dropped span into one motion token per sequence.
            motion_token = (
                (hidden_states * drop_mask.unsqueeze(-1)).sum(dim=1)
                / drop_mask.sum(dim=1, keepdim=True).clamp(min=1)
            ).reshape(bsz, self.motion_token, -1)
            # Per sequence: front tokens + pooled motion token (if any) + back
            # tokens, then left-pad the ragged batch back to a rectangular tensor.
            empty = hidden_states.new_empty(0)
            hidden_states = [
                torch.cat([
                    hidden_states[b][keep_mask_front[b]],
                    motion_token[b] if compress_mask[b] else empty,
                    hidden_states[b][keep_mask_back[b]],
                ], dim=0) for b in range(bsz)
            ]
            hidden_states = self.left_pad_emb_list(hidden_states)
            if kwargs.get("attention_mask") is not None:
                att = kwargs["attention_mask"]
                # The motion token is real content, so it is marked attendable (1).
                att_list = [
                    torch.cat([
                        att[b][keep_mask_front[b]],
                        torch.ones(1, device=att.device, dtype=att.dtype) if compress_mask[b] else att.new_empty(0),
                        att[b][keep_mask_back[b]],
                    ]) for b in range(bsz)
                ]
                # left_pad_emb_list pads with 0, which correctly marks left padding.
                kwargs["attention_mask"] = self.left_pad_emb_list(att_list)
            if kwargs.get("position_ids") is not None:
                pos = kwargs["position_ids"]
                pos_list = [
                    torch.cat([
                        pos[b][keep_mask_front[b]],
                        # The motion token reuses the position of the span's first token.
                        pos[b][begin_idx[b]:begin_idx[b] + 1] if compress_mask[b] else pos.new_empty(0),
                        pos[b][keep_mask_back[b]],
                    ]) for b in range(bsz)
                ]
                kwargs["position_ids"] = self.left_pad_emb_list(pos_list)
            if kwargs.get("position_embeddings") is not None:
                # Rotary embeddings arrive as a (cos, sin) pair; splice both
                # identically, reusing the rotary entry of the span's first
                # token for the motion token.
                cos, sin = kwargs["position_embeddings"]
                emb_x_list = [
                    torch.cat([
                        cos[b][keep_mask_front[b]],
                        cos[b][begin_idx[b]:begin_idx[b] + 1] if compress_mask[b] else cos.new_empty(0),
                        cos[b][keep_mask_back[b]],
                    ], dim=0) for b in range(bsz)
                ]
                emb_y_list = [
                    torch.cat([
                        sin[b][keep_mask_front[b]],
                        sin[b][begin_idx[b]:begin_idx[b] + 1] if compress_mask[b] else sin.new_empty(0),
                        sin[b][keep_mask_back[b]],
                    ], dim=0) for b in range(bsz)
                ]
                emb_x_padded = self.left_pad_emb_list(emb_x_list)
                emb_y_padded = self.left_pad_emb_list(emb_y_list)
                kwargs["position_embeddings"] = (emb_x_padded, emb_y_padded)
if "cache_position" in kwargs and kwargs["cache_position"] is not None:
kwargs["cache_position"] = kwargs["cache_position"][: hidden_states.shape[1]]
        # Return the wrapped layer's output together with the rewritten kwargs so
        # the caller can thread the updated masks/positions into later layers.
        return self.layer(hidden_states, *args, **kwargs), kwargs
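

# Minimal usage sketch (an assumption, not part of the original file): the
# wrapper's forward takes `input_ids` alongside `hidden_states` and returns
# `(layer_output, kwargs)`, which implies a custom driving loop like the one
# below. `decoder_layers` and `run_wrapped_layers` are illustrative names.
def run_wrapped_layers(decoder_layers, hidden_states, input_ids, **kwargs):
    wrapped = [
        LayerWrapper(layer, layer_idx=i, internal_projection=4, motion_token=1)
        for i, layer in enumerate(decoder_layers)
    ]
    for layer in wrapped:
        output, kwargs = layer(hidden_states, input_ids, **kwargs)
        # HF-style decoder layers typically return a tuple whose first element
        # is the hidden states; unwrap defensively.
        hidden_states = output[0] if isinstance(output, tuple) else output
    return hidden_states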