import torch
import whisperx
import gradio as gr
from peft import PeftModel
from configs import get_config_phase2
from transformers import AutoTokenizer, AutoProcessor, CLIPVisionModel, AutoModelForCausalLM

config = get_config_phase2()

# Vision encoder: CLIP vision tower used to extract per-patch image embeddings.
clip_model = CLIPVisionModel.from_pretrained(config.get("clip_model_name"))

# Language model: Phi-2 base weights, with the phase-2 QLoRA adapter merged in below.
base_model = AutoModelForCausalLM.from_pretrained(
    config.get("phi2_model_name"),
    low_cpu_mem_usage=True,
    return_dict=True,
    torch_dtype=torch.float32,
    trust_remote_code=True,
)
ckpts = "ckpts/Qlora_adaptor/"
phi2_model = PeftModel.from_pretrained(base_model, ckpts)
phi2_model = phi2_model.merge_and_unload().to(config.get("device"))

# Projection layer that maps CLIP patch embeddings into the Phi-2 embedding space.
projection_layer = torch.nn.Linear(config.get("clip_embed"), config.get("phi_embed"))
projection_layer.load_state_dict(torch.load('./ckpts/model_phase2.pth', map_location=config.get("device")))

# Tokenizer, image processor, and speech-to-text (WhisperX) model
tokenizer = AutoTokenizer.from_pretrained(config.get("phi2_model_name"), trust_remote_code=True)
processor = AutoProcessor.from_pretrained(config.get("clip_model_name"), trust_remote_code=True)
audio_model = whisperx.load_model('tiny', 'cpu', compute_type="float32")

def generate_answers(img=None, aud=None, q=None, max_tokens=30):
    batch_size = 1

    # <iQ> ... </iQ> tokens delimit the multimodal question context fed to Phi-2.
    start_iq = tokenizer.encode("<iQ>")
    end_iq = tokenizer.encode("</iQ>")
    start_iq_ids = torch.tensor(start_iq).repeat(batch_size, 1)
    end_iq_ids = torch.tensor(end_iq).repeat(batch_size, 1)
    start_iq_embeds = phi2_model.model.embed_tokens(start_iq_ids.to(config.get("device")))
    end_iq_embeds = phi2_model.model.embed_tokens(end_iq_ids.to(config.get("device")))

    inputs_embeddings = [start_iq_embeds]

    # Pre-filled with <|endoftext|> ids; only written to by the manual decoding loop kept below.
    predicted_caption = torch.full((batch_size, max_tokens), 50256, dtype=torch.long, device=config.get('device'))

    # Image branch: CLIP patch embeddings projected into the Phi-2 embedding space.
    if img is not None:
        pixel_values = processor(images=img, return_tensors="pt")['pixel_values'].to(config.get("device"))
        clip_outputs = clip_model(pixel_values=pixel_values)
        # Drop the CLS token and keep only the patch embeddings.
        patch_embeds = clip_outputs.last_hidden_state[:, 1:, :]
        image_embeddings = projection_layer(patch_embeds).to(torch.float32)
        inputs_embeddings.append(image_embeddings)

    # Audio branch: transcribe the clip with WhisperX, then embed the transcript tokens.
    if aud is not None:
        trans = audio_model.transcribe(aud)
        audio_res = "".join(seg['text'] for seg in trans['segments']).strip()
        audio_tokens = tokenizer(audio_res, return_tensors="pt", return_attention_mask=False)['input_ids']
        audio_embeds = phi2_model.model.embed_tokens(audio_tokens.to(config.get("device")))
        inputs_embeddings.append(audio_embeds)

    # Text branch: embed the typed question, if one was provided (handles None and '').
    if q:
        ques = tokenizer(q, return_tensors="pt", return_attention_mask=False)['input_ids']
        q_embeds = phi2_model.model.embed_tokens(ques.to(config.get("device")))
        inputs_embeddings.append(q_embeds)

    inputs_embeddings.append(end_iq_embeds)

    # Concatenate all modality embeddings into a single prompt along the sequence dimension.
    combined_embeds = torch.cat(inputs_embeddings, dim=1)
    predicted_caption = phi2_model.generate(
        inputs_embeds=combined_embeds,
        max_new_tokens=max_tokens,
        return_dict_in_generate=True,
    )

    # Alternative: manual greedy decoding loop, kept commented out for reference.
    # for pos in range(max_tokens - 1):
    #     model_output_logits = phi2_model.forward(inputs_embeds=combined_embeds)['logits']
    #     predicted_word_token_logits = model_output_logits[:, -1, :].unsqueeze(1)
    #     predicted_word_token = torch.argmax(predicted_word_token_logits, dim=-1)
    #     predicted_caption[:, pos] = predicted_word_token.view(1, -1).to('cpu')
    #     next_token_embeds = phi2_model.model.embed_tokens(predicted_word_token)
    #     combined_embeds = torch.cat([combined_embeds, next_token_embeds], dim=1)

    # Decode the generated token ids and strip the end-of-text marker.
    predicted_captions_decoded = tokenizer.batch_decode(predicted_caption.sequences[:, 1:])[0]
    predicted_captions_decoded = predicted_captions_decoded.replace("<|endoftext|>", "")
    return predicted_captions_decoded
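
# Minimal usage sketch (illustrative only): generate_answers can also be called directly,
# outside the Gradio UI. "example.jpg" and the question string below are placeholder
# values, not assets shipped with this repo.
#
#   from PIL import Image
#   print(generate_answers(img=Image.open("example.jpg"), aud=None,
#                          q="What is shown in the image?", max_tokens=30))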

with gr.Blocks() as demo:
    gr.Markdown(
        """
        # TAI2T Model (Text, Audio, Image to Text)
        A multimodal GPT that takes image, audio, and text inputs and returns a text answer.
        """
    )
    with gr.Row():
        with gr.Column():
            image = gr.Image(label='Image', type="pil", value=None)
            audio_q = gr.Audio(label="Audio Question", value=None, sources=['microphone', 'upload'], type='filepath')
            question = gr.Text(label='Question?', value=None)
            max_tokens = gr.Slider(1, 50, value=10, step=1, label="Max tokens")
    with gr.Row():
        answer = gr.Text(label='Answer')
    with gr.Row():
        submit = gr.Button("Submit")
        submit.click(generate_answers, inputs=[image, audio_q, question, max_tokens], outputs=[answer])
        clear_btn = gr.ClearButton([image, audio_q, question, max_tokens, answer])

if __name__ == "__main__":
    demo.launch(share=True, debug=True)