import io
import os
import random
import re
import sys
from typing import Dict

import cv2
import imageio
import numpy as np
import torch
import torchvision.transforms as T
import transformers
from PIL import Image
from torch.utils.data import ConcatDataset, WeightedRandomSampler
from torchvision.transforms.functional import InterpolationMode
from xtuner.utils import IGNORE_INDEX

from ..utils import (get_conv_template, IMG_CONTEXT_TOKEN, IMG_START_TOKEN,
                     IMG_END_TOKEN)

IGNORE_TOKEN_ID = IGNORE_INDEX

try:
    from petrel_client.client import Client
    from petrel_client.common.config import Config
except ImportError:
    print('petrel_client is not installed. If you read data locally instead of from ceph, ignore it.')


def preprocess(
        template_name,
        sources,
        tokenizer: transformers.PreTrainedTokenizer,
        num_image_token_list: list,
        text_only: bool = False,
        group_by_length: bool = False,
        use_packed_ds: bool = False,
        ds_name: str = None,
        num_image: int = 1
) -> Dict:
    """Tokenize conversations and mask every non-assistant token in the labels."""
    conv = get_conv_template(template_name)
    roles = {'human': conv.roles[0], 'gpt': conv.roles[1]}

    # Apply the prompt template to each conversation
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]['from']] != conv.roles[0]:
            # Skip the first message if it is not from the human side
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence['from']]
            assert role == conv.roles[j % 2], f'{i}'
            conv.append_message(role, sentence['value'])
        conversations.append(conv.get_prompt())

    # Replace each '<image>' placeholder with the image token span
    if not text_only:
        new_conversations = []
        for conversation in conversations:
            for i in range(num_image):
                image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}'
                conversation = conversation.replace('<image>', image_tokens, 1)
            new_conversations.append(conversation)
        conversations = new_conversations

    # Tokenize conversations
    input_ids = tokenizer(
        conversations,
        return_tensors='pt',
        padding=False if group_by_length or use_packed_ds else 'max_length',
        max_length=tokenizer.model_max_length,
        truncation=True,
    ).input_ids
    targets = input_ids.clone()

    # Mask targets: only the assistant replies contribute to the loss
    sep = conv.sep + conv.roles[1] + ': '
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        turns = conversation.split(conv.sep2)
        cur_len = 1
        target[:cur_len] = IGNORE_TOKEN_ID
        for i, turn in enumerate(turns):
            if turn == '':
                break
            turn_len = len(tokenizer(turn).input_ids)

            parts = turn.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep
            # "-2" is hardcoded for the Llama tokenizer to make the offset correct
            instruction_len = len(tokenizer(parts[0]).input_ids) - 2

            if i != 0 and not tokenizer.legacy:
                # The legacy and non-legacy modes handle special tokens differently
                instruction_len -= 1

            # Ignore the user instruction part of this turn
            target[cur_len: cur_len + instruction_len] = IGNORE_TOKEN_ID
            cur_len += turn_len

            if i != 0 and not tokenizer.legacy:
                # The legacy and non-legacy modes handle special tokens differently
                cur_len -= 1

        target[cur_len:] = IGNORE_TOKEN_ID

        if False:  # Inspect and check the correctness of masking
            z = target.clone()
            z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z)
            print(repr(tokenizer.decode(z)))
            sys.exit()

        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_TOKEN_ID
                print(
                    f'WARNING: tokenization mismatch: {cur_len} vs. {total_len}.'
                    f' #turn = {len(turns) - 1}. (ignored). This dataset is {ds_name}.'
                )
                sys.stdout.flush()

    return dict(
        input_ids=input_ids,
        labels=targets,
        attention_mask=input_ids.ne(tokenizer.pad_token_id),
    )


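# ---------------------------------------------------------------------------
# Debugging aid (an added sketch, not part of the original pipeline): decode
# only the positions that remain supervised after masking, so the label layout
# produced by the preprocess_* functions can be inspected by eye.
# ---------------------------------------------------------------------------
def _decode_supervised_tokens(labels, tokenizer: transformers.PreTrainedTokenizer) -> str:
    """Decode every label position that is not IGNORE_TOKEN_ID."""
    kept = labels[labels != IGNORE_TOKEN_ID]
    return tokenizer.decode(kept)

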
def preprocess_mpt(
        template_name,
        sources,
        tokenizer: transformers.PreTrainedTokenizer,
        num_image_token_list: list,
        text_only: bool = False,
        group_by_length: bool = False,
        use_packed_ds: bool = False,
        ds_name: str = None,
        num_image: int = 1
) -> Dict:
    """Tokenize conversations for MPT-style chat templates and mask non-assistant tokens."""
    conv = get_conv_template(template_name)
    roles = {'human': conv.roles[0], 'gpt': conv.roles[1]}

    # Apply the prompt template to each conversation
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]['from']] != conv.roles[0]:
            # Skip the first message if it is not from the human side
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence['from']]
            assert role == conv.roles[j % 2], f'{i}'
            conv.append_message(role, sentence['value'])
        conversations.append(conv.get_prompt())

    # Replace each '<image>' placeholder with the image token span
    if not text_only:
        new_conversations = []
        for conversation in conversations:
            for i in range(num_image):
                image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}'
                conversation = conversation.replace('<image>', image_tokens, 1)
            new_conversations.append(conversation)
        conversations = new_conversations

    # Tokenize conversations
    input_ids = tokenizer(
        conversations,
        return_tensors='pt',
        padding=False if group_by_length or use_packed_ds else 'max_length',
        max_length=tokenizer.model_max_length,
        truncation=True,
    ).input_ids
    targets = input_ids.clone()

    # Mask targets: only the assistant replies contribute to the loss
    sep = conv.sep + conv.roles[1]
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        turns = conversation.split(conv.sep)
        re_turns = [conv.sep.join(turns[:3])]  # system + user + gpt
        for conv_idx in range(3, len(turns), 2):
            re_turns.append(conv.sep.join(turns[conv_idx:conv_idx + 2]))  # user + gpt
        cur_len = 0
        target[:cur_len] = IGNORE_TOKEN_ID
        for i, turn in enumerate(re_turns):
            if turn == '':
                break
            turn_len = len(tokenizer(turn).input_ids) + 1

            parts = turn.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep
            instruction_len = len(tokenizer(parts[0]).input_ids)

            # Ignore the user instruction part of this turn
            target[cur_len: cur_len + instruction_len] = IGNORE_TOKEN_ID
            cur_len += turn_len

        target[cur_len:] = IGNORE_TOKEN_ID

        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_TOKEN_ID
                print(
                    f'WARNING: tokenization mismatch: {cur_len} vs. {total_len}.'
                    f' #turn = {len(turns) - 1}. (ignored). This dataset is {ds_name}.'
                )
                sys.stdout.flush()

    return dict(
        input_ids=input_ids,
        labels=targets,
        attention_mask=input_ids.ne(tokenizer.pad_token_id),
    )


def preprocess_phi3(
        template_name,
        sources,
        tokenizer: transformers.PreTrainedTokenizer,
        num_image_token_list: list,
        text_only: bool = False,
        group_by_length: bool = False,
        use_packed_ds: bool = False,
        ds_name: str = None,
        num_image: int = 1
) -> Dict:
    """Tokenize conversations for Phi-3-style chat templates and mask non-assistant tokens."""
    conv = get_conv_template(template_name)
    roles = {'human': conv.roles[0], 'gpt': conv.roles[1]}

    # Apply the prompt template to each conversation
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]['from']] != conv.roles[0]:
            # Skip the first message if it is not from the human side
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence['from']]
            assert role == conv.roles[j % 2], f'{i}'
            conv.append_message(role, sentence['value'])
        conversations.append(conv.get_prompt())

    # Replace each '<image>' placeholder with the image token span
    if not text_only:
        new_conversations = []
        for conversation in conversations:
            for i in range(num_image):
                image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}'
                conversation = conversation.replace('<image>', image_tokens, 1)
            new_conversations.append(conversation)
        conversations = new_conversations

    # Tokenize conversations (right padding keeps the pad tokens at the end)
    tokenizer.padding_side = 'right'
    input_ids = tokenizer(
        conversations,
        return_tensors='pt',
        padding=False if group_by_length or use_packed_ds else 'max_length',
        max_length=tokenizer.model_max_length,
        truncation=True,
    ).input_ids
    targets = input_ids.clone()

    # Mask targets: only the assistant replies contribute to the loss
    sep = conv.sep + conv.roles[1]
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(int(tokenizer.pad_token_id)).sum())

        turns = conversation.split(conv.sep)
        re_turns = [conv.sep.join(turns[:3])]  # system + user + gpt
        for conv_idx in range(3, len(turns), 2):
            re_turns.append(conv.sep.join(turns[conv_idx:conv_idx + 2]))  # user + gpt
        cur_len = 1
        target[:cur_len] = IGNORE_TOKEN_ID
        # Also mask the '<|endoftext|>' padding tokens
        endoftext_id = tokenizer.convert_tokens_to_ids('<|endoftext|>')
        target[target == endoftext_id] = IGNORE_TOKEN_ID

        for i, turn in enumerate(re_turns):
            if turn == '':
                break
            if i == 0:
                turn_len = len(tokenizer(turn).input_ids)
            else:
                turn_len = len(tokenizer(turn).input_ids) - 1
            parts = turn.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep

            if i == 0:
                instruction_len = len(tokenizer(parts[0]).input_ids) - 1
            else:
                instruction_len = len(tokenizer(parts[0]).input_ids) - 2

            # Ignore the user instruction part of this turn
            target[cur_len: cur_len + instruction_len] = IGNORE_TOKEN_ID
            cur_len += turn_len

        target[cur_len:] = IGNORE_TOKEN_ID

        if False:  # Inspect and check the correctness of masking
            z = target.clone()
            z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z)
            print(repr(tokenizer.decode(z)))

        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_TOKEN_ID
                print(
                    f'WARNING: tokenization mismatch: {cur_len} vs. {total_len}.'
                    f' #turn = {len(turns) - 1}. (ignored). This dataset is {ds_name}.'
                )
                sys.stdout.flush()

    return dict(
        input_ids=input_ids,
        labels=targets,
        attention_mask=input_ids.ne(tokenizer.pad_token_id),
    )


def preprocess_internlm(
        template_name,
        sources,
        tokenizer: transformers.PreTrainedTokenizer,
        num_image_token_list: list,
        text_only: bool = False,
        group_by_length: bool = False,
        use_packed_ds: bool = False,
        ds_name: str = None,
        num_image: int = 1
) -> Dict:
    """Tokenize conversations for InternLM2-style chat templates and mask non-assistant tokens."""
    conv = get_conv_template(template_name)
    roles = {'human': conv.roles[0], 'gpt': conv.roles[1]}

    # Apply the prompt template to each conversation
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]['from']] != conv.roles[0]:
            # Skip the first message if it is not from the human side
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence['from']]
            assert role == conv.roles[j % 2], f'{i}'
            sentence['value'] = sentence['value'].strip()
            conv.append_message(role, sentence['value'])
        conversations.append(conv.get_prompt())

    # Replace each '<image>' placeholder with the image token span
    if not text_only:
        new_conversations = []
        for conversation in conversations:
            for i in range(num_image):
                image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}'
                conversation = conversation.replace('<image>', image_tokens, 1)
            new_conversations.append(conversation)
        conversations = new_conversations

    # Tokenize conversations
    input_ids = tokenizer(
        conversations,
        return_tensors='pt',
        padding=False if group_by_length or use_packed_ds else 'max_length',
        max_length=tokenizer.model_max_length,
        truncation=True,
    ).input_ids
    targets = input_ids.clone()

    # Mask targets: only the assistant replies contribute to the loss
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())
        cur_len = 1
        target[:cur_len] = IGNORE_TOKEN_ID  # mask the BOS token
        parts = conversation.split(conv.roles[1])  # split on the assistant role tag
        info = parts[0] + conv.roles[1]
        temp_len = len(tokenizer(info).input_ids) - 1  # drop the BOS added by the tokenizer
        target[cur_len: cur_len + temp_len] = IGNORE_TOKEN_ID
        cur_len = cur_len + temp_len

        for index in range(1, len(parts) - 1):
            info = parts[index]
            part1, part2 = info.split(conv.roles[0])
            temp_len = len(tokenizer(part1).input_ids) - 1  # assistant reply: keep supervised
            cur_len = cur_len + temp_len
            part = conv.roles[0] + part2 + conv.roles[1]
            temp_len = len(tokenizer(part).input_ids) - 1  # next user turn: mask it
            target[cur_len: cur_len + temp_len] = IGNORE_TOKEN_ID
            cur_len = cur_len + temp_len
        last_info = parts[-1]
        temp_len = len(tokenizer(last_info).input_ids) - 1  # final assistant reply
        cur_len = cur_len + temp_len

        target[cur_len:] = IGNORE_TOKEN_ID
        if False:  # Inspect and check the correctness of masking
            z = target.clone()
            z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z)
            print(repr(tokenizer.decode(z)))

        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_TOKEN_ID
                print(f'WARNING: tokenization mismatch: {cur_len} vs. {total_len}. This dataset is {ds_name}.')
                sys.stdout.flush()

    return dict(
        input_ids=input_ids,
        labels=targets,
        attention_mask=input_ids.ne(tokenizer.pad_token_id),
    )


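# ---------------------------------------------------------------------------
# Minimal usage sketch (added for illustration; the checkpoint name, template
# name and per-image token count below are assumptions, not values taken from
# the training configs). It shows the expected `sources` structure: a batch of
# conversations whose turns are dicts with 'from' ('human'/'gpt') and 'value',
# with '<image>' marking where the image token span is spliced in.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(
        'internlm/internlm2-chat-7b', trust_remote_code=True)  # assumed checkpoint
    tokenizer.model_max_length = 4096
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    sources = [[
        {'from': 'human', 'value': '<image>\nWhat is shown in the picture?'},
        {'from': 'gpt', 'value': 'A cat sitting on a mat.'},
    ]]

    batch = preprocess_internlm(
        'internlm2-chat',               # assumed template name
        sources,
        tokenizer,
        num_image_token_list=[256],     # assumed number of image tokens per tile
        group_by_length=True,           # skip max_length padding for this demo
        ds_name='demo',
    )
    print(batch['input_ids'].shape, batch['labels'].shape)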