"""AshishOCR model implementation."""

import math
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging

from .configuration_ashish_ocr import AshishOcrConfig, AshishOcrTextConfig, AshishOcrVisionConfig

logger = logging.get_logger(__name__)


def _get_past_length(past_key_values: Optional[Cache]) -> int:
    """Return the number of tokens already stored in ``past_key_values`` (0 when there is no cache)."""
    if past_key_values is None:
        return 0
    return past_key_values.get_seq_length()


class AshishOcrRMSNorm(nn.Module):
    """Root-mean-square norm with a learned per-channel scale (no bias, no mean-centering).

    Statistics are computed in float32; the normalized tensor is cast back to the input dtype before scaling.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)


class AshishOcrRotaryEmbedding(nn.Module):
    """Produce rotary position embedding (RoPE) cos/sin tables for arbitrary position ids.

    Args:
        dim: rotary dimension (the attention head size in this file).
        max_position_embeddings: stored on the module; ``forward`` does not bound positions by it.
        base: RoPE frequency base.
        device: optional device for the ``inv_freq`` buffer.
    """

    def __init__(self, dim, max_position_embeddings=131072, base=10000, device=None):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.float32, device=device) / self.dim))
        # Non-persistent: recomputed on construction, never stored in checkpoints.
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    @torch.no_grad()
    def forward(self, x, position_ids):
        """Return ``(cos, sin)``, each ``(batch, seq, dim)`` in ``x.dtype``.

        ``x`` is only used for its dtype; ``position_ids`` is ``(batch, seq)``.
        """
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()
        freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
        emb = torch.cat((freqs, freqs), dim=-1)
        cos = emb.cos()
        sin = emb.sin()
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


def rotate_half(x):
    """Rotate the last dimension by halves: ``[x1, x2] -> [-x2, x1]``."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """Apply RoPE to query/key tensors.

    ``cos``/``sin`` are unsqueezed at ``unsqueeze_dim`` (the head axis for ``(batch, heads, seq, dim)`` inputs)
    so they broadcast over heads.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


class AshishOcrMLP(nn.Module):
    """Gated feed-forward block: ``down(act(gate(x)) * up(x))``."""

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))


class AshishOcrAttention(nn.Module):
    """Eager multi-head self-attention with grouped-query KV heads, RoPE and an optional KV cache."""

    def __init__(self, config: AshishOcrTextConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = config.head_dim
        self.num_key_value_heads = config.num_key_value_heads
        # Each KV head is shared by this many query heads.
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.attention_dropout = config.attention_dropout

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
        # NOTE(review): the RoPE base is left at the rotary module's default (10000); confirm this matches the
        # checkpoint (a config-level rope_theta, if one exists, is not read here).
        self.rotary_emb = AshishOcrRotaryEmbedding(
            self.head_dim,
            max_position_embeddings=config.max_position_embeddings,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
        """Attend over ``hidden_states`` (``(batch, seq, hidden)``).

        ``attention_mask`` is an additive mask broadcastable to ``(batch, heads, q_len, kv_len)``. When
        ``past_key_value`` is given, this layer's keys/values are appended to it and attention covers the
        cached tokens too. Returns ``(output, attn_weights or None, past_key_value)``.
        """
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = self.rotary_emb(value_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)

        # Repeat k/v heads for grouped query attention
        key_states = key_states.repeat_interleave(self.num_key_value_groups, dim=1)
        value_states = value_states.repeat_interleave(self.num_key_value_groups, dim=1)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask

        # Softmax in float32 for numerical stability, then back to the activation dtype.
        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = F.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, -1)
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value


class AshishOcrDecoderLayer(nn.Module):
    """Pre-norm decoder layer: ``x + attn(norm(x))`` followed by ``x + mlp(norm(x))``."""

    def __init__(self, config: AshishOcrTextConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = AshishOcrAttention(config, layer_idx)
        self.mlp = AshishOcrMLP(config)
        self.input_layernorm = AshishOcrRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = AshishOcrRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
        """Return ``(hidden_states,)`` plus attention weights if requested, plus the cache if ``use_cache``."""
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (self_attn_weights,)
        if use_cache:
            outputs += (present_key_value,)
        return outputs


# ==================== Vision Encoder ====================


class AshishOcrVisionMLP(nn.Module):
    """Two-layer feed-forward block (``fc1 -> act -> fc2``) used inside vision blocks."""

    def __init__(self, config: AshishOcrVisionConfig):
        super().__init__()
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.act = ACT2FN[config.hidden_act]
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states):
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states


class AshishOcrVisionAttention(nn.Module):
    """Unmasked multi-head self-attention over patch tokens, with a fused QKV projection."""

    def __init__(self, config: AshishOcrVisionConfig):
        super().__init__()
        self.num_heads = config.num_heads
        self.head_dim = config.hidden_size // config.num_heads
        self.qkv = nn.Linear(config.hidden_size, 3 * config.hidden_size, bias=config.attention_bias)
        self.proj = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Attend over ``hidden_states`` of shape ``(batch, seq, hidden)``; returns the same shape."""
        bsz, seq_len, _ = hidden_states.size()
        qkv = self.qkv(hidden_states)
        # (3, batch, heads, seq, head_dim)
        qkv = qkv.reshape(bsz, seq_len, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)

        attn_weights = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        # float32 softmax, matching the text attention, so half-precision inputs stay stable.
        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
        attn_output = torch.matmul(attn_weights, v)

        attn_output = attn_output.transpose(1, 2).reshape(bsz, seq_len, -1)
        attn_output = self.proj(attn_output)
        return attn_output


class AshishOcrVisionBlock(nn.Module):
    """Pre-norm vision transformer block (attention then MLP, each with a residual connection)."""

    def __init__(self, config: AshishOcrVisionConfig):
        super().__init__()
        self.norm1 = AshishOcrRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.attn = AshishOcrVisionAttention(config)
        self.norm2 = AshishOcrRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.mlp = AshishOcrVisionMLP(config)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = hidden_states + self.attn(self.norm1(hidden_states))
        hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
        return hidden_states


class AshishOcrPatchEmbed(nn.Module):
    """Split a 3-channel ``(B, C, T, H, W)`` input into non-overlapping spatio-temporal patches via Conv3d."""

    def __init__(self, config: AshishOcrVisionConfig):
        super().__init__()
        self.patch_size = config.patch_size
        self.temporal_patch_size = config.temporal_patch_size
        self.proj = nn.Conv3d(
            3,
            config.hidden_size,
            kernel_size=(config.temporal_patch_size, config.patch_size, config.patch_size),
            stride=(config.temporal_patch_size, config.patch_size, config.patch_size),
            bias=False,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Map ``(B, C, T, H, W)`` to ``(B, N, D)`` with patches flattened in (t, h, w) row-major order."""
        hidden_states = self.proj(hidden_states)
        hidden_states = hidden_states.flatten(2).transpose(1, 2)  # (B, N, D)
        return hidden_states


class AshishOcrPatchMerger(nn.Module):
    """Merge each ``spatial_merge_size x spatial_merge_size`` patch neighbourhood and project to ``out_hidden_size``."""

    def __init__(self, config: AshishOcrVisionConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.out_hidden_size = config.out_hidden_size
        self.spatial_merge_size = config.spatial_merge_size
        self.mlp = nn.Sequential(
            AshishOcrRMSNorm(config.hidden_size * config.spatial_merge_size ** 2, eps=config.rms_norm_eps),
            nn.Linear(config.hidden_size * config.spatial_merge_size ** 2, config.out_hidden_size, bias=False),
            nn.GELU(),
            nn.Linear(config.out_hidden_size, config.out_hidden_size, bias=False),
        )

    def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
        """Merge patches per sample.

        Args:
            hidden_states: ``(B, N, hidden_size)`` patch features in (t, h, w) order.
            grid_thw: per-sample patch-grid sizes ``(B, 3)``, or a single ``(3,)`` grid shared by all samples.

        Returns:
            ``(B, t * h/m * w/m, out_hidden_size)`` where ``m = spatial_merge_size``. All samples must yield
            the same number of merged tokens (they are stacked).

        Raises:
            ValueError: if a grid's height or width is not divisible by ``spatial_merge_size``.
        """
        merge = self.spatial_merge_size
        merged_states = []
        for b in range(hidden_states.shape[0]):
            grid = grid_thw[b] if grid_thw.dim() > 1 else grid_thw
            t, h, w = (int(v) for v in grid.tolist())
            if h % merge or w % merge:
                raise ValueError(
                    f"Patch grid height/width ({h}, {w}) must be divisible by spatial_merge_size={merge}."
                )
            h_new, w_new = h // merge, w // merge
            states = hidden_states[b, : t * h * w]
            # (t, h_new, m, w_new, m, D) -> (t, h_new, w_new, m, m, D): gather each m x m neighbourhood together.
            states = states.reshape(t, h_new, merge, w_new, merge, -1)
            states = states.permute(0, 1, 3, 2, 4, 5).reshape(t * h_new * w_new, -1)
            merged_states.append(states)

        hidden_states = torch.stack(merged_states, dim=0)
        return self.mlp(hidden_states)


class AshishOcrVisionEncoder(nn.Module):
    """Patch embedding -> stack of vision blocks -> optional spatial patch merger."""

    def __init__(self, config: AshishOcrVisionConfig):
        super().__init__()
        self.config = config
        self.patch_embed = AshishOcrPatchEmbed(config)
        self.blocks = nn.ModuleList([AshishOcrVisionBlock(config) for _ in range(config.depth)])
        self.merger = AshishOcrPatchMerger(config)

    def forward(
        self,
        pixel_values: torch.Tensor,
        grid_thw: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Encode ``(B, C, T, H, W)`` pixels.

        When ``grid_thw`` is ``None`` the merger is skipped and raw ``(B, N, hidden_size)`` features are
        returned; otherwise merged ``(B, N', out_hidden_size)`` features are returned.
        """
        hidden_states = self.patch_embed(pixel_values)
        for block in self.blocks:
            hidden_states = block(hidden_states)
        if grid_thw is not None:
            hidden_states = self.merger(hidden_states, grid_thw)
        return hidden_states


# ==================== Main Model ====================


class AshishOcrPreTrainedModel(PreTrainedModel):
    """Shared base class: config binding and random weight initialisation."""

    config_class = AshishOcrConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["AshishOcrDecoderLayer", "AshishOcrVisionBlock"]

    def _init_weights(self, module):
        """Normal-initialise Linear/Conv3d/Embedding weights and zero biases.

        The std is ``text_config.initializer_range`` for the composite config, the config's own
        ``initializer_range`` for a text-only config, and 0.02 if neither defines it.
        """
        std = getattr(getattr(self.config, "text_config", self.config), "initializer_range", 0.02)
        if isinstance(module, (nn.Linear, nn.Conv3d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()


class AshishOcrTextModel(AshishOcrPreTrainedModel):
    """Decoder-only transformer stack (token embedding, decoder layers, final RMSNorm)."""

    config_class = AshishOcrTextConfig

    def __init__(self, config: AshishOcrTextConfig):
        super().__init__(config)
        self.config = config
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList(
            [AshishOcrDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = AshishOcrRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gradient_checkpointing = False
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run the decoder stack.

        ``inputs_embeds`` takes precedence over ``input_ids``. ``attention_mask`` may be a 2D padding mask
        covering cached + new tokens, or a prepared 4D mask. Default ``position_ids`` continue from the
        cache length.

        Raises:
            ValueError: if neither ``input_ids`` nor ``inputs_embeds`` is given.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if inputs_embeds is None:
            if input_ids is None:
                raise ValueError("You have to specify either input_ids or inputs_embeds")
            inputs_embeds = self.embed_tokens(input_ids)

        seq_length = inputs_embeds.shape[1]

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        # Only allocate a cache when one will be returned; a caller-provided cache is always honoured.
        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()
        past_length = _get_past_length(past_key_values)

        if position_ids is None:
            # New tokens continue after the cached ones so RoPE positions stay correct while decoding.
            position_ids = torch.arange(
                past_length, past_length + seq_length, device=inputs_embeds.device
            ).unsqueeze(0)

        causal_mask = self._prepare_attention_mask(
            attention_mask, seq_length, inputs_embeds.dtype, inputs_embeds.device, past_length=past_length
        )

        hidden_states = inputs_embeds
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                # _gradient_checkpointing_func is installed by PreTrainedModel.gradient_checkpointing_enable().
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]
            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = past_key_values if use_cache else None

        if not return_dict:
            return tuple(
                value
                for value in (hidden_states, next_cache, all_hidden_states, all_self_attns)
                if value is not None
            )

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    def _prepare_attention_mask(self, attention_mask, seq_length, dtype, device, past_length=0):
        """Build the additive mask consumed by the eager attention layers.

        Args:
            attention_mask: ``None``, a 2D padding mask ``(batch, past_length + seq_length)`` (nonzero/True =
                attend), or an already-expanded 4D mask (boolean: True = attend; floating: assumed additive
                and used as-is).
            seq_length: number of query tokens in this pass.
            dtype: dtype of the returned mask.
            device: device of the returned mask.
            past_length: number of tokens already cached; query ``i`` sits at absolute position
                ``past_length + i``.

        Returns:
            A mask broadcastable to ``(batch, heads, seq_length, past_length + seq_length)`` holding 0 where
            attention is allowed and ``torch.finfo(dtype).min`` where it is not. A large finite value is used
            instead of ``-inf`` so that ``0 * mask``-style arithmetic and fully masked rows cannot yield NaN.

        Raises:
            ValueError: for a mask of unsupported rank or a 2D mask of the wrong length.
        """
        min_value = torch.finfo(dtype).min

        if attention_mask is not None and attention_mask.dim() == 4:
            if attention_mask.dtype == torch.bool:
                allowed_4d = attention_mask.to(device)
                return torch.zeros(allowed_4d.shape, dtype=dtype, device=device).masked_fill(~allowed_4d, min_value)
            return attention_mask.to(device=device, dtype=dtype)

        kv_length = past_length + seq_length
        query_positions = torch.arange(past_length, kv_length, device=device)
        key_positions = torch.arange(kv_length, device=device)
        # Causal rule: a query may attend to keys at or before its own absolute position.
        allowed = (key_positions[None, :] <= query_positions[:, None])[None, None, :, :]

        if attention_mask is not None:
            if attention_mask.dim() != 2:
                raise ValueError(
                    f"attention_mask must be 2D or 4D, got a {attention_mask.dim()}D tensor."
                )
            if attention_mask.shape[-1] != kv_length:
                raise ValueError(
                    f"attention_mask length ({attention_mask.shape[-1]}) must equal cached + new tokens "
                    f"({past_length} + {seq_length})."
                )
            padding = attention_mask.to(device=device, dtype=torch.bool)
            allowed = allowed & padding[:, None, None, :]

        return torch.zeros(allowed.shape, dtype=dtype, device=device).masked_fill(~allowed, min_value)


class AshishOcrForConditionalGeneration(AshishOcrPreTrainedModel):
    """Vision encoder + decoder LM: visual features replace image/video placeholder token embeddings."""

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: AshishOcrConfig):
        super().__init__(config)
        self.config = config
        self.visual = AshishOcrVisionEncoder(config.vision_config)
        self.model = AshishOcrTextModel(config.text_config)
        self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)

        self.image_token_id = config.image_token_id
        self.video_token_id = config.video_token_id

        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def _scatter_visual_embeds(self, input_ids, inputs_embeds, visual_embeds, token_id, modality):
        """Return a copy of ``inputs_embeds`` with every ``token_id`` position replaced by ``visual_embeds``.

        Features are consumed in row-major order over ``(batch, seq)`` placeholder positions.

        Raises:
            ValueError: if ``input_ids`` is missing or the placeholder count differs from the feature count.
        """
        if input_ids is None:
            raise ValueError(f"`input_ids` are required to locate the {modality} placeholder tokens.")
        mask = input_ids == token_id
        visual_embeds = visual_embeds.reshape(-1, visual_embeds.shape[-1]).to(
            device=inputs_embeds.device, dtype=inputs_embeds.dtype
        )
        num_placeholders = int(mask.sum().item())
        if num_placeholders != visual_embeds.shape[0]:
            raise ValueError(
                f"Number of {modality} placeholder tokens ({num_placeholders}) does not match the number of "
                f"{modality} features ({visual_embeds.shape[0]})."
            )
        inputs_embeds = inputs_embeds.clone()
        inputs_embeds[mask] = visual_embeds
        return inputs_embeds

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        pixel_values_videos: Optional[torch.FloatTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Compute next-token logits (and a shifted cross-entropy loss when ``labels`` are given).

        Image/video features from ``self.visual`` are written into the embeddings at positions where
        ``input_ids`` equals ``image_token_id`` / ``video_token_id``. Label positions equal to -100 are
        ignored (``CrossEntropyLoss`` default ``ignore_index``).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if inputs_embeds is None:
            if input_ids is None:
                raise ValueError("You have to specify either input_ids or inputs_embeds")
            inputs_embeds = self.model.embed_tokens(input_ids)

        if pixel_values is not None:
            image_embeds = self.visual(pixel_values, image_grid_thw)
            inputs_embeds = self._scatter_visual_embeds(
                input_ids, inputs_embeds, image_embeds, self.image_token_id, "image"
            )

        if pixel_values_videos is not None:
            video_embeds = self.visual(pixel_values_videos, video_grid_thw)
            inputs_embeds = self._scatter_visual_embeds(
                input_ids, inputs_embeds, video_embeds, self.video_token_id, "video"
            )

        outputs = self.model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            # Token t predicts token t + 1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
            shift_labels = shift_labels.view(-1).to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + tuple(outputs[1:])
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        pixel_values=None,
        pixel_values_videos=None,
        image_grid_thw=None,
        video_grid_thw=None,
        **kwargs,
    ):
        """Assemble ``forward`` kwargs for one generation step.

        On the prefill step (empty cache) the full prompt, ``inputs_embeds`` and visual inputs are passed.
        On cached steps only the uncached token ids are passed: visual inputs and ``inputs_embeds`` are
        dropped because their features already live in the KV cache. Position ids are derived from
        ``attention_mask`` (when given) so left-padded batches get correct RoPE positions.
        """
        past_length = _get_past_length(past_key_values)
        is_prefill = past_length == 0

        if not is_prefill:
            if input_ids.shape[1] > past_length:
                input_ids = input_ids[:, past_length:]
            else:
                input_ids = input_ids[:, -1:]

        position_ids = kwargs.get("position_ids")
        if position_ids is None and attention_mask is not None:
            position_ids = attention_mask.long().cumsum(-1) - 1
            # Padding positions get a dummy value; they are masked out of attention anyway.
            position_ids.masked_fill_(attention_mask == 0, 1)
            position_ids = position_ids[:, -input_ids.shape[1]:]

        model_inputs = {
            "input_ids": input_ids,
            "position_ids": position_ids,
            "past_key_values": past_key_values,
            "attention_mask": attention_mask,
            "inputs_embeds": inputs_embeds if is_prefill else None,
            "use_cache": kwargs.get("use_cache"),
            "pixel_values": pixel_values if is_prefill else None,
            "pixel_values_videos": pixel_values_videos if is_prefill else None,
            "image_grid_thw": image_grid_thw,
            "video_grid_thw": video_grid_thw,
        }
        return model_inputs