Fix OOM (out-of-memory) during repeated inference

Changed files:
- app.py (+0, -3): remove redundant `global` declarations
- models/vallex.py (+3, -0): add `import gc` and call `gc.collect()` after decoding in `VALLE.inference`
app.py
CHANGED
|
@@ -116,7 +116,6 @@ def transcribe_one(model, audio_path):
|
|
| 116 |
return lang, text_pr
|
| 117 |
|
| 118 |
def make_npz_prompt(name, uploaded_audio, recorded_audio, transcript_content):
|
| 119 |
-
global model, text_collater, text_tokenizer, audio_tokenizer
|
| 120 |
clear_prompts()
|
| 121 |
audio_prompt = uploaded_audio if uploaded_audio is not None else recorded_audio
|
| 122 |
sr, wav_pr = audio_prompt
|
|
@@ -159,7 +158,6 @@ def make_npz_prompt(name, uploaded_audio, recorded_audio, transcript_content):
|
|
| 159 |
|
| 160 |
|
| 161 |
def make_prompt(name, wav, sr, save=True):
|
| 162 |
-
global whisper_model
|
| 163 |
if not isinstance(wav, torch.FloatTensor):
|
| 164 |
wav = torch.tensor(wav)
|
| 165 |
if wav.abs().max() > 1:
|
|
@@ -185,7 +183,6 @@ def make_prompt(name, wav, sr, save=True):
|
|
| 185 |
def infer_from_audio(text, language, accent, audio_prompt, record_audio_prompt, transcript_content):
|
| 186 |
if len(text) > 150:
|
| 187 |
return "Rejected, Text too long (should be less than 150 characters)", None
|
| 188 |
-
global model, text_collater, text_tokenizer, audio_tokenizer
|
| 189 |
audio_prompt = audio_prompt if audio_prompt is not None else record_audio_prompt
|
| 190 |
sr, wav_pr = audio_prompt
|
| 191 |
if len(wav_pr) / sr > 15:
|
|
|
|
| 116 |
return lang, text_pr
|
| 117 |
|
| 118 |
def make_npz_prompt(name, uploaded_audio, recorded_audio, transcript_content):
|
|
|
|
| 119 |
clear_prompts()
|
| 120 |
audio_prompt = uploaded_audio if uploaded_audio is not None else recorded_audio
|
| 121 |
sr, wav_pr = audio_prompt
|
|
|
|
| 158 |
|
| 159 |
|
| 160 |
def make_prompt(name, wav, sr, save=True):
|
|
|
|
| 161 |
if not isinstance(wav, torch.FloatTensor):
|
| 162 |
wav = torch.tensor(wav)
|
| 163 |
if wav.abs().max() > 1:
|
|
|
|
| 183 |
def infer_from_audio(text, language, accent, audio_prompt, record_audio_prompt, transcript_content):
|
| 184 |
if len(text) > 150:
|
| 185 |
return "Rejected, Text too long (should be less than 150 characters)", None
|
|
|
|
| 186 |
audio_prompt = audio_prompt if audio_prompt is not None else record_audio_prompt
|
| 187 |
sr, wav_pr = audio_prompt
|
| 188 |
if len(wav_pr) / sr > 15:
|
models/vallex.py
CHANGED
|
@@ -14,6 +14,7 @@
|
|
| 14 |
|
| 15 |
import random
|
| 16 |
from typing import Dict, Iterator, List, Tuple, Union
|
|
|
|
| 17 |
|
| 18 |
import numpy as np
|
| 19 |
import torch
|
|
@@ -462,6 +463,7 @@ class VALLE(VALLF):
|
|
| 462 |
**kwargs,
|
| 463 |
):
|
| 464 |
raise NotImplementedError
|
|
|
|
| 465 |
def inference(
|
| 466 |
self,
|
| 467 |
x: torch.Tensor,
|
|
@@ -674,6 +676,7 @@ class VALLE(VALLF):
|
|
| 674 |
y_emb[:, prefix_len:] += embedding_layer(samples)
|
| 675 |
|
| 676 |
assert len(codes) == self.num_quantizers
|
|
|
|
| 677 |
return torch.stack(codes, dim=-1)
|
| 678 |
|
| 679 |
def continual(
|
|
|
|
| 14 |
|
| 15 |
import random
|
| 16 |
from typing import Dict, Iterator, List, Tuple, Union
|
| 17 |
+
import gc
|
| 18 |
|
| 19 |
import numpy as np
|
| 20 |
import torch
|
|
|
|
| 463 |
**kwargs,
|
| 464 |
):
|
| 465 |
raise NotImplementedError
|
| 466 |
+
|
| 467 |
def inference(
|
| 468 |
self,
|
| 469 |
x: torch.Tensor,
|
|
|
|
| 676 |
y_emb[:, prefix_len:] += embedding_layer(samples)
|
| 677 |
|
| 678 |
assert len(codes) == self.num_quantizers
|
| 679 |
+
gc.collect()
|
| 680 |
return torch.stack(codes, dim=-1)
|
| 681 |
|
| 682 |
def continual(
|