yfyeung committed on
Commit
f6daa89
·
verified ·
1 Parent(s): 83f4649

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_clsp.py +10 -1
modeling_clsp.py CHANGED
@@ -102,6 +102,7 @@ class CLAP(nn.Module):
102
  )
103
 
104
  # text branch
 
105
  self.text_encoder = text_encoder = RobertaModel(
106
  RobertaConfig.from_pretrained("roberta-base")
107
  )
@@ -252,8 +253,16 @@ class CLAP(nn.Module):
252
  audio_encoder_out = F.normalize(audio_encoder_out, dim=-1)
253
 
254
  if text is not None:
 
 
 
 
 
 
 
 
 
255
  assert text["input_ids"].ndim == 2, text["input_ids"].shape
256
-
257
  text_encoder_out = self.forward_text_encoder(
258
  text, freeze_encoder=freeze_text_encoder
259
  )
 
102
  )
103
 
104
  # text branch
105
+ self.text_tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
106
  self.text_encoder = text_encoder = RobertaModel(
107
  RobertaConfig.from_pretrained("roberta-base")
108
  )
 
253
  audio_encoder_out = F.normalize(audio_encoder_out, dim=-1)
254
 
255
  if text is not None:
256
+ text = self.text_tokenizer(
257
+ text,
258
+ padding=True,
259
+ truncation=True,
260
+ return_tensors="pt",
261
+ )
262
+ text = {
263
+ k: v.to(device=next(self.parameters()).device) for k, v in text.items()
264
+ }
265
  assert text["input_ids"].ndim == 2, text["input_ids"].shape
 
266
  text_encoder_out = self.forward_text_encoder(
267
  text, freeze_encoder=freeze_text_encoder
268
  )