Upload 3 files
inference.py  CHANGED  (+19 -21)
@@ -2,7 +2,6 @@ import os
 import lightning as L
 import torch
 import time
-import spaces
 from snac import SNAC
 from litgpt import Tokenizer
 from litgpt.utils import (
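Context: `spaces` is the helper package that Hugging Face injects into ZeroGPU Spaces; it is not a regular dependency, so removing the import (and the `@spaces.GPU` decorator further down) is what lets inference.py run outside the Space. For reference, the pattern being dropped looks roughly like this (a sketch; the function name is illustrative):

```python
import spaces  # resolvable only inside a Hugging Face ZeroGPU Space


@spaces.GPU  # asks ZeroGPU to attach a GPU for the duration of each call
def answer(audio_path):
    ...
```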
@@ -147,8 +146,8 @@ def load_audio(path):
 
 def A1_A2_batch(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step,
                 snacmodel, out_dir=None):
-
-
+    with fabric.init_tensor():
+        model.set_kv_cache(batch_size=2)
     tokenlist = generate_TA_BATCH(
         model,
         audio_feature,
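The two added lines recur in every task helper below, varying only `batch_size` (2 for the batched audio pass, 1 elsewhere). Allocating the cache inside `fabric.init_tensor()` makes the fresh KV tensors inherit Fabric's device and default dtype, so they match the model without explicit `.to(...)` calls. A minimal sketch of the pattern, assuming a litgpt-style `GPT` that implements `set_kv_cache` (the helper name is hypothetical):

```python
import lightning as L


def run_task(fabric: L.Fabric, model, batch_size: int):
    # Tensors created under init_tensor() land on fabric.device with the
    # precision plugin's default dtype, so the new cache matches the model.
    with fabric.init_tensor():
        model.set_kv_cache(batch_size=batch_size)
    # ... autoregressive decoding then fills the freshly sized cache ...
```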
@@ -191,8 +190,8 @@ def A1_A2_batch(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step,
 
 
 def A1_T2(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step):
-
-
+    with fabric.init_tensor():
+        model.set_kv_cache(batch_size=1)
     tokenlist = generate_AT(
         model,
         audio_feature,
@@ -214,8 +213,8 @@ def A1_T2(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step):
 
 def A1_A2(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step,
           snacmodel, out_dir=None):
-
-
+    with fabric.init_tensor():
+        model.set_kv_cache(batch_size=1)
     tokenlist = generate_AA(
         model,
         audio_feature,
@@ -256,8 +255,8 @@ def A1_A2(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step,
 
 
 def A1_T1(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step):
-
-
+    with fabric.init_tensor():
+        model.set_kv_cache(batch_size=1)
     tokenlist = generate_ASR(
         model,
         audio_feature,
@@ -280,8 +279,8 @@ def A1_T1(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step):
 
 def T1_A2(fabric, input_ids, model, text_tokenizer, step,
           snacmodel, out_dir=None):
-
-
+    with fabric.init_tensor():
+        model.set_kv_cache(batch_size=1)
     tokenlist = generate_TA(
         model,
         None,
@@ -325,8 +324,8 @@ def T1_A2(fabric, input_ids, model, text_tokenizer, step,
 
 def T1_T2(fabric, input_ids, model, text_tokenizer, step):
 
-
-
+    with fabric.init_tensor():
+        model.set_kv_cache(batch_size=1)
     tokenlist = generate_TT(
         model,
         None,
@@ -356,13 +355,12 @@ def load_model(ckpt_dir, device):
     config.post_adapter = False
 
     with fabric.init_module(empty_init=False):
-        model = GPT(config
+        model = GPT(config)
 
-
+    model = fabric.setup(model)
     state_dict = lazy_load(ckpt_dir + "/lit_model.pth")
     model.load_state_dict(state_dict, strict=True)
-    model
-    model.eval()
+    model.to(device).eval()
 
     return fabric, model, text_tokenizer, snacmodel, whispermodel
 
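The corrected `load_model` now follows the usual litgpt loading sequence: build the module under `fabric.init_module`, wrap it with `fabric.setup`, load the checkpoint via `lazy_load`, and switch to eval. A condensed sketch, assuming `GPT`, `Config`, and `lazy_load` behave like their litgpt counterparts and the checkpoint layout matches the diff (`model_config.yaml` plus `lit_model.pth`):

```python
import lightning as L
from litgpt import GPT, Config
from litgpt.utils import lazy_load


def load_gpt(ckpt_dir: str, device: str):
    fabric = L.Fabric(devices=1)
    config = Config.from_file(ckpt_dir + "/model_config.yaml")
    # Parameters materialize under Fabric's device/precision context.
    with fabric.init_module(empty_init=False):
        model = GPT(config)
    model = fabric.setup(model)  # wrap for the chosen strategy and precision
    # lazy_load reads tensors from the file on demand rather than eagerly.
    state_dict = lazy_load(ckpt_dir + "/lit_model.pth")
    model.load_state_dict(state_dict, strict=True)
    return fabric, model.to(device).eval()
```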
@@ -385,8 +383,7 @@ class OmniInference:
         for _ in self.run_AT_batch_stream(sample):
             pass
 
-
-    @spaces.GPU
+    @torch.inference_mode()
     def run_AT_batch_stream(self,
                             audio_path,
                             stream_stride=4,
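Applied as a decorator, `torch.inference_mode()` wraps every call of the method in PyTorch's inference mode, which disables gradient tracking and tensor version counting and is slightly cheaper than `torch.no_grad()`. It takes the slot vacated by `@spaces.GPU`: outside a ZeroGPU Space no scheduling hook is needed, only gradient-free execution. Sketch (body elided):

```python
import torch


class OmniInference:
    @torch.inference_mode()  # no autograd state is recorded while streaming
    def run_AT_batch_stream(self, audio_path, stream_stride=4):
        ...
```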
@@ -401,7 +398,8 @@ class OmniInference:
         assert os.path.exists(audio_path), f"audio file {audio_path} not found"
         model = self.model
 
-
+        with self.fabric.init_tensor():
+            model.set_kv_cache(batch_size=2)
 
         mel, leng = load_audio(audio_path)
         audio_feature, input_ids = get_input_ids_whisper_ATBatch(mel, leng, self.whispermodel, self.device)
@@ -419,7 +417,7 @@
         list_output = [[] for i in range(8)]
         tokens_A, token_T = next_token_batch(
             model,
-            audio_feature.to(torch.float32).to(device),
+            audio_feature.to(torch.float32).to(model.device),
             input_ids,
             [T - 3, T - 3],
             ["A1T2", "A1T2"],
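The last hunk swaps the free variable `device` for `model.device`, so the feature tensor follows wherever Fabric actually placed the weights. Plain `nn.Module` has no `device` attribute; here it presumably comes from the `fabric.setup` wrapper or a property on the model class. The portable fallback, shown as a sketch (helper name hypothetical), reads the device off a parameter:

```python
import torch


def module_device(module: torch.nn.Module) -> torch.device:
    # nn.Module itself has no .device; the common idiom is to take it from
    # the first parameter, which is what a `model.device` property usually wraps.
    return next(module.parameters()).device


# e.g. audio_feature = audio_feature.to(torch.float32).to(module_device(model))
```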