import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.attention.flex_attention import or_masks, and_masks
from transformers.activations import ACT2FN
class MLPconnector(nn.Module):
def __init__(self, in_dim: int, out_dim: int, hidden_act: str):
super().__init__()
self.activation_fn = ACT2FN[hidden_act]
self.fc1 = nn.Linear(in_dim, out_dim)
self.fc2 = nn.Linear(out_dim, out_dim)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states = self.fc2(hidden_states)
return hidden_states
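# Illustrative usage of MLPconnector (a sketch; the dimensions and activation below are
# placeholder assumptions, not values taken from this repo):
#     connector = MLPconnector(in_dim=1152, out_dim=4096, hidden_act="gelu")
#     projected = connector(torch.randn(2, 256, 1152))  # -> (2, 256, 4096)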
def create_sparse_mask(document_lens, split_lens, attn_modes, parallel_num, device):
    """Build a FlexAttention mask_mod for a packed sequence.

    Attention is causal by default, additionally bidirectional within blocks of
    segments whose mode is 'parallel', and never crosses packed-document boundaries.

    document_lens: length of each packed document (their sum is the sequence length)
    split_lens:    length of each segment in the sequence
    attn_modes:    one mode per segment; 'parallel' enables block attention
    parallel_num:  number of tokens per parallel block
    device:        device on which the per-token index tensors are created
    """
    # The first `parallel_causal_num` tokens of each parallel segment stay purely causal;
    # the remaining tokens are grouped into blocks of `parallel_block_causal_num` tokens.
    parallel_causal_num = 2
    parallel_block_causal_num = parallel_num
    def causal_mask(b, h, q_idx, kv_idx):
        # Standard causal attention: a query attends to itself and earlier positions.
        return q_idx >= kv_idx

    def parallel_block_mask(b, h, q_idx, kv_idx):
        # Bidirectional attention inside one block of a 'parallel' segment.
        same_seg = segment_ids[q_idx] == segment_ids[kv_idx]
        is_par = is_parallel[q_idx]
        lq = local_ids[q_idx]
        lk = local_ids[kv_idx]
        # The first `parallel_causal_num` tokens of the segment are excluded from block attention.
        in_block_region = (lq >= parallel_causal_num) & (lk >= parallel_causal_num)
        same_block = ((lq - parallel_causal_num) // parallel_block_causal_num) == ((lk - parallel_causal_num) // parallel_block_causal_num)
        return same_seg & is_par & in_block_region & same_block

    def sample_mask(b, h, q_idx, kv_idx):
        # Restrict attention to tokens from the same packed document.
        return document_id[q_idx] == document_id[kv_idx]
    # Per-token segment id, position within the segment, and whether the segment is 'parallel'.
    segment_ids_list = []
    local_ids_list = []
    is_parallel_list = []
    current_seg_id = 0
    for length, mode in zip(split_lens, attn_modes):
        segment_ids_list.extend([current_seg_id] * length)
        local_ids_list.extend(list(range(length)))
        is_parallel_list.extend([mode == 'parallel'] * length)
        current_seg_id += 1
    segment_ids = torch.tensor(segment_ids_list, device=device, dtype=torch.long)
    local_ids = torch.tensor(local_ids_list, device=device, dtype=torch.long)
    is_parallel = torch.tensor(is_parallel_list, device=device, dtype=torch.bool)
    # Per-token document id so attention never crosses packed-document boundaries.
    document_id = torch.cat([torch.full((l,), i, device=device) for i, l in enumerate(document_lens, start=1)])

    # Visible if causally visible or within the same parallel block, and only within the same document.
    return and_masks(or_masks(causal_mask, parallel_block_mask), sample_mask)
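# Illustrative usage with FlexAttention (a sketch; the lengths, modes, and device below are
# placeholder assumptions). The returned mask_mod is meant to be compiled into a BlockMask:
#     from torch.nn.attention.flex_attention import create_block_mask
#     mask_mod = create_sparse_mask(document_lens=[128, 128], split_lens=[64, 64, 128],
#                                   attn_modes=['causal', 'parallel', 'causal'],
#                                   parallel_num=16, device='cuda')
#     block_mask = create_block_mask(mask_mod, B=None, H=None, Q_LEN=256, KV_LEN=256, device='cuda')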
def top_k_top_p_filtering(
logits,
top_k: int = 0,
top_p: float = 1.0,
filter_value: float = -float("Inf"),
min_tokens_to_keep: int = 1,
):
"""Filter a distribution of logits using top-k and/or top-p (nucleus) filtering."""
if top_k > 0:
top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits[indices_to_remove] = filter_value
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
if min_tokens_to_keep > 1:
sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
logits[indices_to_remove] = filter_value
return logits
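# Illustrative usage (a sketch; batch size, vocab size, and the k/p values below are placeholder
# assumptions):
#     logits = torch.randn(2, 32000)
#     filtered = top_k_top_p_filtering(logits, top_k=50, top_p=0.9)
#     probs = F.softmax(filtered, dim=-1)  # only the surviving tokens keep non-negligible mass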
def sample_codebook(
pred_logits,
cur_item_type,
codebook,
do_sample: bool = True,
temperature: float = 1.0,
top_k: int = 0,
top_p: float = 1.0,
):
"""
pred_logits: (B, vocab_size)
cur_item_type: 'text' or 'vision'
"""
# 1. Apply temperature
logits = pred_logits / max(temperature, 1e-5)
# 2. Apply top-k / top-p filtering
if top_k > 0 or top_p < 1.0:
logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
# 3. Get probabilities
probs = F.softmax(logits, dim=-1)
# 4. Sample or take argmax
if do_sample:
curr_tokens = torch.multinomial(probs, num_samples=1).squeeze(-1)
else:
curr_tokens = torch.argmax(probs, dim=-1)
curr_embeds = codebook(curr_tokens)
return curr_tokens, curr_embeds
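# Illustrative usage (a sketch; the vocabulary size, embedding dim, and sampling options are
# placeholder assumptions; `codebook` can be any embedding module mapping ids to vectors):
#     codebook = nn.Embedding(8192, 64)
#     logits = torch.randn(2, 8192)
#     tokens, embeds = sample_codebook(logits, cur_item_type='vision', codebook=codebook,
#                                      do_sample=True, temperature=1.0, top_k=0, top_p=0.95)
#     # tokens: (2,), embeds: (2, 64)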
def flip_tensor_elements_uniform_prob(tensor: torch.Tensor, p_max: float) -> torch.Tensor:
    """Flip the sign of each element with a probability drawn uniformly from [0, p_max]."""
    if not 0.0 <= p_max <= 1.0:
        raise ValueError(f"p_max must be in [0.0, 1.0], got {p_max}")
    r1 = torch.rand_like(tensor)
    r2 = torch.rand_like(tensor)
    # Per-element flip probability is p_max * r2, i.e. uniform in [0, p_max].
    flip_mask = r1 < p_max * r2
    multiplier = torch.where(flip_mask, -1.0, 1.0)
    multiplier = multiplier.to(tensor.dtype)
    flipped_tensor = tensor * multiplier
    return flipped_tensor
def gaussian_sample(raw_output):
    """Reparameterization trick: split `raw_output` into (mu, log_var) along the last
    dimension and draw a sample from N(mu, sigma^2)."""
    mu, log_var = raw_output.chunk(2, dim=-1)
    sigma = torch.exp(0.5 * log_var)
    sample = mu + torch.randn_like(mu) * sigma
    return sample
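# Illustrative usage (a sketch; shapes are placeholder assumptions): a (B, 2*D) head output is
# split into mean and log-variance halves and a (B, D) latent is drawn:
#     raw = torch.randn(4, 128)
#     z = gaussian_sample(raw)  # -> (4, 64)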
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, pe_interpolation=1.0):
"""
grid_size: int or tuple/list of (h, w)
return:
pos_embed: [grid_h*grid_w, embed_dim] or [extra_tokens+grid_h*grid_w, embed_dim] (w/ or w/o cls_token)
"""
if isinstance(grid_size, int):
grid_h_size, grid_w_size = grid_size, grid_size
else:
grid_h_size, grid_w_size = grid_size
grid_h = torch.arange(grid_h_size, dtype=torch.float32) / pe_interpolation
grid_w = torch.arange(grid_w_size, dtype=torch.float32) / pe_interpolation
grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing='xy')
grid = torch.stack([grid_w, grid_h], dim=0) # shape: (2, grid_h_size, grid_w_size)
grid = grid.reshape([2, 1, grid_h_size, grid_w_size])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
if cls_token and extra_tokens > 0:
pos_embed = torch.cat([torch.zeros([extra_tokens, embed_dim]), pos_embed], dim=0)
return pos_embed
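# Illustrative usage (a sketch; the embedding dim and grid size below are placeholder
# assumptions):
#     pos_embed = get_2d_sincos_pos_embed(embed_dim=768, grid_size=16)
#     # pos_embed: (16 * 16, 768) = (256, 768)
#     with_cls = get_2d_sincos_pos_embed(embed_dim=768, grid_size=16, cls_token=True, extra_tokens=1)
#     # with_cls: (1 + 256, 768); the first row is zeros for the extra token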
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
assert embed_dim % 2 == 0
    # use half of the dimensions to encode each spatial axis
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
emb = torch.cat([emb_h, emb_w], dim=1) # (H*W, D)
return emb
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (M,)
out: (M, D)
"""
assert embed_dim % 2 == 0
omega = torch.arange(embed_dim // 2, dtype=torch.float32)
omega /= embed_dim / 2.0
omega = 1.0 / 10000**omega # (D/2,)
pos = pos.reshape(-1) # (M,)
out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
emb_sin = torch.sin(out) # (M, D/2)
emb_cos = torch.cos(out) # (M, D/2)
emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
return emb
def remove_first_user_block(x: str) -> str:
    """Remove the first `<|im_start|>user ... <|im_end|>` block from a chat-formatted string,
    returning the input unchanged if no complete block is found."""
start_marker = "<|im_start|>user\n"
end_marker = "<|im_end|>\n"
start_index = x.find(start_marker)
if start_index == -1:
return x
end_index = x.find(end_marker, start_index + len(start_marker))
if end_index == -1:
return x
result = x[:start_index] + x[end_index + len(end_marker):]
return result
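# Illustrative usage (a sketch; the prompt below is a placeholder using ChatML-style markers):
#     text = ("<|im_start|>system\nYou are helpful.<|im_end|>\n"
#             "<|im_start|>user\nHello<|im_end|>\n"
#             "<|im_start|>assistant\nHi!<|im_end|>\n")
#     remove_first_user_block(text)
#     # -> "<|im_start|>system\nYou are helpful.<|im_end|>\n<|im_start|>assistant\nHi!<|im_end|>\n"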