Update modeling_minicpmv.py
Browse files- modeling_minicpmv.py +37 -34
modeling_minicpmv.py
CHANGED
|
@@ -274,7 +274,7 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
|
|
| 274 |
|
| 275 |
def chat(
|
| 276 |
self,
|
| 277 |
-
|
| 278 |
msgs,
|
| 279 |
tokenizer,
|
| 280 |
processor=None,
|
|
@@ -290,42 +290,45 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
|
|
| 290 |
processor = AutoProcessor.from_pretrained(self.config._name_or_path, trust_remote_code=True)
|
| 291 |
if isinstance(msgs, str):
|
| 292 |
msgs = json.loads(msgs)
|
| 293 |
-
copy_msgs = deepcopy(msgs)
|
| 294 |
|
| 295 |
assert len(msgs) > 0, "msgs is empty"
|
| 296 |
assert sampling or not stream, "if use stream mode, make sure sampling=True"
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
copy_msgs[
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
assert role
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
|
|
|
|
|
|
|
|
|
| 329 |
|
| 330 |
if sampling:
|
| 331 |
generation_config = {
|
|
|
|
| 274 |
|
| 275 |
def chat(
|
| 276 |
self,
|
| 277 |
+
images,
|
| 278 |
msgs,
|
| 279 |
tokenizer,
|
| 280 |
processor=None,
|
|
|
|
| 290 |
processor = AutoProcessor.from_pretrained(self.config._name_or_path, trust_remote_code=True)
|
| 291 |
if isinstance(msgs, str):
|
| 292 |
msgs = json.loads(msgs)
|
| 293 |
+
# copy_msgs = deepcopy(msgs)
|
| 294 |
|
| 295 |
assert len(msgs) > 0, "msgs is empty"
|
| 296 |
assert sampling or not stream, "if use stream mode, make sure sampling=True"
|
| 297 |
+
assert(len(msgs) == len(images)), "Make sure to have one image per item in your batch"
|
| 298 |
+
batchM = []
|
| 299 |
+
batchI = []
|
| 300 |
+
for ind in range(len(images)):
|
| 301 |
+
image = images[ind]
|
| 302 |
+
if image is not None and isinstance(copy_msgs[0]["content"], str):
|
| 303 |
+
# deep copy element
|
| 304 |
+
copy_msgs = deepcopy(msgs[ind])
|
| 305 |
+
copy_msgs[0]["content"] = [image, copy_msgs[0]["content"]]
|
| 306 |
+
|
| 307 |
+
imagelist = []
|
| 308 |
+
for i, msg in enumerate(copy_msgs):
|
| 309 |
+
role = msg["role"]
|
| 310 |
+
content = msg["content"]
|
| 311 |
+
assert role in ["user", "assistant"]
|
| 312 |
+
if i == 0:
|
| 313 |
+
assert role == "user", "The role of first msg should be user"
|
| 314 |
+
if isinstance(content, str):
|
| 315 |
+
content = [content]
|
| 316 |
+
cur_msgs = []
|
| 317 |
+
for c in content:
|
| 318 |
+
if isinstance(c, Image.Image):
|
| 319 |
+
imagelist.append(c)
|
| 320 |
+
cur_msgs.append("(<image>./</image>)")
|
| 321 |
+
elif isinstance(c, str):
|
| 322 |
+
cur_msgs.append(c)
|
| 323 |
+
msg["content"] = "\n".join(cur_msgs)
|
| 324 |
+
|
| 325 |
+
if system_prompt:
|
| 326 |
+
sys_msg = {'role': 'system', 'content': system_prompt}
|
| 327 |
+
copy_msgs = [sys_msg] + copy_msgs
|
| 328 |
+
batchM.append(copy_msgs)
|
| 329 |
+
batchI.append(imagelist)
|
| 330 |
+
prompt = processor.tokenizer.apply_chat_template(batchM, tokenize=False, add_generation_prompt=True)
|
| 331 |
+
inputs = processor(prompt, batchI, return_tensors="pt", max_length=max_inp_length).to(self.device)
|
| 332 |
|
| 333 |
if sampling:
|
| 334 |
generation_config = {
|