| from transformers import Gemma3nAudioEncoder, Gemma3nConfig |
| from transformers import AutoFeatureExtractor, PreTrainedModel |
| from transformers.models.gemma3n.modeling_gemma3n import Gemma3nMultimodalEmbedder |
|
|
class Audio(PreTrainedModel):
    """Hold the Gemma3n audio encoder together with its multimodal embedder.

    The two submodules are exposed as ``audio_tower`` and ``embed_audio``;
    these attribute names also become the prefixes of the parameters in the
    module's state dict.
    """

    def __init__(self, config):
        """Build both submodules from the sub-configs carried by *config*.

        Args:
            config: Configuration object exposing ``audio_config`` and
                ``text_config`` attributes (presumably a ``Gemma3nConfig``;
                confirm against the caller).
        """
        super().__init__(config)
        audio_cfg = config.audio_config
        text_cfg = config.text_config
        # Construction order is kept as encoder first, embedder second, so
        # parameter registration order is unchanged.
        self.audio_tower = Gemma3nAudioEncoder(audio_cfg)
        self.embed_audio = Gemma3nMultimodalEmbedder(audio_cfg, text_cfg)
|
|
class GemmaAudio(PreTrainedModel):
    """Audio-only wrapper: encode audio features, then project them.

    The encoder and the embedder live on the nested ``self.model`` module
    (an ``Audio`` instance).
    """

    config_class = Gemma3nConfig

    def __init__(self, config):
        """Create the nested ``Audio`` module from *config*."""
        super().__init__(config)
        self.model = Audio(config)

    def forward(self, input_features, input_features_mask, **kwargs):
        """Run the audio tower and project its first output.

        Args:
            input_features: Audio features passed unchanged to the encoder.
            input_features_mask: Mask that is inverted with ``~`` before being
                handed to the encoder. NOTE(review): presumably a boolean
                tensor marking valid frames, with the encoder expecting the
                opposite polarity -- confirm against ``Gemma3nAudioEncoder``.
            **kwargs: Accepted for call-site compatibility and ignored.

        Returns:
            A 2-tuple ``(projected, second)`` where ``projected`` is the
            embedder's output for the encoder's first result and ``second``
            is the encoder's second result (presumably the encoder-side
            mask; verify).
        """
        inverted_mask = ~input_features_mask
        encoder_output = self.model.audio_tower(input_features, inverted_mask)
        # Index rather than unpack so tuple-like and dict-like encoder outputs
        # are handled exactly as before.
        encoded = encoder_output[0]
        projected = self.model.embed_audio(inputs_embeds=encoded)
        return projected, encoder_output[1]