massabaali committed on
Commit
728c030
·
verified ·
1 Parent(s): 357cead

Upload CoLMbo weights, config, and source code

Browse files
Files changed (3) hide show
  1. config.json +9 -2
  2. modeling_colmbo.py +174 -52
  3. pytorch_model.bin +2 -2
config.json CHANGED
@@ -10,8 +10,15 @@
10
  "n_mels": 80,
11
  "embedding_dim": 192,
12
  "channel": 1024,
13
- "prefix_length": 10,
14
- "gpt_model_name": "gpt2",
 
 
 
 
 
 
 
15
  "sample_rate": 16000,
16
  "torch_dtype": "float32"
17
  }
 
10
  "n_mels": 80,
11
  "embedding_dim": 192,
12
  "channel": 1024,
13
+ "map_type": "mlp",
14
+ "prefix_size": 192,
15
+ "sid_prefix_length": 40,
16
+ "sid_prefix_length_clip": 40,
17
+ "num_layers": 8,
18
+ "norm_sid_emb": false,
19
+ "text_decoder": "gpt2",
20
+ "tok_len": 67,
21
+ "text_prefix_length": 10,
22
  "sample_rate": 16000,
23
  "torch_dtype": "float32"
24
  }
modeling_colmbo.py CHANGED
@@ -1,12 +1,20 @@
1
  """
2
  modeling_colmbo.py — CoLMbo HuggingFace-compatible model wrapper.
3
- Registered with AutoModel via trust_remote_code=True.
 
 
 
 
 
 
4
  """
5
 
 
 
6
  import torch
7
  import torch.nn as nn
8
  import torchaudio
9
- from transformers import PreTrainedModel, PretrainedConfig
10
  from transformers.modeling_outputs import BaseModelOutput
11
 
12
 
@@ -17,11 +25,22 @@ class CoLMboConfig(PretrainedConfig):
17
 
18
  def __init__(
19
  self,
 
20
  n_mels: int = 80,
21
  embedding_dim: int = 192,
22
  channel: int = 1024,
23
- prefix_length: int = 10,
24
- gpt_model_name: str = "gpt2",
 
 
 
 
 
 
 
 
 
 
25
  sample_rate: int = 16000,
26
  **kwargs,
27
  ):
@@ -29,8 +48,15 @@ class CoLMboConfig(PretrainedConfig):
29
  self.n_mels = n_mels
30
  self.embedding_dim = embedding_dim
31
  self.channel = channel
32
- self.prefix_length = prefix_length
33
- self.gpt_model_name = gpt_model_name
 
 
 
 
 
 
 
34
  self.sample_rate = sample_rate
35
 
36
 
@@ -40,9 +66,14 @@ class CoLMboModel(PreTrainedModel):
40
  """
41
  CoLMbo: Speaker Language Model for Descriptive Profiling.
42
 
43
- Usage:
44
- model = AutoModel.from_pretrained("cmu-mlsp/CoLMbo", trust_remote_code=True)
45
- output = model.describe_file("audio.wav", prompt="What is the speaker's gender?")
 
 
 
 
 
46
  """
47
 
48
  config_class = CoLMboConfig
@@ -51,30 +82,44 @@ class CoLMboModel(PreTrainedModel):
51
  def __init__(self, config: CoLMboConfig):
52
  super().__init__(config)
53
 
 
54
  from encoder.encoder import Model
55
  from load_data.extract_fbanks import Mel_Spectrogram
 
56
 
 
57
  self.mel_extractor = Mel_Spectrogram()
58
 
59
- # Speaker encoder
60
  self.sid_model = Model(
61
  n_mels=config.n_mels,
62
  embedding_dim=config.embedding_dim,
63
  channel=config.channel,
64
  )
65
 
66
- # Mapper: linear projection from speaker embedding → LM token space
67
- self.mapper = nn.Linear(config.embedding_dim, 768 * config.prefix_length)
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
- # GPT LM head
70
- from transformers import GPT2LMHeadModel
71
- self.gpt = GPT2LMHeadModel.from_pretrained(config.gpt_model_name)
72
 
73
- self.prefix_length = config.prefix_length
74
  self.post_init()
75
 
76
  # ------------------------------------------------------------------
77
- # Forward — returns speaker embedding (pipeline compatibility)
78
  # ------------------------------------------------------------------
79
  def forward(self, input_values: torch.Tensor) -> BaseModelOutput:
80
  mel = self.mel_extractor(input_values)
@@ -82,50 +127,121 @@ class CoLMboModel(PreTrainedModel):
82
  return BaseModelOutput(last_hidden_state=spk_emb.unsqueeze(1))
83
 
84
  # ------------------------------------------------------------------
85
- # Internal helpers
86
  # ------------------------------------------------------------------
87
- def _get_sid_prefix(self, spk_emb: torch.Tensor) -> torch.Tensor:
88
- batch = spk_emb.size(0)
89
- prefix = self.mapper(spk_emb)
90
- return prefix.view(batch, self.prefix_length, -1)
91
-
92
- def _get_prompt_prefix(self, prompt: str, device):
93
- from transformers import GPT2Tokenizer
94
- tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
95
- tokens = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
96
- return self.gpt.transformer.wte(tokens)
97
 
98
- @torch.no_grad()
99
- def _generate_beam(self, prefix_emb: torch.Tensor, num_beams: int = 5) -> list:
100
- from transformers import GPT2Tokenizer
101
- tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
102
- tokenizer.pad_token = tokenizer.eos_token
103
- generated = self.gpt.generate(
104
- inputs_embeds=prefix_emb,
105
- max_new_tokens=100,
106
- num_beams=num_beams,
107
- early_stopping=True,
108
- pad_token_id=tokenizer.eos_token_id,
109
  )
110
- return [tokenizer.decode(g, skip_special_tokens=True) for g in generated]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
 
112
  # ------------------------------------------------------------------
113
- # Main API
114
  # ------------------------------------------------------------------
115
  @torch.no_grad()
116
  def describe(
117
  self,
118
  waveform: torch.Tensor,
119
  prompt: str = "Please describe the speaker.",
120
- num_beams: int = 5,
 
 
121
  ) -> str:
122
  """
123
  Generate a natural language description of the speaker.
124
 
125
  Args:
126
- waveform: raw audio [1, T] at 16 kHz
127
- prompt: e.g. "What is the speaker's dialect?"
128
- num_beams: beam search width
 
 
129
 
130
  Returns:
131
  str: generated description
@@ -135,16 +251,22 @@ class CoLMboModel(PreTrainedModel):
135
  >>> waveform, sr = torchaudio.load("audio.wav")
136
  >>> print(model.describe(waveform, "What is the speaker's age?"))
137
  """
138
- device = next(self.parameters()).device
139
  self.eval()
140
 
141
- mel = self.mel_extractor(waveform).to(device)
142
- spk_emb = self.sid_model(mel)
143
- sid_pfx = self._get_sid_prefix(spk_emb)
144
- pmt_pfx = self._get_prompt_prefix(prompt, device)
145
- prefix = torch.cat((sid_pfx, pmt_pfx), dim=1)
146
 
147
- return self._generate_beam(prefix, num_beams=num_beams)[0]
 
 
 
 
 
 
148
 
149
  @torch.no_grad()
150
  def describe_file(
 
1
  """
2
  modeling_colmbo.py — CoLMbo HuggingFace-compatible model wrapper.
3
+
4
+ Faithfully wraps the original ExpWrapper inference pipeline so that
5
+ users can run:
6
+
7
+ from transformers import AutoModel
8
+ model = AutoModel.from_pretrained("cmu-mlsp/CoLMbo", trust_remote_code=True)
9
+ text = model.describe_file("audio.wav", "What is the speaker's dialect?")
10
  """
11
 
12
+ import os
13
+ import numpy as np
14
  import torch
15
  import torch.nn as nn
16
  import torchaudio
17
+ from transformers import PreTrainedModel, PretrainedConfig, GPT2LMHeadModel, AutoTokenizer
18
  from transformers.modeling_outputs import BaseModelOutput
19
 
20
 
 
25
 
26
  def __init__(
27
  self,
28
+ # speaker encoder
29
  n_mels: int = 80,
30
  embedding_dim: int = 192,
31
  channel: int = 1024,
32
+ # mapper / prefix
33
+ map_type: str = "mlp",
34
+ prefix_size: int = 192, # matches sid embedding dim
35
+ sid_prefix_length: int = 40,
36
+ sid_prefix_length_clip: int = 40,
37
+ num_layers: int = 8,
38
+ norm_sid_emb: bool = False,
39
+ # LM
40
+ text_decoder: str = "gpt2",
41
+ tok_len: int = 67,
42
+ text_prefix_length: int = 10,
43
+ # audio
44
  sample_rate: int = 16000,
45
  **kwargs,
46
  ):
 
48
  self.n_mels = n_mels
49
  self.embedding_dim = embedding_dim
50
  self.channel = channel
51
+ self.map_type = map_type
52
+ self.prefix_size = prefix_size
53
+ self.sid_prefix_length = sid_prefix_length
54
+ self.sid_prefix_length_clip = sid_prefix_length_clip
55
+ self.num_layers = num_layers
56
+ self.norm_sid_emb = norm_sid_emb
57
+ self.text_decoder = text_decoder
58
+ self.tok_len = tok_len
59
+ self.text_prefix_length = text_prefix_length
60
  self.sample_rate = sample_rate
61
 
62
 
 
66
  """
67
  CoLMbo: Speaker Language Model for Descriptive Profiling.
68
 
69
+ Architecture:
70
+ audio → Mel_Spectrogram → ECAPA encoder → sid_mapper → prefix tokens
71
+ prefix tokens + prompt tokens → GPT-2 LM → natural language description
72
+
73
+ Example:
74
+ >>> from transformers import AutoModel
75
+ >>> model = AutoModel.from_pretrained("cmu-mlsp/CoLMbo", trust_remote_code=True)
76
+ >>> print(model.describe_file("speaker.wav", "What is the speaker's dialect?"))
77
  """
78
 
79
  config_class = CoLMboConfig
 
82
  def __init__(self, config: CoLMboConfig):
83
  super().__init__(config)
84
 
85
+ # Local imports β€” resolved from files shipped in the HF repo
86
  from encoder.encoder import Model
87
  from load_data.extract_fbanks import Mel_Spectrogram
88
+ from mapper import get_sid_mapper
89
 
90
+ # ── Audio frontend ────────────────────────────────────────────
91
  self.mel_extractor = Mel_Spectrogram()
92
 
93
+ # ── Speaker encoder (ECAPA-TDNN) ──────────────────────────────
94
  self.sid_model = Model(
95
  n_mels=config.n_mels,
96
  embedding_dim=config.embedding_dim,
97
  channel=config.channel,
98
  )
99
 
100
+ # ── GPT-2 decoder ─────────────────────────────────────────────
101
+ self.gpt = GPT2LMHeadModel.from_pretrained(config.text_decoder)
102
+ self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
103
+
104
+ # ── Speaker β†’ prefix mapper ───────────────────────────────────
105
+ self.sid_mapper = get_sid_mapper(
106
+ config.map_type,
107
+ None,
108
+ config.prefix_size,
109
+ self.gpt_embedding_size,
110
+ config.sid_prefix_length,
111
+ config.sid_prefix_length_clip,
112
+ config.num_layers,
113
+ )
114
 
115
+ # ── Tokenizer ───────────────────────────────────────���─────────
116
+ self.tokenizer = AutoTokenizer.from_pretrained(config.text_decoder)
117
+ self.tokenizer.add_special_tokens({'pad_token': '!'})
118
 
 
119
  self.post_init()
120
 
121
  # ------------------------------------------------------------------
122
+ # HF-standard forward (returns speaker embedding for pipeline compat)
123
  # ------------------------------------------------------------------
124
  def forward(self, input_values: torch.Tensor) -> BaseModelOutput:
125
  mel = self.mel_extractor(input_values)
 
127
  return BaseModelOutput(last_hidden_state=spk_emb.unsqueeze(1))
128
 
129
  # ------------------------------------------------------------------
130
+ # Internal helpers — mirror ExpWrapper exactly
131
  # ------------------------------------------------------------------
132
def _get_sid_prefix(self, sid_embeddings: torch.Tensor) -> torch.Tensor:
    """Project speaker embeddings into a sequence of LM-space prefix vectors.

    Args:
        sid_embeddings: speaker embeddings; the last axis is the embedding
            axis (assumed ``[batch, prefix_size]`` -- TODO confirm against
            the mapper's expected input).

    Returns:
        The mapper output viewed as
        ``[-1, config.sid_prefix_length, gpt_embedding_size]``.
    """
    if self.config.norm_sid_emb:
        # L2-normalise along the last axis. keepdim=True broadcasts for any
        # number of leading dimensions (the previous reshape(-1, 1) was only
        # correct for 2-D input), and the clamp keeps an all-zero embedding
        # at zero instead of turning it into NaN.
        norms = sid_embeddings.norm(p=2, dim=-1, keepdim=True)
        floor = torch.finfo(sid_embeddings.dtype).tiny
        sid_embeddings = sid_embeddings / norms.clamp_min(floor)

    prefix = self.sid_mapper(sid_embeddings)
    # contiguous() is required before view() in case the mapper returns a
    # non-contiguous tensor.
    return prefix.contiguous().view(
        -1, self.config.sid_prefix_length, self.gpt_embedding_size
    )
 
 
140
 
141
def _preprocess_prompt_single(self, text: str, device) -> dict:
    """Tokenize one prompt into fixed-length, flattened tensors on ``device``.

    The prompt is truncated or padded to exactly 10 tokens.

    Args:
        text: the prompt string.
        device: target device for the returned tensors.

    Returns:
        dict mapping each tokenizer output name (e.g. ``"input_ids"``) to a
        1-D tensor on ``device``.
    """
    # Call the tokenizer directly with padding="max_length": this replaces
    # the legacy encode_plus(..., pad_to_max_length=True) spelling, whose
    # keyword is deprecated in transformers.
    # NOTE(review): 10 equals the default of config.text_prefix_length;
    # presumably the two should track each other -- confirm before wiring.
    encoded = self.tokenizer(
        text,
        add_special_tokens=True,
        max_length=10,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    # return_tensors="pt" yields a leading batch axis; flatten it away.
    return {name: tensor.reshape(-1).to(device) for name, tensor in encoded.items()}
151
+
152
def _get_prompt_prefix(self, text: str, device) -> torch.Tensor:
    """Embed a text prompt with the decoder's token-embedding table.

    Args:
        text: the prompt string.
        device: device on which the token ids are placed.

    Returns:
        Tensor of shape ``[1, seq_len, embedding_dim]`` computed without
        gradient tracking.
    """
    token_ids = self._preprocess_prompt_single(text, device)["input_ids"]
    # The helper returns a flat id vector; add the batch axis back.
    batched_ids = token_ids.unsqueeze(0)
    embed = self.gpt.transformer.wte
    with torch.no_grad():
        prompt_emb = embed(batched_ids)
    return prompt_emb
158
+
159
def _generate_beam(
    self,
    prefix_emb: torch.Tensor,
    beam_size: int = 1,
    entry_length: int = 80,
    temperature: float = 1.0,
    stop_token: str = " <|endoftext|>",
) -> list:
    """Beam-search decode text from prefix embeddings.

    Port of ``ExpWrapper.generate_beam``. Beams are ranked by
    length-normalised cumulative log-probability.

    Args:
        prefix_emb: embeddings fed to the decoder as ``inputs_embeds``.
            The first step squeezes axis 0 of the top-k scores, so a batch
            of exactly one is expected (assumed ``[1, prefix_len, dim]``).
        beam_size: number of beams kept per step (1 behaves greedily).
        entry_length: maximum number of tokens to generate.
        temperature: divisor applied to the logits; non-positive values
            are treated as 1.0.
        stop_token: string whose FIRST encoded token id ends a beam.

    Returns:
        Decoded beam texts, best-scoring first. Empty when
        ``entry_length <= 0`` (previously this crashed on ``None.cpu()``).
    """
    if entry_length <= 0:
        return []

    # NOTE(review): only the first id of the encoded stop string is used.
    # With the default " <|endoftext|>", verify that this id is the EOS
    # token and not a leading-space token for the tokenizer in use.
    stop_token_index = self.tokenizer.encode(stop_token)[0]
    device = next(self.gpt.parameters()).device
    seq_lengths = torch.ones(beam_size, device=device)
    is_stopped = torch.zeros(beam_size, device=device, dtype=torch.bool)
    divisor = temperature if temperature > 0 else 1.0
    tokens = None
    scores = None

    with torch.no_grad():
        generated = prefix_emb
        for _ in range(entry_length):
            logits = self.gpt(inputs_embeds=generated).logits
            # Kept as softmax().log() so scores stay numerically identical
            # to the reference implementation.
            logits = (logits[:, -1, :] / divisor).softmax(-1).log()

            if scores is None:
                # First step: fan the single prefix out into beam_size beams.
                scores, next_tokens = logits.topk(beam_size, -1)
                generated = generated.expand(beam_size, *generated.shape[1:])
                next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0)
                tokens = next_tokens
            else:
                # A stopped beam may only "emit" token id 0 at zero cost, so
                # its score is frozen while it stays in the ranking.
                # (presumably id 0 is the '!' pad token -- confirm)
                logits[is_stopped] = float("-inf")
                logits[is_stopped, 0] = 0
                scores_sum = scores[:, None] + logits
                seq_lengths[~is_stopped] += 1
                scores_sum_average = scores_sum / seq_lengths[:, None]
                scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(
                    beam_size, -1
                )
                # Flat index -> (source beam, token id).
                vocab_size = scores_sum.shape[1]
                next_tokens_source = next_tokens // vocab_size
                seq_lengths = seq_lengths[next_tokens_source]
                next_tokens = (next_tokens % vocab_size).unsqueeze(1)
                tokens = torch.cat((tokens[next_tokens_source], next_tokens), dim=1)
                generated = generated[next_tokens_source]
                scores = scores_sum_average * seq_lengths
                is_stopped = is_stopped[next_tokens_source]

            # next_tokens is [beam, 1]; squeeze only that axis so a single
            # beam does not collapse to a 0-dim tensor.
            next_token_embed = self.gpt.transformer.wte(next_tokens.squeeze(1)).view(
                generated.shape[0], 1, -1
            )
            generated = torch.cat((generated, next_token_embed), dim=1)
            is_stopped = is_stopped | next_tokens.eq(stop_token_index).squeeze(1)
            if is_stopped.all():
                break

    scores = scores / seq_lengths
    lengths = seq_lengths.tolist()
    output_texts = [
        self.tokenizer.decode(ids[: int(length)])
        for ids, length in zip(tokens.tolist(), lengths)
    ]
    order = scores.argsort(descending=True).tolist()
    return [output_texts[i] for i in order]
223
 
224
  # ------------------------------------------------------------------
225
+ # Public API
226
  # ------------------------------------------------------------------
227
  @torch.no_grad()
228
  def describe(
229
  self,
230
  waveform: torch.Tensor,
231
  prompt: str = "Please describe the speaker.",
232
+ beam_size: int = 1,
233
+ entry_length: int = 80,
234
+ temperature: float = 1.0,
235
  ) -> str:
236
  """
237
  Generate a natural language description of the speaker.
238
 
239
  Args:
240
+ waveform: raw audio tensor [1, T] at 16 kHz
241
+ prompt: e.g. "What is the speaker's dialect?"
242
+ beam_size: beam search width (default 1 = greedy)
243
+ entry_length: max tokens to generate
244
+ temperature: sampling temperature
245
 
246
  Returns:
247
  str: generated description
 
251
  >>> waveform, sr = torchaudio.load("audio.wav")
252
  >>> print(model.describe(waveform, "What is the speaker's age?"))
253
  """
254
+ device = next(self.gpt.parameters()).device
255
  self.eval()
256
 
257
+ mel = self.mel_extractor(waveform).to(device)
258
+ spk_emb = self.sid_model(mel)
259
+ sids_prefix = self._get_sid_prefix(spk_emb) # [1, sid_prefix_len, 768]
260
+ pmt_prefix = self._get_prompt_prefix(prompt, device) # [1, tok_len, 768]
261
+ prefix_emb = torch.cat((sids_prefix, pmt_prefix), dim=1) # [1, total_len, 768]
262
 
263
+ texts = self._generate_beam(
264
+ prefix_emb,
265
+ beam_size=beam_size,
266
+ entry_length=entry_length,
267
+ temperature=temperature,
268
+ )
269
+ return texts[0]
270
 
271
  @torch.no_grad()
272
  def describe_file(
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:e0d80efbeffb56f4038bf9d320d15b5377d12b1cb85833e908d9f0f6b5c2bbab
3
- size 2066033810
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:54b52bd0b2c80e0afcddeebf6c30ce4d9645c265b546cdc95c2cf36ba7564b3f
3
+ size 1982720694