import copy
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Sequence

import torch
import torch.distributed as dist
import transformers
from PIL import Image, ImageFile

from src import conversation as conversation_lib
from src.constants import DEFAULT_IMAGE_TOKEN, IGNORE_INDEX
from src.mm_utils import tokenizer_image_token

# Allow PIL to load images from truncated files instead of raising an error.
ImageFile.LOAD_TRUNCATED_IMAGES = True


def rank0_print(*args):
    """Print only on the rank-0 process; fall back to a plain print when
    torch.distributed is not initialized (e.g. single-process runs)."""
    try:
        if dist.get_rank() == 0:
            print(*args)
    except Exception:
        print(*args)


@dataclass
class DataArguments:
    data_paths: List[str] = field(default_factory=list)
    lazy_preprocess: bool = False
    is_multimodal: bool = False
    image_folder: Optional[str] = field(default=None)
    image_aspect_ratio: str = "square"
    image_grid_pinpoints: Optional[str] = field(default=None)


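# Usage sketch (hypothetical CLI values; assumes the standard HfArgumentParser
# pattern from transformers training scripts, not part of this module):
#
#     parser = transformers.HfArgumentParser(DataArguments)
#     (data_args,) = parser.parse_args_into_dataclasses()
#     # e.g. --data_paths a.json b.json --is_multimodal True --image_folder imgs/

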
def _tokenize_fn(
    strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer
) -> Dict:
    """Tokenize a list of strings."""
    tokenized_list = [
        tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        )
        for text in strings
    ]
    # Note: input_ids and labels alias the same tensors; callers that modify
    # labels independently must deepcopy them first.
    input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
    input_ids_lens = labels_lens = [
        tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
        for tokenized in tokenized_list
    ]
    return dict(
        input_ids=input_ids,
        labels=labels,
        input_ids_lens=input_ids_lens,
        labels_lens=labels_lens,
    )


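# Usage sketch (hypothetical tokenizer; assumes tokenizer.pad_token_id is set,
# shown as comments so nothing runs on import):
#
#     tok = transformers.AutoTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")
#     out = _tokenize_fn(["Hello world"], tok)
#     out["input_ids"][0]       # 1-D LongTensor of token ids
#     out["input_ids_lens"][0]  # count of non-pad tokens

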
def _mask_targets(target, tokenized_lens, speakers):
    """Mask the header and every human turn with IGNORE_INDEX so the loss is
    computed only on assistant tokens."""
    cur_idx = tokenized_lens[0]
    tokenized_lens = tokenized_lens[1:]
    target[:cur_idx] = IGNORE_INDEX
    for tokenized_len, speaker in zip(tokenized_lens, speakers):
        if speaker == "human":
            target[cur_idx + 2 : cur_idx + tokenized_len] = IGNORE_INDEX
        cur_idx += tokenized_len


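# Worked example (hypothetical lengths): with tokenized_lens = [5, 8, 6] for
# [header, human turn, gpt turn] and speakers = ["human", "gpt"], the header
# (positions 0:5) and the human turn (positions 7:13; the +2 offset leaves the
# turn's first two signal tokens unmasked) are set to IGNORE_INDEX, so only
# the gpt turn is supervised.

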
def _add_speaker_and_signal(header, source, get_conversation=True):
    """Add speaker and start/end signal on each round."""
    BEGIN_SIGNAL = "### "
    END_SIGNAL = "\n"
    conversation = header
    for sentence in source:
        from_str = sentence["from"]
        if from_str.lower() == "human":
            from_str = conversation_lib.default_conversation.roles[0]
        elif from_str.lower() == "gpt":
            from_str = conversation_lib.default_conversation.roles[1]
        else:
            from_str = "unknown"
        sentence["value"] = (
            BEGIN_SIGNAL + from_str + ": " + sentence["value"] + END_SIGNAL
        )
        if get_conversation:
            conversation += sentence["value"]
    conversation += BEGIN_SIGNAL
    return conversation


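# Example output (assuming default roles like ("Human", "Assistant")): for a
# two-turn source the returned string looks like
#
#     {header}### Human: <question>\n### Assistant: <answer>\n###
#
# and each sentence["value"] is rewritten in place with its signal attached.

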
def preprocess_multimodal(
    sources: Sequence[Sequence[Dict]], data_args: DataArguments
) -> Sequence[Sequence[Dict]]:
    is_multimodal = data_args.is_multimodal
    if not is_multimodal:
        return sources

    for source in sources:
        for sentence in source:
            if DEFAULT_IMAGE_TOKEN in sentence["value"]:
                # Move the image token to the front of the turn, on its own line.
                sentence["value"] = (
                    sentence["value"].replace(DEFAULT_IMAGE_TOKEN, "").strip()
                )
                sentence["value"] = DEFAULT_IMAGE_TOKEN + "\n" + sentence["value"]
                sentence["value"] = sentence["value"].strip()

    return sources


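# Example (assuming DEFAULT_IMAGE_TOKEN == "<image>"): a turn such as
# "What is shown here? <image>" is rewritten to "<image>\nWhat is shown here?",
# i.e. the image token always ends up alone on the leading line.

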
def preprocess_v1(
    sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False
) -> Dict:
    conv = conversation_lib.default_conversation.copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

    # Apply the conversation template to produce one prompt string per source.
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
            # Skip the first message if it is not from the human.
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            assert role == conv.roles[j % 2], f"unexpected role order in source {i}"
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt())

    # Tokenize the full conversations.
    if has_image:
        input_ids = torch.stack(
            [
                tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
                for prompt in conversations
            ],
            dim=0,
        )
    else:
        input_ids = tokenizer(
            conversations,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        ).input_ids

    targets = input_ids.clone()

    assert (
        conv.sep_style == conversation_lib.SeparatorStyle.TWO
        or conv.sep_style == conversation_lib.SeparatorStyle.TWO_NO_SYS
    )

    # Mask everything except the assistant responses.
    sep = conv.sep + conv.roles[1] + ": "
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        rounds = conversation.split(conv.sep2)
        cur_len = 1 + 1  # skip the leading special tokens (e.g. BOS)
        target[:cur_len] = IGNORE_INDEX
        for i, rou in enumerate(rounds):
            if rou == "":
                break

            parts = rou.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep

            if has_image:
                round_len = len(tokenizer_image_token(rou, tokenizer))
                instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 3
            else:
                round_len = len(tokenizer(rou).input_ids)
                instruction_len = len(tokenizer(parts[0]).input_ids) - 2
            round_len -= 1  # tokenizer-dependent off-by-one adjustment
            target[cur_len : cur_len + instruction_len] = IGNORE_INDEX

            cur_len += round_len
        target[cur_len:] = IGNORE_INDEX

        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_INDEX
                print(
                    f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                    f" (ignored)"
                )
    return dict(
        input_ids=input_ids,
        labels=targets,
    )


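# Masking sketch (hypothetical template; assumes SeparatorStyle.TWO with
# sep = " ", sep2 = "</s>", and roles ("USER", "ASSISTANT")): a prompt like
#
#     "<system> USER: Hi ASSISTANT: Hello!</s>USER: Bye ASSISTANT: Sure!</s>"
#
# splits on sep2 into rounds; inside each round everything up to and including
# " ASSISTANT: " is masked, so the loss covers only the assistant replies.

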
def preprocess_plain(
    sources: Sequence[Sequence[Dict]],
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    # Plain (pretraining) format: the prompt is just the image token followed
    # by the caption and a separator; the image-token prefix is masked out.
    conversations = []
    for source in sources:
        assert len(source) == 2
        assert DEFAULT_IMAGE_TOKEN in source[0]["value"]
        source[0]["value"] = DEFAULT_IMAGE_TOKEN
        conversation = (
            source[0]["value"]
            + source[1]["value"]
            + conversation_lib.default_conversation.sep
        )
        conversations.append(conversation)

    input_ids = [
        tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
        for prompt in conversations
    ]
    targets = copy.deepcopy(input_ids)
    for target, source in zip(targets, sources):
        tokenized_len = len(tokenizer_image_token(source[0]["value"], tokenizer))
        target[:tokenized_len] = IGNORE_INDEX

    return dict(input_ids=input_ids, labels=targets)


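# Example (assuming DEFAULT_IMAGE_TOKEN == "<image>"): a source
# [{"from": "human", "value": "<image>"}, {"from": "gpt", "value": "A red bus."}]
# becomes the string "<image>A red bus.{sep}"; only the caption and trailing
# separator contribute to the loss.

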
def preprocess(
    sources: Sequence[Sequence[Dict]],
    tokenizer: transformers.PreTrainedTokenizer,
    has_image: bool = False,
) -> Dict:
    """
    Given a list of sources, each of which is a conversation list, this transform:
    1. Adds the signal '### ' at the beginning of each sentence, with the end
       signal '\\n';
    2. Concatenates the conversations together;
    3. Tokenizes the concatenated conversation;
    4. Makes a deepcopy as the target and masks human words with IGNORE_INDEX.
    """
    if (
        conversation_lib.default_conversation.sep_style
        == conversation_lib.SeparatorStyle.PLAIN
    ):
        return preprocess_plain(sources, tokenizer)
    if conversation_lib.default_conversation.version.startswith("v1"):
        return preprocess_v1(sources, tokenizer, has_image=has_image)

    conversations = []
    for source in sources:
        header = f"{conversation_lib.default_conversation.system}\n\n"
        conversation = _add_speaker_and_signal(header, source)
        conversations.append(conversation)

    # Tokenize the conversations.
    def get_tokenize_len(prompts):
        return [len(tokenizer_image_token(prompt, tokenizer)) for prompt in prompts]

    if has_image:
        input_ids = [
            tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
            for prompt in conversations
        ]
    else:
        conversations_tokenized = _tokenize_fn(conversations, tokenizer)
        input_ids = conversations_tokenized["input_ids"]

    targets = copy.deepcopy(input_ids)
    for target, source in zip(targets, sources):
        # header is identical for every source, so reusing the loop variable
        # from above is safe here.
        if has_image:
            tokenized_lens = get_tokenize_len([header] + [s["value"] for s in source])
        else:
            tokenized_lens = _tokenize_fn(
                [header] + [s["value"] for s in source], tokenizer
            )["input_ids_lens"]
        speakers = [sentence["from"] for sentence in source]
        _mask_targets(target, tokenized_lens, speakers)

    return dict(input_ids=input_ids, labels=targets)


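# Usage sketch (hypothetical data; shown as comments so nothing runs on import):
#
#     source = [
#         {"from": "human", "value": "<image>\nDescribe the image."},
#         {"from": "gpt", "value": "A cat on a sofa."},
#     ]
#     batch = preprocess([source], tokenizer, has_image=True)
#     batch["input_ids"], batch["labels"]  # labels mask everything but replies

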
def load_video(video_file):
    """Decode a video into a list of PIL images, sampling one frame per second."""
    from decord import VideoReader

    vr = VideoReader(video_file)
    fps = vr.get_avg_fps()

    # One frame index per full second of video.
    frame_indices = [int(fps * i) for i in range(int(len(vr) / fps))]
    frames = vr.get_batch(frame_indices).asnumpy()
    return [Image.fromarray(frame) for frame in frames]


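# Usage sketch (hypothetical path): load_video("clip.mp4") returns one
# PIL.Image per second of footage, ready to be fed to an image processor.

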
def expand2square(pil_img, background_color):
    """Pad a PIL image to a square canvas, centering it on background_color."""
    width, height = pil_img.size
    if width == height:
        return pil_img
    elif width > height:
        result = Image.new(pil_img.mode, (width, width), background_color)
        result.paste(pil_img, (0, (width - height) // 2))
        return result
    else:
        result = Image.new(pil_img.mode, (height, height), background_color)
        result.paste(pil_img, ((height - width) // 2, 0))
        return result
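

# Usage sketch (assumes a CLIP-style image processor, as in LLaVA-style
# pipelines where padded images use the processor's mean color as background):
#
#     bg = tuple(int(x * 255) for x in image_processor.image_mean)
#     square = expand2square(pil_img, bg)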