Update app.py
app.py
CHANGED
@@ -15,30 +15,47 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
 clip_embed = 768
 phi_embed = 2560
 
+class SimpleResBlock(nn.Module):
+    def __init__(self, phi_embed):
+        super().__init__()
+        self.pre_norm = nn.LayerNorm(phi_embed)
+        self.proj = nn.Sequential(
+            nn.Linear(phi_embed, phi_embed),
+            nn.GELU(),
+            nn.Linear(phi_embed, phi_embed)
+        )
+    def forward(self, x):
+        x = self.pre_norm(x)
+        return x + self.proj(x)
+
 # models
 clip_model = CLIPVisionModel.from_pretrained(clip_model_name).to(device)
 projection = torch.nn.Linear(clip_embed, phi_embed).to(device)
+resblock = SimpleResBlock(phi_embed).to(device)
+
 phi_model = AutoModelForCausalLM.from_pretrained(phi_model_name,trust_remote_code=True).to(device)
 
 # load weights
 model_to_merge = PeftModel.from_pretrained(phi_model,'./model_chkpt/lora_adaptor')
 merged_model = model_to_merge.merge_and_unload()
 projection.load_state_dict(torch.load('./model_chkpt/step2_projection.pth',map_location=torch.device(device)))
+resblock.load_state_dict(torch.load('./model_chkpt/step2_resblock.pth',map_location=torch.device(device)))
 
 def model_generate_ans(img,val_q):
 
     max_generate_length = 100
 
     # image
-    image_processed
+    image_processed = processor(images=img, return_tensors="pt").to(device)
     clip_val_outputs = clip_model(**image_processed).last_hidden_state[:,1:,:]
-    val_image_embeds = projection(clip_val_outputs)
+    val_image_embeds = projection(clip_val_outputs)
+    val_image_embeds = resblock(val_image_embeds).to(torch.float16)
 
     img_token_tensor = torch.tensor(IMAGE_TOKEN_ID).to(device)
-    img_token_embeds = merged_model.model.
+    img_token_embeds = merged_model.model.embed_tokens(img_token_tensor).unsqueeze(0).unsqueeze(0)
 
     val_q_tokenised = tokenizer(val_q, return_tensors="pt", return_attention_mask=False)['input_ids'].squeeze(0)
-    val_q_embeds
+    val_q_embeds = merged_model.model.embed_tokens(val_q_tokenised).unsqueeze(0)
 
     val_combined_embeds = torch.cat([val_image_embeds, img_token_embeds, val_q_embeds], dim=1) # 4, 69, 2560
 
@@ -49,6 +66,8 @@ def model_generate_ans(img,val_q):
         predicted_word_token_logits = phi_output_logits[:, -1, :].unsqueeze(1) # 4,1,51200
         predicted_word_token = torch.argmax(predicted_word_token_logits, dim = -1) # 4,1
         predicted_caption[:,g] = predicted_word_token.view(1,-1).to(device)
+        next_token_embeds = phi_model.model.embed_tokens(predicted_word_token) # 4,1,2560
+        val_combined_embeds = torch.cat([val_combined_embeds, next_token_embeds], dim=1)
 
     predicted_captions_decoded = tokenizer.batch_decode(predicted_caption,ignore_index = 50256)
 
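For context on the first hunk: SimpleResBlock is a shape-preserving residual block (LayerNorm, then a Linear-GELU-Linear MLP, added back to its input) that now sits on top of the linear CLIP-to-Phi projection, with weights loaded from the new step2_resblock.pth checkpoint. A minimal standalone sketch of that projection path, assuming the 768-dim CLIP hidden size and 2560-dim Phi embedding size from the diff; the batch size and 49-patch count are hypothetical:

    import torch
    import torch.nn as nn

    clip_embed, phi_embed = 768, 2560

    class SimpleResBlock(nn.Module):
        def __init__(self, phi_embed):
            super().__init__()
            self.pre_norm = nn.LayerNorm(phi_embed)
            self.proj = nn.Sequential(
                nn.Linear(phi_embed, phi_embed),
                nn.GELU(),
                nn.Linear(phi_embed, phi_embed),
            )

        def forward(self, x):
            # normalise, transform, add back: output shape == input shape
            x = self.pre_norm(x)
            return x + self.proj(x)

    projection = nn.Linear(clip_embed, phi_embed)
    resblock = SimpleResBlock(phi_embed)

    # stand-in for CLIP patch embeddings after dropping the CLS token ([:, 1:, :])
    clip_patches = torch.randn(1, 49, clip_embed)
    image_embeds = resblock(projection(clip_patches))
    print(image_embeds.shape)  # torch.Size([1, 49, 2560])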
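The prompt handed to Phi is assembled in embedding space rather than token space: projected image embeddings, then the embedding of the special image-token ID, then the embedded question. A hedged sketch of just the concatenation; all shapes are placeholders chosen to reproduce the "69, 2560" comment in the diff (49 image tokens + 1 image-token marker + 19 question tokens):

    import torch

    # stand-ins for the real tensors built in model_generate_ans
    val_image_embeds = torch.randn(1, 49, 2560)  # projection + resblock output
    img_token_embeds = torch.randn(1, 1, 2560)   # embed_tokens(IMAGE_TOKEN_ID), unsqueezed to 3-D
    val_q_embeds = torch.randn(1, 19, 2560)      # embed_tokens(question token ids)

    val_combined_embeds = torch.cat([val_image_embeds, img_token_embeds, val_q_embeds], dim=1)
    print(val_combined_embeds.shape)  # torch.Size([1, 69, 2560])

All three pieces must share the batch and embedding dimensions, which is why the scalar image-token embedding is unsqueezed twice in the diff before the cat.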
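The second hunk fixes the greedy decoding loop: after each argmax step, the predicted token's embedding is appended to val_combined_embeds, so the next forward pass conditions on everything generated so far instead of re-reading the same prompt. A minimal sketch of that loop with stand-ins (embed_table and forward_fn are hypothetical placeholders for phi_model.model.embed_tokens and the Phi forward pass on inputs_embeds):

    import torch

    vocab_size, embed_dim, max_generate_length = 51200, 2560, 100
    embed_table = torch.nn.Embedding(vocab_size, embed_dim)

    def forward_fn(embeds):
        # placeholder for phi_model(inputs_embeds=embeds).logits
        return torch.randn(embeds.shape[0], embeds.shape[1], vocab_size)

    val_combined_embeds = torch.randn(1, 69, embed_dim)
    predicted_caption = torch.full((1, max_generate_length), 50256, dtype=torch.long)

    for g in range(max_generate_length):
        phi_output_logits = forward_fn(val_combined_embeds)  # 1, seq_len, 51200
        predicted_word_token = torch.argmax(phi_output_logits[:, -1, :], dim=-1, keepdim=True)  # 1, 1
        predicted_caption[:, g] = predicted_word_token.view(-1)
        next_token_embeds = embed_table(predicted_word_token)  # 1, 1, 2560
        # the fix: feed the predicted token back into the running sequence
        val_combined_embeds = torch.cat([val_combined_embeds, next_token_embeds], dim=1)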