damerajee committed
Commit 8a1e458 · verified · 1 Parent(s): 8486052

Update modeling_gpt2vision.py

Files changed (1)
  1. modeling_gpt2vision.py +35 -37
modeling_gpt2vision.py CHANGED
@@ -52,46 +52,46 @@ class GPT2Vision(PreTrainedModel):
     def device(self):
         return next(self.language_model.parameters()).device
 
-    def preprocess_inputs(self, question, image):
-        # Convert image to RGB
-        image = image.convert("RGB")
-
-        # Tokenize the question
-        tokens = [self.tokenizer.bos_token_id]
-        q_tokens = self.tokenizer(
-            f"\n\nQuestion: {question}\n\nAnswer:",
-            add_special_tokens=False,
+    def tokenize_encode(self, batch, device):
+        text = batch['text']
+        images = batch['image']
+        if isinstance(text, str):
+            text = [text]
+        input_texts = [f"{IMAGE_TOKEN}{t}" for t in text]
+        text_inputs = self.tokenizer(
+            input_texts,
             padding='max_length',
             truncation=True,
-            max_length=384,
-        ).input_ids
-        tokens.extend(q_tokens)
-
-        # Convert tokens to tensor
-        tokens = torch.tensor([tokens], dtype=torch.long).to(self.device)
-
-        # Create attention mask
-        attention_mask = torch.ones_like(tokens, dtype=torch.bool)
-
-        # Process image
-        with torch.no_grad():
-            img_embs = self.vision_encoder([image], device=self.device)
-
-        # Get token embeddings
-        tok_embs = self.language_model.get_input_embeddings()(tokens)
-
-        # Concatenate image embeddings and token embeddings
+            max_length=768,
+            return_tensors="pt",
+            pad_to_multiple_of=8,
+        ).to(device)
+        pixel_values = self.vision_encoder(images, device)
+        return {
+            "input_ids": text_inputs.input_ids,
+            "attention_mask": text_inputs.attention_mask,
+            "pixel_values": pixel_values
+        }
+
+    def preprocess_inputs(self, batch):
+        pixel_values = batch['pixel_values'].squeeze(1)
+        input_ids = batch['input_ids'].squeeze(1)
+        attention_mask = batch['attention_mask'].squeeze(1)
+        input_ids = input_ids.to(self.device)
+        attention_mask = attention_mask.to(self.device)
+        pixel_values = pixel_values.to(self.device)
+        img_embs = self.mlp(pixel_values)
+        tok_embs = self.language_model.get_input_embeddings()(input_ids)
         inputs_embeds = torch.cat((tok_embs[:, 0:1, :], img_embs, tok_embs[:, 1:, :]), dim=1)
-
-        # Update attention mask to include image tokens
-        img_attention = torch.ones((img_embs.size(0), img_embs.size(1)), dtype=torch.bool, device=self.device)
+        img_attention = torch.ones((img_embs.size(0), img_embs.size(1)), dtype=torch.long, device=self.device)
         attention_mask = torch.cat((attention_mask[:, 0:1], img_attention, attention_mask[:, 1:]), dim=1)
-
-        return inputs_embeds, attention_mask
+        return inputs_embeds, attention_mask, input_ids
 
     def generate(self, question, image, max_new_tokens=30, **kwargs):
-        inputs_embeds, attention_mask = self.preprocess_inputs(question, image)
-
+        prompt = f"Question: {question}\nAnswer:"
+        batch = {"image": [image], "text": prompt}
+        encoded_batch = self.tokenize_encode(batch, self.device)
+        inputs_embeds, attention_mask, input_ids = self.preprocess_inputs(encoded_batch)
         output_sequences = self.language_model.generate(
             inputs_embeds=inputs_embeds,
             attention_mask=attention_mask,
@@ -100,7 +100,5 @@ class GPT2Vision(PreTrainedModel):
             max_new_tokens=max_new_tokens,
             **kwargs
         )
-
-        # Decode the output, skipping the input tokens
-        output = self.tokenizer.decode(output_sequences[0], skip_special_tokens=True)
+        output = self.tokenizer.decode(output_sequences[0], skip_special_tokens=True)
         return output
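The new preprocess_inputs splices the projected image embeddings between the first text token (the IMAGE_TOKEN position) and the rest of the prompt, then widens the attention mask by the same amount. A minimal shape check of that splice, using arbitrary toy dimensions rather than the model's real ones:

# Toy shape check of the embedding splice used in preprocess_inputs.
# Sizes here are illustration values only, not the model's real dims.
import torch

tok_embs = torch.randn(1, 768, 64)   # (batch, text_len, hidden)
img_embs = torch.randn(1, 196, 64)   # (batch, num_image_tokens, hidden)
attention_mask = torch.ones(1, 768, dtype=torch.long)

# Insert image embeddings right after the first text token.
inputs_embeds = torch.cat((tok_embs[:, 0:1, :], img_embs, tok_embs[:, 1:, :]), dim=1)
img_attention = torch.ones((img_embs.size(0), img_embs.size(1)), dtype=torch.long)
attention_mask = torch.cat((attention_mask[:, 0:1], img_attention, attention_mask[:, 1:]), dim=1)

assert inputs_embeds.shape == (1, 768 + 196, 64)
assert attention_mask.shape == (1, 768 + 196)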
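End to end, generate now builds the "Question: ...\nAnswer:" prompt itself, runs tokenize_encode and preprocess_inputs, and decodes the whole generated sequence. A usage sketch under the assumption that the checkpoint is loaded through transformers with trust_remote_code=True; the repo id and image path below are placeholders, not something this commit specifies:

# Hypothetical usage of the updated GPT2Vision.generate API.
# Repo id, image path, and loading via trust_remote_code are assumptions.
from PIL import Image
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "damerajee/gpt2vision",   # placeholder repo id
    trust_remote_code=True,
)

image = Image.open("example.jpg").convert("RGB")
question = "What is shown in the image?"

# generate() handles tokenization, image encoding, embedding splicing,
# and decoding internally, returning the decoded answer string.
answer = model.generate(question, image, max_new_tokens=30)
print(answer)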