Anjan9320 committed on
Commit
c5614c9
·
verified ·
1 Parent(s): 062feb0

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +6 -3
model.py CHANGED
@@ -67,16 +67,19 @@ class INF5Model(PreTrainedModel):
67
  self.ema_model.load_state_dict(state_dict, strict=False)
68
 
69
  def _extract_embedding_from_audio_and_text(self, audio_path: str, text: str) -> torch.Tensor:
70
-
71
  device = next(self.parameters()).device # model device
72
 
73
- # Load audio waveform
74
  waveform, sample_rate = torchaudio.load(audio_path)
75
  target_sample_rate = 24000
76
  if sample_rate != target_sample_rate:
 
 
77
  resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate).to(device)
78
  waveform = resampler(waveform)
79
- waveform = waveform.to(device)
 
 
80
 
81
  # Forward pass - pass waveform and text directly to ema_model
82
  with torch.no_grad():
 
67
  self.ema_model.load_state_dict(state_dict, strict=False)
68
 
69
  def _extract_embedding_from_audio_and_text(self, audio_path: str, text: str) -> torch.Tensor:
 
70
  device = next(self.parameters()).device # model device
71
 
72
+ # Load audio waveform on CPU first
73
  waveform, sample_rate = torchaudio.load(audio_path)
74
  target_sample_rate = 24000
75
  if sample_rate != target_sample_rate:
76
+ # Move waveform to device before resampling to avoid device mismatch
77
+ waveform = waveform.to(device)
78
  resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate).to(device)
79
  waveform = resampler(waveform)
80
+ else:
81
+ # If no resampling, still move waveform to device for model
82
+ waveform = waveform.to(device)
83
 
84
  # Forward pass - pass waveform and text directly to ema_model
85
  with torch.no_grad():