"""
Split long conversations into shorter ones that fit within a maximum token length.

Usage: python3 -m fastchat.data.split_long_conversation \
    --in-file sharegpt_clean.json \
    --out-file sharegpt_split.json \
    --model-name-or-path <model-name-or-path>
"""
| import argparse | |
| import json | |
| from typing import Dict, Sequence, Optional | |
| import transformers | |
| import tqdm | |
| from fastchat import conversation as conversation_lib | |
def split_sample(sample, start_idx, end_idx):
    """
    Build a new sample from the messages ``[start_idx, end_idx)`` of ``sample``.

    The new sample's id is the original id followed by ``_<start_idx>``.
    The slice must contain an even number of messages; an ``AssertionError``
    is raised otherwise.
    """
    span = end_idx - start_idx
    assert span % 2 == 0

    id_suffix = "_" + str(start_idx)
    piece = {"id": sample["id"] + id_suffix}
    piece["conversations"] = sample["conversations"][start_idx:end_idx]
    return piece
def split_contents(content, begin, end, tokenizer, max_length):
    """
    Split every conversation into chunks of whole rounds that fit in max_length.

    A "round" is two consecutive messages. Rounds are accumulated greedily
    into a chunk until adding the next round would push the chunk's token
    count above ``max_length``; the chunk is then emitted through
    ``split_sample`` and a new one is started at that round.

    Args:
        content: Sequence of samples. Each sample is a dict with a
            "conversations" list whose items hold their text under "value".
        begin: Start index of the slice of ``content`` to process
            (``None`` means from the beginning).
        end: End index (exclusive) of the slice to process
            (``None`` means to the end).
        tokenizer: Callable that, given a string, returns an object exposing
            ``input_ids``.
        max_length: Token budget for one chunk.

    Returns:
        A list of the samples produced by ``split_sample``. A conversation
        with fewer than two messages yields nothing. A single round that is
        longer than ``max_length`` cannot be split further and is emitted as
        its own (oversized) chunk.
    """
    content = content[begin:end]
    new_content = []

    for sample in tqdm.tqdm(content):
        conversations = sample["conversations"]
        # Drop a trailing unpaired message so only whole rounds remain.
        conversations = conversations[: len(conversations) // 2 * 2]

        # The +5 pads every message's token count; presumably it covers the
        # role/separator tokens added by the prompt template -- TODO confirm.
        tokenized_lens = [
            len(tokenizer(c["value"]).input_ids) + 5 for c in conversations
        ]

        start_idx = 0
        cur_len = 0
        for i in range(0, len(conversations), 2):
            tmp_len = tokenized_lens[i] + tokenized_lens[i + 1]
            # Close the running chunk before this round would overflow it.
            # `i > start_idx` prevents emitting an empty chunk when the very
            # first round on its own is already over the budget.
            if cur_len + tmp_len > max_length and i > start_idx:
                new_content.append(split_sample(sample, start_idx, i))
                start_idx = i
                cur_len = 0
            cur_len += tmp_len

        # Flush whatever is still pending. Doing this after the loop (rather
        # than in an `elif` on the last iteration) ensures the final round is
        # kept even when it was the one that triggered the last split.
        if start_idx < len(conversations):
            new_content.append(
                split_sample(sample, start_idx, len(conversations))
            )

    return new_content
def filter_invalid_roles(content):
    """
    Keep only the samples whose messages strictly alternate human / gpt.

    A sample is kept when its "conversations" list is non-empty and the
    "from" field of message ``k`` is "human" for even ``k`` and "gpt" for
    odd ``k``. The kept samples are returned in a new list, in their
    original order; the samples themselves are not copied.
    """
    expected_roles = ("human", "gpt")

    def _alternates(turns):
        # True when every message carries the role expected at its position.
        return all(
            turn["from"] == expected_roles[pos % 2]
            for pos, turn in enumerate(turns)
        )

    return [
        item
        for item in content
        if len(item["conversations"]) > 0 and _alternates(item["conversations"])
    ]
def main(args):
    """
    Load conversations, split them to fit the token budget and save the result.

    Args:
        args: Parsed command-line namespace providing ``in_file``,
            ``out_file``, ``begin``, ``end``, ``model_name_or_path`` and
            ``max_length``.
    """
    # Context managers close the handles deterministically; JSON text is
    # UTF-8 by specification, so do not depend on the platform default.
    with open(args.in_file, "r", encoding="utf-8") as in_fp:
        content = json.load(in_fp)

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        args.model_name_or_path,
        model_max_length=args.max_length,
        padding_side="right",
        use_fast=False,
    )

    new_content = split_contents(
        content, args.begin, args.end, tokenizer, args.max_length
    )
    new_content = filter_invalid_roles(new_content)

    # "total" is the size of the whole input file, not of the begin:end slice
    # that split_contents actually processed.
    print(f"total: {len(content)}, new: {len(new_content)}")

    # Closing the file on leaving the block guarantees the output is flushed.
    with open(args.out_file, "w", encoding="utf-8") as out_fp:
        json.dump(new_content, out_fp, indent=2)
if __name__ == "__main__":
    # Command-line entry point: declare the options, parse them, run main().
    parser = argparse.ArgumentParser()
    for flag, options in (
        ("--in-file", {"type": str, "required": True}),
        ("--out-file", {"type": str, "default": "sharegpt_split.json"}),
        ("--begin", {"type": int}),
        ("--end", {"type": int}),
        ("--model-name-or-path", {"type": str, "required": True}),
        ("--max-length", {"type": int, "default": 2048}),
    ):
        parser.add_argument(flag, **options)
    main(parser.parse_args())