import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.attention.flex_attention import or_masks, and_masks
from transformers.activations import ACT2FN


class MLPconnector(nn.Module):
    def __init__(self, in_dim: int, out_dim: int, hidden_act: str):
        super().__init__()
        self.activation_fn = ACT2FN[hidden_act]
        self.fc1 = nn.Linear(in_dim, out_dim)
        self.fc2 = nn.Linear(out_dim, out_dim)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states


def create_sparse_mask(document_lens, split_lens, attn_modes, parallel_num, device):
    """Build a FlexAttention mask_mod combining causal, parallel-block, and per-document masking."""
    parallel_causal_num = 2
    parallel_block_causal_num = parallel_num

    def causal_mask(b, h, q_idx, kv_idx):
        return q_idx >= kv_idx

    def parallel_block_mask(b, h, q_idx, kv_idx):
        # Tokens see each other only within the same segment, only in the "parallel"
        # region (after the first parallel_causal_num tokens), and only inside the
        # same block of parallel_block_causal_num tokens.
        same_seg = segment_ids[q_idx] == segment_ids[kv_idx]
        is_par = is_parallel[q_idx]
        lq = local_ids[q_idx]
        lk = local_ids[kv_idx]
        in_block_region = (lq >= parallel_causal_num) & (lk >= parallel_causal_num)
        same_block = ((lq - parallel_causal_num) // parallel_block_causal_num) == ((lk - parallel_causal_num) // parallel_block_causal_num)
        return same_seg & is_par & in_block_region & same_block

    def sample_mask(b, h, q_idx, kv_idx):
        # No attention across document boundaries.
        return document_id[q_idx] == document_id[kv_idx]

    # Per-token metadata: segment index, position within the segment, and whether
    # the segment uses the "parallel" attention mode.
    segment_ids_list = []
    local_ids_list = []
    is_parallel_list = []
    current_seg_id = 0
    for length, mode in zip(split_lens, attn_modes):
        segment_ids_list.extend([current_seg_id] * length)
        local_ids_list.extend(list(range(length)))
        is_parallel_list.extend([mode == 'parallel'] * length)
        current_seg_id += 1
    segment_ids = torch.tensor(segment_ids_list, device=device, dtype=torch.long)
    local_ids = torch.tensor(local_ids_list, device=device, dtype=torch.long)
    is_parallel = torch.tensor(is_parallel_list, device=device, dtype=torch.bool)
    document_id = torch.cat([torch.full((l,), i, device=device) for i, l in enumerate(document_lens, start=1)])

    return and_masks(or_masks(causal_mask, parallel_block_mask), sample_mask)
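

# Illustrative usage sketch (not part of the original file): how the mask_mod returned
# by create_sparse_mask might be compiled into a BlockMask and used with FlexAttention.
# Assumes a PyTorch build whose torch.nn.attention.flex_attention supports the chosen
# device; all lengths, head counts, and dims below are made-up placeholders.
def _example_create_sparse_mask_usage(device: str = "cpu"):
    from torch.nn.attention.flex_attention import create_block_mask, flex_attention

    # One 128-token document: a 64-token segment with plain causal attention
    # followed by a 64-token "parallel" segment attended in blocks of 16.
    mask_mod = create_sparse_mask(
        document_lens=[128],
        split_lens=[64, 64],
        attn_modes=["causal", "parallel"],
        parallel_num=16,
        device=device,
    )
    B, H, S, D = 1, 2, 128, 64
    block_mask = create_block_mask(mask_mod, B=B, H=H, Q_LEN=S, KV_LEN=S, device=device)
    q, k, v = (torch.randn(B, H, S, D, device=device) for _ in range(3))
    return flex_attention(q, k, v, block_mask=block_mask)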


def top_k_top_p_filtering(
    logits,
    top_k: int = 0,
    top_p: float = 1.0,
    filter_value: float = -float("Inf"),
    min_tokens_to_keep: int = 1,
):
    """Filter a distribution of logits using top-k and/or top-p (nucleus) filtering."""
    if top_k > 0:
        top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))
        # Remove all tokens with a logit below the k-th largest logit.
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value
    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Remove tokens with cumulative probability above the threshold.
        sorted_indices_to_remove = cumulative_probs > top_p
        if min_tokens_to_keep > 1:
            # Always keep at least min_tokens_to_keep tokens.
            sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
        # Shift the mask to the right so the first token above the threshold is kept as well.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        # Scatter the mask back to the original (unsorted) vocabulary order.
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
        logits[indices_to_remove] = filter_value
    return logits


def sample_codebook(
    pred_logits,
    cur_item_type,
    codebook,
    do_sample: bool = True,
    temperature: float = 1.0,
    top_k: int = 0,
    top_p: float = 1.0,
):
    """
    pred_logits: (B, vocab_size)
    cur_item_type: 'text' or 'vision'
    """
    # 1. Apply temperature
    logits = pred_logits / max(temperature, 1e-5)
    # 2. Apply top-k / top-p filtering
    if top_k > 0 or top_p < 1.0:
        logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
    # 3. Get probabilities
    probs = F.softmax(logits, dim=-1)
    # 4. Sample or take argmax
    if do_sample:
        curr_tokens = torch.multinomial(probs, num_samples=1).squeeze(-1)
    else:
        curr_tokens = torch.argmax(probs, dim=-1)
    curr_embeds = codebook(curr_tokens)
    return curr_tokens, curr_embeds
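

# Illustrative usage sketch (not part of the original file): drawing next-token ids and
# their embeddings from a toy codebook with temperature + nucleus sampling. The vocab
# size, hidden dim, and batch size are made-up placeholders.
def _example_sample_codebook_usage():
    vocab_size, hidden_dim, batch = 1024, 32, 4
    codebook = nn.Embedding(vocab_size, hidden_dim)
    pred_logits = torch.randn(batch, vocab_size)
    tokens, embeds = sample_codebook(
        pred_logits,
        cur_item_type="vision",
        codebook=codebook,
        do_sample=True,
        temperature=0.9,
        top_k=50,
        top_p=0.95,
    )
    return tokens, embeds  # tokens: (batch,), embeds: (batch, hidden_dim)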


def flip_tensor_elements_uniform_prob(tensor: torch.Tensor, p_max: float) -> torch.Tensor:
    """Flip the sign of each element independently with a probability drawn uniformly from [0, p_max]."""
    if not 0.0 <= p_max <= 1.0:
        raise ValueError(f"p_max must be in [0.0, 1.0], got {p_max}")
    # Each element is negated with probability p_max * r2 where r2 ~ U[0, 1),
    # i.e. with probability p_max / 2 on average.
    r1 = torch.rand_like(tensor)
    r2 = torch.rand_like(tensor)
    flip_mask = r1 < p_max * r2
    multiplier = torch.where(flip_mask, -1.0, 1.0)
    multiplier = multiplier.to(tensor.dtype)
    flipped_tensor = tensor * multiplier
    return flipped_tensor


def gaussian_sample(raw_output):
    # Split into mean and log-variance, then sample via the reparameterization trick.
    mu, log_var = raw_output.chunk(2, dim=-1)
    sigma = torch.exp(0.5 * log_var)
    sample = mu + torch.randn_like(mu) * sigma
    return sample


def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, pe_interpolation=1.0):
    """
    grid_size: int or tuple/list of (h, w)
    return:
        pos_embed: [grid_h*grid_w, embed_dim] or [extra_tokens+grid_h*grid_w, embed_dim] (w/ or w/o cls_token)
    """
    if isinstance(grid_size, int):
        grid_h_size, grid_w_size = grid_size, grid_size
    else:
        grid_h_size, grid_w_size = grid_size
    grid_h = torch.arange(grid_h_size, dtype=torch.float32) / pe_interpolation
    grid_w = torch.arange(grid_w_size, dtype=torch.float32) / pe_interpolation
    grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing='xy')
    grid = torch.stack([grid_w, grid_h], dim=0)  # shape: (2, grid_h_size, grid_w_size)
    grid = grid.reshape([2, 1, grid_h_size, grid_w_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token and extra_tokens > 0:
        pos_embed = torch.cat([torch.zeros([extra_tokens, embed_dim]), pos_embed], dim=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0
    # use half of dimensions to encode each axis
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
    emb = torch.cat([emb_h, emb_w], dim=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = torch.arange(embed_dim // 2, dtype=torch.float32)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)
    pos = pos.reshape(-1)  # (M,)
    out = torch.einsum("m,d->md", pos, omega)  # (M, D/2), outer product
    emb_sin = torch.sin(out)  # (M, D/2)
    emb_cos = torch.cos(out)  # (M, D/2)
    emb = torch.cat([emb_sin, emb_cos], dim=1)  # (M, D)
    return emb
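

# Illustrative usage sketch (not part of the original file): building a fixed 2D sin-cos
# positional table for a 16x16 patch grid with one extra slot reserved for a cls token.
# The embedding width is a made-up placeholder.
def _example_get_2d_sincos_pos_embed_usage():
    pos_embed = get_2d_sincos_pos_embed(embed_dim=64, grid_size=16, cls_token=True, extra_tokens=1)
    return pos_embed  # shape: (1 + 16 * 16, 64)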


def remove_first_user_block(x: str) -> str:
    """Remove the first "<|im_start|>user ... <|im_end|>" block from a chat-formatted string."""
    start_marker = "<|im_start|>user\n"
    end_marker = "<|im_end|>\n"
    start_index = x.find(start_marker)
    if start_index == -1:
        return x
    end_index = x.find(end_marker, start_index + len(start_marker))
    if end_index == -1:
        return x
    result = x[:start_index] + x[end_index + len(end_marker):]
    return result
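

# Illustrative usage sketch (not part of the original file): stripping the first user turn
# from a ChatML-formatted prompt; only the first "<|im_start|>user ... <|im_end|>" span is removed.
def _example_remove_first_user_block():
    prompt = (
        "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
        "<|im_start|>user\nfirst question<|im_end|>\n"
        "<|im_start|>user\nsecond question<|im_end|>\n"
    )
    return remove_first_user_block(prompt)  # keeps the system block and the second user block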