damerajee committed · Commit 3fbcf10 · verified · 1 Parent(s): cf50367

Update modeling_gpt2vision.py

Files changed (1): modeling_gpt2vision.py (+86 -82)
modeling_gpt2vision.py CHANGED

@@ -20,102 +20,106 @@ def resize_token_embeds(model_name="openai-community/gpt2"):
 
 tokenizer = resize_token_embeds()
 
+class MLP(nn.Module):
+    def __init__(self, in_features: int, hidden_features: int = None, out_features: int = None):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.act = nn.GELU(approximate="tanh")
+        self.fc2 = nn.Linear(hidden_features, out_features)
+        self.dropout = nn.Dropout(p=0.1)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.dropout(x)
+        x = self.fc2(x)
+        return x
+
 class GPT2Vision(PreTrainedModel):
     config_class = GPT2VisionConfig
 
     def __init__(self, config):
         super().__init__(config)
         self.vision_encoder = VisionEncoder()
-
-        if isinstance(config.gpt2_config, dict):
-            gpt2_config = GPT2Config(**config.gpt2_config)
-        else:
-            gpt2_config = config.gpt2_config
-        self.text_model = GPT2LMHeadModel(gpt2_config)
-
-        self.text_model.resize_token_embeddings(len(tokenizer))
+        self.mlp = MLP(in_features=768, hidden_features=768 * 4, out_features=768)
+        self.language_model = GPT2LMHeadModel(config.gpt2_config)
+        self.language_model.resize_token_embeddings(len(tokenizer))
         self.tokenizer = tokenizer
         tokenizer.pad_token = tokenizer.eos_token
         self.image_token_id = self.tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
+        self.img_tokens = 197  # This should match IMG_TOKENS in your training code
 
-
     @property
     def device(self):
-        return self.text_model.device
+        return next(self.language_model.parameters()).device
 
-    def encode_image(self, image, device):
-        return self.vision_encoder(image, device=device)
-
-    def input_embeds(self, prompt, image_embeds, tokenizer):
-        def _tokenize(txt):
-            return tokenizer(
-                txt, return_tensors="pt", add_special_tokens=False
-            ).input_ids.to(self.device)
-
-        text_emb = self.text_model.get_input_embeddings()
-
-        # Add BOS token
-        embeds = []
-        embeds.append(
-            text_emb(torch.tensor([[tokenizer.bos_token_id]], device=self.device))
-        )
-
-        if "<image>" not in prompt:
-            embeds.append(text_emb(_tokenize(prompt)))
-        else:
-            assert prompt.count("<image>") == 1
-            before, after = prompt.split("<image>")
-            embeds.append(text_emb(_tokenize(f"{before}<image>")))
-            embeds.append(image_embeds.to(self.device))
-            embeds.append(text_emb(_tokenize(f"</image>{after}")))
-
-        return torch.cat(embeds, dim=1)
-
-    def generate(
-        self,
-        image_embeds,
-        prompt,
-        tokenizer,
-        eos_text="<|endoftext|>",
-        max_new_tokens=128,
-        **kwargs,
-    ):
-        eos_tokens = tokenizer(eos_text, add_special_tokens=False)["input_ids"]
-
-        generate_config = {
-            "eos_token_id": eos_tokens,
-            "bos_token_id": tokenizer.bos_token_id,
-            "pad_token_id": tokenizer.eos_token_id,
-            "max_new_tokens": max_new_tokens,
-            **kwargs,
-        }
-
-        with torch.no_grad():
-            inputs_embeds = self.input_embeds(prompt, image_embeds, tokenizer)
-            print("inputs_embeds", inputs_embeds.size())
-            output_ids = self.text_model.generate(
-                inputs_embeds=inputs_embeds, **generate_config
-            )
-
-        return tokenizer.batch_decode(output_ids, skip_special_tokens=True)
-
-    def answer_question(
-        self,
-        image_embeds,
-        question,
-        tokenizer,
-        chat_history="",
-        result_queue=None,
-        **kwargs,
-    ):
-        prompt = f"<image>\n\n{chat_history}Question: {question}\n\nAnswer: "
-        answer = self.generate(
-            image_embeds,
-            prompt,
-            tokenizer,
-            eos_text="<|endoftext|>",
-            max_new_tokens=256,
-            **kwargs,
-        )[0]
-        return answer
+    def tokenize_encode(self, batch, device):
+        text = batch['text']
+        images = batch['image']
+        if isinstance(text, str):
+            text = [text]
+        input_texts = [f"{IMAGE_TOKEN}{t}" for t in text]
+        text_inputs = self.tokenizer(
+            input_texts,
+            padding='max_length',
+            truncation=True,
+            max_length=768,
+            return_tensors="pt",
+        ).to(device)
+
+        # Adjust attention mask to account for image tokens and the extra <image> token
+        batch_size = text_inputs.input_ids.shape[0]
+        img_attention = torch.ones((batch_size, self.img_tokens + 1), dtype=torch.long, device=device)
+        attention_mask = torch.cat([img_attention, text_inputs.attention_mask[:, 1:]], dim=1)
+
+        return {
+            "input_ids": text_inputs.input_ids,
+            "attention_mask": attention_mask,
+            "images": images
+        }
+
+    def preprocess_inputs(self, batch):
+        images = batch['images']
+        input_ids = batch['input_ids'].to(self.device)
+        attention_mask = batch['attention_mask'].to(self.device)
+
+        img_embs = self.vision_encoder(images, device=self.device)
+        print("img_embs", img_embs.size())
+        img_embs = self.mlp(img_embs)
+
+        tok_embs = self.language_model.get_input_embeddings()(input_ids)
+
+        inputs_embeds = torch.cat((tok_embs[:, 0:1, :], img_embs, tok_embs[:, 1:, :]), dim=1)
+
+        # Ensure the attention mask aligns with the inputs_embeds
+        assert inputs_embeds.shape[1] == attention_mask.shape[1], f"Mismatch between embeddings ({inputs_embeds.shape[1]}) and attention mask length ({attention_mask.shape[1]})."
+
+        return inputs_embeds, attention_mask
+
+    def forward(self, batch, **kwargs):
+        inputs_embeds, attention_mask = self.preprocess_inputs(batch)
+
+        outputs = self.language_model(
+            inputs_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            **kwargs
+        )
+        return outputs
+
+    def generate(self, question, image, max_new_tokens=30, **kwargs):
+        prompt = f"Question: {question}\nAnswer:"
+        batch = {"image": [image], "text": prompt}
+        encoded_batch = self.tokenize_encode(batch, self.device)
+        inputs_embeds, attention_mask = self.preprocess_inputs(encoded_batch)
+        output_sequences = self.language_model.generate(
+            inputs_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            pad_token_id=self.tokenizer.eos_token_id,
+            eos_token_id=self.tokenizer.eos_token_id,
+            max_new_tokens=max_new_tokens,
+            **kwargs
+        )
+        output = self.tokenizer.decode(output_sequences[0], skip_special_tokens=True)
+        return output
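
For reference, a minimal usage sketch of the API after this commit. The module import path, checkpoint id, and image path below are illustrative assumptions, not part of the commit; only GPT2Vision and its methods come from the diff above.

    # Minimal usage sketch; assumed names are marked in the comments.
    import torch
    from PIL import Image

    from modeling_gpt2vision import GPT2Vision  # assumption: imported from this file

    # Hypothetical checkpoint id; substitute the actual repository.
    model = GPT2Vision.from_pretrained("your-username/gpt2vision").eval()

    image = Image.open("cat.jpg")  # placeholder image path

    # Inference path: generate now takes the raw question and a PIL image directly;
    # tokenization, vision encoding, MLP projection, and embedding splicing all
    # happen inside the model.
    with torch.no_grad():
        answer = model.generate("What animal is shown?", image, max_new_tokens=30)
    print(answer)

    # Training-style path: tokenize_encode builds the padded batch and the widened
    # attention mask; forward splices the image embeddings after the leading
    # <image> token before running the language model.
    batch = model.tokenize_encode({"image": [image], "text": "A cat."}, model.device)
    outputs = model(batch)
    print(outputs.logits.shape)  # (batch, seq_len, vocab)

Compared with the previous input_embeds/answer_question flow, callers no longer pass precomputed image embeddings: prompt assembly and attention-mask bookkeeping now live inside the model.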