martylabs commited on
Commit
271e08d
·
verified ·
1 Parent(s): f266d6a

Update generate_multitalk.py

Browse files
Files changed (1) hide show
  1. generate_multitalk.py +1 -1
generate_multitalk.py CHANGED
@@ -214,7 +214,7 @@ def _parse_args():
214
  return args
215
 
216
  def custom_init(device, wav2vec):
217
- audio_encoder = Wav2Vec2Model.from_pretrained(wav2vec, local_files_only=True).to(device)
218
  audio_encoder.feature_extractor._freeze_parameters()
219
  wav2vec_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(wav2vec, local_files_only=True)
220
  return wav2vec_feature_extractor, audio_encoder
 
214
  return args
215
 
216
  def custom_init(device, wav2vec):
217
+ audio_encoder = Wav2Vec2ForCTC.from_pretrained(args.wav2vec_dir, attn_implementation="eager").to(device)
218
  audio_encoder.feature_extractor._freeze_parameters()
219
  wav2vec_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(wav2vec, local_files_only=True)
220
  return wav2vec_feature_extractor, audio_encoder