import os
import torch
from litgpt.generate.base import next_token_image_batch
import soundfile as sf
from utils.snac_utils import layershift, reconscruct_snac, reconstruct_tensors, get_time_str
from utils.snac_utils import get_snac, generate_audio_data
import clip
import inference
from tqdm import tqdm
from inference import OmniInference, load_model, load_audio, download_model
from inference import text_vocabsize, padded_text_vocabsize, get_text_stream
from PIL import Image

torch.set_printoptions(sci_mode=False)
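# Special prompt tokens (image/audio delimiters, pads, and input/answer markers) re-exported from inference.py.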
_image = inference._image
_eoimage = inference._eoimage
_pad_t = inference._pad_t
_input_t = inference._input_t
_answer_t = inference._answer_t
_eot = inference._eot
_eoa = inference._eoa
_pad_a = inference._pad_a
_input_a = inference._input_a
_answer_a = inference._answer_a

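# Build the batched prompt ids for one (image, spoken question) pair.
# Layers 0-6 carry layer-shifted audio-codec token slots (50 positions reserved for the image,
# then the audio-question span); layer 7 carries the text-layer tokens.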
def get_input_ids_ImageQA_ATBatch(mel, leng, whispermodel, device):
    # Encode the question audio with Whisper and keep only the valid frames.
    with torch.no_grad():
        mel = mel.unsqueeze(0).to(device)
        audio_feature = whispermodel.embed_audio(mel)[0][:leng]

    audio_len = audio_feature.size(0)

    input_ids = []

    # Batch item 0 (task 'ImageQA_A'): the seven audio layers end with an answer-audio token.
    input_ids_item = [[] for _ in range(8)]
    for i in range(7):
        input_ids_item[i] = [layershift(_image, i)] + [layershift(_pad_a, i)] * 50 + [layershift(_eoimage, i)]
        input_ids_item[i] += [layershift(_input_a, i)] + [layershift(_pad_a, i)] * audio_len + [layershift(_eoa, i)]
        input_ids_item[i] += [layershift(_answer_a, i)]
    input_ids_item[-1] = [_pad_t] * (52 + 2 + audio_len) + [_answer_t]
    input_ids_item = [torch.tensor(item) for item in input_ids_item]
    input_ids.append(input_ids_item)

    # Batch item 1 (task 'ImageQA_AT'): the audio layers end with a pad instead of an answer-audio token.
    input_ids_item = [[] for _ in range(8)]
    for i in range(7):
        input_ids_item[i] = [layershift(_image, i)] + [layershift(_pad_a, i)] * 50 + [layershift(_eoimage, i)]
        input_ids_item[i] += [layershift(_input_a, i)] + [layershift(_pad_a, i)] * audio_len + [layershift(_eoa, i)] + [layershift(_pad_a, i)]
    input_ids_item[-1] = [_pad_t] * (52 + 2 + audio_len) + [_answer_t]
    input_ids_item = [torch.tensor(item) for item in input_ids_item]
    input_ids.append(input_ids_item)

    # Stack the two items layer by layer: eight tensors of shape (2, prompt_len).
    stacked_inputids = [[] for _ in range(8)]
    for i in range(2):
        for j in range(8):
            stacked_inputids[j].append(input_ids[i][j])
    stacked_inputids = [torch.stack(tensors) for tensors in stacked_inputids]

    # The Whisper features are duplicated so both batch items see the same audio.
    return torch.stack([audio_feature, audio_feature]), stacked_inputids

def load_clip_model(ckpt_dir, device):
    # Prefer a local ViT-B/32 checkpoint; otherwise let the clip package fetch it by name.
    clip_model_path = ckpt_dir + "/ViT-B-32.pt"
    if not os.path.exists(clip_model_path):
        clip_model_path = "ViT-B/32"
    clipmodel, clippreprocess = clip.load(clip_model_path, device=device)
    return clipmodel, clippreprocess
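# OmniVisionInference extends the audio-only OmniInference with a CLIP ViT-B/32 image encoder,
# so a single generation call can condition on both the image and the spoken question.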
class OmniVisionInference(OmniInference):

    def __init__(self, ckpt_dir='./checkpoint', device='cuda:0'):
        self.device = device
        if not os.path.exists(ckpt_dir):
            print(f"checkpoint directory {ckpt_dir} not found, downloading from huggingface")
            download_model(ckpt_dir)
        self.fabric, self.model, self.text_tokenizer, self.snacmodel, self.whispermodel = load_model(ckpt_dir, device)
        self.clipmodel, self.clippreprocess = load_clip_model(ckpt_dir, device)

    def warm_up(self,
                audio_sample='./data/samples/vision_qa_audio.wav',
                image_sample='./data/samples/vision_qa_image.jpg'
                ):
        # Run one short streaming pass so the first real request does not pay the startup cost.
        for _ in self.run_vision_AA_batch_stream(audio_sample, image_sample,
                                                 save_path="./data/samples/vision_qa_output.wav",
                                                 warm_up=True):
            pass
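    # Streaming vision QA: prefill the multimodal prompt once, then decode step by step,
    # yielding an (audio_chunk, text_chunk) pair every `stream_stride` steps.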
    @torch.inference_mode()
    def run_vision_AA_batch_stream(self, audio_path, image_path,
                                   stream_stride=4,
                                   max_returned_tokens=2048,
                                   temperature=0.9,
                                   top_k=1,
                                   top_p=1.0,
                                   eos_id_a=_eoa,
                                   eos_id_t=_eot,
                                   pad_id=_pad_t,
                                   save_path=None,
                                   warm_up=False
                                   ):
        with self.fabric.init_tensor():
            self.model.set_kv_cache(batch_size=2)

        model = self.model

        mel, leng = load_audio(audio_path)
        img = Image.open(image_path)

        # Encode the spoken question (Whisper) and the image (CLIP); both are duplicated
        # so they match the two-item prompt batch.
        audio_feature, input_ids = get_input_ids_ImageQA_ATBatch(mel, leng, self.whispermodel, self.device)
        ima = self.clippreprocess(img).unsqueeze(0).to(self.device)
        ima_feature = self.clipmodel.encode_image(ima).squeeze(0).to(self.device)

        ima_feature = torch.stack([ima_feature.clone(), ima_feature.clone()]).to(self.device)
        leng = [leng, leng]
        task = ['ImageQA_A', 'ImageQA_AT']

        T = input_ids[0].size(1)
        assert max_returned_tokens > T, f"max_returned_tokens {max_returned_tokens} should be greater than the prompt length {T}"

        if model.max_seq_length < max_returned_tokens - 1:
            raise NotImplementedError(
                f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}"
            )

        # Collected outputs: seven SNAC code layers plus one text layer.
        list_output = [[] for _ in range(8)]

        # Prefill: run the full multimodal prompt and sample the first tokens.
        tokens_A, token_T = next_token_image_batch(
            model,
            audio_feature.to(torch.float32).to(self.device),
            ima_feature.to(torch.float32).to(self.device),
            input_ids,
            whisper_lens=leng,
            task=task,
            input_pos=torch.arange(0, T, device=self.device),
            temperature=temperature,
            top_k=top_k,
            top_p=top_p
        )
        for i in range(7):
            list_output[i].append(tokens_A[i].tolist()[0])
        list_output[7].append(token_T.tolist()[0])

        text_end = False
        index = 1
        nums_generate = stream_stride
        begin_generate = False
        current_index = 0
        input_pos = torch.tensor([T], device=self.device)

        # Next-step input: row 0 of each audio layer takes the sampled (layer-shifted) token,
        # row 1 takes a fixed layer-shifted placeholder (4097); the text layer repeats the sampled text token.
        model_input_ids = [[] for _ in range(8)]
        for i in range(7):
            tokens_A[i] = tokens_A[i].clone() + padded_text_vocabsize + i * 4160
            model_input_ids[i].append(tokens_A[i].clone().to(self.device).to(torch.int32))
            model_input_ids[i].append(torch.tensor([layershift(4097, i)], device=self.device))
            model_input_ids[i] = torch.stack(model_input_ids[i])

        model_input_ids[-1].append(token_T.clone().to(torch.int32))
        model_input_ids[-1].append(token_T.clone().to(torch.int32))
        model_input_ids[-1] = torch.stack(model_input_ids[-1])

        text_index = 0
        is_text_end = False
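        # Autoregressive decode loop: feed the previous step's tokens back in, one position at a time.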
        for _ in tqdm(range(2, max_returned_tokens - T + 1)):

            tokens_A, token_T = next_token_image_batch(model, None, None,
                                                       input_ids=model_input_ids,
                                                       whisper_lens=None,
                                                       task=None,
                                                       input_pos=input_pos,
                                                       temperature=temperature,
                                                       top_k=top_k,
                                                       top_p=top_p)

            # Once the text stream has ended, keep feeding text pads.
            if text_end:
                token_T = torch.tensor([_pad_t], device=self.device)

            if tokens_A[-1] == eos_id_a:
                break
            if token_T == eos_id_t:
                text_end = True

            for i in range(7):
                list_output[i].append(tokens_A[i].tolist()[0])
            list_output[7].append(token_T.tolist()[0])

            # Start emitting audio only after enough SNAC codes have accumulated for a full chunk.
            if index == 7:
                begin_generate = True

            if begin_generate:
                current_index += 1
                if current_index == nums_generate:
                    current_index = 0
                    snac = get_snac(list_output, index, nums_generate)
                    audio_stream = generate_audio_data(snac, self.snacmodel, self.device)
                    if is_text_end:
                        text_stream = ""
                    else:
                        text_stream, text_index, is_text_end = get_text_stream(list_output, text_index, self.text_tokenizer)

                    yield (audio_stream, text_stream)

                    if warm_up:
                        break

            # Pack the sampled tokens into the next-step input (same layout as after the prefill).
            input_pos = input_pos.add_(1)
            model_input_ids = [[] for _ in range(8)]
            for i in range(7):
                tokens_A[i] = tokens_A[i].clone() + padded_text_vocabsize + i * 4160
                model_input_ids[i].append(tokens_A[i].clone().to(self.device).to(torch.int32))
                model_input_ids[i].append(torch.tensor([layershift(4097, i)], device=self.device))
                model_input_ids[i] = torch.stack(model_input_ids[i])

            model_input_ids[-1].append(token_T.clone().to(torch.int32))
            model_input_ids[-1].append(token_T.clone().to(torch.int32))
            model_input_ids[-1] = torch.stack(model_input_ids[-1])

            index += 1
        # Decode the full text answer, truncating at the end-of-text marker if one was produced.
        text_tokens = list_output[-1]
        if text_vocabsize in text_tokens:
            text_tokens = text_tokens[:text_tokens.index(text_vocabsize)]
        res_text = self.text_tokenizer.decode(torch.tensor(text_tokens))
        print(f"text output: {res_text}")

        if save_path is not None:
            # Reassemble the SNAC codes for the whole answer and decode them to a waveform.
            audiolist = reconscruct_snac(list_output)
            audio = reconstruct_tensors(audiolist)
            with torch.inference_mode():
                audio_hat = self.snacmodel.decode(audio)
            sf.write(save_path, audio_hat.squeeze().cpu().numpy(), 24000)

        model.clear_kv_cache()
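# Minimal end-to-end check: warm up, stream one vision-QA request, and print the accumulated text.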
def test_vision_infer():
    client = OmniVisionInference()
    client.warm_up()
    input_audio_path = './data/samples/vision_qa_audio.wav'
    input_image_path = './data/samples/vision_qa_image.jpg'

    res_text = ""
    for audio_stream, text_stream in client.run_vision_AA_batch_stream(
            input_audio_path,
            input_image_path,
            save_path="./vision_qa_output.wav"
    ):
        res_text += text_stream
    print(f"text_output: {res_text}")


if __name__ == "__main__":
    test_vision_infer()