sanjanatule committed on
Commit 9c48aca · verified · 1 Parent(s): b6bd78d

Create app.py

Files changed (1)
  1. app.py +109 -0
app.py ADDED
@@ -0,0 +1,109 @@
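# Gradio demo: chat with an image. A CLIP vision encoder turns the image into patch
# embeddings, a linear projection maps them into phi-2's embedding space, and a
# LoRA-adapted, 4-bit quantized phi-2 generates the answer token by token.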
import gradio as gr
import peft
from peft import LoraConfig
from transformers import AutoTokenizer, BitsAndBytesConfig, AutoModelForCausalLM, CLIPVisionModel, AutoProcessor
import torch
from PIL import Image
import requests
import numpy as np

clip_model_name = "openai/clip-vit-base-patch32"
phi_model_name = "microsoft/phi-2"
tokenizer = AutoTokenizer.from_pretrained(phi_model_name, trust_remote_code=True)
processor = AutoProcessor.from_pretrained(clip_model_name)
tokenizer.pad_token = tokenizer.eos_token
IMAGE_TOKEN_ID = 23893  # token id for the word 'comment', inserted between the image embeddings and the question
device = "cuda" if torch.cuda.is_available() else "cpu"
clip_embed = 768   # CLIP ViT-B/32 hidden size
phi_embed = 2560   # phi-2 hidden size

# models
clip_model = CLIPVisionModel.from_pretrained(clip_model_name).to(device)
projection = torch.nn.Linear(clip_embed, phi_embed).to(device)  # maps CLIP features into phi-2's embedding space
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

phi_model = AutoModelForCausalLM.from_pretrained(
    phi_model_name,
    torch_dtype=torch.float32,
    quantization_config=bnb_config,
    trust_remote_code=True,
)
lora_alpha = 16
lora_dropout = 0.1
lora_r = 64
peft_config = LoraConfig(
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    r=lora_r,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "fc1",
        "fc2",
    ],
)
# wrap the quantized phi-2 with LoRA adapter layers; the trained weights are loaded below
peft_model = peft.get_peft_model(phi_model, peft_config).to(device)

# load trained weights: the LoRA adapter for phi-2 and the image-projection layer
model_to_merge = peft.PeftModel.from_pretrained(phi_model, './model_chkpt/lora_adaptor')
merged_model = model_to_merge.merge_and_unload()
projection.load_state_dict(torch.load('./model_chkpt/step2_projection.pth', map_location=device))

def model_generate_ans(img, val_q):
    """Answer a question about an image by greedy decoding."""

    max_generate_length = 100

    # image: CLIP patch embeddings (CLS token dropped), projected into phi-2's embedding space
    image_processed = processor(images=img, return_tensors="pt").to(device)
    clip_val_outputs = clip_model(**image_processed).last_hidden_state[:, 1:, :]
    val_image_embeds = projection(clip_val_outputs).to(torch.float16)

    # embedding of the image-boundary token
    img_token_tensor = torch.tensor(IMAGE_TOKEN_ID).to(device)
    img_token_embeds = peft_model.model.model.embed_tokens(img_token_tensor).unsqueeze(0).unsqueeze(0)

    # embeddings of the question tokens
    val_q_tokenised = tokenizer(val_q, return_tensors="pt", return_attention_mask=False)['input_ids'].squeeze(0).to(device)
    val_q_embeds = peft_model.model.model.embed_tokens(val_q_tokenised).unsqueeze(0)

    # prompt layout: [image patches][image token][question] -> (1, seq_len, 2560)
    val_combined_embeds = torch.cat([val_image_embeds, img_token_embeds, val_q_embeds], dim=1)

    # pre-fill the output with eos (50256) so unused positions decode to nothing
    predicted_caption = torch.full((1, max_generate_length), 50256)

    with torch.no_grad():
        for g in range(max_generate_length):
            phi_output_logits = peft_model(inputs_embeds=val_combined_embeds)['logits']  # (1, seq_len, 51200)
            predicted_word_token_logits = phi_output_logits[:, -1, :].unsqueeze(1)       # (1, 1, 51200)
            predicted_word_token = torch.argmax(predicted_word_token_logits, dim=-1)     # (1, 1)
            predicted_caption[:, g] = predicted_word_token.view(1, -1).to('cpu')
            if predicted_word_token.item() == tokenizer.eos_token_id:
                break
            # feed the predicted token back in so the next step conditions on it
            next_token_embeds = peft_model.model.model.embed_tokens(predicted_word_token)  # (1, 1, 2560)
            val_combined_embeds = torch.cat([val_combined_embeds, next_token_embeds], dim=1)

    predicted_captions_decoded = tokenizer.batch_decode(predicted_caption, skip_special_tokens=True)

    return predicted_captions_decoded

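# Hypothetical direct call, bypassing the UI (assumes a local sample.jpg; not part of
# the original app):
#   answer = model_generate_ans(Image.open("sample.jpg"), "What is shown in the image?")
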
with gr.Blocks() as demo:

    gr.Markdown(
        """
        # Chat with MultiModal GPT!
        Built by combining the CLIP vision model with the phi-2 language model.
        """
    )

    # app GUI
    with gr.Row():
        with gr.Column():
            img_input = gr.Image(label='Image')
            img_question = gr.Text(label='Question')
        with gr.Column():
            img_answer = gr.Text(label='Answer')

    section_btn = gr.Button("Submit")
    section_btn.click(model_generate_ans, inputs=[img_input, img_question], outputs=[img_answer])

if __name__ == "__main__":
    demo.launch()