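# File: tear_141/quantized/infer.py
# Uploaded by JOSESMOKE via huggingface_hub ("Upload folder using huggingface_hub"), commit a6bd5e1 (verified).
# NOTE: all of the code below is commented out, so importing this module defines nothing.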
# import os
# import torch
# from litgpt.generate.base import next_token_image_batch
# import soundfile as sf
# from utils.snac_utils import layershift, reconscruct_snac, reconstruct_tensors, get_time_str
# from utils.snac_utils import get_snac, generate_audio_data
# import clip
# import inference
# from tqdm import tqdm
# from inference import OmniInference, load_model, load_audio, download_model
# from inference import text_vocabsize, padded_text_vocabsize, get_text_stream
# from PIL import Image
# # Module setup (disabled). Turns off scientific notation in torch tensor
# # printing and aliases special-token ids from the project-local `inference`
# # module under short names for the helpers below.
# # NOTE(review): set_printoptions changes torch's print settings for the whole
# # process at import time. The numeric id values live in `inference` and are
# # not visible here. `_input_t` is not referenced anywhere else in this file.
# torch.set_printoptions(sci_mode=False)
# _image = inference._image
# _eoimage = inference._eoimage
# _pad_t = inference._pad_t
# _input_t = inference._input_t
# _answer_t = inference._answer_t
# _eot = inference._eot
# _eoa = inference._eoa
# _pad_a = inference._pad_a
# _input_a = inference._input_a
# _answer_a = inference._answer_a
# # Build a two-row prompt batch (image placeholder + spoken question) for
# # batched image-QA decoding.
# #
# # Args (shapes/types are not visible here -- confirm against callers):
# #   mel: audio features from `load_audio`; a batch dim is added before the
# #        Whisper embedding.
# #   leng: number of embedding frames to keep (slices the Whisper output).
# #   whispermodel: object exposing `embed_audio`.
# #   device: torch device that `mel` is moved to.
# # Returns:
# #   (torch.stack([audio_feature, audio_feature]), stacked_inputids):
# #   - the same audio features repeated for both rows;
# #   - a list of 8 tensors, one per token layer (7 audio layers whose ids go
# #     through `layershift(id, layer)`, then 1 text layer). Each tensor stacks
# #     the 2 rows, and every row has length 55 + audio_len.
# # NOTE(review): the two rows are built by near-duplicate loops that differ
# # only in the final audio-layer token (`_answer_a` vs `_pad_a`); a small helper
# # taking that trailing token would remove the duplication. The 50 image
# # placeholder slots are a magic number, and 52 = 1 + 50 + 1 (`_image`,
# # placeholders, `_eoimage`).
# def get_input_ids_ImageQA_ATBatch(mel, leng, whispermodel, device):
#     with torch.no_grad():
#         mel = mel.unsqueeze(0).to(device)
#         audio_feature = whispermodel.embed_audio(mel)[0][:leng]
#     audio_len = audio_feature.size(0)
#     input_ids = []
#     # Row 0, audio layers:
#     # [_image, 50 x _pad_a, _eoimage, _input_a, audio_len x _pad_a, _eoa, _answer_a]
#     input_ids_item = [[] for i in range(8)]
#     for i in range(7):
#         input_ids_item[i] = [layershift(_image,i)] + [layershift(_pad_a,i)] * 50 + [layershift(_eoimage,i)]
#         input_ids_item[i] += [layershift(_input_a,i)]+[layershift(_pad_a,i)]*(audio_len)+[layershift(_eoa,i)]
#         input_ids_item[i] += [layershift(_answer_a,i)]
#     # Text layer: pads under the image block (52), the _input_a/_eoa pair (2)
#     # and the audio frames, then _answer_t -- same length as the audio layers.
#     input_ids_item[-1] = [_pad_t]* (52 + 2 + audio_len) + [_answer_t]
#     input_ids_item = [torch.tensor(item) for item in input_ids_item]
#     input_ids.append(input_ids_item)
#     # Row 1: identical except the audio layers end with _pad_a instead of _answer_a.
#     input_ids_item = [[] for i in range(8)]
#     for i in range(7):
#         input_ids_item[i] = [layershift(_image,i)] + [layershift(_pad_a,i)] * 50 + [layershift(_eoimage,i)]
#         input_ids_item[i] += [layershift(_input_a,i)]+[layershift(_pad_a,i)]*(audio_len)+[layershift(_eoa,i)] + [layershift(_pad_a,i)]
#     input_ids_item[-1] = [_pad_t]* (52 + 2 + audio_len) + [_answer_t]
#     input_ids_item = [torch.tensor(item) for item in input_ids_item]
#     input_ids.append(input_ids_item)
#     # Regroup rows x layers into layers x rows; each layer becomes a (2, L) tensor.
#     stacked_inputids = [[] for _ in range(8)]
#     for i in range(2):
#         for j in range(8):
#             stacked_inputids[j].append(input_ids[i][j])
#     stacked_inputids = [torch.stack(tensors) for tensors in stacked_inputids]
#     return torch.stack([audio_feature,audio_feature]), stacked_inputids
# # Load the CLIP image encoder and its preprocessing transform.
# # Uses the local file `<ckpt_dir>/ViT-B-32.pt` when it exists. Otherwise it
# # passes the model name "ViT-B/32" to `clip.load` and lets the clip package
# # resolve it (presumably by downloading -- confirm against the clip package).
# # Returns the (model, preprocess) pair produced by `clip.load`.
# # NOTE(review): os.path.join(ckpt_dir, "ViT-B-32.pt") would be more robust
# # than string concatenation (e.g. when ckpt_dir ends with a separator).
# def load_clip_model(ckpt_dir, device):
#     clip_model_path = ckpt_dir + "/ViT-B-32.pt"
#     if not os.path.exists(clip_model_path):
#         clip_model_path = "ViT-B/32"
#     clipmodel, clippreprocess = clip.load(clip_model_path, device=device)
#     return clipmodel, clippreprocess
# # Vision question answering on top of OmniInference. It adds a CLIP image
# # encoder and a batch-of-2 streaming decoder that answers a spoken question
# # about an image, emitting interleaved audio and text tokens.
# # NOTE(review): the indentation of this disabled block was reconstructed from a
# # flattened copy. The nesting of a few statements (e.g. `if warm_up: break` and
# # the `sf.write` call) is inferred -- verify against upstream before re-enabling.
# class OmniVisionInference(OmniInference):
#     # Load every model. If `ckpt_dir` does not exist it is first fetched with
#     # `download_model`.
#     # NOTE(review): OmniInference.__init__ is never called (no super().__init__()),
#     # so any state the base class sets up there is skipped -- confirm intended.
#     def __init__(self, ckpt_dir='./checkpoint', device='cuda:0'):
#         self.device = device
#         if not os.path.exists(ckpt_dir):
#             print(f"checkpoint directory {ckpt_dir} not found, downloading from huggingface")
#             download_model(ckpt_dir)
#         self.fabric, self.model, self.text_tokenizer, self.snacmodel, self.whispermodel = load_model(ckpt_dir, device)
#         self.clipmodel, self.clippreprocess = load_clip_model(ckpt_dir, device)
#     # Drive one short streaming pass over a sample audio/image pair
#     # (warm_up=True makes the generator stop early).
#     # NOTE(review): because save_path is set, this also writes
#     # ./data/samples/vision_qa_output.wav as a side effect.
#     def warm_up(self,
#                 audio_sample='./data/samples/vision_qa_audio.wav',
#                 image_sample='./data/samples/vision_qa_image.jpg'
#         ):
#         for _ in self.run_vision_AA_batch_stream(audio_sample, image_sample,
#                                                  save_path="./data/samples/vision_qa_output.wav",
#                                                  warm_up=True):
#             pass
#     # Generator: answer the spoken question in `audio_path` about the image in
#     # `image_path`, yielding (audio_stream, text_stream) tuples.
#     #
#     # Args:
#     #   stream_stride: number of decode steps between yielded chunks.
#     #   max_returned_tokens: bound on prompt length + decode steps. It must exceed
#     #       the prompt length T and be <= model.max_seq_length + 1.
#     #   temperature, top_k, top_p: forwarded to next_token_image_batch.
#     #   eos_id_a: decoding stops when tokens_A[-1] (last audio layer) equals it.
#     #   eos_id_t: once text emits it, later text inputs are replaced by _pad_t.
#     #   pad_id: unused (see NOTE below).
#     #   save_path: if not None, all audio tokens are decoded with the SNAC model
#     #       and written to this path at 24000 Hz via soundfile.
#     #   warm_up: stop after the first yielded chunk.
#     # Yields:
#     #   (audio_stream, text_stream); text_stream is "" once get_text_stream has
#     #   reported the end of the text.
#     # Side effects: allocates the model KV cache for batch size 2, prints the
#     # decoded text, and clears the KV cache at the end.
#     # NOTE(review):
#     #   - `pad_id` is never used; the loop uses the module-level `_pad_t` directly.
#     #   - The `assert` on max_returned_tokens is stripped under `python -O`; raise
#     #     instead. Its message says "audio length", but T is the full prompt length
#     #     (image placeholders + audio).
#     #   - model.clear_kv_cache() only runs if the generator finishes normally. If a
#     #     consumer stops iterating early or an exception escapes, the cache stays
#     #     allocated; a try/finally around the body would fix this.
#     #   - `if tokens_A[-1] == eos_id_a` / `if token_T == eos_id_t` rely on tensor
#     #     truthiness, which only works for single-element tensors -- confirm the
#     #     shapes returned by next_token_image_batch.
#     @torch.inference_mode()
#     def run_vision_AA_batch_stream(self, audio_path, image_path,
#                                    stream_stride=4,
#                                    max_returned_tokens=2048,
#                                    temperature=0.9,
#                                    top_k=1,
#                                    top_p=1.0,
#                                    eos_id_a=_eoa,
#                                    eos_id_t=_eot,
#                                    pad_id=_pad_t,
#                                    save_path=None,
#                                    warm_up=False
#         ):
#         with self.fabric.init_tensor():
#             self.model.set_kv_cache(batch_size=2)
#         model = self.model
#         mel, leng = load_audio(audio_path)
#         img = Image.open(image_path)
#         audio_feature, input_ids = get_input_ids_ImageQA_ATBatch(mel, leng, self.whispermodel, self.device)
#         ima = self.clippreprocess(img).unsqueeze(0).to(self.device)
#         ima_feature = self.clipmodel.encode_image(ima).squeeze(0).to(self.device)
#         # Duplicate the image features and audio lengths for both batch rows.
#         ima_feature = torch.stack([ima_feature.clone(),ima_feature.clone()]).to(self.device)
#         leng = [leng,leng]
#         task = ['ImageQA_A','ImageQA_AT']
#         # T = prompt length (dim 1 of the first stacked layer tensor).
#         T = input_ids[0].size(1)
#         assert max_returned_tokens > T, f"max_returned_tokens {max_returned_tokens} should be greater than audio length {T}"
#         if model.max_seq_length < max_returned_tokens - 1:
#             raise NotImplementedError(
#                 f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}"
#             )
#         # 7 audio layers + 1 text layer of sampled token ids.
#         list_output = [[] for i in range(8)]
#         # Prefill: run the whole prompt (positions 0..T-1) and sample the first step.
#         tokens_A , token_T = next_token_image_batch(
#             model,
#             audio_feature.to(torch.float32).to(self.device),
#             ima_feature.to(torch.float32).to(self.device) ,
#             input_ids ,
#             whisper_lens = leng ,
#             task = task,
#             input_pos = torch.arange(0, T, device=self.device),
#             temperature=temperature,
#             top_k=top_k,
#             top_p=top_p
#         )
#         for i in range(7): list_output[i].append(tokens_A[i].tolist()[0])
#         list_output[7].append(token_T.tolist()[0])
#         text_end = False
#         index = 1
#         nums_generate = stream_stride
#         begin_generate = False
#         current_index = 0
#         input_pos = torch.tensor([T], device=self.device)
#         # Next-step inputs: per audio layer, row 0 is the sampled token shifted by
#         # padded_text_vocabsize + i * 4160, and row 1 is the constant
#         # layershift(4097, i). Both rows get the same text token.
#         # NOTE(review): 4160 and 4097 are magic numbers (presumably the per-layer
#         # audio codebook stride and a placeholder audio token -- confirm against
#         # layershift). This construction is duplicated at the end of the loop below.
#         model_input_ids = [[] for i in range(8)]
#         for i in range(7):
#             tokens_A[i] = tokens_A[i].clone() + padded_text_vocabsize+ i * 4160
#             model_input_ids[i].append(tokens_A[i].clone().to(self.device).to(torch.int32))
#             model_input_ids[i].append(torch.tensor([layershift(4097,i)],device=self.device))
#             model_input_ids[i] = torch.stack(model_input_ids[i])
#         model_input_ids[-1].append(token_T.clone().to(torch.int32))
#         model_input_ids[-1].append(token_T.clone().to(torch.int32))
#         model_input_ids[-1] = torch.stack(model_input_ids[-1])
#         text_index = 0
#         is_text_end = False
#         # Decode one position per step until audio EOS or the token budget is used.
#         for _ in tqdm(range(2, max_returned_tokens - T + 1)):
#             tokens_A , token_T = next_token_image_batch(model, None , None ,
#                 input_ids = model_input_ids,
#                 whisper_lens= None,
#                 task = None,
#                 input_pos = input_pos,
#                 temperature=temperature,
#                 top_k=top_k,
#                 top_p=top_p)
#             # After text EOS, keep feeding the text pad token.
#             if text_end:
#                 token_T = torch.tensor([_pad_t], device=self.device)
#             if tokens_A[-1] == eos_id_a:
#                 break
#             if token_T == eos_id_t:
#                 text_end = True
#             for i in range(7): list_output[i].append(tokens_A[i].tolist()[0])
#             list_output[7].append(token_T.tolist()[0])
#             # Audio chunks are decoded only once index reaches 7 (presumably
#             # the layer delay of the 7 audio layers -- confirm in get_snac),
#             # and then every `nums_generate` steps.
#             if index == 7:
#                 begin_generate = True
#             if begin_generate:
#                 current_index += 1
#                 if current_index == nums_generate:
#                     current_index = 0
#                     snac = get_snac(list_output,index,nums_generate)
#                     audio_stream = generate_audio_data(snac, self.snacmodel, self.device)
#                     if is_text_end:
#                         text_stream = ""
#                     else:
#                         text_stream, text_index, is_text_end = get_text_stream(list_output, text_index, self.text_tokenizer)
#                     yield (audio_stream, text_stream)
#                     if warm_up:
#                         break
#             input_pos = input_pos.add_(1)
#             model_input_ids = [[] for i in range(8)]
#             for i in range(7):
#                 tokens_A[i] = tokens_A[i].clone() + padded_text_vocabsize+ i * 4160
#                 model_input_ids[i].append(tokens_A[i].clone().to(self.device).to(torch.int32))
#                 model_input_ids[i].append(torch.tensor([layershift(4097,i)],device=self.device))
#                 model_input_ids[i] = torch.stack(model_input_ids[i])
#             model_input_ids[-1].append(token_T.clone().to(torch.int32))
#             model_input_ids[-1].append(token_T.clone().to(torch.int32))
#             model_input_ids[-1] = torch.stack(model_input_ids[-1])
#             index += 1
#         # Decode the full text layer, cut at the first `text_vocabsize` id
#         # (presumably the end-of-text marker -- confirm in `inference`).
#         text_tokens = list_output[-1]
#         if text_vocabsize in text_tokens:
#             text_tokens = text_tokens[:text_tokens.index(text_vocabsize)]
#         res_text = self.text_tokenizer.decode(torch.tensor(text_tokens))
#         print(f"text output: {res_text}")
#         if save_path is not None:
#             audiolist = reconscruct_snac(list_output)
#             audio = reconstruct_tensors(audiolist)
#             with torch.inference_mode():
#                 audio_hat = self.snacmodel.decode(audio)
#                 sf.write(save_path, audio_hat.squeeze().cpu().numpy(), 24000)
#         model.clear_kv_cache()
# # Manual smoke test. It builds the client (loads every model; default device
# # 'cuda:0'), warms it up, streams an answer for the bundled sample audio/image,
# # and prints the concatenated text. It writes ./vision_qa_output.wav.
# # NOTE(review): despite the `test_` prefix this asserts nothing and needs
# # checkpoints and a GPU. If re-enabled, pytest would collect and run it.
# # The yielded audio_stream chunks are ignored here.
# def test_vision_infer():
#     client = OmniVisionInference()
#     client.warm_up()
#     input_audio_path = './data/samples/vision_qa_audio.wav'
#     input_image_path = './data/samples/vision_qa_image.jpg'
#     res_text = ""
#     for audio_stream, text_stream in client.run_vision_AA_batch_stream(
#         input_audio_path,
#         input_image_path,
#         save_path="./vision_qa_output.wav"
#     ):
#         res_text += text_stream
#     print(f"text_output: {res_text}")
# # Script entry point: run the smoke test when this file is executed directly.
# if __name__ == "__main__":
#     test_vision_infer()
# # 1234232434232
# # 1234232434232
# # 1234232434232
# # 1234232434232
# # 1234232434232
# # 1234232434232
# # 5069
# # 3670