Spaces:
Build error
Build error
| import gradio as gui | |
| import peft | |
| from peft import LoraConfig | |
| from transformers import AutoTokenizer,BitsAndBytesConfig, AutoModelForCausalLM, CLIPVisionModel, AutoProcessor | |
| import torch | |
| from peft import PeftModel | |
| import torch.nn as nn | |
| import whisper | |
| import os | |
# --- Configuration -------------------------------------------------------
# Hugging Face hub ids of the two pretrained backbones.
clip_model_name = "openai/clip-vit-base-patch32"
phi_model_name = "microsoft/phi-2"

# Phi-2 tokenizer; the EOS token is reused as the padding token.
tokenizer = AutoTokenizer.from_pretrained(phi_model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token

# CLIP preprocessor used to turn input images into pixel tensors.
processor = AutoProcessor.from_pretrained(clip_model_name)

# Marker token ids spliced into the prompt embedding sequence.
IMAGE_TOKEN_ID = 23893  # placed after the image embeddings (author's note: token for the word "comment")
QA_TOKEN_ID = 50295     # appended after an audio and/or text question

device = "cuda" if torch.cuda.is_available() else "cpu"

# Hidden widths: CLIP vision features (clip_embed) are projected to Phi-2's
# embedding width (phi_embed) before being fed to the language model.
clip_embed = 768
phi_embed = 2560

audio_batch_size = 16  # NOTE(review): not referenced in the visible code

# Base directory under which model_chkpt/ is looked up; this is the process's
# working directory, not necessarily the directory containing this script.
current_dir = os.getcwd()
class SimpleResBlock(nn.Module):
    """Pre-norm residual MLP applied to projected image features.

    The input is first layer-normalized; the output is that normalized tensor
    plus a Linear -> GELU -> Linear transform of it (the residual is taken
    from the normalized input, not the raw input).

    Submodule names (``pre_norm``, ``proj.0``, ``proj.2``) must not change:
    trained weights are restored into this module with ``load_state_dict``.
    """

    def __init__(self, phi_embed):
        super().__init__()
        width = phi_embed
        self.pre_norm = nn.LayerNorm(width)
        layers = [nn.Linear(width, width), nn.GELU(), nn.Linear(width, width)]
        self.proj = nn.Sequential(*layers)

    def forward(self, x):
        normed = self.pre_norm(x)
        residual = self.proj(normed)
        return normed + residual
# --- Models ----------------------------------------------------------------
# CLIP vision encoder producing per-patch image features.
clip_model = CLIPVisionModel.from_pretrained(clip_model_name).to(device)
# Image adapter: linear map from CLIP width to Phi-2 width, refined by a residual MLP.
projection = torch.nn.Linear(clip_embed, phi_embed).to(device)
resblock = SimpleResBlock(phi_embed).to(device)
# Phi-2 base language model; the LoRA adaptor loaded below is merged into it.
phi_model = AutoModelForCausalLM.from_pretrained(phi_model_name, trust_remote_code=True).to(device)
# Whisper "tiny" speech-to-text model used to transcribe audio questions.
audio_model = whisper.load_model("tiny", device=device)

# Checkpoint locations, resolved against current_dir.
lora_adaptor_path = os.path.join(current_dir, 'model_chkpt', 'lora_adaptor')
projection_path = os.path.join(current_dir, 'model_chkpt', 'step2_projection.pth')
resblock_path = os.path.join(current_dir, 'model_chkpt', 'step2_resblock.pth')

# --- Load trained weights ----------------------------------------------------
model_to_merge = PeftModel.from_pretrained(phi_model, lora_adaptor_path, local_files_only=True, device_map={'': device})
# Fold the LoRA weights into the base model so inference uses a plain model.
merged_model = model_to_merge.merge_and_unload()
# NOTE(review): torch.load unpickles arbitrary Python objects. These are local
# checkpoints, but weights_only=True (torch >= 1.13) would be safer if the
# installed torch version supports it.
projection.load_state_dict(torch.load(projection_path, map_location=torch.device(device)))
resblock.load_state_dict(torch.load(resblock_path, map_location=torch.device(device)))

# These modules are used for inference only. Freshly constructed nn.Modules
# (projection, resblock) start in training mode, so switch everything to eval
# mode explicitly; this keeps inference deterministic if dropout-style layers
# are ever added.
for _inference_module in (clip_model, projection, resblock, merged_model):
    _inference_module.eval()
def _embed_token_id(token_id):
    """Embed one token id with the merged Phi-2 embedding table; returns a (1, 1, hidden) tensor."""
    token = torch.tensor([[token_id]], device=device)
    return merged_model.model.embed_tokens(token)


def _embed_text(text):
    """Tokenize ``text`` and embed it with Phi-2; returns a (1, num_tokens, hidden) tensor."""
    token_ids = tokenizer(text, return_tensors="pt", return_attention_mask=False)["input_ids"].to(device)
    return merged_model.model.embed_tokens(token_ids)


def generate_response(img=None, img_audio=None, val_q=None):
    """Generate a text response from any combination of image, audio and text inputs.

    The prompt is built directly in Phi-2 embedding space, in this order:
    the projected CLIP patch embeddings followed by IMAGE_TOKEN_ID (if ``img``),
    the Whisper transcript (if ``img_audio``), the question text (if ``val_q``),
    and finally QA_TOKEN_ID whenever there is an audio or text question.

    Args:
        img: PIL image from the Gradio image component, or None.
        img_audio: Path to an audio file from the Gradio audio component, or None.
        val_q: Question text; None or "" means no text question.

    Returns:
        The decoded generation with ``<|endoftext|>`` markers removed. If all
        three inputs are empty, a short message asking for input is returned.
    """
    max_generate_length = 100
    has_question = bool(val_q)  # accept None as well as "" for "no text question"
    if img is None and img_audio is None and not has_question:
        # Nothing to condition on; torch.cat([]) would raise, so answer early.
        return "Please provide an image, audio clip, or text question."

    prompt_parts = []
    with torch.no_grad():
        if img is not None:
            pixel_inputs = processor(images=img, return_tensors="pt").to(device)
            # Drop position 0 (CLIP's CLS token) and keep only the patch features.
            patch_feats = clip_model(**pixel_inputs).last_hidden_state[:, 1:, :]
            img_token_embeds = _embed_token_id(IMAGE_TOKEN_ID)
            # Cast to the language model's embedding dtype rather than a fixed
            # float16, so torch.cat below never mixes dtypes.
            image_embeds = resblock(projection(patch_feats)).to(img_token_embeds.dtype)
            prompt_parts.append(image_embeds)
            prompt_parts.append(img_token_embeds)

        if img_audio is not None:
            transcription = audio_model.transcribe(img_audio)
            audio_text = "".join(seg["text"] for seg in transcription["segments"]).strip()
            prompt_parts.append(_embed_text(audio_text))

        if has_question:
            prompt_parts.append(_embed_text(val_q))

        if img_audio is not None or has_question:
            prompt_parts.append(_embed_token_id(QA_TOKEN_ID))

        inputs_embeds = torch.cat(prompt_parts, dim=1)
        # Single unpadded sequence: every position is attended to.
        attention_mask = torch.ones(inputs_embeds.shape[:2], dtype=torch.long, device=inputs_embeds.device)
        generated = merged_model.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            max_new_tokens=max_generate_length,
            return_dict_in_generate=True,
        )

    # Decode the whole sequence instead of slicing off position 0. Depending on
    # the transformers version, generate() called with only inputs_embeds either
    # prepends a BOS token (presumably <|endoftext|> for Phi-2; verify) or
    # returns only the new tokens, so a fixed [:, 1:] slice can drop the first
    # real token. Removing <|endoftext|> afterwards handles both cases.
    decoded = tokenizer.batch_decode(generated.sequences)[0]
    return decoded.replace("<|endoftext|>", "")
# Gradio UI: an image on the left, audio + text question on the right, and the
# generated response below. Inputs are passed positionally to
# generate_response(img, img_audio, val_q).
# NOTE(review): indentation was lost in the source this was recovered from;
# the submit button is placed at Blocks level after the response row — confirm.
with gui.Blocks() as app_interface:
    with gui.Row():
        with gui.Column():
            image_input = gui.Image(type="pil", label='Upload Image')
        with gui.Column():
            audio_input = gui.Audio(
                type='filepath',
                sources=['microphone', 'upload'],
                label="Audio Input",
            )
            text_input = gui.Text(
                placeholder="Type your query here...",
                label='Enter Text',
            )
    with gui.Row():
        output_response = gui.Textbox(
            lines=5,
            placeholder="Response will appear here...",
            label='Generated Response',
        )
    submit_button = gui.Button("Generate Response", variant="primary")
    submit_button.click(
        generate_response,
        inputs=[image_input, audio_input, text_input],
        outputs=output_response,
    )

if __name__ == "__main__":
    # share=True requests a public Gradio share link when run locally.
    app_interface.launch(share=True)