Update modeling_mplug_owl2.py
Browse files- modeling_mplug_owl2.py +7 -4
modeling_mplug_owl2.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
# Copyright 2023 Haotian Liu & Qinghao Ye (Modified from LLaVA)
|
| 2 |
#
|
| 3 |
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
# you may not use this file except in compliance with the License.
|
|
@@ -271,20 +271,23 @@ class MPLUGOwl2LlamaForCausalLM(LlamaForCausalLM, MPLUGOwl2MetaForCausalLM):
|
|
| 271 |
task_: str = "quality",
|
| 272 |
input_: str = "image",
|
| 273 |
return_dict=False,
|
|
|
|
| 274 |
):
|
| 275 |
if not hasattr(self, "weight_tensor"):
|
| 276 |
self.weight_tensor = torch.Tensor([5.,4.,3.,2.,1.]).half().to(self.device)
|
| 277 |
prompt = "USER: How would you rate the {} of this {}?\n<|image|>\nASSISTANT: The {} of the {} is".format(task_, input_, task_, input_)
|
| 278 |
if input_ == "image":
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
image_tensor = self.image_processor.preprocess(images, return_tensors="pt")["pixel_values"].half().to(self.device)
|
|
|
|
| 283 |
output_logits = self(input_ids.repeat(image_tensor.shape[0], 1),
|
| 284 |
images=image_tensor)["logits"][:,-1, self.preferential_ids_]
|
| 285 |
if return_dict:
|
| 286 |
return {"logits": output_logits, "scores": torch.softmax(output_logits, -1) @ self.weight_tensor}
|
| 287 |
return torch.softmax(output_logits, -1) @ self.weight_tensor
|
|
|
|
| 288 |
else:
|
| 289 |
video = [[expand2square(frame, tuple(int(x*255) for x in self.image_processor.image_mean)) for frame in vid] for vid in images]
|
| 290 |
input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
|
|
|
|
| 1 |
+
# Copyright 2023 Haotian Liu & Qinghao Ye & Haoning Wu (Modified from LLaVA, and mPLUG-Owl2)
|
| 2 |
#
|
| 3 |
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
# you may not use this file except in compliance with the License.
|
|
|
|
| 271 |
task_: str = "quality",
|
| 272 |
input_: str = "image",
|
| 273 |
return_dict=False,
|
| 274 |
+
image_tensor = None,
|
| 275 |
):
|
| 276 |
if not hasattr(self, "weight_tensor"):
|
| 277 |
self.weight_tensor = torch.Tensor([5.,4.,3.,2.,1.]).half().to(self.device)
|
| 278 |
prompt = "USER: How would you rate the {} of this {}?\n<|image|>\nASSISTANT: The {} of the {} is".format(task_, input_, task_, input_)
|
| 279 |
if input_ == "image":
|
| 280 |
+
if image_tensor is None:
|
| 281 |
+
images = [expand2square(img, tuple(int(x*255) for x in self.image_processor.image_mean)) for img in images]
|
| 282 |
+
input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
|
| 283 |
image_tensor = self.image_processor.preprocess(images, return_tensors="pt")["pixel_values"].half().to(self.device)
|
| 284 |
+
with torch.inference_mode():
|
| 285 |
output_logits = self(input_ids.repeat(image_tensor.shape[0], 1),
|
| 286 |
images=image_tensor)["logits"][:,-1, self.preferential_ids_]
|
| 287 |
if return_dict:
|
| 288 |
return {"logits": output_logits, "scores": torch.softmax(output_logits, -1) @ self.weight_tensor}
|
| 289 |
return torch.softmax(output_logits, -1) @ self.weight_tensor
|
| 290 |
+
|
| 291 |
else:
|
| 292 |
video = [[expand2square(frame, tuple(int(x*255) for x in self.image_processor.image_mean)) for frame in vid] for vid in images]
|
| 293 |
input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
|