File size: 6,177 Bytes
f5a276a
 
36e20ed
f5a276a
 
b4adf90
f5a276a
6e88e9b
57b9db9
 
3d653ef
6e88e9b
f5a276a
6e88e9b
f5a276a
6e88e9b
 
ea831bc
 
6e88e9b
f5a276a
36e20ed
f5a276a
6e88e9b
f5a276a
36e20ed
f5a276a
6e88e9b
cb34c43
f5a276a
6e88e9b
f5a276a
cb34c43
 
 
 
 
36e20ed
 
 
 
f5a276a
36e20ed
 
f5a276a
36e20ed
 
f5a276a
36e20ed
 
 
f5a276a
961cbae
f5a276a
 
36e20ed
 
 
 
 
f5a276a
36e20ed
 
f5a276a
36e20ed
f5a276a
36e20ed
 
 
 
f5a276a
36e20ed
 
 
 
 
 
 
 
 
 
 
 
f5a276a
36e20ed
6e88e9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
---
library_name: transformers
license: gemma
---

# Google USM: Extracted Gemma-3n Audio Encoder (USM)

> [!Note]
> このモデルの実態は不明確です。[Introducing Gemma 3n: The developer guide](https://developers.googleblog.com/en/introducing-gemma-3n-developer-guide/#:~:text=Gemma%203n%20uses%20an%20advanced%20audio%20encoder%20based%20on%20the%20Universal%20Speech%20Model%20(USM).)には、
> USMに基づくエンコーダーが使用されていると記述されていますが、USMの論文とこのモデルにはいくつかの異なる点が存在します。
> このモデルは0.6Bですが、USMの論文の0.6Bモデルとは層の数が異なります。
> このモデルは Gemma 3n の AudioEncoder であり、本来の USM とは異なる可能性があります。

## Model Description

このモデルは、Googleのマルチモーダルモデル [google/gemma-3n-e2b-it](https://huggingface.co/google/gemma-3n-e2b-it) から、音声エンコーダー部分 (`audio_tower`) のみを抽出したものです。

bf16版:https://huggingface.co/Atotti/google-usm-bf16

アーキテクチャは、論文 [Universal Speech Model](https://arxiv.org/abs/2303.01037) に基づくGemma3nAudioEncoderです。

このエンコーダーは、音声波形データを受け取り、その内容を表現する高次元の特徴量(エンコーディング)のシーケンスに変換する役割を果たします。

## Intended Use

このモデルは単体で音声認識(文字起こし)などを行うものではなく、より大きなモデルのコンポーネントとして使用されることを想定しています。

* マルチモーダルモデルの音声入力部として: 生成AIに音声情報を与えるための特徴量を抽出します。
* 音声分類: このモデルの出力に分類ヘッドを追加して、特定の音声を分類するタスクでファインチューニングします。

## How to Use

### dependencies
```
pip install transformers==4.53.0
```

```python
import torch
import soundfile as sf
from transformers import Gemma3nAudioEncoder, Gemma3nAudioFeatureExtractor

encoder_id = "Atotti/google-usm"
source_model_id = "google/gemma-3n-e2b-it"

audio_encoder = Gemma3nAudioEncoder.from_pretrained(encoder_id)
feature_extractor = Gemma3nAudioFeatureExtractor.from_pretrained(source_model_id)

device = "cuda" if torch.cuda.is_available() else "cpu"
audio_encoder.to(device)
audio_encoder.eval()

waveform, sampling_rate = sf.read("/path/to/your_audio_file.wav")


inputs = feature_extractor(
    [waveform],
    sampling_rate=sampling_rate,
    return_tensors="pt"
)

audio_mel = inputs["input_features"].to(device)
audio_mel_mask = (inputs["input_features_mask"] == 0).to(device)

with torch.inference_mode():

    audio_encodings, output_mask = audio_encoder(
        audio_mel=audio_mel,
        audio_mel_mask=audio_mel_mask
    )

print(audio_encodings.shape) # torch.Size([1, 18, 1536])
print(audio_encodings[0, :5, :10])
# tensor([[ 0.0014, -0.0044,  0.0003,  0.0084, -0.0076, -0.0194,  0.0071,  0.0160,
#           0.0137,  0.0146],
#         [-0.0153,  0.0051,  0.0111, -0.0134, -0.0032, -0.0134,  0.0112, -0.0163,
#           0.0050,  0.0036],
#         [ 0.0003, -0.0022,  0.0164, -0.0090, -0.0033, -0.0043,  0.0030, -0.0042,
#          -0.0060,  0.0066],
#         [-0.0006, -0.0194, -0.0006, -0.0097, -0.0049, -0.0132,  0.0012,  0.0175,
#          -0.0242, -0.0091],
#         [ 0.0127,  0.0122,  0.0125,  0.0277,  0.0116,  0.0152,  0.0142, -0.0099,
#          -0.0080, -0.0233]], device='cuda:0')

```

## Model Architecture
```
Gemma3nAudioEncoder(
  (subsample_conv_projection): Gemma3nAudioSubSampleConvProjection(
    (conv_0): Gemma3nAudioSSCPConvBlock(
      (conv): Conv2d(1, 128, kernel_size=(3, 3), stride=(2, 2), bias=False)
      (norm): Gemma3nAudioCumulativeGroupNorm()
      (activation): ReLU()
    )
    (conv_1): Gemma3nAudioSSCPConvBlock(
      (conv): Conv2d(128, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
      (norm): Gemma3nAudioCumulativeGroupNorm()
      (activation): ReLU()
    )
    (input_proj_linear): Linear(in_features=1024, out_features=1536, bias=False)
  )
  (conformer): ModuleList(
    (0-11): 12 x Gemma3nAudioConformerBlock(
      (ffw_layer_start): Gemma3nAudioConformerFeedForward(
        (pre_layer_norm): Gemma3nRMSNorm((1536,), eps=1e-06)
        (ffw_layer_1): Linear(in_features=1536, out_features=6144, bias=False)
        (ffw_layer_2): Linear(in_features=6144, out_features=1536, bias=False)
        (post_layer_norm): Gemma3nRMSNorm((1536,), eps=1e-06)
      )
      (attention): Gemma3nAudioConformerAttention(
        (pre_attn_norm): Gemma3nRMSNorm((1536,), eps=1e-06)
        (attn): Gemma3nAudioAttention(
          (relative_position_embedding): Gemma3nAudioRelativePositionEmbedding(
            (pos_proj): Linear(in_features=1536, out_features=1536, bias=False)
          )
          (q_proj): Linear(in_features=1536, out_features=1536, bias=False)
          (k_proj): Linear(in_features=1536, out_features=1536, bias=False)
          (v_proj): Linear(in_features=1536, out_features=1536, bias=False)
        )
        (post): Linear(in_features=1536, out_features=1536, bias=False)
        (post_norm): Gemma3nRMSNorm((1536,), eps=1e-06)
      )
      (lconv1d): Gemma3nAudioConformerLightConv1d(
        (pre_layer_norm): Gemma3nRMSNorm((1536,), eps=1e-06)
        (linear_start): Linear(in_features=1536, out_features=3072, bias=False)
        (depthwise_conv1d): Conv1d(1536, 1536, kernel_size=(5,), stride=(1,), groups=1536, bias=False)
        (conv_norm): Gemma3nRMSNorm((1536,), eps=1e-06)
        (linear_end): Linear(in_features=1536, out_features=1536, bias=False)
      )
      (ffw_layer_end): Gemma3nAudioConformerFeedForward(
        (pre_layer_norm): Gemma3nRMSNorm((1536,), eps=1e-06)
        (ffw_layer_1): Linear(in_features=1536, out_features=6144, bias=False)
        (ffw_layer_2): Linear(in_features=6144, out_features=1536, bias=False)
        (post_layer_norm): Gemma3nRMSNorm((1536,), eps=1e-06)
      )
      (norm): Gemma3nRMSNorm((1536,), eps=1e-06)
    )
  )
)
```