Update model.py
Browse files
model.py
CHANGED
|
@@ -67,16 +67,19 @@ class INF5Model(PreTrainedModel):
|
|
| 67 |
self.ema_model.load_state_dict(state_dict, strict=False)
|
| 68 |
|
| 69 |
def _extract_embedding_from_audio_and_text(self, audio_path: str, text: str) -> torch.Tensor:
|
| 70 |
-
|
| 71 |
device = next(self.parameters()).device # model device
|
| 72 |
|
| 73 |
-
# Load audio waveform
|
| 74 |
waveform, sample_rate = torchaudio.load(audio_path)
|
| 75 |
target_sample_rate = 24000
|
| 76 |
if sample_rate != target_sample_rate:
|
|
|
|
|
|
|
| 77 |
resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate).to(device)
|
| 78 |
waveform = resampler(waveform)
|
| 79 |
-
|
|
|
|
|
|
|
| 80 |
|
| 81 |
# Forward pass - pass waveform and text directly to ema_model
|
| 82 |
with torch.no_grad():
|
|
|
|
| 67 |
self.ema_model.load_state_dict(state_dict, strict=False)
|
| 68 |
|
| 69 |
def _extract_embedding_from_audio_and_text(self, audio_path: str, text: str) -> torch.Tensor:
|
|
|
|
| 70 |
device = next(self.parameters()).device # model device
|
| 71 |
|
| 72 |
+
# Load audio waveform on CPU first
|
| 73 |
waveform, sample_rate = torchaudio.load(audio_path)
|
| 74 |
target_sample_rate = 24000
|
| 75 |
if sample_rate != target_sample_rate:
|
| 76 |
+
# Move waveform to device before resampling to avoid device mismatch
|
| 77 |
+
waveform = waveform.to(device)
|
| 78 |
resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate).to(device)
|
| 79 |
waveform = resampler(waveform)
|
| 80 |
+
else:
|
| 81 |
+
# No resampling needed, but still move the waveform to the model's device
|
| 82 |
+
waveform = waveform.to(device)
|
| 83 |
|
| 84 |
# Forward pass - pass waveform and text directly to ema_model
|
| 85 |
with torch.no_grad():
|