| | import datasets |
| | import os |
| | import json |
| | import torch |
| | import random |
| | import time |
| | random.seed(time.time()) |
| | import logging |
| | from tqdm import tqdm |
| | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') |
| |
|
def verify_jsonl_files(data_files):
    """Check that every line of each JSONL file parses as JSON.

    Each file is scanned line by line, and scanning of a file stops at the
    first line that fails to parse. Whitespace-only lines are ignored rather
    than treated as corruption, so a stray blank line does not cause the
    whole file to be reported as invalid.

    Args:
        data_files: Iterable of paths to the JSONL files to check.

    Returns:
        A list with one ``(file_path, detail)`` tuple per rejected file.
        ``detail`` is the 1-based line number (``int``) of the first invalid
        line, or a ``str`` describing the error when the file could not be
        read.
    """
    invalid_files = []

    for file_path in tqdm(data_files, desc="验证文件"):
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                for line_number, line in enumerate(f, start=1):
                    if not line.strip():
                        # A blank line carries no record; it is not an error.
                        continue
                    try:
                        json.loads(line)
                    except json.JSONDecodeError:
                        invalid_files.append((file_path, line_number))
                        logging.error(f"文件 {file_path} 在第 {line_number} 行有无效的 JSON")
                        # One bad line is enough to reject the file.
                        break
        except Exception as e:
            # Best effort: an unreadable file is reported, never fatal.
            invalid_files.append((file_path, f"读取错误: {str(e)}"))
            logging.error(f"无法读取文件 {file_path}: {str(e)}")

    return invalid_files
def _token_field(record, key):
    """Return ``record[key]`` when it is a list, otherwise an empty list."""
    value = record.get(key, [])
    return value if isinstance(value, list) else []


def _text_field(record, key):
    """Return ``record[key]`` as a string; a missing or null value becomes ""."""
    value = record.get(key)
    # str(None) would yield the literal text "None", so map null to "".
    return "" if value is None else str(value)


def load_jsonl_dataset(directory, tokenizer, max_length=2048):
    """Load every ``.jsonl`` file found recursively under *directory*.

    Files are first checked with ``verify_jsonl_files``; any file it reports
    is skipped entirely. Each line of the remaining files is parsed as JSON
    and reduced to four fields, then over-long samples are filtered out.

    Args:
        directory: Root directory that is walked recursively.
        tokenizer: Object with an ``encode(str)`` method whose result
            supports ``len()``; used here only to measure text length.
        max_length: Exclusive upper bound applied independently to the
            number of prompt speech tokens, target speech tokens, text
            tokens and prompt-text tokens of a sample. Defaults to 2048.

    Returns:
        A ``datasets.Dataset`` with the columns ``llm_prompt_speech_token``,
        ``tts_speech_tokens``, ``text`` and ``prompt_text``.

    Raises:
        ValueError: If no sample could be loaded from any file.
    """
    data_files = [
        os.path.join(root, file)
        for root, _dirs, files in os.walk(directory)
        for file in files
        if file.endswith('.jsonl')
    ]

    logging.info(f"找到 {len(data_files)} 个 JSONL 文件")

    invalid_files = verify_jsonl_files(data_files)
    if invalid_files:
        logging.error(f"发现 {len(invalid_files)} 个无效文件:")
        for file_info in invalid_files:
            if isinstance(file_info[1], int):
                logging.error(f" - {file_info[0]} (错误在第 {file_info[1]} 行)")
            else:
                logging.error(f" - {file_info[0]} ({file_info[1]})")

        # Build the lookup once so each membership test is O(1).
        invalid_paths = {info[0] for info in invalid_files}
        valid_files = [f for f in data_files if f not in invalid_paths]
        logging.info(f"继续处理剩余的 {len(valid_files)} 个有效文件")
        data_files = valid_files

    all_samples = []

    for file_path in tqdm(data_files, desc="加载数据集"):
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                for line in f:
                    try:
                        data = json.loads(line)
                        all_samples.append({
                            'llm_prompt_speech_token': _token_field(data, 'llm_prompt_speech_token'),
                            'tts_speech_tokens': _token_field(data, 'tts_speech_tokens'),
                            'text': _text_field(data, 'text'),
                            'prompt_text': _text_field(data, 'prompt_text'),
                        })
                    except json.JSONDecodeError:
                        # Undecodable (including blank) lines are skipped.
                        continue
                    except Exception as e:
                        # e.g. a line whose JSON value is not an object.
                        logging.error(f"处理样本时出错: {str(e)}")
        except Exception as e:
            logging.error(f"打开文件 {file_path} 时出错: {str(e)}")

    if not all_samples:
        raise ValueError("没有成功加载任何样本")

    logging.info(f"手动创建数据集,包含 {len(all_samples)} 个样本")
    dataset = datasets.Dataset.from_list(all_samples)

    logging.info(f"成功加载 {len(dataset)} 个样本")

    # Drop samples in which any single field reaches the length limit.
    dataset = dataset.filter(
        lambda x: len(x['llm_prompt_speech_token']) < max_length
        and len(x['tts_speech_tokens']) < max_length
        and len(tokenizer.encode(x['text'])) < max_length
        and len(tokenizer.encode(x['prompt_text'])) < max_length
    )
    logging.info(f"过滤后剩余 {len(dataset)} 个样本")

    return dataset
| |
|
def collate_fn(batch, tokenizer, pad_to_max_length=True, max_length=2048, drop_prompt_audio_rate=-0.1):
    """Convert a list of samples into padded torch tensors.

    For every sample the text is tokenized with ``tokenizer.encode``. Unless
    the prompt is dropped, the prompt-text tokens are prepended to the text
    tokens and the prompt speech tokens are prepended to the target speech
    tokens. The drop decision is drawn once per call, so it applies to every
    sample in the batch; it lets the model learn to generate audio without
    guidance. With the default rate of -0.1 nothing is ever dropped.

    Args:
        batch: List of dicts with the keys ``text``, ``prompt_text``,
            ``tts_speech_tokens`` and ``llm_prompt_speech_token``.
        tokenizer: Object with an ``encode(str)`` method returning a list of
            token ids.
        pad_to_max_length: When true and the batch is not skipped, the speech
            tokens are right-padded by ``max_length`` minus the longest
            per-sample (text + speech) length.
        max_length: Limit on the per-sample (text + speech) length.
        drop_prompt_audio_rate: Probability of dropping the prompt for the
            whole batch.

    Returns:
        A dict with ``text_token`` and ``speech_token`` (zero-padded,
        batch-first int32 tensors), ``text_token_len`` and
        ``speech_token_len`` (int32 tensors of unpadded lengths) and ``skip``
        (``True`` when the longest sample exceeds ``max_length``).
    """
    is_drop_prompt = random.random() < drop_prompt_audio_rate

    all_text_tokens = []
    all_speech_tokens = []
    text_token_len = []
    speech_token_len = []
    my_max_length = 0

    for sample in batch:
        text_tokens = tokenizer.encode(sample['text'])
        speech_tokens = sample['tts_speech_tokens']
        if not is_drop_prompt:
            # Prompt material goes in front of the target material.
            text_tokens = tokenizer.encode(sample['prompt_text']) + text_tokens
            speech_tokens = sample['llm_prompt_speech_token'] + speech_tokens

        all_text_tokens.append(torch.tensor(text_tokens, dtype=torch.int32))
        all_speech_tokens.append(torch.tensor(speech_tokens, dtype=torch.int32))
        text_token_len.append(len(text_tokens))
        speech_token_len.append(len(speech_tokens))

        my_max_length = max(my_max_length, len(text_tokens) + len(speech_tokens))

    # The batch is still returned when too long; the caller decides via 'skip'.
    skip = my_max_length > max_length

    all_text_tokens = torch.nn.utils.rnn.pad_sequence(all_text_tokens, batch_first=True, padding_value=0)
    all_speech_tokens = torch.nn.utils.rnn.pad_sequence(all_speech_tokens, batch_first=True, padding_value=0)

    if pad_to_max_length and not skip:
        # NOTE(review): only the speech tensor is widened here; for batches
        # where the longest text and longest speech come from different
        # samples the two widths add up to more than max_length - confirm
        # this is what the consumer expects.
        pad_length = max_length - my_max_length
        if pad_length > 0:
            all_speech_tokens = torch.nn.functional.pad(all_speech_tokens, (0, pad_length), value=0)

    return {
        'text_token': all_text_tokens,
        'text_token_len': torch.tensor(text_token_len, dtype=torch.int32),
        'speech_token': all_speech_tokens,
        'speech_token_len': torch.tensor(speech_token_len, dtype=torch.int32),
        'skip': skip
    }
| | |
| | |
if __name__ == '__main__':
    # Manual smoke test: load the corpus, build a DataLoader and print the
    # first collated batch together with its tensor shapes.
    from functools import partial

    from transformers import AutoTokenizer

    model_path = "/external_data/models/rwkv7-2.9B-world"
    # NOTE(review): trust_remote_code=True executes Python code shipped with
    # the model directory - only point this at a trusted checkpoint.
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    directory = '/external_data/yueyudata/speech_corpus'
    dataset = load_jsonl_dataset(directory, tokenizer)
    print(dataset)
    print(dataset[0])

    # Bind the partial to its own name so the module-level collate_fn is not
    # shadowed and stays callable with its full signature.
    collate = partial(collate_fn, tokenizer=tokenizer, pad_to_max_length=False)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, collate_fn=collate)
    for data in dataloader:
        print(data)
        print(data['speech_token'].shape)
        print(data['text_token'].shape)
        # Only the first batch is needed for the check.
        break