# Fine-tune a Llama-based TTS model (snorTTS-indicv0) on SNAC audio tokens
# using Unsloth, TRL's SFTTrainer, and Weights & Biases for tracking.

from unsloth import FastLanguageModel  # imported first so Unsloth's patches apply

import ast
import os

import soundfile as sf
import torch
import wandb
from datasets import load_dataset
from huggingface_hub import login
from loguru import logger
from snac import SNAC
from trl import SFTConfig, SFTTrainer

# ---- Configuration ----
STAGE = 1
HUGGINGFACE_USERNAME = ""

if STAGE == 1:
    # Stage 1 starts from the pretrained checkpoint and also trains the
    # embeddings and output head (embed_tokens, lm_head).
    BASE_MODEL = "canopylabs/3b-hi-pretrain-research_release"
    TARGET_MODULES = [
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "up_proj",
        "down_proj",
        "gate_proj",
        "lm_head",
        "embed_tokens",
    ]
    TRAIN_CSV_PATH = "data/data_stage_1.csv"
    VALID_CSV_PATH = "data/data_eval.csv"
    LR = 2e-4
    EPOCHS = 2
    MODEL_NAME = f"snorTTS-indicv0-stage-{STAGE}"
else:
    # Stage 2 continues from the stage-1 checkpoint and trains only the
    # attention and MLP projections.
    BASE_MODEL = f"{HUGGINGFACE_USERNAME}/snorTTS-indicv0-stage-1"
    TARGET_MODULES = [
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "up_proj",
        "down_proj",
        "gate_proj",
    ]
    TRAIN_CSV_PATH = "data/data_stage_2.csv"
    VALID_CSV_PATH = "data/data_eval.csv"
    LR = 2e-4
    EPOCHS = 2
    MODEL_NAME = f"snorTTS-indicv0-stage-{STAGE}"

TRAIN_NUM_SAMPLES = None  # None trains on the full dataset
EVAL_NUM_SAMPLES = 250
MAX_SEQ_LENGTH = 2048
PER_DEVICE_TRAIN_BATCH_SIZE = 8
GRADIENT_ACCUMULATION_STEPS = 4
WANDB_USERNAME = ""
WANDB_PROJECT = MODEL_NAME
WANDB_LOG_MODEL = "checkpoint"
WANDB_RUN_NAME = f"{MODEL_NAME}-training"
WANDB_RUN_ID = None  # set to resume training from a W&B model artifact
SEED = 3407
HUGGINGFACE_TOKEN = ""
WANDB_TOKEN = ""

# ---- Authentication and experiment tracking ----
login(token=HUGGINGFACE_TOKEN)
wandb.login(key=WANDB_TOKEN)

os.environ["WANDB_PROJECT"] = WANDB_PROJECT
os.environ["WANDB_LOG_MODEL"] = WANDB_LOG_MODEL

# ---- Model loading ----
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=BASE_MODEL,
    load_in_4bit=True,
    max_seq_length=MAX_SEQ_LENGTH,
    token=HUGGINGFACE_TOKEN,
)
logger.success(f"Loaded model: {BASE_MODEL}")

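# Token-ID layout: the base tokenizer occupies IDs 0..tokeniser_length - 1
# (128255); the custom control tokens below sit directly above it, and the
# SNAC audio tokens start at tokeniser_length + 10 = 128266 (audio_start_id).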
tokeniser_length = 128256  # vocabulary size of the base tokenizer

start_of_text_id = 128000
end_of_text_id = 128009
start_of_speech_id = tokeniser_length + 1
end_of_speech_id = tokeniser_length + 2
start_of_human_id = tokeniser_length + 3
end_of_human_id = tokeniser_length + 4
start_of_ai_id = tokeniser_length + 5
end_of_ai_id = tokeniser_length + 6
pad_token_id = tokeniser_length + 7
audio_start_id = tokeniser_length + 10  # 128266: first SNAC audio token

start_of_text_token = tokenizer.decode([start_of_text_id])
end_of_text_token = tokenizer.decode([end_of_text_id])
start_of_speech_token = tokenizer.decode([start_of_speech_id])
end_of_speech_token = tokenizer.decode([end_of_speech_id])
start_of_human_token = tokenizer.decode([start_of_human_id])
end_of_human_token = tokenizer.decode([end_of_human_id])
start_of_ai_token = tokenizer.decode([start_of_ai_id])
end_of_ai_token = tokenizer.decode([end_of_ai_id])
pad_token = tokenizer.decode([pad_token_id])
audio_start_token = tokenizer.decode([audio_start_id])

logger.success("Loaded special tokens for the tokenizer.")

tokenizer.pad_token = pad_token
tokenizer.padding_side = "left"
logger.success("Set padding token and padding side for the tokenizer.")

model = FastLanguageModel.get_peft_model(
    model,
    r=192,  # LoRA rank; lora_alpha is set to 2 * r
    target_modules=TARGET_MODULES,
    lora_alpha=384,
    random_state=SEED,
)
logger.success("Initialized parameter-efficient fine-tuning model.")

# ---- Data loading ----
train_dataset = load_dataset("csv", data_files=TRAIN_CSV_PATH)["train"]
eval_dataset = load_dataset("csv", data_files=VALID_CSV_PATH)["train"]

if TRAIN_NUM_SAMPLES:
    train_dataset = train_dataset.shuffle(seed=SEED).select(
        range(min(TRAIN_NUM_SAMPLES, len(train_dataset)))
    )

if EVAL_NUM_SAMPLES:
    eval_dataset = eval_dataset.shuffle(seed=SEED).select(
        range(min(EVAL_NUM_SAMPLES, len(eval_dataset)))
    )

logger.success(
    f"Loaded datasets: {len(train_dataset)} training samples, "
    f"{len(eval_dataset)} evaluation samples."
)

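# SNAC (24 kHz) encodes each audio frame as three hierarchical codebooks:
# 1 coarse, 2 medium, and 4 fine codes, each in [0, 4096). The helper below
# flattens a frame into 7 consecutive LM tokens, giving each of the 7 slots
# its own 4096-wide ID block above audio_start_id. For example, a frame with
# codes c0, (m0, m1), (f0, f1, f2, f3) becomes
#   [c0, m0 + 4096, f0 + 2*4096, f1 + 3*4096, m1 + 4*4096, f2 + 5*4096, f3 + 6*4096]
# all shifted by 128266.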
def flatten_and_get_audio_input_ids(row):
    """Interleave the three SNAC codebooks into a flat 7-token-per-frame stream."""
    audio_codes = row["snac_codes"]
    if isinstance(audio_codes, str):
        # The CSV stores the nested code lists as strings; parse them safely.
        audio_codes = ast.literal_eval(audio_codes)
    snac_token_ids = []
    for i in range(len(audio_codes[0])):
        snac_token_ids.append(audio_codes[0][i] + 128266)
        snac_token_ids.append(audio_codes[1][2 * i] + 128266 + 4096)
        snac_token_ids.append(audio_codes[2][4 * i] + 128266 + (2 * 4096))
        snac_token_ids.append(audio_codes[2][(4 * i) + 1] + 128266 + (3 * 4096))
        snac_token_ids.append(audio_codes[1][(2 * i) + 1] + 128266 + (4 * 4096))
        snac_token_ids.append(audio_codes[2][(4 * i) + 2] + 128266 + (5 * 4096))
        snac_token_ids.append(audio_codes[2][(4 * i) + 3] + 128266 + (6 * 4096))
    row["snac_token_ids"] = snac_token_ids
    return row


train_dataset = train_dataset.map(flatten_and_get_audio_input_ids)
eval_dataset = eval_dataset.map(flatten_and_get_audio_input_ids)
logger.success("Flattened and extracted SNAC token IDs from audio codes.")

train_dataset = train_dataset.filter(
    lambda x: x["snac_token_ids"] is not None and len(x["snac_token_ids"]) > 0
)
eval_dataset = eval_dataset.filter(
    lambda x: x["snac_token_ids"] is not None and len(x["snac_token_ids"]) > 0
)
logger.success("Filtered datasets to remove rows with empty or None audio codes.")

def remove_duplicate_frames(row):
    """Drop a frame when its coarse code repeats the previous frame's."""
    vals = row["snac_token_ids"]
    if len(vals) % 7 != 0:
        raise ValueError("Input list length must be divisible by 7")
    result = vals[:7]
    for i in range(7, len(vals), 7):
        current_first = vals[i]
        previous_first = result[-7]
        if current_first != previous_first:
            result.extend(vals[i : i + 7])
    row["snac_token_ids"] = result
    return row


train_dataset = train_dataset.map(remove_duplicate_frames)
eval_dataset = eval_dataset.map(remove_duplicate_frames)
logger.success("Removed duplicate frames from audio codes.")

def format_text(row):
    text = (
        f"{start_of_human_token}{start_of_text_token}{row['language']}{row['user']}: {row['utterance']}{end_of_text_token}"
        f"{end_of_human_token}{start_of_ai_token}{start_of_speech_token}"
        f"{tokenizer.decode(row['snac_token_ids'])}{end_of_speech_token}{end_of_ai_token}"
    )
    eval_text_user = (
        f"{start_of_human_token}{start_of_text_token}{row['language']}{row['user']}: {row['utterance']}{end_of_text_token}"
        f"{end_of_human_token}{start_of_ai_token}{start_of_speech_token}"
    )
    eval_text_no_user = (
        f"{start_of_human_token}{start_of_text_token}{row['utterance']}{end_of_text_token}"
        f"{end_of_human_token}{start_of_ai_token}{start_of_speech_token}"
    )
    row["text"] = text
    row["eval_text_user"] = eval_text_user
    row["eval_text_no_user"] = eval_text_no_user
    return row


train_dataset = train_dataset.map(format_text)
eval_dataset = eval_dataset.map(format_text)
logger.success("Formatted text for training and evaluation datasets.")

def tokenize_function(example):
    return tokenizer(
        example["text"],
        add_special_tokens=False,  # the control tokens are already in the text
        truncation=True,
        max_length=MAX_SEQ_LENGTH,
    )


train_dataset = train_dataset.map(tokenize_function)
eval_dataset = eval_dataset.map(tokenize_function)
logger.success("Tokenized text in the datasets without adding special tokens.")

training_args = SFTConfig(
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=PER_DEVICE_TRAIN_BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    optim="adamw_8bit",
    learning_rate=LR,
    lr_scheduler_type="cosine",
    warmup_ratio=0.02,
    do_eval=True,
    eval_strategy="steps",
    eval_steps=50,
    logging_strategy="steps",
    logging_steps=1,
    save_strategy="steps",
    save_only_model=True,
    save_steps=1250,
    output_dir="outputs",
    report_to="wandb",
    run_name=WANDB_RUN_NAME,
    seed=SEED,
)

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    max_seq_length=MAX_SEQ_LENGTH,
    dataset_num_proc=2,
    packing=True,
    args=training_args,
)
logger.success("Initialized SFTTrainer with the specified configuration.")

logger.info("Starting the training process...")

run = wandb.init()

if WANDB_RUN_ID:
    # Resume from a model checkpoint previously logged to W&B as an artifact.
    logger.info(f"Resuming from Weights & Biases run ID: {WANDB_RUN_ID}")
    artifact = run.use_artifact(
        f"{WANDB_USERNAME}/{WANDB_PROJECT}/{WANDB_RUN_ID}", type="model"
    )
    artifact_dir = artifact.download()
    trainer.train(resume_from_checkpoint=artifact_dir)
else:
    try:
        logger.info("Attempting to resume training from the last checkpoint...")
        trainer.train(resume_from_checkpoint=True)
    except Exception:
        # No usable checkpoint in output_dir yet; start a fresh run.
        logger.info("No checkpoint found; starting training from scratch.")
        trainer.train()

wandb.finish()

logger.success("Training completed successfully.")

# ---- Inference ----
FastLanguageModel.for_inference(model)
logger.success(f"Model {MODEL_NAME} is ready for inference.")

snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
logger.success("Loaded SNAC model for audio decoding.")

def generate_audio(
    row,
    model,
    tokenizer,
    user=False,
    temperature=0.4,
    top_p=0.9,
    repetition_penalty=1.05,
):
    """Generate speech for one dataset row and decode it to a waveform via SNAC."""
    try:
        prompt = row["eval_text_user"] if user else row["eval_text_no_user"]
        inputs = tokenizer(prompt, add_special_tokens=False, return_tensors="pt")
        max_tokens = MAX_SEQ_LENGTH - inputs.input_ids.shape[1]
        output = model.generate(
            input_ids=inputs.input_ids.to("cuda"),
            attention_mask=inputs.attention_mask.to("cuda"),
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            eos_token_id=end_of_speech_id,
        )
        # Keep only audio tokens; text and control tokens sit below audio_start_id.
        audio_ids = [t.item() for t in output[0] if t.item() >= audio_start_id]
        # Remove the per-slot offsets, dropping any trailing partial frame.
        clean_audio_ids = []
        for i in range(len(audio_ids) // 7):
            for j in range(7):
                clean_audio_ids.append(audio_ids[7 * i + j] - audio_start_id)
        # Undo the interleaving: regroup the 7 slots into SNAC's three codebooks.
        codes = [[], [], []]
        for i in range(len(clean_audio_ids) // 7):
            codes[0].append(clean_audio_ids[7 * i])
            codes[1].append(clean_audio_ids[7 * i + 1] - 4096)
            codes[2].append(clean_audio_ids[7 * i + 2] - (2 * 4096))
            codes[2].append(clean_audio_ids[7 * i + 3] - (3 * 4096))
            codes[1].append(clean_audio_ids[7 * i + 4] - (4 * 4096))
            codes[2].append(clean_audio_ids[7 * i + 5] - (5 * 4096))
            codes[2].append(clean_audio_ids[7 * i + 6] - (6 * 4096))
        codes = [
            torch.tensor(codes[0]).unsqueeze(0),
            torch.tensor(codes[1]).unsqueeze(0),
            torch.tensor(codes[2]).unsqueeze(0),
        ]
        audio = snac_model.decode(codes)
        return audio.detach().squeeze().to("cpu").numpy()
    except Exception as e:
        logger.error(f"Error decoding audio: {e}")
        return None

train_sample = generate_audio(train_dataset[0], model, tokenizer, user=True)
if train_sample is None:
    logger.error("Failed to generate audio for training sample.")
else:
    sf.write(f"train_{STAGE}.wav", train_sample, 24000)  # SNAC runs at 24 kHz
    logger.success("Generated and saved training sample audio.")

dir_ = f"eval_{STAGE}/"
os.makedirs(dir_, exist_ok=True)
for i in range(10):
    eval_sample = generate_audio(eval_dataset[i], model, tokenizer, user=True)
    if eval_sample is None:
        logger.error(f"Failed to generate audio for evaluation sample {i}.")
    else:
        filename = dir_ + f"eval_{i}.wav"
        sf.write(filename, eval_sample, 24000)
        logger.success(f"Generated and saved evaluation sample audio as {filename}.")

# ---- Save and publish ----
model.save_pretrained_merged(
    f"{HUGGINGFACE_USERNAME}/{MODEL_NAME}",
    tokenizer,
    save_method="merged_16bit",
)
logger.success("Saved the model and tokenizer locally.")

model.push_to_hub_merged(
    f"{HUGGINGFACE_USERNAME}/{MODEL_NAME}",
    tokenizer,
    save_method="merged_16bit",
    token=HUGGINGFACE_TOKEN,
)
logger.success("Pushed the model and tokenizer to the Hugging Face Hub.")