cond_gen / train.py
Leon299's picture
Add files using upload-large-folder tool
8337fa0 verified
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Train MAGEL directly from a vanilla Qwen3 checkpoint.
Compared with train.py/train_newparaonly.py, this script:
1) Loads an original Qwen3 base checkpoint.
2) Resolves MAGEL hparams explicitly at construction time.
3) Initializes MAGEL extra modules from scratch and trains end-to-end.
"""
import argparse
import os
import torch
from transformers import (
AutoConfig,
Trainer,
TrainingArguments,
)
import datasets
from dataset import DataCollate, MusicDataset
from modelling_qwen3 import MAGEL
def resolve_model_source(model_path: str, resume_from_checkpoint: str | None) -> str:
if not resume_from_checkpoint:
return model_path
if os.path.abspath(model_path) != os.path.abspath(resume_from_checkpoint):
print(
"Ignoring --model_path during resume and loading config/model from: "
f"{resume_from_checkpoint}"
)
return resume_from_checkpoint
def create_model(
    model_path: str,
    model_dtype: torch.dtype,
    target_vocab_size: int,
    attn_implementation: str,
) -> MAGEL:
    """Load a MAGEL model from a local checkpoint and print a size summary.

    Args:
        model_path: Local directory holding the config and weights.
        model_dtype: Dtype forwarded to ``from_pretrained`` as ``torch_dtype``.
        target_vocab_size: Size passed to ``resize_token_embeddings`` after
            loading.
        attn_implementation: Attention backend name forwarded to
            ``from_pretrained``.

    Returns:
        The loaded model with its token embeddings resized.
    """
    print(f"Loading Qwen3 model from: {model_path}")

    # Config and weights are both resolved from local files only.
    local_only = {"local_files_only": True}
    config = AutoConfig.from_pretrained(model_path, **local_only)
    model = MAGEL.from_pretrained(
        model_path,
        torch_dtype=model_dtype,
        config=config,
        attn_implementation=attn_implementation,
        ignore_mismatched_sizes=True,
        **local_only,
    )
    model.resize_token_embeddings(target_vocab_size)

    # Parameter statistics are gathered after the resize so they include it.
    all_params = list(model.parameters())
    total_params = sum(p.numel() for p in all_params)
    trainable_params = sum(p.numel() for p in all_params if p.requires_grad)

    # Parameters whose qualified name contains one of these markers are
    # counted as MAGEL-specific additions.
    extra_markers = ("condition_encoder", "dit_adaln")
    magel_extra_params = sum(
        p.numel()
        for name, p in model.named_parameters()
        if any(marker in name for marker in extra_markers)
    )

    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    print(f"MAGEL extra parameters: {magel_extra_params:,}")
    print(
        "MAGEL config: "
        f"adaln_dim={model.adaln_dim}, "
        f"chord_dropout_trigger_prob={model.chord_dropout_trigger_prob}, "
        f"structure_dropout_trigger_prob={model.structure_dropout_trigger_prob}"
    )
    return model
def create_dataset(
    dataset_path: str,
    tokenizer_path: str,
    num_audio_token: int = 16384,
) -> MusicDataset:
    """Build the training ``MusicDataset`` from a dataset saved on disk.

    Args:
        dataset_path: Directory readable by ``datasets.load_from_disk``.
        tokenizer_path: Tokenizer location handed to ``MusicDataset``.
        num_audio_token: Forwarded to ``MusicDataset`` unchanged.

    Returns:
        The dataset constructed for the ``"train"`` split.
    """
    print(f"Loading dataset from: {dataset_path}")
    print(f"Loading tokenizer from: {tokenizer_path}")

    raw_dataset = datasets.load_from_disk(dataset_path)
    dataset_kwargs = {
        "split": "train",
        "tokenizer_path": tokenizer_path,
        "num_audio_token": num_audio_token,
        "use_fast": True,
    }
    train_dataset = MusicDataset(raw_dataset, **dataset_kwargs)

    print(f"Dataset size: {len(train_dataset)}")
    return train_dataset
def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser holding every training option."""
    parser = argparse.ArgumentParser(
        description="Train MAGEL directly from a vanilla Qwen3 base checkpoint."
    )

    # --- Data / checkpoint locations -------------------------------------
    parser.add_argument(
        "--dataset_path",
        type=str,
        default="muse_mucodec_chord.ds",
    )
    parser.add_argument(
        "--model_path",
        type=str,
        default="checkpoints/Qwen3-0.6B",
        help="Local Qwen3 base checkpoint path.",
    )
    parser.add_argument(
        "--tokenizer_path",
        type=str,
        default="checkpoints/Qwen3-0.6B",
        help="Local tokenizer checkpoint path.",
    )

    # --- Model loading options -------------------------------------------
    parser.add_argument(
        "--model_dtype",
        type=str,
        default="bfloat16",
        choices=["float32", "float16", "bfloat16"],
    )
    parser.add_argument(
        "--attn_implementation",
        type=str,
        default="sdpa",
        choices=["eager", "sdpa", "flash_attention_2"],
    )

    # --- Optimisation / Trainer options ----------------------------------
    parser.add_argument("--output_dir", type=str, default="./output_qwen3_0p6b_train")
    parser.add_argument("--per_device_train_batch_size", type=int, default=1)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=4)
    parser.add_argument("--learning_rate", type=float, default=1e-4)
    parser.add_argument("--weight_decay", type=float, default=0.01)
    parser.add_argument("--num_train_epochs", type=float, default=20)
    parser.add_argument("--warmup_steps", type=int, default=1000)
    parser.add_argument("--max_grad_norm", type=float, default=5.0)
    parser.add_argument("--logging_steps", type=int, default=10)
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help="Resume training from a Trainer checkpoint directory such as output_dir/checkpoint-500.",
    )
    parser.add_argument("--dataloader_num_workers", type=int, default=12)
    parser.add_argument(
        "--gradient_checkpointing",
        dest="gradient_checkpointing",
        action="store_true",
    )
    parser.add_argument(
        "--deepspeed",
        type=str,
        default=None,
        help="Path to DeepSpeed config. Leave unset to disable DeepSpeed.",
    )

    # --- Experiment tracking ---------------------------------------------
    parser.add_argument("--report_to", type=str, default="wandb")
    parser.add_argument("--wandb_project", type=str, default="vaultum-qwen3-0p6b")
    parser.add_argument("--wandb_run_name", type=str, default=None)
    return parser


def main() -> None:
    """Parse CLI arguments, build the dataset and model, and run training.

    Steps, in order:
      1. Parse command-line options.
      2. Decide whether to load from ``--model_path`` or the resume checkpoint.
      3. Read ``magel_num_audio_token`` from that checkpoint's config.
      4. Build the training dataset and take the vocabulary size from it.
      5. Load the model and resize its embeddings to that vocabulary size.
      6. Train with ``Trainer`` (optionally resuming), then save the final
         model and tokenizer under ``<output_dir>/final``.
    """
    args = _build_arg_parser().parse_args()

    # ``choices`` on --model_dtype guarantees the key is present.
    model_dtype = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }[args.model_dtype]

    model_source = resolve_model_source(
        model_path=args.model_path,
        resume_from_checkpoint=args.resume_from_checkpoint,
    )

    # The config is read up front because the dataset needs the audio-token
    # count before the model itself is instantiated.
    # NOTE(review): this requires ``magel_num_audio_token`` to exist in the
    # checkpoint's config; a stock Qwen3 config presumably lacks it - confirm
    # the expected preparation step.
    base_config = AutoConfig.from_pretrained(
        model_source,
        local_files_only=True,
    )
    num_audio_token = int(base_config.magel_num_audio_token)
    print(f"Using num_audio_token={num_audio_token}")

    train_dataset = create_dataset(
        dataset_path=args.dataset_path,
        tokenizer_path=args.tokenizer_path,
        num_audio_token=num_audio_token,
    )
    # The dataset reports the vocabulary size the embeddings are resized to.
    target_vocab_size = train_dataset.tokenizer_vocab_size

    model = create_model(
        model_path=model_source,
        model_dtype=model_dtype,
        attn_implementation=args.attn_implementation,
        target_vocab_size=target_vocab_size,
    )

    # NOTE(review): with --model_dtype float16 the weights are loaded in fp16
    # *and* fp16 mixed precision is enabled; verify this combination works in
    # the non-DeepSpeed path (gradient unscaling of fp16 weights).
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        per_device_train_batch_size=args.per_device_train_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        num_train_epochs=args.num_train_epochs,
        warmup_steps=args.warmup_steps,
        max_grad_norm=args.max_grad_norm,
        logging_steps=args.logging_steps,
        save_strategy="epoch",
        dataloader_num_workers=args.dataloader_num_workers,
        bf16=(args.model_dtype == "bfloat16"),
        fp16=(args.model_dtype == "float16"),
        gradient_checkpointing=args.gradient_checkpointing,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        deepspeed=args.deepspeed,
        # The collator consumes the dataset's own fields, so Trainer must not
        # drop columns based on the model's forward signature.
        remove_unused_columns=False,
        dataloader_drop_last=True,
        report_to=args.report_to,
        logging_dir=None,
        run_name=args.wandb_run_name,
    )

    # The project name is handed to wandb through its environment variable.
    if args.wandb_project and "wandb" in args.report_to:
        os.environ["WANDB_PROJECT"] = args.wandb_project

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        data_collator=DataCollate(),
    )

    if args.resume_from_checkpoint:
        print(f"Resuming training from checkpoint: {args.resume_from_checkpoint}")
    else:
        print("Starting training from current model initialization.")
    trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)

    final_dir = os.path.join(args.output_dir, "final")
    # Called on every rank; Trainer.save_model itself restricts the actual
    # write to the saving process.
    trainer.save_model(final_dir)
    # Apply the same rule to the tokenizer so that, in distributed runs, the
    # ranks do not all write the same tokenizer files concurrently.
    if training_args.should_save:
        train_dataset.tokenizer.save_pretrained(final_dir)
    print(f"Training complete. Final model saved to: {final_dir}")
# Start training only when this file is executed as a script, so importing
# the module (e.g. to reuse create_model) has no side effects.
if __name__ == "__main__":
    main()