File size: 4,686 Bytes
8f993ed |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
from sympy import false
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation
class LlavaUHDV3VisionConfig(PretrainedConfig):
model_type = "llava_uhd_v3"
base_config_key = "vision_config"
def __init__(
self,
patch_size: int = 14,
init_pos_emb_height: int = 64,
init_pos_emb_width: int = 64,
num_attention_heads: int = 16,
num_hidden_layers: int = 27,
hidden_size: int = 1152,
intermediate_size: int = 4304,
merger_layer_index: list = None,
merging_method: str = None,
**kwargs,
):
super().__init__(**kwargs)
self.patch_size = patch_size
# Positional embedding config
self.init_pos_emb_height = init_pos_emb_height
self.init_pos_emb_width = init_pos_emb_width
# Transformer config
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
# Merging config
self.merger_layer_index = merger_layer_index
self.merging_method = merging_method
self.attn_implementation = "flash_attention_2"
class LlavaUHDV3TextConfig(PretrainedConfig):
model_type = "llava_uhd_v3"
base_config_key = "text_config"
def __init__(
self,
vocab_size=152064,
hidden_size=3584,
intermediate_size=18944,
num_hidden_layers=28,
num_attention_heads=28,
num_key_value_heads=4,
hidden_act="silu",
max_position_embeddings=131072,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
tie_word_embeddings=False,
rope_theta=1000000.0,
rope_scaling=None,
use_sliding_window=False,
sliding_window=131072,
max_window_layers=28,
layer_types=None,
attention_dropout=0.0,
**kwargs,
):
self.attn_implementation = "flash_attention_2"
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.use_sliding_window = use_sliding_window
self.sliding_window = sliding_window if self.use_sliding_window else None
self.max_window_layers = max_window_layers
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.attention_dropout = attention_dropout
# Validate the correctness of rotary position embeddings parameters
# BC: if there is a 'type' field, move it to 'rope_type'.
if self.rope_scaling is not None and "type" in self.rope_scaling:
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
rope_config_validation(self)
self.layer_types = layer_types
if self.layer_types is None:
self.layer_types = [
"sliding_attention"
if self.sliding_window is not None and i >= self.max_window_layers
else "full_attention"
for i in range(self.num_hidden_layers)
]
super().__init__(
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
class LlavaUHDV3Config(PretrainedConfig):
model_type = "llava_uhd_v3"
sub_configs = {"vision_config": LlavaUHDV3VisionConfig, "text_config": LlavaUHDV3TextConfig}
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
text_config=None,
vision_config=None,
**kwargs,
):
if isinstance(vision_config, dict):
self.vision_config = self.sub_configs["vision_config"](**vision_config)
elif vision_config is None:
self.vision_config = self.sub_configs["vision_config"]()
if isinstance(text_config, dict):
self.text_config = self.sub_configs["text_config"](**text_config)
elif text_config is None:
self.text_config = self.sub_configs["text_config"]()
super().__init__(**kwargs)
__all__ = ["LlavaUHDV3Config"] |