Spaces:
Running
Running
Update inference.py
Browse files · inference.py (+14, −14)
inference.py
CHANGED
|
@@ -147,8 +147,8 @@ def load_audio(path):
|
|
| 147 |
|
| 148 |
def A1_A2_batch(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step,
|
| 149 |
snacmodel, out_dir=None):
|
| 150 |
-
|
| 151 |
-
|
| 152 |
tokenlist = generate_TA_BATCH(
|
| 153 |
model,
|
| 154 |
audio_feature,
|
|
@@ -191,8 +191,8 @@ def A1_A2_batch(fabric, audio_feature, input_ids, leng, model, text_tokenizer, s
|
|
| 191 |
|
| 192 |
|
| 193 |
def A1_T2(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step):
|
| 194 |
-
|
| 195 |
-
|
| 196 |
tokenlist = generate_AT(
|
| 197 |
model,
|
| 198 |
audio_feature,
|
|
@@ -214,8 +214,8 @@ def A1_T2(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step):
|
|
| 214 |
|
| 215 |
def A1_A2(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step,
|
| 216 |
snacmodel, out_dir=None):
|
| 217 |
-
|
| 218 |
-
|
| 219 |
tokenlist = generate_AA(
|
| 220 |
model,
|
| 221 |
audio_feature,
|
|
@@ -256,8 +256,8 @@ def A1_A2(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step,
|
|
| 256 |
|
| 257 |
|
| 258 |
def A1_T1(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step):
|
| 259 |
-
|
| 260 |
-
|
| 261 |
tokenlist = generate_ASR(
|
| 262 |
model,
|
| 263 |
audio_feature,
|
|
@@ -280,8 +280,8 @@ def A1_T1(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step):
|
|
| 280 |
|
| 281 |
def T1_A2(fabric, input_ids, model, text_tokenizer, step,
|
| 282 |
snacmodel, out_dir=None):
|
| 283 |
-
|
| 284 |
-
|
| 285 |
tokenlist = generate_TA(
|
| 286 |
model,
|
| 287 |
None,
|
|
@@ -325,8 +325,8 @@ def T1_A2(fabric, input_ids, model, text_tokenizer, step,
|
|
| 325 |
|
| 326 |
def T1_T2(fabric, input_ids, model, text_tokenizer, step):
|
| 327 |
|
| 328 |
-
|
| 329 |
-
|
| 330 |
tokenlist = generate_TT(
|
| 331 |
model,
|
| 332 |
None,
|
|
@@ -386,6 +386,7 @@ class OmniInference:
|
|
| 386 |
pass
|
| 387 |
|
| 388 |
@torch.inference_mode()
|
|
|
|
| 389 |
def run_AT_batch_stream(self,
|
| 390 |
audio_path,
|
| 391 |
stream_stride=4,
|
|
@@ -400,8 +401,7 @@ class OmniInference:
|
|
| 400 |
assert os.path.exists(audio_path), f"audio file {audio_path} not found"
|
| 401 |
model = self.model
|
| 402 |
|
| 403 |
-
|
| 404 |
-
model.set_kv_cache(batch_size=2)
|
| 405 |
|
| 406 |
mel, leng = load_audio(audio_path)
|
| 407 |
audio_feature, input_ids = get_input_ids_whisper_ATBatch(mel, leng, self.whispermodel, self.device)
|
|
|
|
| 147 |
|
| 148 |
def A1_A2_batch(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step,
|
| 149 |
snacmodel, out_dir=None):
|
| 150 |
+
|
| 151 |
+
model.set_kv_cache(batch_size=2)
|
| 152 |
tokenlist = generate_TA_BATCH(
|
| 153 |
model,
|
| 154 |
audio_feature,
|
|
|
|
| 191 |
|
| 192 |
|
| 193 |
def A1_T2(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step):
|
| 194 |
+
|
| 195 |
+
model.set_kv_cache(batch_size=1)
|
| 196 |
tokenlist = generate_AT(
|
| 197 |
model,
|
| 198 |
audio_feature,
|
|
|
|
| 214 |
|
| 215 |
def A1_A2(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step,
|
| 216 |
snacmodel, out_dir=None):
|
| 217 |
+
|
| 218 |
+
model.set_kv_cache(batch_size=1)
|
| 219 |
tokenlist = generate_AA(
|
| 220 |
model,
|
| 221 |
audio_feature,
|
|
|
|
| 256 |
|
| 257 |
|
| 258 |
def A1_T1(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step):
|
| 259 |
+
|
| 260 |
+
model.set_kv_cache(batch_size=1)
|
| 261 |
tokenlist = generate_ASR(
|
| 262 |
model,
|
| 263 |
audio_feature,
|
|
|
|
| 280 |
|
| 281 |
def T1_A2(fabric, input_ids, model, text_tokenizer, step,
|
| 282 |
snacmodel, out_dir=None):
|
| 283 |
+
|
| 284 |
+
model.set_kv_cache(batch_size=1)
|
| 285 |
tokenlist = generate_TA(
|
| 286 |
model,
|
| 287 |
None,
|
|
|
|
| 325 |
|
| 326 |
def T1_T2(fabric, input_ids, model, text_tokenizer, step):
|
| 327 |
|
| 328 |
+
|
| 329 |
+
model.set_kv_cache(batch_size=1)
|
| 330 |
tokenlist = generate_TT(
|
| 331 |
model,
|
| 332 |
None,
|
|
|
|
| 386 |
pass
|
| 387 |
|
| 388 |
@torch.inference_mode()
|
| 389 |
+
@spaces.GPU
|
| 390 |
def run_AT_batch_stream(self,
|
| 391 |
audio_path,
|
| 392 |
stream_stride=4,
|
|
|
|
| 401 |
assert os.path.exists(audio_path), f"audio file {audio_path} not found"
|
| 402 |
model = self.model
|
| 403 |
|
| 404 |
+
model.set_kv_cache(batch_size=2)
|
|
|
|
| 405 |
|
| 406 |
mel, leng = load_audio(audio_path)
|
| 407 |
audio_feature, input_ids = get_input_ids_whisper_ATBatch(mel, leng, self.whispermodel, self.device)
|