Update modeling_minicpmv.py
Browse files- modeling_minicpmv.py +33 -29
modeling_minicpmv.py
CHANGED
|
@@ -231,44 +231,48 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
|
|
| 231 |
|
| 232 |
def generate(
|
| 233 |
self,
|
| 234 |
-
|
| 235 |
tokenizer=None,
|
| 236 |
vision_hidden_states=None,
|
| 237 |
stream=False,
|
| 238 |
**kwargs
|
| 239 |
):
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
for
|
| 251 |
-
img_inps
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 265 |
|
| 266 |
# output_ids = self._decode(input_embeds, tokenizer, **kwargs)
|
| 267 |
if stream:
|
| 268 |
kwargs.pop("decode_text")
|
| 269 |
-
result = self._decode_stream(
|
| 270 |
else:
|
| 271 |
-
result = self._decode(
|
| 272 |
|
| 273 |
return result
|
| 274 |
|
|
@@ -366,5 +370,5 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
|
|
| 366 |
return stream_gen()
|
| 367 |
|
| 368 |
else:
|
| 369 |
-
answer = res
|
| 370 |
return answer
|
|
|
|
def generate(
    self,
    model_inputs_batch,
    tokenizer=None,
    vision_hidden_states=None,
    stream=False,
    **kwargs
):
    """Generate text for a batch of pre-processed multimodal inputs.

    Args:
        model_inputs_batch: iterable of dicts, each expected to carry
            "input_ids", "pixel_values" and "tgt_sizes" keys
            (schema assumed from usage here — TODO confirm against caller).
        tokenizer: tokenizer forwarded unchanged to the decoding helpers.
        vision_hidden_states: optional precomputed vision features; when
            provided, per-image device transfer is skipped and the features
            are attached to every item's inputs instead.
        stream: if True, decode with the streaming helper.
        **kwargs: extra decoding options passed through to
            ``self._decode`` / ``self._decode_stream``.

    Returns:
        The result of ``self._decode`` (or ``self._decode_stream`` when
        ``stream`` is True) over the accumulated input embeddings.

    Raises:
        AssertionError: if an item's image list length does not match its
            input_ids batch size.
    """
    batch = []
    for model_inputs in model_inputs_batch:
        bs = len(model_inputs["input_ids"])
        img_list = model_inputs["pixel_values"]
        if img_list is None:
            # No images supplied: one empty image list per sample.
            img_list = [[] for _ in range(bs)]
        assert bs == len(img_list), (
            f"pixel_values length {len(img_list)} != input_ids batch size {bs}"
        )
        if vision_hidden_states is None:
            # Move every image tensor onto the model's device; an empty
            # per-sample list simply stays empty.  (The original also
            # reassigned model_inputs['tgt_sizes'] to itself — a no-op,
            # removed here.)
            model_inputs["pixel_values"] = [
                [img.to(self.device) for img in imgs] for imgs in img_list
            ]
        else:
            model_inputs["vision_hidden_states"] = vision_hidden_states

        # Fuse text ids and vision features into one embedding tensor
        # (project helper defined elsewhere in this class).
        (
            input_embeds,
            vision_hidden_states,
        ) = self.get_vllm_embedding(model_inputs)
        batch.append(input_embeds)

    if stream:
        # "decode_text" is not a valid option for the streaming decoder;
        # drop it if present.  The original popped without a default and
        # raised KeyError when the caller omitted it.
        kwargs.pop("decode_text", None)
        result = self._decode_stream(batch, tokenizer, **kwargs)
    else:
        result = self._decode(batch, tokenizer, **kwargs)

    return result
|
|
|
|
| 370 |
return stream_gen()
|
| 371 |
|
| 372 |
else:
|
| 373 |
+
answer = res
|
| 374 |
return answer
|