"""AshishOCR model configuration."""

from transformers.configuration_utils import PretrainedConfig


class AshishOcrVisionConfig(PretrainedConfig):
    """Configuration class for the AshishOCR vision encoder.

    Every constructor argument is stored unchanged as an attribute of the
    same name; how each value is interpreted is up to the modeling code that
    consumes this config. Any extra keyword arguments are forwarded to
    ``PretrainedConfig.__init__``.
    """

    model_type = "ashish_ocr_vision"

    def __init__(
        self,
        hidden_size=1024,
        depth=24,
        num_heads=16,
        attention_bias=True,
        intermediate_size=4096,
        hidden_act="silu",
        hidden_dropout_prob=0.0,
        initializer_range=0.02,
        image_size=336,
        patch_size=14,
        out_hidden_size=1536,
        rms_norm_eps=1e-05,
        spatial_merge_size=2,
        temporal_patch_size=2,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size
        self.depth = depth
        self.num_heads = num_heads
        self.attention_bias = attention_bias
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.initializer_range = initializer_range
        self.image_size = image_size
        self.patch_size = patch_size
        self.out_hidden_size = out_hidden_size
        self.rms_norm_eps = rms_norm_eps
        self.spatial_merge_size = spatial_merge_size
        self.temporal_patch_size = temporal_patch_size


class AshishOcrTextConfig(PretrainedConfig):
    """Configuration class for the AshishOCR text decoder.

    Constructor arguments are stored as attributes of the same name, with one
    exception: ``eos_token_id=None`` is replaced by the default list
    ``[59246, 59253]``. ``rope_parameters`` is stored exactly as given; no
    validation happens here. Extra keyword arguments are forwarded to
    ``PretrainedConfig.__init__``.
    """

    model_type = "ashish_ocr_text"

    def __init__(
        self,
        vocab_size=59392,
        hidden_size=1536,
        intermediate_size=4608,
        num_hidden_layers=16,
        num_attention_heads=16,
        num_key_value_heads=8,
        head_dim=128,
        hidden_act="silu",
        max_position_embeddings=131072,
        initializer_range=0.02,
        rms_norm_eps=1e-05,
        use_cache=True,
        tie_word_embeddings=False,
        attention_bias=False,
        attention_dropout=0.0,
        pad_token_id=59246,
        eos_token_id=None,
        num_nextn_predict_layers=1,
        rope_parameters=None,
        dtype="bfloat16",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.head_dim = head_dim
        self.hidden_act = hidden_act
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        # NOTE(review): the attributes below are assigned after
        # super().__init__(), so they override any same-named attributes the
        # base initializer may have set.
        self.tie_word_embeddings = tie_word_embeddings
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.pad_token_id = pad_token_id
        # Build the default list per call so instances never share one list.
        self.eos_token_id = eos_token_id if eos_token_id is not None else [59246, 59253]
        self.num_nextn_predict_layers = num_nextn_predict_layers
        self.rope_parameters = rope_parameters
        self.dtype = dtype


def _build_sub_config(config_cls, value):
    """Return a ``config_cls`` instance built from ``value``.

    Args:
        config_cls: The ``PretrainedConfig`` subclass to produce.
        value: ``None`` (use all defaults), an existing ``config_cls``
            instance (used as-is), or a mapping of constructor keyword
            arguments.

    Returns:
        An instance of ``config_cls``.

    Raises:
        TypeError: If ``value`` is a config of a different class, or is
            neither ``None``, a ``config_cls`` instance, nor a mapping.
    """
    if value is None:
        return config_cls()
    if isinstance(value, config_cls):
        return value
    if isinstance(value, PretrainedConfig):
        raise TypeError(
            f"expected {config_cls.__name__} or a dict, got {type(value).__name__}"
        )
    # Mapping path: identical to the previous behaviour; a non-mapping
    # still raises TypeError from the ** unpacking.
    return config_cls(**value)


class AshishOcrConfig(PretrainedConfig):
    """Configuration class for the AshishOCR multimodal model.

    Combines an ``AshishOcrTextConfig`` (``text_config``) and an
    ``AshishOcrVisionConfig`` (``vision_config``) with the special token ids
    given as constructor arguments. Each sub-config may be passed as ``None``
    (defaults), a dict of constructor arguments, or an already-built instance
    of the matching config class. Extra keyword arguments are forwarded to
    ``PretrainedConfig.__init__``.

    ``vocab_size`` and ``hidden_size`` are copied from ``text_config`` once,
    at construction time; later edits to ``text_config`` are not reflected
    in these top-level attributes.
    """

    model_type = "ashish_ocr"
    sub_configs = {"text_config": AshishOcrTextConfig, "vision_config": AshishOcrVisionConfig}

    def __init__(
        self,
        text_config=None,
        vision_config=None,
        image_start_token_id=59256,
        image_end_token_id=59257,
        video_start_token_id=59258,
        video_end_token_id=59259,
        image_token_id=59280,
        video_token_id=59281,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Resolve the classes through sub_configs so a subclass can swap them.
        self.text_config = _build_sub_config(self.sub_configs["text_config"], text_config)
        self.vision_config = _build_sub_config(self.sub_configs["vision_config"], vision_config)
        self.image_start_token_id = image_start_token_id
        self.image_end_token_id = image_end_token_id
        self.video_start_token_id = video_start_token_id
        self.video_end_token_id = video_end_token_id
        self.image_token_id = image_token_id
        self.video_token_id = video_token_id
        # Inherit key parameters from text config (snapshot copy).
        self.vocab_size = self.text_config.vocab_size
        self.hidden_size = self.text_config.hidden_size