# Vision+audio streaming inference: batches an ImageQA_A (audio answer) and an
# ImageQA_AT (text answer) prompt pair and streams SNAC audio plus text.
import os

import torch
import soundfile as sf
import clip
from PIL import Image
from tqdm import tqdm

from litgpt.generate.base import next_token_image_batch
from utils.snac_utils import layershift, reconscruct_snac, reconstruct_tensors, get_time_str
from utils.snac_utils import get_snac, generate_audio_data
import inference
from inference import OmniInference, load_model, load_audio, download_model
from inference import text_vocabsize, padded_text_vocabsize, get_text_stream

torch.set_printoptions(sci_mode=False)
# Special tokens re-exported from inference: image/audio/text boundary markers,
# pads, and input/answer markers.
_image = inference._image
_eoimage = inference._eoimage
_pad_t = inference._pad_t
_input_t = inference._input_t
_answer_t = inference._answer_t
_eot = inference._eot
_eoa = inference._eoa
_pad_a = inference._pad_a
_input_a = inference._input_a
_answer_a = inference._answer_a
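
# The 7 SNAC audio codebooks occupy disjoint 4160-token slabs of the unified
# vocabulary placed after the padded text vocab. Judging from how sampled audio
# tokens are re-shifted in run_vision_AA_batch_stream below (our inference, not
# something this file states explicitly):
#   layershift(tok, i) == padded_text_vocabsize + i * 4160 + tok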


def get_input_ids_ImageQA_ATBatch(mel, leng, whispermodel, device):
    """Build a 2-sequence batch of image+audio QA prompts.

    Sequence 0 asks for an audio answer (ImageQA_A), sequence 1 for a text
    answer (ImageQA_AT); both share the same image and audio placeholders.
    """
    with torch.no_grad():
        mel = mel.unsqueeze(0).to(device)
        # Whisper encoder features for the question audio, truncated to the
        # true (unpadded) length.
        audio_feature = whispermodel.embed_audio(mel)[0][:leng]
    audio_len = audio_feature.size(0)

    input_ids = []

    # Sequence 0 (ImageQA_A): the 7 layer-shifted audio rows end with _answer_a.
    input_ids_item = [[] for _ in range(8)]
    for i in range(7):
        # 50 image placeholder slots framed by _image/_eoimage markers.
        input_ids_item[i] = [layershift(_image, i)] + [layershift(_pad_a, i)] * 50 + [layershift(_eoimage, i)]
        # Audio placeholder slots framed by _input_a/_eoa markers.
        input_ids_item[i] += [layershift(_input_a, i)] + [layershift(_pad_a, i)] * audio_len + [layershift(_eoa, i)]
        input_ids_item[i] += [layershift(_answer_a, i)]
    # Text row: pads for the whole prompt, then the text answer marker.
    input_ids_item[-1] = [_pad_t] * (52 + 2 + audio_len) + [_answer_t]
    input_ids_item = [torch.tensor(item) for item in input_ids_item]
    input_ids.append(input_ids_item)

    # Sequence 1 (ImageQA_AT): same prompt, but the audio rows end with a pad
    # instead of _answer_a, so only the text row carries an answer marker.
    input_ids_item = [[] for _ in range(8)]
    for i in range(7):
        input_ids_item[i] = [layershift(_image, i)] + [layershift(_pad_a, i)] * 50 + [layershift(_eoimage, i)]
        input_ids_item[i] += [layershift(_input_a, i)] + [layershift(_pad_a, i)] * audio_len + [layershift(_eoa, i)] + [layershift(_pad_a, i)]
    input_ids_item[-1] = [_pad_t] * (52 + 2 + audio_len) + [_answer_t]
    input_ids_item = [torch.tensor(item) for item in input_ids_item]
    input_ids.append(input_ids_item)

    # Regroup from two per-sequence lists of 8 rows into 8 stacked (2, T) tensors.
    stacked_inputids = [[] for _ in range(8)]
    for i in range(2):
        for j in range(8):
            stacked_inputids[j].append(input_ids[i][j])
    stacked_inputids = [torch.stack(tensors) for tensors in stacked_inputids]

    # Duplicate the whisper features so both batch entries see the question.
    return torch.stack([audio_feature, audio_feature]), stacked_inputids
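
# Resulting prompt layout per audio row (our reading of the code above):
#   [_image][50 pads][_eoimage][_input_a][audio_len pads][_eoa][answer/pad]
# i.e. 52 (image block) + 2 + audio_len (audio block) + 1 positions, matching
# the text row's (52 + 2 + audio_len) pads plus _answer_t.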


def load_clip_model(ckpt_dir, device):
    clip_model_path = ckpt_dir + "/ViT-B-32.pt"
    if not os.path.exists(clip_model_path):
        # Fall back to the model name; clip.load downloads the weights on first use.
        clip_model_path = "ViT-B/32"
    clipmodel, clippreprocess = clip.load(clip_model_path, device=device)
    return clipmodel, clippreprocess


class OmniVisionInference(OmniInference):

    def __init__(self, ckpt_dir='./checkpoint', device='cuda:0'):
        self.device = device
        if not os.path.exists(ckpt_dir):
            print(f"checkpoint directory {ckpt_dir} not found, downloading from huggingface")
            download_model(ckpt_dir)
        self.fabric, self.model, self.text_tokenizer, self.snacmodel, self.whispermodel = load_model(ckpt_dir, device)
        self.clipmodel, self.clippreprocess = load_clip_model(ckpt_dir, device)

    def warm_up(self,
                audio_sample='./data/samples/vision_qa_audio.wav',
                image_sample='./data/samples/vision_qa_image.jpg'):
        # Run one truncated generation pass so kernels and caches are
        # initialized before the first real request.
        for _ in self.run_vision_AA_batch_stream(audio_sample, image_sample,
                                                 save_path="./data/samples/vision_qa_output.wav",
                                                 warm_up=True):
            pass

    @torch.inference_mode()
    def run_vision_AA_batch_stream(self, audio_path, image_path,
                                   stream_stride=4,
                                   max_returned_tokens=2048,
                                   temperature=0.9,
                                   top_k=1,
                                   top_p=1.0,
                                   eos_id_a=_eoa,
                                   eos_id_t=_eot,
                                   pad_id=_pad_t,
                                   save_path=None,
                                   warm_up=False):
        with self.fabric.init_tensor():
            self.model.set_kv_cache(batch_size=2)

        model = self.model

        mel, leng = load_audio(audio_path)
        img = Image.open(image_path)
        audio_feature, input_ids = get_input_ids_ImageQA_ATBatch(mel, leng, self.whispermodel, self.device)

        # CLIP image embedding, duplicated for the 2-sequence batch.
        ima = self.clippreprocess(img).unsqueeze(0).to(self.device)
        ima_feature = self.clipmodel.encode_image(ima).squeeze(0).to(self.device)
        ima_feature = torch.stack([ima_feature.clone(), ima_feature.clone()]).to(self.device)

        leng = [leng, leng]
        task = ['ImageQA_A', 'ImageQA_AT']

        T = input_ids[0].size(1)
        assert max_returned_tokens > T, f"max_returned_tokens {max_returned_tokens} should be greater than the prompt length {T}"

        if model.max_seq_length < max_returned_tokens - 1:
            raise NotImplementedError(
                f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}"
            )
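
        # Prefill: feed the full multimodal prompt once to prime the KV cache
        # and sample the first audio (7 layers) and text tokens.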
        list_output = [[] for _ in range(8)]

        tokens_A, token_T = next_token_image_batch(
            model,
            audio_feature.to(torch.float32).to(self.device),
            ima_feature.to(torch.float32).to(self.device),
            input_ids,
            whisper_lens=leng,
            task=task,
            input_pos=torch.arange(0, T, device=self.device),
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        )

        for i in range(7):
            list_output[i].append(tokens_A[i].tolist()[0])
        list_output[7].append(token_T.tolist()[0])

        text_end = False
        index = 1
        nums_generate = stream_stride
        begin_generate = False
        current_index = 0

        input_pos = torch.tensor([T], device=self.device)

        # Build the next-step input: audio tokens are shifted back into their
        # per-layer vocab slabs; the second (text-answer) sequence gets a
        # constant audio token (4097, presumably a pad/special id). The text
        # token is duplicated across the batch.
        model_input_ids = [[] for _ in range(8)]
        for i in range(7):
            tokens_A[i] = tokens_A[i].clone() + padded_text_vocabsize + i * 4160
            model_input_ids[i].append(tokens_A[i].clone().to(self.device).to(torch.int32))
            model_input_ids[i].append(torch.tensor([layershift(4097, i)], device=self.device))
            model_input_ids[i] = torch.stack(model_input_ids[i])

        model_input_ids[-1].append(token_T.clone().to(torch.int32))
        model_input_ids[-1].append(token_T.clone().to(torch.int32))
        model_input_ids[-1] = torch.stack(model_input_ids[-1])

        text_index = 0
        is_text_end = False
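
        # Autoregressive decode: one position per step, consuming only the
        # tokens sampled at the previous step.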
        for _ in tqdm(range(2, max_returned_tokens - T + 1)):

            tokens_A, token_T = next_token_image_batch(model, None, None,
                                                       input_ids=model_input_ids,
                                                       whisper_lens=None,
                                                       task=None,
                                                       input_pos=input_pos,
                                                       temperature=temperature,
                                                       top_k=top_k,
                                                       top_p=top_p)

            if text_end:
                token_T = torch.tensor([_pad_t], device=self.device)

            if tokens_A[-1] == eos_id_a:
                break
            if token_T == eos_id_t:
                text_end = True

            for i in range(7):
                list_output[i].append(tokens_A[i].tolist()[0])
            list_output[7].append(token_T.tolist()[0])

            # The 7 SNAC layers are staggered, so a full decodable frame only
            # exists after the first 7 steps; from then on, emit a chunk every
            # nums_generate steps.
            if index == 7:
                begin_generate = True

            if begin_generate:
                current_index += 1
                if current_index == nums_generate:
                    current_index = 0
                    snac = get_snac(list_output, index, nums_generate)
                    audio_stream = generate_audio_data(snac, self.snacmodel, self.device)
                    if is_text_end:
                        text_stream = ""
                    else:
                        text_stream, text_index, is_text_end = get_text_stream(list_output, text_index, self.text_tokenizer)

                    yield (audio_stream, text_stream)

                    if warm_up:
                        break

            input_pos = input_pos.add_(1)

            # Same re-shifting as after the prefill step.
            model_input_ids = [[] for _ in range(8)]
            for i in range(7):
                tokens_A[i] = tokens_A[i].clone() + padded_text_vocabsize + i * 4160
                model_input_ids[i].append(tokens_A[i].clone().to(self.device).to(torch.int32))
                model_input_ids[i].append(torch.tensor([layershift(4097, i)], device=self.device))
                model_input_ids[i] = torch.stack(model_input_ids[i])

            model_input_ids[-1].append(token_T.clone().to(torch.int32))
            model_input_ids[-1].append(token_T.clone().to(torch.int32))
            model_input_ids[-1] = torch.stack(model_input_ids[-1])

            index += 1

        # Decode the accumulated text tokens, truncating at end-of-text if present.
        text_tokens = list_output[-1]
        if text_vocabsize in text_tokens:
            text_tokens = text_tokens[:text_tokens.index(text_vocabsize)]
        res_text = self.text_tokenizer.decode(torch.tensor(text_tokens))
        print(f"text output: {res_text}")

        if save_path is not None:
            # Reassemble the full SNAC stream and write the waveform to disk.
            audiolist = reconscruct_snac(list_output)
            audio = reconstruct_tensors(audiolist)
            with torch.inference_mode():
                audio_hat = self.snacmodel.decode(audio)
            sf.write(save_path, audio_hat.squeeze().cpu().numpy(), 24000)

        model.clear_kv_cache()


def test_vision_infer():
    client = OmniVisionInference()
    client.warm_up()
    input_audio_path = './data/samples/vision_qa_audio.wav'
    input_image_path = './data/samples/vision_qa_image.jpg'

    res_text = ""
    for audio_stream, text_stream in client.run_vision_AA_batch_stream(
        input_audio_path,
        input_image_path,
        save_path="./vision_qa_output.wav"
    ):
        res_text += text_stream
    print(f"text_output: {res_text}")


if __name__ == "__main__":
    test_vision_infer()