File size: 5,305 Bytes
6cda9a1
912afa9
91e0ece
 
2a9f4d0
91e0ece
 
1b67e19
91e0ece
 
 
 
 
 
 
 
 
 
 
 
 
 
2c8dad9
91e0ece
 
 
6cda9a1
2a9f4d0
 
 
91e0ece
 
6cda9a1
 
b89d6d1
2a9f4d0
 
 
 
 
91e0ece
 
2a9f4d0
 
 
6cda9a1
91e0ece
2a9f4d0
91e0ece
75dfe2d
2a9f4d0
91e0ece
6cda9a1
 
91e0ece
2a9f4d0
 
 
 
 
 
 
 
52e4280
91e0ece
2a9f4d0
 
75dfe2d
2a9f4d0
91e0ece
2a9f4d0
 
 
 
 
1cd5299
 
 
2a9f4d0
1cd5299
 
 
 
 
 
 
 
 
 
 
 
52e4280
2a9f4d0
 
 
 
 
 
 
 
 
 
 
 
 
52e4280
 
 
 
 
2a9f4d0
 
52e4280
 
 
 
2a9f4d0
 
52e4280
715a82c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import torch
import whisperx
import gradio as gr
from peft import PeftModel
from configs import get_config_phase2
from transformers import AutoTokenizer, AutoProcessor, CLIPVisionModel, AutoModelForCausalLM

config = get_config_phase2()

# CLIP vision encoder. It must live on the configured device because
# generate_answers() moves the image pixel_values there before calling it.
clip_model = CLIPVisionModel.from_pretrained(config.get("clip_model_name")).to(config.get("device"))

# Phi-2 base language model, loaded in full precision.
base_model = AutoModelForCausalLM.from_pretrained(
    config.get("phi2_model_name"),
    low_cpu_mem_usage=True,
    return_dict=True,
    torch_dtype=torch.float32,
    trust_remote_code=True,
)

# Attach the adapter checkpoint and merge it into the base weights, so inference
# runs on a plain model without the PEFT wrapper.
ckpts = "ckpts/Qlora_adaptor/"
phi2_model = PeftModel.from_pretrained(base_model, ckpts)
phi2_model = phi2_model.merge_and_unload().to(config.get("device"))

# Linear projection from CLIP hidden size ("clip_embed") to Phi-2 embedding
# size ("phi_embed").
projection_layer = torch.nn.Linear(config.get("clip_embed"), config.get("phi_embed"))
projection_layer.load_state_dict(torch.load('./ckpts/model_phase2.pth', map_location=config.get("device")))
# load_state_dict() copies values into the layer's existing parameters;
# map_location does not relocate the layer itself, so move it explicitly.
projection_layer = projection_layer.to(config.get("device"))

# Text tokenizer for Phi-2 and image preprocessor for CLIP.
tokenizer = AutoTokenizer.from_pretrained(config.get("phi2_model_name"), trust_remote_code=True)
processor = AutoProcessor.from_pretrained(config.get("clip_model_name"), trust_remote_code=True)

# Speech-to-text model ("tiny" size via whisperx), always run on the CPU.
audio_model = whisperx.load_model('tiny', 'cpu', compute_type="float32")


def _embed_text(text):
    """Tokenize *text* and return its Phi-2 input embeddings (batch dim of 1)."""
    token_ids = tokenizer(text, return_tensors="pt", return_attention_mask=False)["input_ids"]
    return phi2_model.model.embed_tokens(token_ids.to(config.get("device")))


def _embed_image(img):
    """Encode *img* with CLIP and project its patch features to Phi-2 space."""
    pixel_values = processor(images=img, return_tensors="pt")["pixel_values"].to(config.get("device"))
    clip_outputs = clip_model(pixel_values=pixel_values)
    # Drop the leading CLS token and keep only the patch tokens.
    patch_features = clip_outputs.last_hidden_state[:, 1:, :]
    return projection_layer(patch_features).to(torch.float32)


def _transcribe(aud):
    """Transcribe *aud* with whisperx and return all segments joined and stripped."""
    result = audio_model.transcribe(aud)
    return "".join(seg["text"] for seg in result["segments"]).strip()


def generate_answers(img=None, aud=None, q=None, max_tokens=30):
    """Answer a multimodal query with the merged Phi-2 model.

    The prompt is assembled directly in embedding space as
    ``<iQ> [image patches] [audio transcript] [question] </iQ>``. Each
    bracketed part is included only when its input is provided. The result
    is passed to ``phi2_model.generate`` as ``inputs_embeds``.

    Args:
        img: Image accepted by the CLIP processor (the UI passes a PIL image),
            or None to skip the image.
        aud: Audio input for whisperx (the UI passes a file path), or None.
        q: Question text. None or an empty string means no text question.
        max_tokens: Maximum number of new tokens to generate.

    Returns:
        The decoded answer text with ``<|endoftext|>`` markers removed.
    """
    # Pure inference: avoid building autograd graphs through CLIP, the
    # projection layer and the embedding lookups.
    with torch.no_grad():
        parts = [_embed_text("<iQ>")]
        if img is not None:
            parts.append(_embed_image(img))
        if aud is not None:
            parts.append(_embed_text(_transcribe(aud)))
        # Truthiness check: skips both "" (empty textbox) and None (the default).
        if q:
            parts.append(_embed_text(q))
        parts.append(_embed_text("</iQ>"))

        combined_embeds = torch.cat(parts, dim=1)
        output = phi2_model.generate(
            inputs_embeds=combined_embeds,
            # The UI slider may deliver a float; generate() expects an int.
            max_new_tokens=int(max_tokens),
            return_dict_in_generate=True,
        )

    # Decode the whole sequence. With inputs_embeds only, recent transformers
    # releases return just the new tokens, so slicing off position 0 would
    # lose the first answer token. Some older releases prepend a BOS token
    # (presumably "<|endoftext|>" for this tokenizer; confirm). Any such
    # marker is removed by the replace() below.
    answer = tokenizer.batch_decode(output.sequences)[0]
    return answer.replace("<|endoftext|>", "")


with gr.Blocks() as demo:

    # Page header.
    gr.Markdown(
    """
    # TAI2T Model(Text, Audio, Image to Text Model)
    Multimodal GPT with inputs as Image, Audio, Text with output as Text.
    """
    )

    # Inputs: optional image, optional recorded/uploaded audio question,
    # optional text question, and a cap on generated tokens.
    with gr.Row():
        with gr.Column():
            image = gr.Image(label='Image', type="pil", value=None)
            audio_q = gr.Audio(label="Audio Question", value=None, sources=['microphone', 'upload'], type='filepath')
            question = gr.Text(label='Question?', value=None)
            max_tokens = gr.Slider(1, 50, value=10, step=1, label="Max tokens")

    # Output.
    with gr.Row():
        answer = gr.Text(label='Answer')

    # Actions: Submit runs generate_answers on the inputs (in this order);
    # Clear resets every input and the answer box.
    with gr.Row():
        submit = gr.Button("Submit")
        submit.click(generate_answers, inputs=[image, audio_q, question, max_tokens], outputs=[answer])
        clear_btn = gr.ClearButton([image, audio_q, question, max_tokens, answer])

if __name__ == "__main__":
    # share=True asks Gradio for a public share link; debug=True blocks the
    # main thread and prints errors to the console.
    demo.launch(share=True, debug=True)