Update modeling_minicpmv.py (#56)
- Update modeling_minicpmv.py (da88bdc057fcaf87792be979f5f695fe12350716)
Co-authored-by: qianyu chen <qianyuchen@users.noreply.huggingface.co>
modeling_minicpmv.py CHANGED (+8 -8)

@@ -42,13 +42,13 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
 
         return model
 
-    def init_resampler(self, embed_dim, vision_dim):
+    def init_resampler(self, embed_dim, vision_dim,):
         return Resampler(
             num_queries=self.config.query_num,
             embed_dim=embed_dim,
             num_heads=embed_dim // 128,
             kv_dim=vision_dim,
-            adaptive=True
+            adaptive=True,
         )
 
     def init_transform(self):
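This first hunk is cosmetic: it adds trailing commas to the `init_resampler` signature and to the `adaptive=True` argument, with no behavior change. For orientation, the call pins the attention head size to 128 (`num_heads=embed_dim // 128`) and maps vision features of width `vision_dim` onto `self.config.query_num` learned query tokens. Below is a minimal sketch of what a query-based resampler with this interface does; it is not the repo's `Resampler` implementation, and all internals are assumptions for illustration only.

```python
import torch
import torch.nn as nn

class ResamplerSketch(nn.Module):
    """Cross-attention pooling: num_queries learned tokens attend to vision features."""
    def __init__(self, num_queries, embed_dim, num_heads, kv_dim, adaptive=True):
        super().__init__()
        # A fixed set of learned queries; the output always has num_queries tokens.
        self.query = nn.Parameter(torch.zeros(num_queries, embed_dim))
        # Project vision features (kv_dim) into the LLM embedding space (embed_dim).
        self.kv_proj = nn.Linear(kv_dim, embed_dim, bias=False)
        # The caller passes num_heads = embed_dim // 128, i.e. a fixed head size of 128.
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.adaptive = adaptive  # in MiniCPM-V this gates tgt_sizes-aware pos-embeds

    def forward(self, x, tgt_sizes=None):
        # x: (B, L, kv_dim) vision-encoder features -> (B, num_queries, embed_dim)
        kv = self.kv_proj(x)
        q = self.query.unsqueeze(0).expand(x.size(0), -1, -1)
        out, _ = self.attn(q, kv, kv)
        return out
```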
@@ -60,17 +60,17 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
             ),
         ]
     )
-
+
     def get_input_embeddings(self):
         return self.llm.get_input_embeddings()
 
     def set_input_embeddings(self, value):
         self.llm.embed_tokens = value
-
+
     def get_vllm_embedding(self, data):
         if 'vision_hidden_states' not in data:
-            dtype = self.
-            device = self.
+            dtype = self.llm.model.embed_tokens.weight.dtype
+            device = self.llm.model.embed_tokens.weight.device
             tgt_sizes = data['tgt_sizes']
             pixel_values_list = data['pixel_values']
             vision_hidden_states = []
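The substantive change in this hunk is where `get_vllm_embedding` gets its dtype and device: the new code reads both from the LLM's input embedding table (`self.llm.model.embed_tokens.weight`), replacing an older attribute chain that is truncated in the diff above. Pixel values are later cast with `.type(dtype)`, so they automatically match whatever precision and device the checkpoint was loaded in. A small self-contained illustration of the pattern, with toy sizes rather than the model's real shapes:

```python
import torch
import torch.nn as nn

# Toy embedding table standing in for self.llm.model.embed_tokens.
embed_tokens = nn.Embedding(32000, 64).to(torch.float16)

dtype = embed_tokens.weight.dtype    # torch.float16
device = embed_tokens.weight.device  # cpu here; cuda:0 if the model were on GPU

pixel_values = torch.rand(1, 3, 224, 224)  # preprocessing emits float32
pixel_values = pixel_values.to(device=device, dtype=dtype)  # same effect as .type(dtype)
print(pixel_values.dtype)  # torch.float16, matching the embedding weights
```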
@@ -107,6 +107,7 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
                 single_pixel_values = single_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)
                 single_vision_embedding = self.vpm(single_pixel_values.type(dtype)).last_hidden_state
                 single_vision_embedding = self.resampler(single_vision_embedding, single_tgt_size.unsqueeze(0))
+
                 vision_embedding.append(single_vision_embedding)
             vision_embedding = torch.vstack(vision_embedding)
 
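This hunk only inserts a blank line, but the surrounding loop is worth spelling out: each image is pushed through the vision tower (`self.vpm`) and then the resampler individually, so every iteration appends one `(1, num_queries, embed_dim)` tensor, and `torch.vstack` concatenates them along dim 0 into `(num_images, num_queries, embed_dim)`. A toy version with made-up sizes:

```python
import torch

# Pretend query_num=4 and embed_dim=8; three images processed one at a time.
num_queries, embed_dim = 4, 8
vision_embedding = [torch.randn(1, num_queries, embed_dim) for _ in range(3)]

stacked = torch.vstack(vision_embedding)  # concatenates along dim 0
print(stacked.shape)  # torch.Size([3, 4, 8])
```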
@@ -152,14 +153,13 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
                 image_indices = torch.stack(
                     [torch.arange(r[0], r[1], dtype=torch.long) for r in cur_image_bound]
                 ).to(vllm_embedding.device)
-
                 cur_vllm_emb.scatter_(0, image_indices.view(-1, 1).repeat(1, cur_vllm_emb.shape[-1]),
                                       cur_vs_hs.view(-1, cur_vs_hs.shape[-1]))
             elif self.training:
                 cur_vllm_emb += cur_vs_hs[0].mean() * 0
 
         return vllm_embedding, vision_hidden_states
-
+
     def forward(self, data, **kwargs):
         vllm_embedding, vision_hidden_states = self.get_vllm_embedding(data)
         position_ids = data["position_ids"]
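The last hunk's changes are blank-line cleanups, but the code around them does the actual image-text splicing: `cur_image_bound` holds `[start, end)` token ranges where image placeholders sit, the index tensor is widened to the full embedding width, and `scatter_(0, ...)` overwrites exactly those rows of the token embeddings with the resampled vision features. The `cur_vs_hs[0].mean() * 0` branch adds a zero-valued term on text-only training samples so the vision parameters stay in the autograd graph (a common trick to avoid unused-parameter errors in distributed training). A toy version of the scatter, with made-up sizes:

```python
import torch

seq_len, hidden = 10, 6
vllm_emb = torch.zeros(seq_len, hidden)  # token embeddings for one sample
vs_hs = torch.arange(1., 4.).view(3, 1).repeat(1, hidden)  # 3 vision rows

image_bound = [(2, 5)]  # image placeholder tokens occupy positions 2..4
image_indices = torch.stack(
    [torch.arange(r[0], r[1], dtype=torch.long) for r in image_bound]
)

# Widen indices to (num_rows, hidden) so scatter_ replaces whole rows.
vllm_emb.scatter_(0, image_indices.view(-1, 1).repeat(1, hidden), vs_hs)
print(vllm_emb[2:5, 0])  # tensor([1., 2., 3.]) -- the rows were overwritten
```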