import torch
import whisperx
import gradio as gr
from peft import PeftModel
from configs import get_config_phase2
from transformers import AutoTokenizer, AutoProcessor, CLIPVisionModel, AutoModelForCausalLM
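
# Gradio demo: a multimodal (image / audio / text -> text) assistant. A CLIP vision
# encoder plus a learned projection maps images into the embedding space of a
# QLoRA-finetuned Phi-2 model, and WhisperX transcribes spoken questions.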

config = get_config_phase2() 

# CLIP vision encoder: turns the input image into a sequence of patch embeddings
clip_model = CLIPVisionModel.from_pretrained(config.get("clip_model_name")).to(config.get("device"))

# Base Phi-2 language model (float32); the QLoRA adapter is merged into it below
base_model = AutoModelForCausalLM.from_pretrained(
    config.get("phi2_model_name"),
    low_cpu_mem_usage=True,
    return_dict=True,
    torch_dtype=torch.float32,
    trust_remote_code=True
)


# Load the phase-2 QLoRA adapter and merge it into the base Phi-2 weights for inference
ckpts = "ckpts/Qlora_adaptor/"
phi2_model = PeftModel.from_pretrained(base_model, ckpts)
phi2_model = phi2_model.merge_and_unload().to(config.get("device"))

# Projection layer that maps CLIP patch embeddings into Phi-2's embedding space
projection_layer = torch.nn.Linear(config.get("clip_embed"), config.get("phi_embed")).to(config.get("device"))
projection_layer.load_state_dict(torch.load('./ckpts/model_phase2.pth', map_location=config.get("device")))

# Tokenizer for Phi-2 and image processor for CLIP
tokenizer = AutoTokenizer.from_pretrained(config.get("phi2_model_name"), trust_remote_code=True)
processor = AutoProcessor.from_pretrained(config.get("clip_model_name"), trust_remote_code=True)

# WhisperX (tiny, CPU) transcribes the spoken question to text
audio_model = whisperx.load_model('tiny', 'cpu', compute_type="float32")


# Inference only, so disable gradient tracking
@torch.no_grad()
def generate_answers(img=None, aud=None, q=None, max_tokens=30):
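    """Build one multimodal prompt (<iQ> [image patches] [audio transcript] [typed question] </iQ>)
    and greedily decode an answer of up to max_tokens tokens from the merged Phi-2 model."""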
    batch_size = 1
    # The Gradio slider may hand max_tokens over as a float, so coerce it to int
    max_tokens = int(max_tokens)
    start_iq = tokenizer.encode("<iQ>")
    end_iq = tokenizer.encode("</iQ>")
    start_iq_embeds = torch.tensor(start_iq).repeat(batch_size, 1)
    end_iq_embeds = torch.tensor(end_iq).repeat(batch_size, 1)
    start_iq_embeds = phi2_model.model.embed_tokens(start_iq_embeds.to(config.get("device")))
    end_iq_embeds = phi2_model.model.embed_tokens(end_iq_embeds.to(config.get("device")))
    
    inputs_embeddings = []
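    # Prompt segments, in order: <iQ>, image patches, audio transcript, typed question, </iQ>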
    inputs_embeddings.append(start_iq_embeds)

    # Output buffer pre-filled with Phi-2's <|endoftext|> id (50256)
    predicted_caption = torch.full((batch_size, max_tokens), 50256, dtype=torch.long, device=config.get('device'))
    
    if img is not None:
        pixel_values = processor(images=img, return_tensors="pt")['pixel_values'].to(config.get("device"))
        clip_outputs = clip_model(pixel_values=pixel_values)
        # drop the CLS token and keep only the patch embeddings
        patch_embeds = clip_outputs.last_hidden_state[:, 1:, :]
        # project CLIP features into Phi-2's embedding space
        image_embeddings = projection_layer(patch_embeds).to(torch.float32)
        inputs_embeddings.append(image_embeddings)
    
    if aud is not None:
        # Transcribe the spoken question with WhisperX, then embed the text
        trans = audio_model.transcribe(aud)
        audio_res = "".join(seg['text'] for seg in trans['segments']).strip()
        audio_tokens = tokenizer(audio_res, return_tensors="pt", return_attention_mask=False)['input_ids']
        audio_embeds = phi2_model.model.embed_tokens(audio_tokens.to(config.get("device")))
        inputs_embeddings.append(audio_embeds)
        
    if q is not None:
        ques = tokenizer(q, return_tensors="pt", return_attention_mask=False)['input_ids']
        q_embeds = phi2_model.model.embed_tokens(ques.to(config.get("device")))
        inputs_embeddings.append(q_embeds)
        
    inputs_embeddings.append(end_iq_embeds)
    # Concatenate all prompt segments into one embedding sequence
    combined_embeds = torch.cat(inputs_embeddings, dim=1)

    # Greedy decoding: re-run the model on the growing embedding sequence and
    # append the embedding of the most likely next token at each step.
    for pos in range(max_tokens - 1):
        model_output_logits = phi2_model(inputs_embeds=combined_embeds)['logits']
        predicted_word_token_logits = model_output_logits[:, -1, :].unsqueeze(1)
        predicted_word_token = torch.argmax(predicted_word_token_logits, dim=-1)
        predicted_caption[:, pos] = predicted_word_token.view(-1)
        next_token_embeds = phi2_model.model.embed_tokens(predicted_word_token)
        combined_embeds = torch.cat([combined_embeds, next_token_embeds], dim=1)
    # Decode the generated ids and strip the <|endoftext|> padding
    predicted_captions_decoded = tokenizer.batch_decode(predicted_caption)[0]
    predicted_captions_decoded = predicted_captions_decoded.replace("<|endoftext|>", "")
    return predicted_captions_decoded


with gr.Blocks() as demo:
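    # UI: optional image, spoken question (microphone or upload), and typed question
    # in, plus a max-token slider; the generated answer comes back as text.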

    gr.Markdown(
    """
    # TAI2T Model (Text, Audio, Image to Text)
    A multimodal GPT that takes image, audio, and text as input and produces text as output.
    """
    )

    with gr.Row():
        with gr.Column():
            image = gr.Image(label='Image', type="pil", value=None)
            audio_q = gr.Audio(label="Audio Question", value=None, sources=['microphone', 'upload'], type='filepath')
            question = gr.Text(label='Question?', value=None)
            max_tokens = gr.Slider(1, 50, value=10, step=1, label="Max tokens")
    with gr.Row():
        answer = gr.Text(label='Answer')
    with gr.Row():
        submit = gr.Button("Submit")
        submit.click(generate_answers, inputs=[image, audio_q, question, max_tokens], outputs=[answer])
        clear_btn = gr.ClearButton([image, audio_q, question, max_tokens, answer])
    
if __name__ == "__main__":
    demo.launch(share=True)