Upload modeling_mplug_owl2.py with huggingface_hub
modeling_mplug_owl2.py  CHANGED  +58 -1
@@ -37,6 +37,40 @@ IMAGE_TOKEN_INDEX = -200
 DEFAULT_IMAGE_TOKEN = "<|image|>"
 from icecream import ic
 
+def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
+    prompt_chunks = [tokenizer(chunk).input_ids if len(chunk) > 0 else [] for chunk in prompt.split(DEFAULT_IMAGE_TOKEN)]
+
+    def insert_separator(X, sep):
+        return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]
+
+    input_ids = []
+    offset = 0
+    if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
+        offset = 1
+        input_ids.append(prompt_chunks[0][0])
+
+    for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
+        input_ids.extend(x[offset:])
+
+    if return_tensors is not None:
+        if return_tensors == 'pt':
+            return torch.tensor(input_ids, dtype=torch.long)
+        raise ValueError(f'Unsupported tensor type: {return_tensors}')
+    return input_ids
+
+def expand2square(pil_img, background_color):
+    width, height = pil_img.size
+    if width == height:
+        return pil_img
+    elif width > height:
+        result = Image.new(pil_img.mode, (width, width), background_color)
+        result.paste(pil_img, (0, (width - height) // 2))
+        return result
+    else:
+        result = Image.new(pil_img.mode, (height, height), background_color)
+        result.paste(pil_img, ((height - width) // 2, 0))
+        return result
+
 class MPLUGOwl2MetaModel:
     def __init__(self, config):
         super(MPLUGOwl2MetaModel, self).__init__(config)
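A quick sanity check on the splicing logic just added: the sketch below (not part of the commit) runs `tokenizer_image_token` against a stub tokenizer. The stub, its fake id scheme, and the `modeling_mplug_owl2` import path are assumptions for illustration only. Note also that `expand2square` relies on PIL's `Image`, which this hunk assumes is imported elsewhere in the module.

# Illustration only (not part of the commit): how tokenizer_image_token
# splices the IMAGE_TOKEN_INDEX sentinel (-200) where "<|image|>" appeared.
from types import SimpleNamespace
from modeling_mplug_owl2 import tokenizer_image_token  # assumes this module is importable

class StubTokenizer:
    """Fake tokenizer: BOS, then 100 + word position for each word."""
    bos_token_id = 1
    def __call__(self, text):
        return SimpleNamespace(input_ids=[self.bos_token_id] + [100 + i for i, _ in enumerate(text.split())])

ids = tokenizer_image_token("USER: hello <|image|> world", StubTokenizer())
# Chunks: "USER: hello " -> [1, 100, 101]; " world" -> [1, 100].
# The leading BOS is kept once (offset = 1); every later chunk drops its
# fake BOS, and one -200 sentinel is spliced per "<|image|>" occurrence.
print(ids)  # [1, 100, 101, -200, 100]

The forward pass later replaces each -200 position with projected visual tokens, which is why the sentinel is a negative id that cannot collide with a real vocabulary entry.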
@@ -218,13 +252,36 @@ class MPLUGOwl2LlamaForCausalLM(LlamaForCausalLM, MPLUGOwl2MetaForCausalLM):
         self.model = MPLUGOwl2LlamaModel(config)
 
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        self.preferential_ids_ = [id_[1] for id_ in self.tokenizer(["excellent", "good", "fair", "poor", "bad"])["input_ids"]]  # token after BOS for each rating word; assumes a tokenizer is attached as self.tokenizer
+        self.weight_tensor = torch.Tensor([5., 4., 3., 2., 1.]).half().to(self.device)
 
         # Initialize weights and apply final processing
         self.post_init()
 
     def get_model(self):
         return self.model
-
+
+    def score(self, images,
+              task_: str = "quality",
+              input_: str = "image",
+              ):
+        prompt = "USER: How would you rate the {} of this {}?\n<|image|>\nASSISTANT: The {} of the {} is".format(task_, input_, task_, input_)
+        input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
+        if input_ == "image":
+            images = [expand2square(img, tuple(int(x * 255) for x in self.image_processor.image_mean)) for img in images]
+            with torch.inference_mode():
+                image_tensor = self.image_processor.preprocess(images, return_tensors="pt")["pixel_values"].half().to(self.device)
+                output_logits = self(input_ids.repeat(image_tensor.shape[0], 1),
+                                     images=image_tensor)["logits"][:, -1, self.preferential_ids_]
+                return torch.softmax(output_logits, -1) @ self.weight_tensor  # expected rating in [1, 5]
+        else:
+            video = [[expand2square(frame, tuple(int(x * 255) for x in self.image_processor.image_mean)) for frame in vid] for vid in images]
+            with torch.inference_mode():
+                video_tensors = [self.image_processor.preprocess(vid, return_tensors="pt")["pixel_values"].half().to(self.device) for vid in video]
+                output_logits = self(input_ids.repeat(len(video_tensors), 1),
+                                     images=video_tensors)["logits"][:, -1, self.preferential_ids_]
+                return torch.softmax(output_logits, -1) @ self.weight_tensor  # expected rating in [1, 5]
+
     def forward(
         self,
         input_ids: torch.LongTensor = None,