Allen172 committed
Commit ee79e81 · verified · 1 Parent(s): 13bf996

Upload Gemma3OmniForConditionalGeneration

README.md ADDED
@@ -0,0 +1,199 @@
1
+ ---
2
+ library_name: transformers
3
+ tags: []
4
+ ---
5
+
6
+ # Model Card for Model ID
7
+
8
+ <!-- Provide a quick summary of what the model is/does. -->
9
+
10
+
11
+
12
+ ## Model Details
13
+
14
+ ### Model Description
15
+
16
+ <!-- Provide a longer summary of what this model is. -->
17
+
18
+ This is the model card of a 🤗 transformers model that has been pushed to the Hub. This model card has been automatically generated.
19
+
20
+ - **Developed by:** [More Information Needed]
21
+ - **Funded by [optional]:** [More Information Needed]
22
+ - **Shared by [optional]:** [More Information Needed]
23
+ - **Model type:** [More Information Needed]
24
+ - **Language(s) (NLP):** [More Information Needed]
25
+ - **License:** [More Information Needed]
26
+ - **Finetuned from model [optional]:** [More Information Needed]
27
+
28
+ ### Model Sources [optional]
29
+
30
+ <!-- Provide the basic links for the model. -->
31
+
32
+ - **Repository:** [More Information Needed]
33
+ - **Paper [optional]:** [More Information Needed]
34
+ - **Demo [optional]:** [More Information Needed]
35
+
36
+ ## Uses
37
+
38
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
39
+
40
+ ### Direct Use
41
+
42
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
43
+
44
+ [More Information Needed]
45
+
46
+ ### Downstream Use [optional]
47
+
48
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
49
+
50
+ [More Information Needed]
51
+
52
+ ### Out-of-Scope Use
53
+
54
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
55
+
56
+ [More Information Needed]
57
+
58
+ ## Bias, Risks, and Limitations
59
+
60
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
61
+
62
+ [More Information Needed]
63
+
64
+ ### Recommendations
65
+
66
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
67
+
68
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
69
+
70
+ ## How to Get Started with the Model
71
+
72
+ Use the code below to get started with the model.
73
+
74
+ [More Information Needed]
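+
+ The snippet below is a minimal, untested sketch of how this checkpoint could be loaded. It assumes a placeholder repo id (`Allen172/<this-repo>`), that running the repository's custom code is acceptable (`trust_remote_code=True`, since the `Gemma3Omni*` classes are registered via `auto_map` in `config.json`), and a text-only prompt; the exact image/audio inputs expected by `Gemma3OmniProcessor` are defined in `processing_gemma3_omni.py` and are not shown here.
+
+ ```python
+ import torch
+ from transformers import AutoModelForCausalLM, AutoProcessor
+
+ # Placeholder repo id -- replace with the actual Hub path of this model.
+ repo_id = "Allen172/<this-repo>"
+
+ # trust_remote_code=True is required: the config, model, and processor classes
+ # (Gemma3OmniConfig, Gemma3OmniForConditionalGeneration, Gemma3OmniProcessor)
+ # live in this repository and are wired up through `auto_map` in config.json.
+ processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)
+ model = AutoModelForCausalLM.from_pretrained(
+     repo_id,
+     torch_dtype=torch.bfloat16,  # matches the dtype stored in config.json
+     device_map="auto",
+     trust_remote_code=True,
+ )
+
+ # Text-only example; image and audio inputs would also go through the processor.
+ inputs = processor(text="Describe what this model can do.", return_tensors="pt").to(model.device)
+ outputs = model.generate(**inputs, max_new_tokens=64)
+ print(processor.batch_decode(outputs, skip_special_tokens=True)[0])
+ ```
+
+ Sampling defaults for `generate` (`do_sample=True`, `top_k=64`, `top_p=0.95`) come from `generation_config.json`.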
75
+
76
+ ## Training Details
77
+
78
+ ### Training Data
79
+
80
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
81
+
82
+ [More Information Needed]
83
+
84
+ ### Training Procedure
85
+
86
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
87
+
88
+ #### Preprocessing [optional]
89
+
90
+ [More Information Needed]
91
+
92
+
93
+ #### Training Hyperparameters
94
+
95
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
96
+
97
+ #### Speeds, Sizes, Times [optional]
98
+
99
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
100
+
101
+ [More Information Needed]
102
+
103
+ ## Evaluation
104
+
105
+ <!-- This section describes the evaluation protocols and provides the results. -->
106
+
107
+ ### Testing Data, Factors & Metrics
108
+
109
+ #### Testing Data
110
+
111
+ <!-- This should link to a Dataset Card if possible. -->
112
+
113
+ [More Information Needed]
114
+
115
+ #### Factors
116
+
117
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
118
+
119
+ [More Information Needed]
120
+
121
+ #### Metrics
122
+
123
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
124
+
125
+ [More Information Needed]
126
+
127
+ ### Results
128
+
129
+ [More Information Needed]
130
+
131
+ #### Summary
132
+
133
+
134
+
135
+ ## Model Examination [optional]
136
+
137
+ <!-- Relevant interpretability work for the model goes here -->
138
+
139
+ [More Information Needed]
140
+
141
+ ## Environmental Impact
142
+
143
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
144
+
145
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
146
+
147
+ - **Hardware Type:** [More Information Needed]
148
+ - **Hours used:** [More Information Needed]
149
+ - **Cloud Provider:** [More Information Needed]
150
+ - **Compute Region:** [More Information Needed]
151
+ - **Carbon Emitted:** [More Information Needed]
152
+
153
+ ## Technical Specifications [optional]
154
+
155
+ ### Model Architecture and Objective
156
+
157
+ [More Information Needed]
158
+
159
+ ### Compute Infrastructure
160
+
161
+ [More Information Needed]
162
+
163
+ #### Hardware
164
+
165
+ [More Information Needed]
166
+
167
+ #### Software
168
+
169
+ [More Information Needed]
170
+
171
+ ## Citation [optional]
172
+
173
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
174
+
175
+ **BibTeX:**
176
+
177
+ [More Information Needed]
178
+
179
+ **APA:**
180
+
181
+ [More Information Needed]
182
+
183
+ ## Glossary [optional]
184
+
185
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
186
+
187
+ [More Information Needed]
188
+
189
+ ## More Information [optional]
190
+
191
+ [More Information Needed]
192
+
193
+ ## Model Card Authors [optional]
194
+
195
+ [More Information Needed]
196
+
197
+ ## Model Card Contact
198
+
199
+ [More Information Needed]
config.json ADDED
@@ -0,0 +1,132 @@
1
+ {
2
+ "architectures": [
3
+ "Gemma3OmniForConditionalGeneration"
4
+ ],
5
+ "audio_token_index": 262151,
6
+ "auto_map": {
7
+ "AutoConfig": "configuration_gemma3_omni.Gemma3OmniConfig",
8
+ "AutoFeatureExtractor": "processing_gemma3_omni.Gemma3AudioFeatureExtractor",
9
+ "AutoModel": "modeling_gemma3_omni.Gemma3OmniForConditionalGeneration",
10
+ "AutoModelForCausalLM": "modeling_gemma3_omni.Gemma3OmniForConditionalGeneration",
11
+ "AutoProcessor": "processing_gemma3_omni.Gemma3OmniProcessor"
12
+ },
13
+ "boi_token_index": 255999,
14
+ "eoi_token_index": 256000,
15
+ "eos_token_id": [
16
+ 1,
17
+ 106
18
+ ],
19
+ "image_token_index": 262152,
20
+ "initializer_range": 0.02,
21
+ "mm_tokens_per_image": 256,
22
+ "model_type": "gemma3omni",
23
+ "text_config": {
24
+ "attention_bias": false,
25
+ "attention_dropout": 0.0,
26
+ "attn_logit_softcapping": null,
27
+ "final_logit_softcapping": null,
28
+ "head_dim": 128,
29
+ "hidden_activation": "gelu_pytorch_tanh",
30
+ "hidden_size": 5376,
31
+ "initializer_range": 0.02,
32
+ "intermediate_size": 21504,
33
+ "layer_types": [
34
+ "sliding_attention",
35
+ "sliding_attention",
36
+ "sliding_attention",
37
+ "sliding_attention",
38
+ "sliding_attention",
39
+ "full_attention",
40
+ "sliding_attention",
41
+ "sliding_attention",
42
+ "sliding_attention",
43
+ "sliding_attention",
44
+ "sliding_attention",
45
+ "full_attention",
46
+ "sliding_attention",
47
+ "sliding_attention",
48
+ "sliding_attention",
49
+ "sliding_attention",
50
+ "sliding_attention",
51
+ "full_attention",
52
+ "sliding_attention",
53
+ "sliding_attention",
54
+ "sliding_attention",
55
+ "sliding_attention",
56
+ "sliding_attention",
57
+ "full_attention",
58
+ "sliding_attention",
59
+ "sliding_attention",
60
+ "sliding_attention",
61
+ "sliding_attention",
62
+ "sliding_attention",
63
+ "full_attention",
64
+ "sliding_attention",
65
+ "sliding_attention",
66
+ "sliding_attention",
67
+ "sliding_attention",
68
+ "sliding_attention",
69
+ "full_attention",
70
+ "sliding_attention",
71
+ "sliding_attention",
72
+ "sliding_attention",
73
+ "sliding_attention",
74
+ "sliding_attention",
75
+ "full_attention",
76
+ "sliding_attention",
77
+ "sliding_attention",
78
+ "sliding_attention",
79
+ "sliding_attention",
80
+ "sliding_attention",
81
+ "full_attention",
82
+ "sliding_attention",
83
+ "sliding_attention",
84
+ "sliding_attention",
85
+ "sliding_attention",
86
+ "sliding_attention",
87
+ "full_attention",
88
+ "sliding_attention",
89
+ "sliding_attention",
90
+ "sliding_attention",
91
+ "sliding_attention",
92
+ "sliding_attention",
93
+ "full_attention",
94
+ "sliding_attention",
95
+ "sliding_attention"
96
+ ],
97
+ "max_position_embeddings": 131072,
98
+ "model_type": "gemma3_text",
99
+ "num_attention_heads": 32,
100
+ "num_hidden_layers": 62,
101
+ "num_key_value_heads": 16,
102
+ "query_pre_attn_scalar": 168,
103
+ "rms_norm_eps": 1e-06,
104
+ "rope_local_base_freq": 10000.0,
105
+ "rope_scaling": {
106
+ "factor": 8.0,
107
+ "rope_type": "linear"
108
+ },
109
+ "rope_theta": 1000000.0,
110
+ "sliding_window": 1024,
111
+ "torch_dtype": "bfloat16",
112
+ "use_cache": true,
113
+ "vocab_size": 262208
114
+ },
115
+ "torch_dtype": "bfloat16",
116
+ "transformers_version": "4.53.0",
117
+ "vision_config": {
118
+ "attention_dropout": 0.0,
119
+ "hidden_act": "gelu_pytorch_tanh",
120
+ "hidden_size": 1152,
121
+ "image_size": 896,
122
+ "intermediate_size": 4304,
123
+ "layer_norm_eps": 1e-06,
124
+ "model_type": "siglip_vision_model",
125
+ "num_attention_heads": 16,
126
+ "num_channels": 3,
127
+ "num_hidden_layers": 27,
128
+ "patch_size": 14,
129
+ "torch_dtype": "bfloat16",
130
+ "vision_use_head": false
131
+ }
132
+ }
configuration_gemma3_omni.py ADDED
@@ -0,0 +1,55 @@
1
+ from typing import Optional, Union, Dict, Any
2
+
3
+ from transformers import Gemma3TextConfig, SiglipVisionConfig, PretrainedConfig
4
+ from transformers.utils import logging
5
+
6
+ logger = logging.get_logger(__name__)
7
+
8
+
9
+ class Gemma3OmniConfig(PretrainedConfig):
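+     """Configuration for Gemma3Omni: bundles a Gemma3 text config and a SigLIP vision config with the multimodal special-token indices (image, audio, BOI/EOI) and `mm_tokens_per_image`."""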
10
+ model_type = "gemma3omni"
11
+ attribute_map = {
12
+ "image_token_id": "image_token_index",
13
+ "audio_token_id": "audio_token_index",
14
+ "boi_token_id": "boi_token_index",
15
+ "eoi_token_id": "eoi_token_index",
16
+ }
17
+ sub_configs = {
18
+ "text_config": Gemma3TextConfig,
19
+ "vision_config": SiglipVisionConfig,
20
+ }
21
+
22
+ def __init__(
23
+ self,
24
+ text_config: Optional[Union[Gemma3TextConfig, Dict[str, Any]]] = None,
25
+ vision_config: Optional[Union[SiglipVisionConfig, Dict[str, Any]]] = None,
26
+ mm_tokens_per_image: int = 256,
27
+ boi_token_index: int = 255_999,
28
+ eoi_token_index: int = 256_000,
29
+ image_token_index: int = 262_152,
30
+ audio_token_index: int = 262_151,
31
+ initializer_range: float = 0.02,
32
+ **kwargs,
33
+ ):
34
+ if text_config is None:
35
+ text_config = Gemma3TextConfig()
36
+ logger.info("text_config is None, using default Gemma3TextConfig text config.")
37
+ elif isinstance(text_config, dict):
38
+ text_config = Gemma3TextConfig(**text_config)
39
+
40
+ if isinstance(vision_config, dict):
41
+ vision_config = SiglipVisionConfig(**vision_config)
42
+ elif vision_config is None:
43
+ vision_config = SiglipVisionConfig()
44
+ logger.info("vision_config is None, using default SiglipVisionConfig vision config.")
45
+
46
+ self.text_config = text_config
47
+ self.vision_config = vision_config
48
+ self.mm_tokens_per_image = mm_tokens_per_image
49
+ self.boi_token_index = boi_token_index
50
+ self.eoi_token_index = eoi_token_index
51
+ self.image_token_index = image_token_index
52
+ self.audio_token_index = audio_token_index
53
+ self.initializer_range = initializer_range
54
+
55
+ super().__init__(**kwargs)
generation_config.json ADDED
@@ -0,0 +1,13 @@
1
+ {
2
+ "bos_token_id": 2,
3
+ "cache_implementation": "hybrid",
4
+ "do_sample": true,
5
+ "eos_token_id": [
6
+ 1,
7
+ 106
8
+ ],
9
+ "pad_token_id": 0,
10
+ "top_k": 64,
11
+ "top_p": 0.95,
12
+ "transformers_version": "4.53.0"
13
+ }
model-00001-of-00012.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:265edf8d47207aa1371ad2e9b48a198de6374fa649c185b902879c2c42e86303
3
+ size 4922387560
model-00002-of-00012.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8ebcc142679e7a1982b4fbbbb1ef9a242b70eeb3731a3e98db15eadfe70296ba
3
+ size 4954792944
model-00003-of-00012.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f3d6aec4724a4b220cfc7b9b428a21e559511ad77d3067a09037ad7d2c0bafb1
3
+ size 4954792960
model-00004-of-00012.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3a50678e8853653af42583b0cce60f8aaaee3affdede5b8ea5596aa45a2d2fdf
3
+ size 4954793016
model-00005-of-00012.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:11d1cdb15b770bf3a8c340f4124ed2def7262e4b77bc97be6b24ac4eabf377b9
3
+ size 4954793016
model-00006-of-00012.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6daf7a5de65a8ec47420d11fc4f66fe027873507009d2dedc490a9f9a441a53c
3
+ size 4954793016
model-00007-of-00012.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9d18a18b1285e9d1d4f40d9a2df892b73c433116b0a5ac12abc59cf38bf345cd
3
+ size 4954793016
model-00008-of-00012.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b5b09b4980766a7add4def1801d8e2791fa439bb3dd8abb9c72324652eaabf4e
3
+ size 4954793016
model-00009-of-00012.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:95f7b177c303ac266c1ff7fb897426235ac87ff275185971cee1462851b961ec
3
+ size 4954793016
model-00010-of-00012.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dfa233912793743d12ace582c227097120174bcb1865b0989d21a12ec7a0406f
3
+ size 4954793016
model-00011-of-00012.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8638f5b356e7b7f2d059365d9afd8dc03bdd2b524dee1c14338d62fa6eac483d
3
+ size 4954793016
model-00012-of-00012.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6c042662255b2c23d6546318982a10474af91b1e2d0fdf83dcb265ed323ca4e5
3
+ size 1288275520
model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
modeling_gemma3_omni.py ADDED
@@ -0,0 +1,484 @@
1
+ # -*- coding: utf-8 -*-
2
+ from __future__ import annotations
3
+
4
+ from typing import List, Optional, Tuple, Union, Callable
5
+
6
+ from transformers import (
7
+ AutoModel,
8
+ Cache,
9
+ PreTrainedModel,
10
+ PretrainedConfig, )
11
+ from transformers.generation import GenerationMixin
12
+ from transformers.masking_utils import create_causal_mask, create_masks_for_generate, create_sliding_window_causal_mask
13
+ from transformers.models.gemma3.modeling_gemma3 import (
14
+ Gemma3CausalLMOutputWithPast,
15
+ Gemma3RMSNorm, Gemma3PreTrainedModel, Gemma3ModelOutputWithPast,
16
+ )
17
+ from transformers.utils import is_torchdynamo_compiling, logging, is_torch_flex_attn_available
18
+
19
+ try:
20
+ from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
21
+ except ImportError:
22
+ LigerFusedLinearCrossEntropyLoss = None
23
+
24
+ from .configuration_gemma3_omni import Gemma3OmniConfig
25
+ from .speech_conformer_encoder import ConformerEncoder
26
+
27
+ logger = logging.get_logger(__name__)
28
+
29
+ if is_torch_flex_attn_available():
30
+ from torch.nn.attention.flex_attention import BlockMask
31
+
32
+
33
+ class Gemma3AudioProjectorConfig(PretrainedConfig):
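+     """Config for the audio branch: Conformer encoder width/depth and mel-spectrogram front-end settings (sample_rate, n_mels, n_fft, hop_length)."""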
34
+ model_type = "gemma3_audio"
35
+
36
+ def __init__(
37
+ self,
38
+ hidden_size: int = 1024,
39
+ num_hidden_layers: int = 24,
40
+ sample_rate: int = 16_000,
41
+ n_mels: int = 80,
42
+         image_token_index: int = 0,  # Unused by the audio projector; appears to be carried over from the vision config
43
+ # Added Mel spectrogram specific parameters
44
+ n_fft: int = 400, # Typical for 25ms window at 16kHz
45
+ hop_length: int = 160, # Typical for 10ms hop at 16kHz
46
+ **kwargs,
47
+ ):
48
+ super().__init__(**kwargs)
49
+ self.hidden_size = hidden_size
50
+ self.num_hidden_layers = num_hidden_layers
51
+ self.sample_rate = sample_rate
52
+ self.n_mels = n_mels
53
+ self.image_token_index = image_token_index
54
+ self.n_fft = n_fft
55
+ self.hop_length = hop_length
56
+
57
+
58
+ import torch
59
+ from torch import nn
60
+
61
+
62
+ class LayerWiseWeightedSum(nn.Module):
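+     """Softmax-weighted sum over per-layer hidden states (a learnable scalar mix across encoder layers)."""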
63
+ def __init__(self, num_layers: int, learnable: bool = True):
64
+ super().__init__()
65
+ self.num_layers = num_layers
66
+ if learnable:
67
+ self.scalar = nn.Parameter(torch.zeros(num_layers))
68
+ else:
69
+ self.register_buffer("scalar", torch.zeros(num_layers))
70
+
71
+ def forward(self, hidden_states: list[torch.Tensor]) -> torch.Tensor:
72
+ assert len(hidden_states) == self.num_layers
73
+ norm_w = torch.softmax(self.scalar, dim=0).view(-1, 1, 1, 1)
74
+ stacked = torch.stack(hidden_states, dim=0)
75
+ return (norm_w * stacked).sum(dim=0)
76
+
77
+
78
+ class Gemma3AudioProjector(PreTrainedModel):
79
+ """Conformer-based audio encoder → project to LM hidden-dim."""
80
+
81
+ config_class = Gemma3AudioProjectorConfig
82
+ base_model_prefix = "audio_projector"
83
+
84
+ def __init__(self, config: Gemma3AudioProjectorConfig):
85
+ super().__init__(config)
86
+ encoder_config = {
87
+ "activation": "swish",
88
+ "activation_checkpointing": "",
89
+ "attention_dim": 1024,
90
+ "attention_heads": 16,
91
+ "batch_norm": False,
92
+ "bias_in_glu": True,
93
+ "causal": True,
94
+ "chunk_size": -1,
95
+ "conv_activation": "swish",
96
+ "conv_glu_type": "swish",
97
+ "depthwise_multiplier": 1,
98
+ "depthwise_seperable_out_channel": 1024,
99
+ "dropout_rate": 0.0,
100
+ "encoder_embedding_config": {
101
+ "input_size": config.n_mels # This is feat_in for NemoConvSubsampling
102
+ },
103
+ "ext_pw_kernel_size": 1,
104
+ "ext_pw_out_channel": 1024,
105
+ "input_layer": "nemo_conv",
106
+ "input_size": config.n_mels, # Also feat_in for NemoConvSubsampling, consistency
107
+ "kernel_size": 3,
108
+ "left_chunk": 18,
109
+ "linear_units": 1536,
110
+ "nemo_conv_settings": {
111
+ "conv_channels": 1024,
112
+ },
113
+ "num_blocks": 24,
114
+ "relative_attention_bias_args": {
115
+ "t5_bias_max_distance": 500,
116
+ "type": "t5"
117
+ },
118
+ "time_reduction": 8
119
+ }
120
+ self.encoder = ConformerEncoder(**encoder_config)
121
+ self.layer_weighter = LayerWiseWeightedSum(
122
+ num_layers=encoder_config["num_blocks"]
123
+ )
124
+ self.proj = nn.Linear(encoder_config['attention_dim'], config.hidden_size, bias=False)
125
+
126
+ def forward(self, mel: torch.Tensor, mel_mask: torch.Tensor):
127
+ mel = mel.squeeze(1) # (B, T, 80)
128
+ mel_mask = mel_mask.squeeze(1) # (B, L)
129
+
130
+ if mel_mask.size(1) != mel.size(1):
131
+ mel_mask = mel_mask[..., : mel.size(1)]
132
+
133
+ _, out_mask, hidden_list = self.encoder(mel, mel_mask)
134
+ hidden_sum = self.layer_weighter(hidden_list)
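+         # NOTE: the weighted sum above is currently unused; the projection below uses only the last encoder layer's output.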
135
+ hidden = self.proj(hidden_list[-1])
136
+ return hidden, out_mask
137
+
138
+
139
+ class Gemma3VisionProjector(nn.Module):
140
+ def __init__(self, config):
141
+ super().__init__()
142
+ self.mm_input_projection_weight = nn.Parameter(
143
+ torch.zeros(config.vision_config.hidden_size, config.text_config.hidden_size)
144
+ )
145
+ self.mm_soft_emb_norm = Gemma3RMSNorm(
146
+ config.vision_config.hidden_size, eps=config.vision_config.layer_norm_eps
147
+ )
148
+ self.patches_per_image = config.vision_config.image_size // config.vision_config.patch_size
149
+ self.tokens_per_side = int(config.mm_tokens_per_image ** 0.5)
150
+ self.kernel_size = self.patches_per_image // self.tokens_per_side
151
+ self.avg_pool = nn.AvgPool2d(kernel_size=self.kernel_size, stride=self.kernel_size)
152
+
153
+ def forward(self, vision_outputs: torch.Tensor):
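+         # vision_outputs: (batch, num_patches, vision_hidden). Rearranged channels-first so the patch grid can be
+         # average-pooled down to mm_tokens_per_image tokens, then RMS-normalized and projected into the LM embedding space.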
154
+ b, _, seq_len = vision_outputs.shape
155
+ x = vision_outputs.transpose(1, 2).reshape(
156
+ b, seq_len, self.patches_per_image, self.patches_per_image
157
+ )
158
+ x = self.avg_pool(x).flatten(2).transpose(1, 2)
159
+ x = self.mm_soft_emb_norm(x)
160
+ return torch.matmul(x, self.mm_input_projection_weight).type_as(vision_outputs)
161
+
162
+
163
+ def token_type_ids_mask_function(token_type_ids: Optional[torch.Tensor]) -> Optional[Callable]:
164
+ """
165
+ This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths,
166
+ not start and end indices.
167
+ """
168
+ # Do not return an additional mask in this case
169
+ if token_type_ids is None:
170
+ return None
171
+
172
+ def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
173
+ return token_type_ids[batch_idx, kv_idx] != 0
174
+ return inner_mask
175
+
176
+
177
+ class Gemma3OmniModel(Gemma3PreTrainedModel):
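+     """Multimodal backbone: SigLIP vision tower and Conformer audio projector feeding a Gemma3 language model by scattering image/audio features into the special-token positions."""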
178
+ config_class = Gemma3OmniConfig
179
+
180
+ def __init__(self, config):
181
+ super().__init__(config)
182
+ self.vision_tower = AutoModel.from_config(config=config.vision_config)
183
+ self.multi_modal_projector = Gemma3VisionProjector(config)
184
+ self.audio_projector = Gemma3AudioProjector(
185
+ Gemma3AudioProjectorConfig(hidden_size=config.text_config.hidden_size)
186
+ )
187
+ self.vocab_size = config.text_config.vocab_size
188
+
189
+ language_model = AutoModel.from_config(config=config.text_config)
190
+ self.language_model = language_model
191
+
192
+ self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
193
+ self.post_init()
194
+
195
+ def get_input_embeddings(self):
196
+ return self.language_model.get_input_embeddings()
197
+
198
+ def set_input_embeddings(self, value):
199
+ self.language_model.set_input_embeddings(value)
200
+
201
+ def forward(
202
+ self,
203
+ input_ids: torch.LongTensor = None,
204
+ pixel_values: torch.FloatTensor = None,
205
+ input_audio_embeds: Optional[torch.FloatTensor] = None,
206
+ audio_attention_mask: Optional[torch.LongTensor] = None,
207
+ attention_mask: Optional[torch.Tensor] = None,
208
+ position_ids: Optional[torch.LongTensor] = None,
209
+ past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
210
+ token_type_ids: Optional[torch.LongTensor] = None,
211
+ cache_position: Optional[torch.LongTensor] = None,
212
+ inputs_embeds: Optional[torch.FloatTensor] = None,
213
+ labels: Optional[torch.LongTensor] = None,
214
+ use_cache: Optional[bool] = None,
215
+ output_attentions: Optional[bool] = None,
216
+ output_hidden_states: Optional[bool] = None,
217
+ return_dict: Optional[bool] = None,
218
+ **lm_kwargs,
219
+ ) -> Union[Tuple, Gemma3ModelOutputWithPast]:
220
+ if (input_ids is None) ^ (inputs_embeds is not None):
221
+ print("input_ids:", input_ids, "inputs_embeds:", inputs_embeds)
222
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
223
+
224
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
225
+ output_hidden_states = (
226
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
227
+ )
228
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
229
+
230
+         # Replace the image id with PAD if the image token is OOV, to avoid index errors
231
+ if input_ids is not None and self.config.image_token_id >= self.vocab_size:
232
+ special_image_mask = input_ids == self.config.image_token_id
233
+ llm_input_ids = input_ids.clone()
234
+ llm_input_ids[special_image_mask] = 0
235
+ else:
236
+ llm_input_ids = input_ids
237
+
238
+ if inputs_embeds is None:
239
+ inputs_embeds = self.get_input_embeddings()(llm_input_ids).clone()
240
+
241
+ if cache_position is None:
242
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
243
+ cache_position = torch.arange(
244
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
245
+ )
246
+
247
+ if pixel_values is not None and past_key_values is None:
248
+ image_features = self.get_image_features(pixel_values)
249
+
250
+ if input_ids is None:
251
+ special_image_mask = inputs_embeds == self.get_input_embeddings()(
252
+ torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
253
+ )
254
+ else:
255
+ special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
256
+ special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
257
+
258
+ if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
259
+ image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
260
+ raise ValueError(
261
+ f"Number of images does not match number of special image tokens in the input text. "
262
+ f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
263
+ "tokens from image embeddings."
264
+ )
265
+ image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
266
+ inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
267
+
268
+ if input_audio_embeds is not None and past_key_values is None:
269
+ audio_features, audio_feat_mask = self.audio_projector(
270
+ input_audio_embeds, audio_attention_mask
271
+ )
272
+ if input_ids is None:
273
+ special_audio_mask = (
274
+ inputs_embeds
275
+ == self.get_input_embeddings()(
276
+ torch.tensor(
277
+ self.config.audio_token_index,
278
+ dtype=torch.long,
279
+ device=inputs_embeds.device,
280
+ )
281
+ )
282
+ )
283
+ else:
284
+ special_audio_mask = (
285
+ input_ids == self.config.audio_token_index
286
+ ).unsqueeze(-1) # [B, L, 1]
287
+ special_audio_mask = special_audio_mask.expand_as(inputs_embeds).to(
288
+ inputs_embeds.device
289
+ )
290
+ if (
291
+ not is_torchdynamo_compiling()
292
+ and inputs_embeds[special_audio_mask].numel() != audio_features.numel()
293
+ ):
294
+ audio_tokens_in_text = special_audio_mask.sum(dim=1).sum(dim=0)[0]
295
+ raise ValueError(
296
+ f"Number of audio tokens in the text ({audio_tokens_in_text}) "
297
+ f"≠ number of tokens from audio embeddings "
298
+ f"({audio_features.shape[0] * audio_features.shape[1]})."
299
+ )
300
+ audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
301
+ audio_features = audio_features.reshape(-1)
302
+ inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_features)
303
+
304
+ if not isinstance(causal_mask_mapping := attention_mask, dict):
305
+ # Prepare mask arguments
306
+ mask_kwargs = {
307
+ "config": self.config.get_text_config(),
308
+ "input_embeds": inputs_embeds,
309
+ "attention_mask": attention_mask,
310
+ "cache_position": cache_position,
311
+ "past_key_values": past_key_values,
312
+ }
313
+ if token_type_ids is not None and inputs_embeds.shape[1] != 1:
314
+ mask_kwargs["or_mask_function"] = token_type_ids_mask_function(
315
+ token_type_ids.to(cache_position.device)
316
+ )
317
+
318
+ # Create the masks
319
+ causal_mask_mapping = {
320
+ "full_attention": create_causal_mask(**mask_kwargs),
321
+ "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
322
+ }
323
+
324
+ outputs = self.language_model(
325
+ attention_mask=causal_mask_mapping,
326
+ position_ids=position_ids,
327
+ past_key_values=past_key_values,
328
+ inputs_embeds=inputs_embeds,
329
+ use_cache=use_cache,
330
+ output_attentions=output_attentions,
331
+ output_hidden_states=output_hidden_states,
332
+ return_dict=True,
333
+ cache_position=cache_position,
334
+ **lm_kwargs,
335
+ )
336
+
337
+ return Gemma3ModelOutputWithPast(
338
+ last_hidden_state=outputs.last_hidden_state,
339
+ past_key_values=outputs.past_key_values if use_cache else None,
340
+ hidden_states=outputs.hidden_states,
341
+ attentions=outputs.attentions,
342
+             image_hidden_states=image_features if (pixel_values is not None and past_key_values is None) else None,
343
+ )
344
+
345
+
346
+ class Gemma3OmniForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
347
+     """Gemma-3 Omni: vision + audio + text causal LM."""
348
+     config_class = Gemma3OmniConfig
349
+ _checkpoint_conversion_mapping = {
350
+ "^language_model.model": "model.language_model",
351
+ "^vision_tower": "model.vision_tower",
352
+ "^multi_modal_projector": "model.multi_modal_projector",
353
+ "^language_model.lm_head": "lm_head",
354
+ }
355
+ _tied_weights_keys = ["lm_head.weight"]
356
+
357
+ def __init__(self, config):
358
+ super().__init__(config)
359
+ self.model = Gemma3OmniModel(config)
360
+ self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
361
+ self.post_init()
362
+
363
+ def get_input_embeddings(self):
364
+ return self.model.get_input_embeddings()
365
+
366
+ def set_input_embeddings(self, value):
367
+ self.model.set_input_embeddings(value)
368
+
369
+ def get_output_embeddings(self):
370
+ return self.lm_head
371
+
372
+ def set_output_embeddings(self, new_embeddings):
373
+ self.lm_head = new_embeddings
374
+
375
+ def forward(
376
+ self,
377
+ input_ids: torch.LongTensor = None,
378
+ pixel_values: torch.FloatTensor = None,
379
+ input_audio_embeds: Optional[torch.FloatTensor] = None,
380
+ audio_attention_mask: Optional[torch.LongTensor] = None,
381
+ attention_mask: Optional[torch.Tensor] = None,
382
+ position_ids: Optional[torch.LongTensor] = None,
383
+ past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
384
+ token_type_ids: Optional[torch.LongTensor] = None,
385
+ cache_position: Optional[torch.LongTensor] = None,
386
+ inputs_embeds: Optional[torch.FloatTensor] = None,
387
+ labels: Optional[torch.LongTensor] = None,
388
+ use_cache: Optional[bool] = None,
389
+ output_attentions: Optional[bool] = None,
390
+ output_hidden_states: Optional[bool] = None,
391
+ return_dict: Optional[bool] = None,
392
+ logits_to_keep: Union[int, torch.Tensor] = 0,
393
+ **lm_kwargs,
394
+ ) -> Union[Tuple, Gemma3CausalLMOutputWithPast]:
395
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
396
+ output_hidden_states = (
397
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
398
+ )
399
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
400
+
401
+ outputs = self.model(
402
+ input_ids=input_ids,
403
+ pixel_values=pixel_values,
404
+ input_audio_embeds=input_audio_embeds,
405
+ audio_attention_mask=audio_attention_mask,
406
+ token_type_ids=token_type_ids,
407
+ attention_mask=attention_mask,
408
+ position_ids=position_ids,
409
+ past_key_values=past_key_values,
410
+ inputs_embeds=inputs_embeds,
411
+ use_cache=use_cache,
412
+ labels=labels,
413
+ output_attentions=output_attentions,
414
+ output_hidden_states=output_hidden_states,
415
+ return_dict=return_dict,
416
+ cache_position=cache_position,
417
+ **lm_kwargs,
418
+ )
419
+
420
+ hidden_states = outputs[0]
421
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
422
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
423
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
424
+
425
+ loss = None
426
+ if labels is not None:
427
+ if LigerFusedLinearCrossEntropyLoss is not None:
428
+ shift_hidden_states = hidden_states[..., :-1, :] # (B, S-1, H)
429
+ shift_labels = labels[..., 1:] # (B, S-1)
430
+ hidden_device = shift_hidden_states.device
431
+
432
+ if attention_mask is not None:
433
+ shift_attention_mask = attention_mask[:, -shift_hidden_states.shape[1]:].to(hidden_device)
434
+ shift_hidden_states = shift_hidden_states[shift_attention_mask != 0].contiguous()
435
+ shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
436
+ else:
437
+ shift_hidden_states = shift_hidden_states.contiguous()
438
+ shift_labels = shift_labels.contiguous()
439
+
440
+ shift_hidden_states = shift_hidden_states.view(-1, self.config.text_config.hidden_size) # (N, H)
441
+ shift_labels = shift_labels.view(-1).to(hidden_device)
442
+
443
+ loss_fct = LigerFusedLinearCrossEntropyLoss()
444
+ loss = loss_fct(self.lm_head.weight, shift_hidden_states, shift_labels)
445
+ else:
446
+ logits = logits.float()
447
+ shift_logits = logits[..., :-1, :] # (B, S-1, V)
448
+ shift_labels = labels[..., 1:] # (B, S-1)
449
+
450
+ if attention_mask is not None:
451
+ shift_attention_mask = attention_mask[:, -shift_logits.shape[1]:].to(logits.device)
452
+ shift_logits = shift_logits[shift_attention_mask != 0].contiguous()
453
+ shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
454
+ else:
455
+ shift_logits = shift_logits.contiguous()
456
+ shift_labels = shift_labels.contiguous()
457
+
458
+ flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
459
+ flat_labels = shift_labels.view(-1).to(shift_logits.device)
460
+
461
+ loss_fct = nn.CrossEntropyLoss()
462
+ loss = loss_fct(flat_logits, flat_labels)
463
+
464
+ if not return_dict:
465
+ output = (logits,) + outputs[1:]
466
+ return (loss,) + output if loss is not None else output
467
+
468
+ return Gemma3CausalLMOutputWithPast(
469
+ loss=loss,
470
+ logits=logits,
471
+ past_key_values=outputs.past_key_values,
472
+ hidden_states=outputs.hidden_states,
473
+ attentions=outputs.attentions,
474
+ image_hidden_states=outputs.image_hidden_states,
475
+ )
476
+
477
+
478
+ __all__ = [
479
+ "Gemma3AudioProjectorConfig",
480
+ "Gemma3AudioProjector",
481
+ "Gemma3VisionProjector",
482
+ "Gemma3OmniModel",
483
+ "Gemma3OmniForConditionalGeneration",
484
+ ]
speech_conformer_encoder.py ADDED
The diff for this file is too large to render. See raw diff