Update modeling_gpt2vision.py
modeling_gpt2vision.py (+52 -2)
@@ -18,6 +18,30 @@ def resize_token_embeds(model_name="openai-community/gpt2"):
 
 tokenizer = resize_token_embeds()
 
+def create_labels(input_ids, tokenizer, attention_mask):
+    labels = input_ids.clone()
+
+    labels[attention_mask == 0] = -100
+
+    answer_start_tokens = tokenizer.encode("Answer:", add_special_tokens=False)
+
+    for i, seq in enumerate(input_ids):
+        # Find the start of the answer
+        answer_start = (seq == answer_start_tokens[0]).nonzero(as_tuple=True)[0]
+        if len(answer_start) > 0:
+            answer_start = answer_start[0]
+            if seq[answer_start:answer_start + len(answer_start_tokens)].tolist() == answer_start_tokens:
+                # Mask out everything before the answer
+                labels[i, :answer_start] = -100
+
+        # Find the end of the sequence (last non-padding token)
+        sequence_end = attention_mask[i].nonzero(as_tuple=True)[0][-1]
+
+        # Keep the last token (EOS) as part of the label
+        labels[i, sequence_end + 1:] = -100
+
+    return labels
+
 class MLP(nn.Module):
     def __init__(self, in_features: int, hidden_features: int = None, out_features: int = None):
         super().__init__()
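For reference, here is a minimal sketch of what the new create_labels does on a toy batch. It assumes a stock GPT-2 tokenizer with its pad token set (the repo's own tokenizer comes from resize_token_embeds instead) and that create_labels is importable from this module; the prompt text is hypothetical:

import torch
from transformers import GPT2Tokenizer

from modeling_gpt2vision import create_labels  # the function added above

# Assumption: a plain GPT-2 tokenizer with pad_token set; the repo builds
# its tokenizer via resize_token_embeds() instead.
tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
tokenizer.pad_token = tokenizer.eos_token

enc = tokenizer(
    ["Question: What color is the sky?\nAnswer: blue"],
    padding="max_length", truncation=True, max_length=24, return_tensors="pt",
)
labels = create_labels(enc["input_ids"], tokenizer, enc["attention_mask"])

# Everything before the "Answer:" marker and all padding positions are now
# -100 (ignored by the loss); the marker and answer tokens remain as targets.
print(labels[0])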
@@ -62,7 +86,7 @@ class GPT2Vision(PreTrainedModel):
             input_texts,
             padding='max_length',
             truncation=True,
-            max_length=
+            max_length=768,
             return_tensors="pt",
         ).to(device)
         pixel_values = self.vision_encoder(images, device)
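Since padding='max_length' is already in effect, this change fixes every text batch at 768 tokens. A quick sketch of the resulting shape (hypothetical prompt, same tokenizer assumption as above); GPT-2's context window is 1024 positions, so this presumably leaves headroom for the image embeddings that preprocess_inputs splices in:

enc = tokenizer(
    ["Question: Hi\nAnswer:"],
    padding="max_length", truncation=True, max_length=768, return_tensors="pt",
)
print(enc["input_ids"].shape)  # torch.Size([1, 768])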
@@ -72,20 +96,46 @@ class GPT2Vision(PreTrainedModel):
             "pixel_values": pixel_values
         }
 
+    def freeze_model_components(self, freeze_vision=True, freeze_language=True, freeze_mlp=False):
+        for param in self.vision_model.parameters():
+            param.requires_grad = not freeze_vision
+        for param in self.language_model.parameters():
+            param.requires_grad = not freeze_language
+        for param in self.mlp.parameters():
+            param.requires_grad = not freeze_mlp
+
     def preprocess_inputs(self, batch):
         pixel_values = batch['pixel_values'].squeeze(1)
         input_ids = batch['input_ids'].squeeze(1)
         attention_mask = batch['attention_mask'].squeeze(1)
+
         input_ids = input_ids.to(self.device)
         attention_mask = attention_mask.to(self.device)
         pixel_values = pixel_values.to(self.device)
+
+        labels = create_labels(input_ids, self.tokenizer, attention_mask)
+        labels = labels.to(self.device)
+
         img_embs = self.mlp(pixel_values)
         tok_embs = self.language_model.get_input_embeddings()(input_ids)
+
         inputs_embeds = torch.cat((tok_embs[:, 0:1, :], img_embs, tok_embs[:, 1:, :]), dim=1)
+
         img_attention = torch.ones((img_embs.size(0), img_embs.size(1)), dtype=torch.long, device=self.device)
         attention_mask = torch.cat((attention_mask[:, 0:1], img_attention, attention_mask[:, 1:]), dim=1)
-        return inputs_embeds, attention_mask, input_ids
 
+        img_labels = torch.full((labels.size(0), img_embs.size(1)), fill_value=-100, dtype=torch.long, device=self.device)
+        labels = torch.cat((labels[:, 0:1], img_labels, labels[:, 1:]), dim=1)
+
+        return inputs_embeds, attention_mask, input_ids, labels
+
+
+    def forward(self, batch, **kwargs):
+        inputs_embeds, attention_mask, input_ids, labels = self.preprocess_inputs(batch)
+
+        outputs = self.language_model(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels)
+        return outputs
+
     def generate(self, question, image, max_new_tokens=30, **kwargs):
         prompt = f"Question: {question}\nAnswer:"
         batch = {"image": [image], "text": prompt}
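The bookkeeping in the new preprocess_inputs is easiest to see with toy shapes: once the image embeddings are spliced in after the first token, the labels must be padded the same way so every per-position tensor shares one sequence length. A self-contained sketch (the 197 image tokens are an assumption, e.g. a ViT with 196 patches plus CLS):

import torch

batch, txt_len, img_len, dim = 2, 768, 197, 768  # img_len is hypothetical
tok_embs = torch.randn(batch, txt_len, dim)
img_embs = torch.randn(batch, img_len, dim)
labels = torch.randint(0, 50257, (batch, txt_len))  # 50257 = GPT-2 vocab size

# Splice image embeddings in after the first text token, as the diff does.
inputs_embeds = torch.cat((tok_embs[:, 0:1, :], img_embs, tok_embs[:, 1:, :]), dim=1)

# Pad labels with -100 at the image positions so the loss ignores them.
img_labels = torch.full((batch, img_len), -100, dtype=torch.long)
labels = torch.cat((labels[:, 0:1], img_labels, labels[:, 1:]), dim=1)

assert inputs_embeds.size(1) == labels.size(1) == txt_len + img_len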
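Taken together, freeze_model_components, the label-aware preprocess_inputs, and the new forward support a projector-only training stage. A hypothetical loop (model, dataloader, and the learning rate are assumptions, not part of this commit; batches are dicts of input_ids, attention_mask, and pixel_values as built elsewhere in the file):

import torch

# Train only the MLP projector; vision and language towers stay frozen.
model.freeze_model_components(freeze_vision=True, freeze_language=True, freeze_mlp=False)
optimizer = torch.optim.AdamW(
    (p for p in model.parameters() if p.requires_grad), lr=1e-4  # lr is an assumption
)

for batch in dataloader:
    outputs = model(batch)   # forward() builds labels via create_labels internally
    outputs.loss.backward()  # cross-entropy over answer tokens only
    optimizer.step()
    optimizer.zero_grad()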