damerajee committed on
Commit
8486052
·
verified ·
1 Parent(s): a32f3ab

Update modeling_gpt2vision.py

Browse files
Files changed (1) hide show
  1. modeling_gpt2vision.py +37 -35
modeling_gpt2vision.py CHANGED
@@ -52,46 +52,46 @@ class GPT2Vision(PreTrainedModel):
52
  def device(self):
53
  return next(self.language_model.parameters()).device
54
 
55
- def tokenize_encode(self, batch, device):
56
- text = batch['text']
57
- images = batch['image']
58
- if isinstance(text, str):
59
- text = [text]
60
- input_texts = [f"{IMAGE_TOKEN}{t}" for t in text]
61
- text_inputs = self.tokenizer(
62
- input_texts,
 
63
  padding='max_length',
64
  truncation=True,
65
- max_length=768,
66
- return_tensors="pt",
67
- pad_to_multiple_of=8,
68
- ).to(device)
69
- pixel_values = self.vision_encoder(images, device)
70
- return {
71
- "input_ids": text_inputs.input_ids,
72
- "attention_mask": text_inputs.attention_mask,
73
- "pixel_values": pixel_values
74
- }
75
-
76
- def preprocess_inputs(self, batch):
77
- pixel_values = batch['pixel_values'].squeeze(1)
78
- input_ids = batch['input_ids'].squeeze(1)
79
- attention_mask = batch['attention_mask'].squeeze(1)
80
- input_ids = input_ids.to(self.device)
81
- attention_mask = attention_mask.to(self.device)
82
- pixel_values = pixel_values.to(self.device)
83
- img_embs = self.mlp(pixel_values)
84
- tok_embs = self.language_model.get_input_embeddings()(input_ids)
85
  inputs_embeds = torch.cat((tok_embs[:, 0:1, :], img_embs, tok_embs[:, 1:, :]), dim=1)
86
- img_attention = torch.ones((img_embs.size(0), img_embs.size(1)), dtype=torch.long, device=self.device)
 
 
87
  attention_mask = torch.cat((attention_mask[:, 0:1], img_attention, attention_mask[:, 1:]), dim=1)
88
- return inputs_embeds, attention_mask, input_ids
 
89
 
90
  def generate(self, question, image, max_new_tokens=30, **kwargs):
91
- prompt = f"Question: {question}\nAnswer:"
92
- batch = {"image": [image], "text": prompt}
93
- encoded_batch = self.tokenize_encode(batch, self.device)
94
- inputs_embeds, attention_mask, input_ids = self.preprocess_inputs(encoded_batch)
95
  output_sequences = self.language_model.generate(
96
  inputs_embeds=inputs_embeds,
97
  attention_mask=attention_mask,
@@ -100,5 +100,7 @@ class GPT2Vision(PreTrainedModel):
100
  max_new_tokens=max_new_tokens,
101
  **kwargs
102
  )
103
- output = self.tokenizer.decode(output_sequences[0], skip_special_tokens=True)
 
 
104
  return output
 
52
  def device(self):
53
  return next(self.language_model.parameters()).device
54
 
55
+ def preprocess_inputs(self, question, image):
56
+ # Convert image to RGB
57
+ image = image.convert("RGB")
58
+
59
+ # Tokenize the question
60
+ tokens = [self.tokenizer.bos_token_id]
61
+ q_tokens = self.tokenizer(
62
+ f"\n\nQuestion: {question}\n\nAnswer:",
63
+ add_special_tokens=False,
64
  padding='max_length',
65
  truncation=True,
66
+ max_length=384,
67
+ ).input_ids
68
+ tokens.extend(q_tokens)
69
+
70
+ # Convert tokens to tensor
71
+ tokens = torch.tensor([tokens], dtype=torch.long).to(self.device)
72
+
73
+ # Create attention mask
74
+ attention_mask = torch.ones_like(tokens, dtype=torch.bool)
75
+
76
+ # Process image
77
+ with torch.no_grad():
78
+ img_embs = self.vision_encoder([image], device=self.device)
79
+
80
+ # Get token embeddings
81
+ tok_embs = self.language_model.get_input_embeddings()(tokens)
82
+
83
+ # Concatenate image embeddings and token embeddings
 
 
84
  inputs_embeds = torch.cat((tok_embs[:, 0:1, :], img_embs, tok_embs[:, 1:, :]), dim=1)
85
+
86
+ # Update attention mask to include image tokens
87
+ img_attention = torch.ones((img_embs.size(0), img_embs.size(1)), dtype=torch.bool, device=self.device)
88
  attention_mask = torch.cat((attention_mask[:, 0:1], img_attention, attention_mask[:, 1:]), dim=1)
89
+
90
+ return inputs_embeds, attention_mask
91
 
92
  def generate(self, question, image, max_new_tokens=30, **kwargs):
93
+ inputs_embeds, attention_mask = self.preprocess_inputs(question, image)
94
+
 
 
95
  output_sequences = self.language_model.generate(
96
  inputs_embeds=inputs_embeds,
97
  attention_mask=attention_mask,
 
100
  max_new_tokens=max_new_tokens,
101
  **kwargs
102
  )
103
+
104
+ # Decode the output, skipping the input tokens
105
+ output = self.tokenizer.decode(output_sequences[0][inputs_embeds.size(1):], skip_special_tokens=True)
106
  return output