| |
| |
| |
| |
| @@ -160,12 +160,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: |
| return x |
| |
| |
| -def apply_rotary_pos_emb_flashatt(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: |
| - tensor_ = tensor.float() |
| - cos = freqs.cos().float() |
| - sin = freqs.sin().float() |
| - output = apply_rotary_emb(tensor_, cos, sin).type_as(tensor) |
| - return output |
| +def apply_rotary_pos_emb_flashatt( |
| + q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor |
| +) -> Tuple[torch.Tensor, torch.Tensor]: |
| + cos = cos.chunk(2, dim=-1)[0].contiguous().float()  # cast to fp32 so cos/sin share the dtype of q.float()/k.float() |
| + sin = sin.chunk(2, dim=-1)[0].contiguous().float() |
| + q_embed = apply_rotary_emb(q.float(), cos, sin).type_as(q) |
| + k_embed = apply_rotary_emb(k.float(), cos, sin).type_as(k) |
| + return q_embed, k_embed |
| |
| |
| class Qwen2_5_VLVisionFlashAttention2(nn.Module): |
| @@ -179,12 +181,26 @@ def forward( |
| self, |
| hidden_states: torch.Tensor, |
| cu_seqlens: torch.Tensor, |
| - rotary_pos_emb: torch.Tensor = None, |
| + rotary_pos_emb: Optional[torch.Tensor] = None, |
| + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, |
| ) -> torch.Tensor: |
| seq_length = hidden_states.shape[0] |
| q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) |
| - q = apply_rotary_pos_emb_flashatt(q.unsqueeze(0), rotary_pos_emb).squeeze(0) |
| - k = apply_rotary_pos_emb_flashatt(k.unsqueeze(0), rotary_pos_emb).squeeze(0) |
| + if position_embeddings is None: |
| + logger.warning_once( |
| + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " |
| + "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed " |
| + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be " |
| + "removed and `position_embeddings` will be mandatory." |
| + ) |
| + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) |
| + cos = emb.cos().float() |
| + sin = emb.sin().float() |
| + else: |
| + cos, sin = position_embeddings |
| + q, k = apply_rotary_pos_emb_flashatt(q.unsqueeze(0), k.unsqueeze(0), cos, sin) |
| + q = q.squeeze(0) |
| + k = k.squeeze(0) |
| |
| max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() |
| attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape( |
| @@ -201,16 +217,18 @@ def rotate_half(x): |
| return torch.cat((-x2, x1), dim=-1) |
| |
| |
| -def apply_rotary_pos_emb_vision(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: |
| - orig_dtype = tensor.dtype |
| - tensor = tensor.float() |
| - cos = freqs.cos() |
| - sin = freqs.sin() |
| - cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() |
| - sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() |
| - output = (tensor * cos) + (rotate_half(tensor) * sin) |
| - output = output.to(orig_dtype) |
| - return output |
| +def apply_rotary_pos_emb_vision( |
| + q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor |
| +) -> Tuple[torch.Tensor, torch.Tensor]: |
| + orig_q_dtype = q.dtype |
| + orig_k_dtype = k.dtype |
| + q, k = q.float(), k.float() |
| + cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()  # explicit fp32, same as q/k, instead of relying on type promotion |
| + q_embed = (q * cos) + (rotate_half(q) * sin) |
| + k_embed = (k * cos) + (rotate_half(k) * sin) |
| + q_embed = q_embed.to(orig_q_dtype) |
| + k_embed = k_embed.to(orig_k_dtype) |
| + return q_embed, k_embed |
| |
| |
| class Qwen2_5_VLVisionAttention(nn.Module): |
| @@ -222,12 +240,27 @@ def __init__(self, dim: int, num_heads: int = 16) -> None: |
| self.proj = nn.Linear(dim, dim) |
| |
| def forward( |
| - self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None |
| + self, |
| + hidden_states: torch.Tensor, |
| + cu_seqlens: torch.Tensor, |
| + rotary_pos_emb: Optional[torch.Tensor] = None, |
| + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, |
| ) -> torch.Tensor: |
| seq_length = hidden_states.shape[0] |
| q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) |
| - q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0) |
| - k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0) |
| + if position_embeddings is None: |
| + logger.warning_once( |
| + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " |
| + "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed " |
| + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be " |
| + "removed and `position_embeddings` will be mandatory." |
| + ) |
| + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) |
| + cos = emb.cos().float() |
| + sin = emb.sin().float() |
| + else: |
| + cos, sin = position_embeddings |
| + q, k = apply_rotary_pos_emb_vision(q, k, cos, sin) |
| |
| attention_mask = torch.full( |
| [1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype |
| @@ -256,12 +289,27 @@ def __init__(self, dim: int, num_heads: int = 16) -> None: |
| self.proj = nn.Linear(dim, dim) |
| |
| def forward( |
| - self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None |
| + self, |
| + hidden_states: torch.Tensor, |
| + cu_seqlens: torch.Tensor, |
| + rotary_pos_emb: Optional[torch.Tensor] = None, |
| + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, |
| ) -> torch.Tensor: |
| seq_length = hidden_states.shape[0] |
| q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) |
| - q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0) |
| - k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0) |
| + if position_embeddings is None: |
| + logger.warning_once( |
| + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " |
| + "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed " |
| + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be " |
| + "removed and `position_embeddings` will be mandatory." |
| + ) |
| + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) |
| + cos = emb.cos().float() |
| + sin = emb.sin().float() |
| + else: |
| + cos, sin = position_embeddings |
| + q, k = apply_rotary_pos_emb_vision(q, k, cos, sin) |
| |
| attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool) |
| for i in range(1, len(cu_seqlens)): |
| @@ -293,11 +341,18 @@ def __init__(self, config, attn_implementation: str = "sdpa") -> None: |
| ) |
| self.mlp = Qwen2_5_VLMLP(config, bias=True) |
| |
| - def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor: |
| + def forward( |
| + self, |
| + hidden_states: torch.Tensor, |
| + cu_seqlens: torch.Tensor, |
| + rotary_pos_emb: Optional[torch.Tensor] = None, |
| + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, |
| + ) -> torch.Tensor: |
| hidden_states = hidden_states + self.attn( |
| self.norm1(hidden_states), |
| cu_seqlens=cu_seqlens, |
| rotary_pos_emb=rotary_pos_emb, |
| + position_embeddings=position_embeddings, |
| ) |
| hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) |
| return hidden_states |
| @@ -477,6 +532,8 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch. |
| rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) |
| rotary_pos_emb = rotary_pos_emb[window_index, :, :] |
| rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) |
| + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) |
| + position_embeddings = (emb.cos(), emb.sin()) |
| |
| cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( |
| dim=0, |
| @@ -495,14 +552,10 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch. |
| cu_seqlens_now = cu_window_seqlens |
| if self.gradient_checkpointing and self.training: |
| hidden_states = self._gradient_checkpointing_func( |
| - blk.__call__, hidden_states, cu_seqlens_now, rotary_pos_emb |
| + blk.__call__, hidden_states, cu_seqlens_now, None, position_embeddings |
| ) |
| else: |
| - hidden_states = blk( |
| - hidden_states, |
| - cu_seqlens=cu_seqlens_now, |
| - rotary_pos_emb=rotary_pos_emb, |
| - ) |
| + hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens_now, position_embeddings=position_embeddings) |
| |
| hidden_states = self.merger(hidden_states) |
| reverse_indices = torch.argsort(window_index) |
| |
| |
| |
| |
| @@ -51,7 +51,7 @@ |
| from ...image_utils import ImageInput, VideoInput |
| from ...processing_utils import ProcessingKwargs, Unpack, VideosKwargs |
| from ...tokenization_utils_base import PreTokenizedInput, TextInput |
| -from ...utils import is_flash_attn_2_available, is_torchdynamo_compiling |
| +from ...utils import is_flash_attn_2_available, is_torchdynamo_compiling, logging |
| |
| |
| if is_flash_attn_2_available(): |
| @@ -63,12 +63,17 @@ |
| apply_rotary_emb = None |
| |
| |
| -def apply_rotary_pos_emb_flashatt(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: |
| - tensor_ = tensor.float() |
| - cos = freqs.cos().float() |
| - sin = freqs.sin().float() |
| - output = apply_rotary_emb(tensor_, cos, sin).type_as(tensor) |
| - return output |
| +logger = logging.get_logger(__name__) |
| + |
| + |
| +def apply_rotary_pos_emb_flashatt( |
| + q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor |
| +) -> Tuple[torch.Tensor, torch.Tensor]: |
| + cos = cos.chunk(2, dim=-1)[0].contiguous().float()  # cast to fp32 so cos/sin share the dtype of q.float()/k.float() |
| + sin = sin.chunk(2, dim=-1)[0].contiguous().float() |
| + q_embed = apply_rotary_emb(q.float(), cos, sin).type_as(q) |
| + k_embed = apply_rotary_emb(k.float(), cos, sin).type_as(k) |
| + return q_embed, k_embed |
| |
| |
| class Qwen2_5_VLVisionConfig(PretrainedConfig): |
| @@ -153,12 +158,26 @@ def forward( |
| self, |
| hidden_states: torch.Tensor, |
| cu_seqlens: torch.Tensor, |
| - rotary_pos_emb: torch.Tensor = None, |
| + rotary_pos_emb: Optional[torch.Tensor] = None, |
| + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, |
| ) -> torch.Tensor: |
| seq_length = hidden_states.shape[0] |
| q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) |
| - q = apply_rotary_pos_emb_flashatt(q.unsqueeze(0), rotary_pos_emb).squeeze(0) |
| - k = apply_rotary_pos_emb_flashatt(k.unsqueeze(0), rotary_pos_emb).squeeze(0) |
| + if position_embeddings is None: |
| + logger.warning_once( |
| + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " |
| + "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed " |
| + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be " |
| + "removed and `position_embeddings` will be mandatory." |
| + ) |
| + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) |
| + cos = emb.cos().float() |
| + sin = emb.sin().float() |
| + else: |
| + cos, sin = position_embeddings |
| + q, k = apply_rotary_pos_emb_flashatt(q.unsqueeze(0), k.unsqueeze(0), cos, sin) |
| + q = q.squeeze(0) |
| + k = k.squeeze(0) |
| |
| max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() |
| attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape( |
| @@ -193,11 +212,18 @@ def __init__(self, config, attn_implementation: str = "sdpa") -> None: |
| ) |
| self.mlp = Qwen2_5_VLMLP(config, bias=True) |
| |
| - def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor: |
| + def forward( |
| + self, |
| + hidden_states: torch.Tensor, |
| + cu_seqlens: torch.Tensor, |
| + rotary_pos_emb: Optional[torch.Tensor] = None, |
| + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, |
| + ) -> torch.Tensor: |
| hidden_states = hidden_states + self.attn( |
| self.norm1(hidden_states), |
| cu_seqlens=cu_seqlens, |
| rotary_pos_emb=rotary_pos_emb, |
| + position_embeddings=position_embeddings, |
| ) |
| hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) |
| return hidden_states |
| @@ -337,6 +363,8 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch. |
| rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) |
| rotary_pos_emb = rotary_pos_emb[window_index, :, :] |
| rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) |
| + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) |
| + position_embeddings = (emb.cos(), emb.sin()) |
| |
| cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( |
| dim=0, |
| @@ -355,14 +383,10 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch. |
| cu_seqlens_now = cu_window_seqlens |
| if self.gradient_checkpointing and self.training: |
| hidden_states = self._gradient_checkpointing_func( |
| - blk.__call__, hidden_states, cu_seqlens_now, rotary_pos_emb |
| + blk.__call__, hidden_states, cu_seqlens_now, None, position_embeddings |
| ) |
| else: |
| - hidden_states = blk( |
| - hidden_states, |
| - cu_seqlens=cu_seqlens_now, |
| - rotary_pos_emb=rotary_pos_emb, |
| - ) |
| + hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens_now, position_embeddings=position_embeddings) |
| |
| hidden_states = self.merger(hidden_states) |
| reverse_indices = torch.argsort(window_index) |
| |
| |
| |
| |
| @@ -214,16 +214,18 @@ def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim |
| return q_embed, k_embed |
| |
| |
| -def apply_rotary_pos_emb_vision(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: |
| - orig_dtype = tensor.dtype |
| - tensor = tensor.float() |
| - cos = freqs.cos() |
| - sin = freqs.sin() |
| - cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() |
| - sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() |
| - output = (tensor * cos) + (rotate_half(tensor) * sin) |
| - output = output.to(orig_dtype) |
| - return output |
| +def apply_rotary_pos_emb_vision( |
| + q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor |
| +) -> Tuple[torch.Tensor, torch.Tensor]: |
| + orig_q_dtype = q.dtype |
| + orig_k_dtype = k.dtype |
| + q, k = q.float(), k.float() |
| + cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()  # explicit fp32, same as q/k, instead of relying on type promotion |
| + q_embed = (q * cos) + (rotate_half(q) * sin) |
| + k_embed = (k * cos) + (rotate_half(k) * sin) |
| + q_embed = q_embed.to(orig_q_dtype) |
| + k_embed = k_embed.to(orig_k_dtype) |
| + return q_embed, k_embed |
| |
| |
| class VisionRotaryEmbedding(nn.Module): |
| @@ -300,12 +302,27 @@ def __init__(self, dim: int, num_heads: int = 16) -> None: |
| self.proj = nn.Linear(dim, dim) |
| |
| def forward( |
| - self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None |
| + self, |
| + hidden_states: torch.Tensor, |
| + cu_seqlens: torch.Tensor, |
| + rotary_pos_emb: Optional[torch.Tensor] = None, |
| + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, |
| ) -> torch.Tensor: |
| seq_length = hidden_states.shape[0] |
| q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) |
| - q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0) |
| - k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0) |
| + if position_embeddings is None: |
| + logger.warning_once( |
| + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " |
| + "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed " |
| + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be " |
| + "removed and `position_embeddings` will be mandatory." |
| + ) |
| + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) |
| + cos = emb.cos().float() |
| + sin = emb.sin().float() |
| + else: |
| + cos, sin = position_embeddings |
| + q, k = apply_rotary_pos_emb_vision(q, k, cos, sin) |
| |
| attention_mask = torch.full( |
| [1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype |
| @@ -334,12 +351,27 @@ def __init__(self, dim: int, num_heads: int = 16) -> None: |
| self.proj = nn.Linear(dim, dim) |
| |
| def forward( |
| - self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None |
| + self, |
| + hidden_states: torch.Tensor, |
| + cu_seqlens: torch.Tensor, |
| + rotary_pos_emb: Optional[torch.Tensor] = None, |
| + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, |
| ) -> torch.Tensor: |
| seq_length = hidden_states.shape[0] |
| q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) |
| - q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0) |
| - k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0) |
| + if position_embeddings is None: |
| + logger.warning_once( |
| + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " |
| + "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed " |
| + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be " |
| + "removed and `position_embeddings` will be mandatory." |
| + ) |
| + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) |
| + cos = emb.cos().float() |
| + sin = emb.sin().float() |
| + else: |
| + cos, sin = position_embeddings |
| + q, k = apply_rotary_pos_emb_vision(q, k, cos, sin) |
| |
| max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() |
| attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape( |
| @@ -357,12 +389,27 @@ def __init__(self, dim: int, num_heads: int = 16) -> None: |
| self.proj = nn.Linear(dim, dim) |
| |
| def forward( |
| - self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None |
| + self, |
| + hidden_states: torch.Tensor, |
| + cu_seqlens: torch.Tensor, |
| + rotary_pos_emb: Optional[torch.Tensor] = None, |
| + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, |
| ) -> torch.Tensor: |
| seq_length = hidden_states.shape[0] |
| q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) |
| - q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0) |
| - k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0) |
| + if position_embeddings is None: |
| + logger.warning_once( |
| + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " |
| + "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed " |
| + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be " |
| + "removed and `position_embeddings` will be mandatory." |
| + ) |
| + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) |
| + cos = emb.cos().float() |
| + sin = emb.sin().float() |
| + else: |
| + cos, sin = position_embeddings |
| + q, k = apply_rotary_pos_emb_vision(q, k, cos, sin) |
| |
| attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool) |
| for i in range(1, len(cu_seqlens)): |
| @@ -396,9 +443,18 @@ def __init__(self, config, attn_implementation: str = "sdpa") -> None: |
| ) |
| self.mlp = VisionMlp(dim=config.embed_dim, hidden_dim=mlp_hidden_dim, hidden_act=config.hidden_act) |
| |
| - def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor: |
| + def forward( |
| + self, |
| + hidden_states: torch.Tensor, |
| + cu_seqlens: torch.Tensor, |
| + rotary_pos_emb: Optional[torch.Tensor] = None, |
| + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, |
| + ) -> torch.Tensor: |
| hidden_states = hidden_states + self.attn( |
| - self.norm1(hidden_states), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb |
| + self.norm1(hidden_states), |
| + cu_seqlens=cu_seqlens, |
| + rotary_pos_emb=rotary_pos_emb, |
| + position_embeddings=position_embeddings, |
| ) |
| hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) |
| return hidden_states |
| @@ -961,6 +1017,8 @@ def rot_pos_emb(self, grid_thw): |
| def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor: |
| hidden_states = self.patch_embed(hidden_states) |
| rotary_pos_emb = self.rot_pos_emb(grid_thw) |
| + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) |
| + position_embeddings = (emb.cos(), emb.sin()) |
| |
| cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( |
| dim=0, |
| @@ -975,10 +1033,10 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch. |
| for blk in self.blocks: |
| if self.gradient_checkpointing and self.training: |
| hidden_states = self._gradient_checkpointing_func( |
| - blk.__call__, hidden_states, cu_seqlens, rotary_pos_emb |
| + blk.__call__, hidden_states, cu_seqlens, None, position_embeddings |
| ) |
| else: |
| - hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb) |
| + hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings) |
| |
| return self.merger(hidden_states) |
| |
|
|