Commit ·
66c268e
1
Parent(s): 12af9b4
Added float16 loading to the Whisper audio tower
Browse files
- modeling_bunny_phi.py (+2 -4)
modeling_bunny_phi.py
CHANGED
|
@@ -615,7 +615,7 @@ class WhisperAudioTower(nn.Module):
|
|
| 615 |
if self.is_loaded:
|
| 616 |
return
|
| 617 |
|
| 618 |
-
self.audio_tower = WhisperModel.from_pretrained(self.audio_tower_name)
|
| 619 |
|
| 620 |
self.audio_tower.requires_grad_(False)
|
| 621 |
self.audio_tower.eval()
|
|
@@ -2627,10 +2627,8 @@ class BunnyPhiForCausalLM(PhiForCausalLM, BunnyMetaForCausalLM):
|
|
| 2627 |
audio_tower = self.get_audio_tower()
|
| 2628 |
if not audio_tower.is_loaded:
|
| 2629 |
audio_tower.load_model()
|
| 2630 |
-
audio_tower.to(device='cuda', dtype=torch.float16)
|
| 2631 |
audio_processor = audio_tower.audio_processor
|
| 2632 |
-
audio_processor
|
| 2633 |
-
features = audio_processor(audio, sampling_rate=16000, return_tensors="pt").input_features # replace 16k with arg later
|
| 2634 |
audio_tensor = features.to(self.device, dtype=self.dtype)
|
| 2635 |
return audio_tensor
|
| 2636 |
|
|
|
|
| 615 |
if self.is_loaded:
|
| 616 |
return
|
| 617 |
|
| 618 |
+
self.audio_tower = WhisperModel.from_pretrained(self.audio_tower_name, torch_dtype=torch.float16)
|
| 619 |
|
| 620 |
self.audio_tower.requires_grad_(False)
|
| 621 |
self.audio_tower.eval()
|
|
|
|
| 2627 |
audio_tower = self.get_audio_tower()
|
| 2628 |
if not audio_tower.is_loaded:
|
| 2629 |
audio_tower.load_model()
|
|
|
|
| 2630 |
audio_processor = audio_tower.audio_processor
|
| 2631 |
+
features = audio_processor(audio, sampling_rate=16000, return_tensors="pt", device='cuda').input_features # replace 16k with arg later
|
|
|
|
| 2632 |
audio_tensor = features.to(self.device, dtype=self.dtype)
|
| 2633 |
return audio_tensor
|
| 2634 |
|