Vasudevakrishna committed · verified
Commit 91e0ece · 1 Parent(s): 912afa9

Update app.py

Files changed (1):
  1. app.py +39 -18
app.py CHANGED
@@ -1,41 +1,60 @@
-import gradio as gr
 import torch
 import whisperx
-from model import MainQLoraModel
+import gradio as gr
+from peft import PeftModel
 from configs import get_config_phase2
-from transformers import AutoTokenizer, AutoProcessor
+from transformers import AutoTokenizer, AutoProcessor, CLIPVisionModel, AutoModelForCausalLM
+
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+clip_model = CLIPVisionModel.from_pretrained(config.get("clip_model_name"))
+
+base_model = AutoModelForCausalLM.from_pretrained(
+    config.get("phi2_model_name"),
+    low_cpu_mem_usage=True,
+    return_dict=True,
+    torch_dtype=torch.float32,
+    trust_remote_code=True
+)
+
+
+ckpts = "ckpts/Qlora_adaptor/"
+phi2_model = PeftModel.from_pretrained(base_model, ckpts)
+phi2_model = phi2_model.merge_and_unload().to(device)
+
+projection_layer = torch.nn.Linear(config.get("clip_embed"), config.get("phi_embed"))
+projection_layer.load_state_dict(torch.load('./ckpts/model_phase2.pth', map_location=config.get("device")))
 
-# get config
-config = get_config_phase2()
 # tokenizer
 tokenizer = AutoTokenizer.from_pretrained(config.get("phi2_model_name"), trust_remote_code=True)
 processor = AutoProcessor.from_pretrained(config.get("clip_model_name"), trust_remote_code=True)
-llmModel = MainQLoraModel(tokenizer, config).to(config.get("device"))
-audio_model = whisperx.load_model('tiny', 'cpu', compute_type="float16")
+
+audio_model = whisperx.load_model('tiny', 'cpu', compute_type="float32")
 
 
 def generate_answers(img=None, aud = None, q = None, max_tokens = 30):
+    print(img, aud, q)
 
     batch_size = 1
     start_iq = tokenizer.encode("<iQ>")
     end_iq = tokenizer.encode("</iQ>")
     start_iq_embeds = torch.tensor(start_iq).repeat(batch_size, 1)
     end_iq_embeds = torch.tensor(end_iq).repeat(batch_size, 1)
-    start_iq_embeds = llmModel.phi2_model.model.model.embed_tokens(start_iq_embeds.to(config.get("device")))
-    end_iq_embeds = llmModel.phi2_model.model.model.embed_tokens(end_iq_embeds.to(config.get("device")))
+    start_iq_embeds = phi2_model.model.embed_tokens(start_iq_embeds.to(config.get("device")))
+    end_iq_embeds = phi2_model.model.embed_tokens(end_iq_embeds.to(config.get("device")))
 
     inputs_embeddings = []
     inputs_embeddings.append(start_iq_embeds)
 
-    predicted_caption = torch.full((batch_size, max_tokens), llmModel.EOS_TOKEN_ID, dtype=torch.long, device=config.get('device'))
+    predicted_caption = torch.full((batch_size, max_tokens), 50256, dtype=torch.long, device=config.get('device'))
 
-    if images is not None:
-        images = processor(images=img, return_tensors="pt").to(config.get("device"))
+    if img is not None:
+        images = processor(images=img, return_tensors="pt")['pixel_values'].to(device)
         images = {'pixel_values': images.to(config.get("device"))}
-        clip_outputs = llmModel.clip_model(**images)
+        clip_outputs = clip_model(**images)
         # remove cls token
         images = clip_outputs.last_hidden_state[:, 1:, :]
-        image_embeddings = llmModel.projection_layer(images).to(torch.float16)
+        image_embeddings = projection_layer(images).to(torch.float32)
         inputs_embeddings.append(image_embeddings)
 
     if aud is not None:
@@ -44,13 +63,14 @@ def generate_answers(img=None, aud = None, q = None, max_tokens = 30):
         for seg in trans['segments']:
             audio_res += seg['text']
         audio_res = audio_res.strip()
+        print(audio_res)
         audio_tokens = tokenizer(q,return_tensors="pt", return_attention_mask=False)['input_ids']
-        audio_embeds = llmModel.phi2_model.model.model.embed_tokens(audio_tokens.to(config.get("device")))
+        audio_embeds = phi2_model.model.embed_tokens(audio_tokens.to(config.get("device")))
         inputs_embeddings.append(audio_embeds)
 
     if q is not None:
         ques = tokenizer(q, return_tensors="pt", return_attention_mask=False)['input_ids']
-        q_embeds = llmModel.phi2_model.model.model.embed_tokens(ques.to(config.get("device")))
+        q_embeds = phi2_model.model.embed_tokens(ques.to(config.get("device")))
         inputs_embeddings.append(q_embeds)
 
     inputs_embeddings.append(end_iq_embeds)
@@ -58,11 +78,12 @@
     combined_embeds = torch.cat(inputs_embeddings, dim=1)
 
     for pos in range(max_tokens - 1):
-        model_output_logits = llmModel.phi2_model.forward(inputs_embeds = combined_embeds)['logits']
+        model_output_logits = phi2_model.forward(inputs_embeds = combined_embeds)['logits']
+        print(model_output_logits.shape)
        predicted_word_token_logits = model_output_logits[:, -1, :].unsqueeze(1)
         predicted_word_token = torch.argmax(predicted_word_token_logits, dim = -1)
         predicted_caption[:, pos] = predicted_word_token.view(1,-1).to('cpu')
-        next_token_embeds = llmModel.phi2_model.model.model.embed_tokens(predicted_word_token)
+        next_token_embeds = phi2_model.model.embed_tokens(predicted_word_token)
         combined_embeds = torch.cat([combined_embeds, next_token_embeds], dim=1)
     predicted_captions_decoded = tokenizer.batch_decode(predicted_caption,ignore_index = 50256)[0]
     return predicted_captions_decoded
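
One wrinkle when reading the new version: the commit deletes the config = get_config_phase2() assignment, yet the updated module still reads config.get(...) from the very first model load onward. A minimal sketch of the presumably intended module top, restoring that assignment (assuming get_config_phase2() still returns the same phase-2 settings dict):

import torch
from configs import get_config_phase2

# Restore the assignment the diff drops: every config.get(...) below depends on it.
config = get_config_phase2()
device = 'cuda' if torch.cuda.is_available() else 'cpu'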
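
In the audio branch, the transcribe call itself falls between the shown hunks, and both the old and new code tokenize q rather than the transcript assembled in audio_res (which is only printed). A sketch of the branch as presumably intended; the transcribe call and the use of audio_res here are assumptions, not what the committed file does:

# Hypothetical reading of the audio branch; whisperx models return a dict
# with a 'segments' list of {'text': ...} entries.
trans = audio_model.transcribe(aud)
audio_res = "".join(seg['text'] for seg in trans['segments']).strip()
# Presumed intent: embed the transcript rather than the question string q.
audio_tokens = tokenizer(audio_res, return_tensors="pt", return_attention_mask=False)['input_ids']
audio_embeds = phi2_model.model.embed_tokens(audio_tokens.to(device))
inputs_embeddings.append(audio_embeds)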
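
The decoding loop re-runs the full forward pass over the growing combined_embeds each step and takes the argmax, i.e. greedy decoding without a KV cache. With the adapter merged into a plain causal LM, the same greedy decode can be sketched with transformers' generate(), assuming the installed transformers version accepts inputs_embeds for decoder-only models:

# Sketch: greedy decoding via generate() instead of the manual argmax loop.
with torch.no_grad():
    out_ids = phi2_model.generate(
        inputs_embeds=combined_embeds,
        max_new_tokens=max_tokens,
        do_sample=False,      # greedy, matching torch.argmax above
        pad_token_id=50256,   # the EOS/pad id the loop fills predicted_caption with
    )
answer = tokenizer.batch_decode(out_ids, skip_special_tokens=True)[0]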
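
app.py presumably wires generate_answers into a Gradio UI below the region covered by this diff (the gradio import is kept in the new file). A minimal, illustrative wiring; the component choices and labels are assumptions, not the file's actual interface:

import gradio as gr

# Illustrative only; the real interface lives outside the shown hunks.
demo = gr.Interface(
    fn=generate_answers,
    inputs=[
        gr.Image(type="pil", label="Image"),       # fed to the CLIP processor
        gr.Audio(type="filepath", label="Audio"),  # transcribed by whisperx
        gr.Textbox(label="Question"),              # tokenized for phi-2
    ],
    outputs=gr.Textbox(label="Answer"),
)

if __name__ == "__main__":
    demo.launch()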