Van committed on
Commit
73fcb8d
Β·
1 Parent(s): 21ed6fb
app.py ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os

import gradio as gr
import peft
import torch
import torch.nn as nn
import whisperx
from peft import LoraConfig, PeftModel
from transformers import (
    AutoModelForCausalLM,
    AutoProcessor,
    AutoTokenizer,
    BitsAndBytesConfig,
    CLIPVisionModel,
)

# Model identifiers: CLIP supplies image features, Phi-2 generates text.
clip_model_name = "openai/clip-vit-base-patch32"
phi_model_name = "microsoft/phi-2"

# The tokenizer turns text into Phi-2 token ids; the processor prepares
# images for the CLIP vision tower.
tokenizer = AutoTokenizer.from_pretrained(phi_model_name, trust_remote_code=True)
processor = AutoProcessor.from_pretrained(clip_model_name)
tokenizer.pad_token = tokenizer.eos_token

IMAGE_TOKEN_ID = 23893  # token for word comment
device = "cuda" if torch.cuda.is_available() else "cpu"
clip_embed = 768    # CLIP hidden size
phi_embed = 2560    # Phi-2 hidden size
compute_type = "float32"
audio_batch_size = 16
26
class SimpleResBlock(nn.Module):
    """Residual MLP block used on top of the projected image embeddings.

    Applies LayerNorm, then adds a two-layer GELU MLP of the normalized
    value back onto the normalized value: norm(x) + MLP(norm(x)).
    The attribute names (``pre_norm``, ``proj``) are part of the checkpoint
    state-dict layout and must not be renamed.
    """

    def __init__(self, phi_embed):
        super().__init__()
        self.pre_norm = nn.LayerNorm(phi_embed)
        self.proj = nn.Sequential(
            nn.Linear(phi_embed, phi_embed),
            nn.GELU(),
            nn.Linear(phi_embed, phi_embed),
        )

    def forward(self, x):
        normed = self.pre_norm(x)
        return normed + self.proj(normed)
40
+
41
# --- Model construction ------------------------------------------------------
# CLIP vision tower -> linear projection (768 -> 2560) -> residual block,
# feeding image features into the Phi-2 language model; WhisperX transcribes
# audio queries.
clip_model = CLIPVisionModel.from_pretrained(clip_model_name).to(device)
projection = torch.nn.Linear(clip_embed, phi_embed).to(device)
resblock = SimpleResBlock(phi_embed).to(device)
phi_model = AutoModelForCausalLM.from_pretrained(phi_model_name, trust_remote_code=True).to(device)
audio_model = whisperx.load_model(
    "tiny",
    device,
    compute_type=compute_type,
    asr_options={
        'max_new_tokens': 2048,
        'clip_timestamps': True,
        'hallucination_silence_threshold': 0.25,
    },
)

# --- Checkpoint loading ------------------------------------------------------
# Merge the LoRA adapter into Phi-2, then restore the fine-tuned projection
# and residual-block weights. NOTE: "finetunned_projection.pth" is misspelled,
# but it matches the checkpoint filename shipped in model_chkpt/ — keep as is.
_ckpt_root = os.getcwd()
model_to_merge = PeftModel.from_pretrained(phi_model, os.path.join(_ckpt_root, 'model_chkpt/lora_adaptor'))
merged_model = model_to_merge.merge_and_unload()
projection.load_state_dict(
    torch.load(os.path.join(_ckpt_root, 'model_chkpt/finetunned_projection.pth'), map_location=torch.device(device))
)
resblock.load_state_dict(
    torch.load(os.path.join(_ckpt_root, 'model_chkpt/finetuned_resblock.pth'), map_location=torch.device(device))
)
59
+
60
+
61
# Multimodal answer generation:
#   * image -> CLIP patch embeddings -> projection -> residual block
#   * audio -> WhisperX transcription -> Phi-2 token embeddings
#   * text  -> Phi-2 token embeddings
# The pieces are concatenated along the sequence axis and fed to the merged
# (LoRA + Phi-2) model, which greedily decodes up to max_generate_length tokens.
def model_generate_ans(img=None, img_audio=None, val_q=None):
    """Generate a text answer from any combination of image, audio and text.

    Args:
        img: PIL image (or None) — visual context.
        img_audio: path to an audio file (or None) — spoken query.
        val_q: text query string (or None / empty string).

    Returns:
        The decoded answer, truncated at the first <|endoftext|> marker,
        or a prompt string when no input at all was supplied.
    """
    max_generate_length = 100
    eos_token_id = 50256  # Phi-2 / GPT-2 <|endoftext|>
    val_combined_embeds = []

    with torch.no_grad():

        # image: CLIP patch features (CLS token at index 0 dropped),
        # projected to Phi-2 width, followed by the image-marker token embed.
        if img is not None:
            image_processed = processor(images=img, return_tensors="pt").to(device)
            clip_val_outputs = clip_model(**image_processed).last_hidden_state[:, 1:, :]
            val_image_embeds = projection(clip_val_outputs)
            # NOTE(review): fp16 cast assumes the other embeds dtype-match at
            # concat time — confirm behavior on CPU-only runs.
            val_image_embeds = resblock(val_image_embeds).to(torch.float16)

            img_token_tensor = torch.tensor(IMAGE_TOKEN_ID).to(device)
            img_token_embeds = merged_model.model.embed_tokens(img_token_tensor).unsqueeze(0).unsqueeze(0)

            val_combined_embeds.append(val_image_embeds)
            val_combined_embeds.append(img_token_embeds)

        # audio: transcribe with WhisperX, then embed the transcript tokens
        if img_audio is not None:
            audio_result = audio_model.transcribe(img_audio)
            audio_text = ''.join(seg['text'] for seg in audio_result['segments']).strip()
            audio_tokens = tokenizer(audio_text, return_tensors="pt", return_attention_mask=False)['input_ids'].squeeze(0).to(device)
            audio_embeds = merged_model.model.embed_tokens(audio_tokens).unsqueeze(0)
            val_combined_embeds.append(audio_embeds)

        # text question: embed the raw query tokens.
        # `if val_q:` also guards against None (the default), which the
        # previous `len(val_q) != 0` check crashed on.
        if val_q:
            val_q_tokenised = tokenizer(val_q, return_tensors="pt", return_attention_mask=False)['input_ids'].squeeze(0).to(device)
            val_q_embeds = merged_model.model.embed_tokens(val_q_tokenised).unsqueeze(0)
            val_combined_embeds.append(val_q_embeds)

        # No input at all: fail gracefully instead of letting torch.cat([])
        # raise an opaque RuntimeError in the UI.
        if not val_combined_embeds:
            return "Please provide an image, an audio query, or a text query."

        val_combined_embeds = torch.cat(val_combined_embeds, dim=1)

        # Greedy decoding. The output buffer is pre-filled with EOS, so
        # breaking out early decodes identically to a full-length run.
        predicted_caption = torch.full((1, max_generate_length), eos_token_id).to(device)
        for g in range(max_generate_length):
            phi_output_logits = merged_model(inputs_embeds=val_combined_embeds)['logits']  # (1, seq, vocab)
            predicted_word_token = torch.argmax(phi_output_logits[:, -1, :], dim=-1).unsqueeze(1)  # (1, 1)
            predicted_caption[:, g] = predicted_word_token.view(1, -1)
            if predicted_word_token.item() == eos_token_id:
                break  # end-of-text reached; remaining slots already hold EOS
            # Use merged_model consistently (the original mixed in phi_model here).
            next_token_embeds = merged_model.model.embed_tokens(predicted_word_token)  # (1, 1, phi_embed)
            val_combined_embeds = torch.cat([val_combined_embeds, next_token_embeds], dim=1)

    # batch_decode has no `ignore_index` parameter (the old kwarg was silently
    # swallowed); truncation at <|endoftext|> below does the real cleanup.
    predicted_captions_decoded = tokenizer.batch_decode(predicted_caption)[0]
    result = predicted_captions_decoded.split('<|endoftext|>')[0]
    return result.strip()  # Strip any trailing spaces or newlines
122
+
123
+
124
# --- Gradio UI ---------------------------------------------------------------
# Page header: an inline stylesheet plus the title, passed verbatim to
# gr.Markdown, so the literal below must stay exactly as-is.
_HEADER_MARKDOWN = """
<style>
/* General Layout */
body {
font-family: 'Arial', sans-serif;
background-color: #ffe4e1;
margin: 0;
padding: 0;
}
/* Header */
h1, h2, h3 {
text-align: center;
color: #3a3a3a;
font-weight: bold;
}
gr-Markdown h1 {
font-size: 28px;
color: #a3d5d3; /* Soft pastel teal for the header */
}
/* Container and Columns */
.gr-row {
display: flex;
justify-content: center;
margin: 20px 0;
}
.gr-column {
flex: 1;
margin: 0 10px;
padding: 10px;
box-shadow: 0px 0px 10px rgba(0, 0, 0, 0.05);
background-color: #f8f0fa; /* Pastel pink background for columns */
border-radius: 8px;
}
/* Input Components */
.gr-Image, .gr-Audio, .gr-Text {
width: 100%;
margin-bottom: 15px;
background-color: #fff5e1; /* Soft pastel yellow for inputs */
border: 1px solid #e3e3e3;
border-radius: 8px;
}
.gr-Image label, .gr-Audio label, .gr-Text label {
font-size: 16px;
font-weight: bold;
color: #8b8b8b;
}
/* Submit Button */
.gr-Button {
width: 100%;
background-color: #b2c7e1; /* Pastel blue button */
color: white;
padding: 10px;
font-size: 16px;
border: none;
border-radius: 5px;
cursor: pointer;
transition: background-color 0.3s ease;
}
.gr-Button:hover {
background-color: #9db6d3; /* Darker pastel blue on hover */
}
/* Text Output */
.gr-Text {
font-size: 16px;
color: #333;
min-height: 100px;
padding: 10px;
border: 1px solid #ddd;
border-radius: 5px;
background-color: #edf5e1; /* Light pastel green for the output text box */
}
/* Responsive Design */
@media (max-width: 768px) {
.gr-row {
flex-direction: column;
}
.gr-column {
margin: 10px 0;
}
}
</style>

# Engage with MultiModal GPT!
A seamless AI experience combining CLIP and Phi-2 models.
"""

with gr.Blocks() as demo:
    gr.Markdown(_HEADER_MARKDOWN)

    # Left column: the three inputs; right column: the model's answer.
    with gr.Row():
        with gr.Column():
            img_input = gr.Image(label='Image', type="pil")
            img_audio = gr.Audio(label="Audio Query", sources=['microphone', 'upload'], type='filepath')
            img_question = gr.Text(label='Text Query')
        with gr.Column():
            img_answer = gr.Text(label='Answer')

    # Wire the submit button to the multimodal generation function.
    section_btn = gr.Button("Submit")
    section_btn.click(model_generate_ans, inputs=[img_input, img_audio, img_question], outputs=[img_answer])

demo.launch()
model_chkpt/finetuned_resblock.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:11a26279751b1c92a8bf42360ee424976019a2a79549995030b941f5cdde3b9f
3
+ size 52472630
model_chkpt/finetunned_projection.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4ed7d9eeccd4d0e6db66bd78d58bbeb371d9e68f0a8bf4abf154936004bbbe6d
3
+ size 7876204
model_chkpt/lora_adaptor/adapter_config.json ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "alpha_pattern": {},
3
+ "auto_mapping": null,
4
+ "base_model_name_or_path": "microsoft/phi-2",
5
+ "bias": "none",
6
+ "fan_in_fan_out": false,
7
+ "inference_mode": true,
8
+ "init_lora_weights": true,
9
+ "layer_replication": null,
10
+ "layers_pattern": null,
11
+ "layers_to_transform": null,
12
+ "loftq_config": {},
13
+ "lora_alpha": 16,
14
+ "lora_dropout": 0.1,
15
+ "megatron_config": null,
16
+ "megatron_core": "megatron.core",
17
+ "modules_to_save": null,
18
+ "peft_type": "LORA",
19
+ "r": 64,
20
+ "rank_pattern": {},
21
+ "revision": null,
22
+ "target_modules": [
23
+ "gate_proj",
24
+ "k_proj",
25
+ "up_proj",
26
+ "down_proj",
27
+ "o_proj",
28
+ "v_proj",
29
+ "q_proj"
30
+ ],
31
+ "task_type": "CAUSAL_LM",
32
+ "use_dora": false,
33
+ "use_rslora": false
34
+ }
model_chkpt/lora_adaptor/adapter_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3fcda1f14a5c72b01440f752d4680078b4c591a6cc2106e49fb8f2dab8b85572
3
+ size 125855064
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ peft
3
+ accelerate
4
+ transformers==4.37
5
+ einops
6
+ git+https://github.com/m-bain/whisperx.git
7
+ bitsandbytes
8
+ wandb
9
+ ffmpeg
10
+ pydub
11
+ gradio