# Source: uploaded by zhouyik with huggingface_hub
# ("Upload folder using huggingface_hub", commit 032e687, verified).
import io
import os
import random
import re
from typing import Dict
import cv2
import imageio
import numpy as np
import torch
import torchvision.transforms as T
import transformers
from PIL import Image
from torch.utils.data import ConcatDataset, WeightedRandomSampler
from torchvision.transforms.functional import InterpolationMode
from xtuner.utils import IGNORE_INDEX
IGNORE_TOKEN_ID = IGNORE_INDEX
from ..utils import (get_conv_template, IMG_CONTEXT_TOKEN, IMG_START_TOKEN,
IMG_END_TOKEN, )
try:
from petrel_client.client import Client
from petrel_client.common.config import Config
except ImportError as E:
print('petrel_client is not installed. If you read data locally instead of from ceph, ignore it.')
import sys
def preprocess(
        template_name,
        sources,
        tokenizer: transformers.PreTrainedTokenizer,
        num_image_token_list: list,
        text_only: bool = False,
        group_by_length: bool = False,
        use_packed_ds: bool = False,
        ds_name: str = None,
        num_image: int = 1
) -> Dict:
    """Render conversations with a prompt template, tokenize them and build labels.

    Labels are a copy of the token ids in which every position that should not
    contribute to the loss is overwritten with ``IGNORE_TOKEN_ID``; only the
    assistant answers are left unmasked.

    NOTE(review): the masking assumes a template whose rounds are separated by
    ``conv.sep2`` and whose assistant prefix is ``conv.sep + roles[1] + ': '``
    (the original code carried a disabled assert for
    ``SeparatorStyle.ADD_COLON_TWO``). The ``- 2`` / ``legacy`` offsets are
    hard-coded for the Llama tokenizer.

    Args:
        template_name: Name handed to ``get_conv_template``.
        sources: Sequence of conversations. Each conversation is a sequence of
            dicts with a ``'from'`` key (``'human'`` or ``'gpt'``) and a
            ``'value'`` key holding the message text.
        tokenizer: Tokenizer called on the rendered prompts. Its
            ``model_max_length``, ``pad_token_id`` and ``legacy`` attributes
            are read.
        num_image_token_list: ``num_image_token_list[i]`` is how many times
            ``IMG_CONTEXT_TOKEN`` is repeated for the i-th ``<image>`` tag.
        text_only: When true, ``<image>`` tags are left untouched.
        group_by_length: When true, padding is disabled.
        use_packed_ds: When true, padding is disabled.
        ds_name: Only used in the mismatch warning message.
        num_image: Number of ``<image>`` tags to expand per conversation
            (each replacement targets the first remaining occurrence).

    Returns:
        A dict with ``input_ids``, ``labels`` and ``attention_mask``
        (``input_ids != tokenizer.pad_token_id``).

    Raises:
        AssertionError: If the speakers of a conversation do not alternate in
            the order given by ``conv.roles``.
        KeyError: If a message's ``'from'`` is neither ``'human'`` nor ``'gpt'``.
    """
    conv = get_conv_template(template_name)
    roles = {'human': conv.roles[0], 'gpt': conv.roles[1]}

    # Apply prompt templates
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]['from']] != conv.roles[0]:
            # Skip the first one if it is not from human
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence['from']]
            assert role == conv.roles[j % 2], f'{i}'
            conv.append_message(role, sentence['value'])
        conversations.append(conv.get_prompt())

    if not text_only:
        # Expand each of the first `num_image` <image> tags into
        # IMG_START + N * IMG_CONTEXT + IMG_END.
        new_conversations = []
        for conversation in conversations:
            for i in range(num_image):
                image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}'
                conversation = conversation.replace('<image>', image_tokens, 1)
            new_conversations.append(conversation)
        conversations = new_conversations

    # Tokenize conversations
    input_ids = tokenizer(
        conversations,
        return_tensors='pt',
        padding=False if group_by_length or use_packed_ds else 'max_length',
        max_length=tokenizer.model_max_length,
        truncation=True,
    ).input_ids
    targets = input_ids.clone()

    # Mask targets. Only compute loss on the assistant outputs.
    sep = conv.sep + conv.roles[1] + ': '
    for conversation, target in zip(conversations, targets):
        # Number of non-padding tokens in this row.
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        turns = conversation.split(conv.sep2)
        # The first token is always masked (presumably BOS -- TODO confirm).
        cur_len = 1
        target[:cur_len] = IGNORE_TOKEN_ID
        for i, turn in enumerate(turns):
            if turn == '':
                break
            turn_len = len(tokenizer(turn).input_ids)

            parts = turn.split(sep)
            if len(parts) != 2:
                # Not exactly one assistant prefix in this round: stop masking here.
                break
            parts[0] += sep
            # "-2" is hardcoded for the Llama tokenizer to make the offset correct.
            instruction_len = len(tokenizer(parts[0]).input_ids) - 2

            if i != 0 and not tokenizer.legacy:
                # The legacy and non-legacy modes handle special tokens differently
                instruction_len -= 1

            # Ignore the user instructions
            target[cur_len: cur_len + instruction_len] = IGNORE_TOKEN_ID
            cur_len += turn_len

            if i != 0 and not tokenizer.legacy:
                # The legacy and non-legacy modes handle special tokens differently
                cur_len -= 1

        # Everything after the last processed round (e.g. padding) is ignored.
        target[cur_len:] = IGNORE_TOKEN_ID

        # If the running offset disagrees with the real token count (and the
        # sample was not simply cut at the maximum length), the masks cannot be
        # trusted: drop the whole sample from the loss and report it.
        if cur_len < tokenizer.model_max_length and cur_len != total_len:
            target[:] = IGNORE_TOKEN_ID
            print(
                f'WARNING: tokenization mismatch: {cur_len} vs. {total_len}.'
                f' #turn = {len(turns) - 1}. (ignored). This dataset is {ds_name}.'
            )
            sys.stdout.flush()

    return dict(
        input_ids=input_ids,
        labels=targets,
        attention_mask=input_ids.ne(tokenizer.pad_token_id),
    )
def preprocess_mpt(
        template_name,
        sources,
        tokenizer: transformers.PreTrainedTokenizer,
        num_image_token_list: list,
        text_only: bool = False,
        group_by_length: bool = False,
        use_packed_ds: bool = False,
        ds_name: str = None,
        num_image: int = 1
) -> Dict:
    """Build token ids, loss labels and attention mask for templates split on ``conv.sep``.

    Each conversation is rendered through the template returned by
    ``get_conv_template(template_name)``, its ``<image>`` tags are expanded
    (unless ``text_only``), and the result is tokenized. In the labels, only
    the assistant answers keep their token ids; every other position is set to
    ``IGNORE_TOKEN_ID``.

    Args:
        template_name: Name handed to ``get_conv_template``.
        sources: Sequence of conversations; each is a sequence of dicts with
            ``'from'`` (``'human'`` / ``'gpt'``) and ``'value'`` keys.
        tokenizer: Tokenizer called on the prompts; ``model_max_length`` and
            ``pad_token_id`` are read from it.
        num_image_token_list: Repetition count of ``IMG_CONTEXT_TOKEN`` for
            each expanded ``<image>`` tag, by index.
        text_only: Skip ``<image>`` expansion when true.
        group_by_length: Disables padding when true.
        use_packed_ds: Disables padding when true.
        ds_name: Only used in the mismatch warning.
        num_image: How many ``<image>`` tags are expanded per conversation.

    Returns:
        Dict with ``input_ids``, ``labels`` and ``attention_mask``.

    Raises:
        AssertionError: If speakers do not alternate as ``conv.roles`` demands.
    """
    template = get_conv_template(template_name)
    role_of = {'human': template.roles[0], 'gpt': template.roles[1]}

    # Render every conversation through the prompt template.
    prompts = []
    for src_idx, dialogue in enumerate(sources):
        # A conversation must open with the human side; drop a leading
        # message from anyone else.
        if role_of[dialogue[0]['from']] != template.roles[0]:
            dialogue = dialogue[1:]

        template.messages = []
        for msg_idx, message in enumerate(dialogue):
            speaker = role_of[message['from']]
            assert speaker == template.roles[msg_idx % 2], f'{src_idx}'
            template.append_message(speaker, message['value'])
        prompts.append(template.get_prompt())

    # Swap the first `num_image` <image> tags for their token runs.
    if not text_only:
        expanded = []
        for prompt in prompts:
            for img_idx in range(num_image):
                n_context = num_image_token_list[img_idx]
                placeholder = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * n_context}{IMG_END_TOKEN}'
                prompt = prompt.replace('<image>', placeholder, 1)
            expanded.append(prompt)
        prompts = expanded

    # Tokenize; pad to the model maximum unless batching handles lengths itself.
    pad_mode = False if group_by_length or use_packed_ds else 'max_length'
    input_ids = tokenizer(
        prompts,
        return_tensors='pt',
        padding=pad_mode,
        max_length=tokenizer.model_max_length,
        truncation=True,
    ).input_ids
    targets = input_ids.clone()

    # Only the assistant answers stay unmasked in the labels.
    answer_prefix = template.sep + template.roles[1]  # e.g. <|im_end|><|im_start|>assistant\n
    for prompt, label_row in zip(prompts, targets):
        n_real_tokens = int(label_row.ne(tokenizer.pad_token_id).sum())

        pieces = prompt.split(template.sep)
        # First round bundles three pieces (system + user + gpt), every later
        # round bundles two (user + gpt).
        rounds = [template.sep.join(pieces[:3])]
        rounds.extend(
            template.sep.join(pieces[start:start + 2])
            for start in range(3, len(pieces), 2)
        )

        offset = 0
        label_row[:offset] = IGNORE_TOKEN_ID
        for round_text in rounds:
            if round_text == '':
                break
            # "+ 1": presumably the separator token lost by the split above
            # -- TODO confirm against the tokenizer.
            round_len = len(tokenizer(round_text).input_ids) + 1

            halves = round_text.split(answer_prefix)
            if len(halves) != 2:
                break
            question = halves[0] + answer_prefix
            question_len = len(tokenizer(question).input_ids)

            # Hide the instruction part of this round.
            label_row[offset: offset + question_len] = IGNORE_TOKEN_ID
            offset += round_len
        label_row[offset:] = IGNORE_TOKEN_ID

        # A disagreement between the running offset and the real token count
        # (below the length cap) invalidates the sample's labels.
        if offset < tokenizer.model_max_length and offset != n_real_tokens:
            label_row[:] = IGNORE_TOKEN_ID
            print(
                f'WARNING: tokenization mismatch: {offset} vs. {n_real_tokens}.'
                f' #turn = {len(pieces) - 1}. (ignored). This dataset is {ds_name}.'
            )
            sys.stdout.flush()

    attention_mask = input_ids.ne(tokenizer.pad_token_id)
    return {
        'input_ids': input_ids,
        'labels': targets,
        'attention_mask': attention_mask,
    }
def preprocess_phi3(
        template_name,
        sources,
        tokenizer: transformers.PreTrainedTokenizer,
        num_image_token_list: list,
        text_only: bool = False,
        group_by_length: bool = False,
        use_packed_ds: bool = False,
        ds_name: str = None,
        num_image: int = 1
) -> Dict:
    """Render, tokenize and label conversations (Phi-3 variant).

    Same overall flow as the other ``preprocess*`` functions, with these
    differences visible in the code: the tokenizer's ``padding_side`` is set
    to ``'right'`` (this mutates the tokenizer passed in), every
    ``<|endoftext|>`` token is masked in the labels, and the per-round length
    offsets differ between the first round and later rounds.

    Args:
        template_name: Name handed to ``get_conv_template``.
        sources: Sequence of conversations; each is a sequence of dicts with
            ``'from'`` (``'human'`` / ``'gpt'``) and ``'value'`` keys.
        tokenizer: Tokenizer called on the prompts; ``model_max_length`` and
            ``pad_token_id`` are read, ``padding_side`` is overwritten.
        num_image_token_list: Repetition count of ``IMG_CONTEXT_TOKEN`` for
            each expanded ``<image>`` tag, by index.
        text_only: Skip ``<image>`` expansion when true.
        group_by_length: Disables padding when true.
        use_packed_ds: Disables padding when true.
        ds_name: Only used in the mismatch warning.
        num_image: How many ``<image>`` tags are expanded per conversation.

    Returns:
        Dict with ``input_ids``, ``labels`` (non-answer positions set to
        ``IGNORE_TOKEN_ID``) and ``attention_mask``.

    Raises:
        AssertionError: If speakers do not alternate as ``conv.roles`` demands.
    """
    conv = get_conv_template(template_name)
    roles = {'human': conv.roles[0], 'gpt': conv.roles[1]}

    # Apply prompt templates
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]['from']] != conv.roles[0]:
            # Skip the first one if it is not from human
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence['from']]
            assert role == conv.roles[j % 2], f'{i}'
            conv.append_message(role, sentence['value'])
        conversations.append(conv.get_prompt())

    if not text_only:
        # Expand each of the first `num_image` <image> tags into
        # IMG_START + N * IMG_CONTEXT + IMG_END.
        new_conversations = []
        for conversation in conversations:
            for i in range(num_image):
                image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}'
                conversation = conversation.replace('<image>', image_tokens, 1)
            new_conversations.append(conversation)
        conversations = new_conversations

    # Tokenize conversations
    tokenizer.padding_side = 'right'
    input_ids = tokenizer(
        conversations,
        return_tensors='pt',
        padding=False if group_by_length or use_packed_ds else 'max_length',
        max_length=tokenizer.model_max_length,
        truncation=True,
    ).input_ids
    targets = input_ids.clone()

    # Mask targets. Only compute loss on the assistant outputs.
    sep = conv.sep + conv.roles[1]  # <|end|>\n<|assistant|>
    # The id does not depend on the conversation, so look it up once.
    endoftext_id = tokenizer.convert_tokens_to_ids('<|endoftext|>')
    for conversation, target in zip(conversations, targets):
        # Number of non-padding tokens in this row.
        total_len = int(target.ne(int(tokenizer.pad_token_id)).sum())

        turns = conversation.split(conv.sep)
        re_turns = [conv.sep.join(turns[:3])]  # system + user + gpt
        for conv_idx in range(3, len(turns), 2):
            re_turns.append(conv.sep.join(turns[conv_idx:conv_idx + 2]))  # user + gpt

        # The first token is always masked (presumably BOS -- TODO confirm).
        cur_len = 1
        target[:cur_len] = IGNORE_TOKEN_ID
        target[target == endoftext_id] = IGNORE_TOKEN_ID

        for i, turn in enumerate(re_turns):
            if turn == '':
                break
            # NOTE(review): the -1 / -2 corrections below look like empirical
            # offsets for tokens the tokenizer adds around each piece; verify
            # them against the actual Phi-3 tokenizer.
            turn_len = len(tokenizer(turn).input_ids)
            if i != 0:
                turn_len -= 1

            parts = turn.split(sep)
            if len(parts) != 2:
                # Not exactly one assistant prefix in this round: stop masking here.
                break
            parts[0] += sep
            instruction_len = len(tokenizer(parts[0]).input_ids) - (1 if i == 0 else 2)

            # Ignore the user instructions
            target[cur_len: cur_len + instruction_len] = IGNORE_TOKEN_ID
            cur_len += turn_len

        # Everything after the last processed round (e.g. padding) is ignored.
        target[cur_len:] = IGNORE_TOKEN_ID

        # If the running offset disagrees with the real token count (and the
        # sample was not simply cut at the maximum length), drop the whole
        # sample from the loss and report it.
        if cur_len < tokenizer.model_max_length and cur_len != total_len:
            target[:] = IGNORE_TOKEN_ID
            print(
                f'WARNING: tokenization mismatch: {cur_len} vs. {total_len}.'
                f' #turn = {len(turns) - 1}. (ignored). This dataset is {ds_name}.'
            )
            sys.stdout.flush()

    return dict(
        input_ids=input_ids,
        labels=targets,
        attention_mask=input_ids.ne(tokenizer.pad_token_id),
    )
def preprocess_internlm(
        template_name,
        sources,
        tokenizer: transformers.PreTrainedTokenizer,
        num_image_token_list: list,
        text_only: bool = False,
        group_by_length: bool = False,
        use_packed_ds: bool = False,
        ds_name: str = None,
        num_image: int = 1
) -> Dict:
    """Render, tokenize and label conversations (InternLM variant).

    Message texts are stripped of surrounding whitespace before being added to
    the prompt; the dicts inside ``sources`` are not modified. The labels are
    built by splitting the rendered prompt on the assistant role string
    (``conv.roles[1]``) and walking through the pieces while accumulating token
    counts: prompt/instruction spans are set to ``IGNORE_TOKEN_ID`` and answer
    spans keep their ids.

    Args:
        template_name: Name handed to ``get_conv_template``.
        sources: Sequence of conversations; each is a sequence of dicts with
            ``'from'`` (``'human'`` / ``'gpt'``) and ``'value'`` keys.
        tokenizer: Tokenizer called on the prompts; ``model_max_length`` and
            ``pad_token_id`` are read from it.
        num_image_token_list: Repetition count of ``IMG_CONTEXT_TOKEN`` for
            each expanded ``<image>`` tag, by index.
        text_only: Skip ``<image>`` expansion when true.
        group_by_length: Disables padding when true.
        use_packed_ds: Disables padding when true.
        ds_name: Only used in the mismatch warning.
        num_image: How many ``<image>`` tags are expanded per conversation.

    Returns:
        Dict with ``input_ids``, ``labels`` and ``attention_mask``.

    Raises:
        AssertionError: If speakers do not alternate as ``conv.roles`` demands.
        ValueError: If a middle piece of the prompt does not contain the user
            role string exactly once (the two-way unpacking of the split fails).
    """
    conv = get_conv_template(template_name)
    roles = {'human': conv.roles[0], 'gpt': conv.roles[1]}

    # Apply prompt templates
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]['from']] != conv.roles[0]:
            # Skip the first one if it is not from human
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence['from']]
            assert role == conv.roles[j % 2], f'{i}'
            # Strip into a temporary so the caller's message dicts stay untouched.
            conv.append_message(role, sentence['value'].strip())
        conversations.append(conv.get_prompt())

    if not text_only:
        # Expand each of the first `num_image` <image> tags into
        # IMG_START + N * IMG_CONTEXT + IMG_END.
        new_conversations = []
        for conversation in conversations:
            for i in range(num_image):
                image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}'
                conversation = conversation.replace('<image>', image_tokens, 1)
            new_conversations.append(conversation)
        conversations = new_conversations

    # Tokenize conversations
    input_ids = tokenizer(
        conversations,
        return_tensors='pt',
        padding=False if group_by_length or use_packed_ds else 'max_length',
        max_length=tokenizer.model_max_length,
        truncation=True,
    ).input_ids
    targets = input_ids.clone()

    for conversation, target in zip(conversations, targets):
        # Non-padding token count. In InternLM (Puyu), pad_token_id = eos_token_id.
        total_len = int(target.ne(tokenizer.pad_token_id).sum())
        cur_len = 1
        target[:cur_len] = IGNORE_TOKEN_ID  # <s>

        parts = conversation.split(conv.roles[1])  # [UNUSED_TOKEN_146]assistant\n

        # Leading piece up to and including the first assistant prefix
        # is not trained on.
        info = parts[0] + conv.roles[1]
        temp_len = len(tokenizer(info).input_ids) - 1  # drop the <s> added by the tokenizer
        target[cur_len: cur_len + temp_len] = IGNORE_TOKEN_ID
        cur_len = cur_len + temp_len

        # Middle pieces: "<answer><user role><question>".
        for index in range(1, len(parts) - 1):
            info = parts[index]
            part1, part2 = info.split(conv.roles[0])
            # Answer span: keep the labels, just advance the offset.
            temp_len = len(tokenizer(part1).input_ids) - 1
            cur_len = cur_len + temp_len
            # Next question plus assistant prefix: masked.
            part = conv.roles[0] + part2 + conv.roles[1]
            temp_len = len(tokenizer(part).input_ids) - 1
            target[cur_len: cur_len + temp_len] = IGNORE_TOKEN_ID
            cur_len = cur_len + temp_len

        # Final answer: keep the labels, then mask whatever follows (padding).
        last_info = parts[-1]
        temp_len = len(tokenizer(last_info).input_ids) - 1
        cur_len = cur_len + temp_len
        target[cur_len:] = IGNORE_TOKEN_ID

        # If the running offset disagrees with the real token count (and the
        # sample was not simply cut at the maximum length), drop the whole
        # sample from the loss and report it.
        if cur_len < tokenizer.model_max_length and cur_len != total_len:
            target[:] = IGNORE_TOKEN_ID
            print(f'WARNING: tokenization mismatch: {cur_len} vs. {total_len}. This dataset is {ds_name}.')
            sys.stdout.flush()

    return dict(
        input_ids=input_ids,
        labels=targets,
        attention_mask=input_ids.ne(tokenizer.pad_token_id),
    )