| |
| |
| """ |
| Train MAGEL directly from a vanilla Qwen3 checkpoint. |
| |
| Compared with train.py/train_newparaonly.py, this script: |
| 1) Loads an original Qwen3 base checkpoint. |
| 2) Resolves MAGEL hparams explicitly at construction time. |
| 3) Initializes MAGEL extra modules from scratch and trains end-to-end. |
| """ |
|
|
| import argparse |
| import os |
|
|
| import torch |
| from transformers import ( |
| AutoConfig, |
| Trainer, |
| TrainingArguments, |
| ) |
| import datasets |
| from dataset import DataCollate, MusicDataset |
| from modelling_qwen3 import MAGEL |
|
|
|
|
| def resolve_model_source(model_path: str, resume_from_checkpoint: str | None) -> str: |
| if not resume_from_checkpoint: |
| return model_path |
|
|
| if os.path.abspath(model_path) != os.path.abspath(resume_from_checkpoint): |
| print( |
| "Ignoring --model_path during resume and loading config/model from: " |
| f"{resume_from_checkpoint}" |
| ) |
| return resume_from_checkpoint |
|
|
|
|
def create_model(
    model_path: str,
    model_dtype: torch.dtype,
    target_vocab_size: int,
    attn_implementation: str,
) -> MAGEL:
    """Load a MAGEL model from a local checkpoint and fit its vocabulary.

    Args:
        model_path: Local checkpoint directory (no network access is attempted).
        model_dtype: dtype the weights are materialized in.
        target_vocab_size: Size the token-embedding matrix is resized to.
        attn_implementation: Attention backend name forwarded to ``from_pretrained``.

    Returns:
        The loaded model, after parameter statistics and the MAGEL
        hyper-parameters (``adaln_dim`` and the dropout trigger
        probabilities) have been printed.
    """
    print(f"Loading Qwen3 model from: {model_path}")

    load_kwargs = {
        "torch_dtype": model_dtype,
        "config": AutoConfig.from_pretrained(model_path, local_files_only=True),
        "attn_implementation": attn_implementation,
        # Do not fail on checkpoint/model shape mismatches; per the HF docs the
        # mismatched tensors are freshly initialized instead of loaded.
        "ignore_mismatched_sizes": True,
        "local_files_only": True,
    }
    model = MAGEL.from_pretrained(model_path, **load_kwargs)
    model.resize_token_embeddings(target_vocab_size)

    # Single pass over the (deduplicated) parameters. The "extra" count covers
    # tensors whose qualified name mentions a MAGEL-specific submodule.
    total = trainable = extra = 0
    for param_name, param in model.named_parameters():
        count = param.numel()
        total += count
        if param.requires_grad:
            trainable += count
        if "condition_encoder" in param_name or "dit_adaln" in param_name:
            extra += count

    print(f"Total parameters: {total:,}")
    print(f"Trainable parameters: {trainable:,}")
    print(f"MAGEL extra parameters: {extra:,}")
    print(
        "MAGEL config: "
        f"adaln_dim={model.adaln_dim}, "
        f"chord_dropout_trigger_prob={model.chord_dropout_trigger_prob}, "
        f"structure_dropout_trigger_prob={model.structure_dropout_trigger_prob}"
    )
    return model
|
|
|
|
def create_dataset(
    dataset_path: str,
    tokenizer_path: str,
    num_audio_token: int = 16384,
) -> MusicDataset:
    """Build the training split of ``MusicDataset`` from a saved HF dataset.

    Args:
        dataset_path: Directory produced by ``datasets.Dataset.save_to_disk``.
        tokenizer_path: Tokenizer location handed to ``MusicDataset``.
        num_audio_token: Audio-token count handed to ``MusicDataset``.

    Returns:
        The ``"train"`` split wrapper, loaded with a fast tokenizer.
    """
    print(f"Loading dataset from: {dataset_path}")
    print(f"Loading tokenizer from: {tokenizer_path}")

    dataset_kwargs = dict(
        split="train",
        tokenizer_path=tokenizer_path,
        num_audio_token=num_audio_token,
        use_fast=True,
    )
    dataset = MusicDataset(datasets.load_from_disk(dataset_path), **dataset_kwargs)

    print(f"Dataset size: {len(dataset)}")
    return dataset
|
|
|
|
def _parse_args() -> argparse.Namespace:
    """Define this script's command-line interface and parse ``sys.argv``."""
    parser = argparse.ArgumentParser(
        description="Train MAGEL directly from a vanilla Qwen3 base checkpoint."
    )

    # Data and checkpoint locations.
    parser.add_argument(
        "--dataset_path",
        type=str,
        default="muse_mucodec_chord.ds",
    )
    parser.add_argument(
        "--model_path",
        type=str,
        default="checkpoints/Qwen3-0.6B",
        help="Local Qwen3 base checkpoint path.",
    )
    parser.add_argument(
        "--tokenizer_path",
        type=str,
        default="checkpoints/Qwen3-0.6B",
        help="Local tokenizer checkpoint path.",
    )
    parser.add_argument(
        "--model_dtype",
        type=str,
        default="bfloat16",
        choices=["float32", "float16", "bfloat16"],
    )
    parser.add_argument(
        "--attn_implementation",
        type=str,
        default="sdpa",
        choices=["eager", "sdpa", "flash_attention_2"],
    )

    # Optimization and schedule.
    parser.add_argument("--output_dir", type=str, default="./output_qwen3_0p6b_train")
    parser.add_argument("--per_device_train_batch_size", type=int, default=1)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=4)
    parser.add_argument("--learning_rate", type=float, default=1e-4)
    parser.add_argument("--weight_decay", type=float, default=0.01)
    parser.add_argument("--num_train_epochs", type=float, default=20)
    parser.add_argument("--warmup_steps", type=int, default=1000)
    parser.add_argument("--max_grad_norm", type=float, default=5.0)
    parser.add_argument("--logging_steps", type=int, default=10)
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help="Resume training from a Trainer checkpoint directory such as output_dir/checkpoint-500.",
    )

    # Runtime / infrastructure.
    parser.add_argument("--dataloader_num_workers", type=int, default=12)
    parser.add_argument(
        "--gradient_checkpointing",
        dest="gradient_checkpointing",
        action="store_true",
    )
    parser.add_argument(
        "--deepspeed",
        type=str,
        default=None,
        help="Path to DeepSpeed config. Leave unset to disable DeepSpeed.",
    )

    # Experiment tracking.
    parser.add_argument("--report_to", type=str, default="wandb")
    parser.add_argument("--wandb_project", type=str, default="vaultum-qwen3-0p6b")
    parser.add_argument("--wandb_run_name", type=str, default=None)

    return parser.parse_args()


def _build_training_args(args: argparse.Namespace) -> TrainingArguments:
    """Map parsed CLI options onto Hugging Face ``TrainingArguments``.

    Mixed-precision flags follow ``--model_dtype``: ``bf16`` for bfloat16,
    ``fp16`` for float16, and neither for float32.
    """
    return TrainingArguments(
        output_dir=args.output_dir,
        per_device_train_batch_size=args.per_device_train_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        num_train_epochs=args.num_train_epochs,
        warmup_steps=args.warmup_steps,
        max_grad_norm=args.max_grad_norm,
        logging_steps=args.logging_steps,
        save_strategy="epoch",
        dataloader_num_workers=args.dataloader_num_workers,
        bf16=(args.model_dtype == "bfloat16"),
        fp16=(args.model_dtype == "float16"),
        gradient_checkpointing=args.gradient_checkpointing,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        deepspeed=args.deepspeed,
        # The collator needs the dataset's custom columns, so keep them all.
        remove_unused_columns=False,
        dataloader_drop_last=True,
        report_to=args.report_to,
        logging_dir=None,
        run_name=args.wandb_run_name,
    )


def main():
    """Parse the CLI, build the dataset and model, train, and save the result.

    The final model and tokenizer are written to ``<output_dir>/final``.

    Raises:
        ValueError: If the resolved checkpoint config has no
            ``magel_num_audio_token`` entry.
    """
    args = _parse_args()

    model_dtype = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }[args.model_dtype]

    # fp16=True enables AMP with a GradScaler. The scaler refuses to unscale
    # fp16 gradients ("Attempting to unscale FP16 gradients"), so the weights
    # must stay fp32 while autocast runs the compute in fp16. With DeepSpeed
    # the fp16 handling belongs to its own optimizer wrapper, so the requested
    # dtype is kept unchanged on that path.
    load_dtype = model_dtype
    if args.model_dtype == "float16" and not args.deepspeed:
        print(
            "fp16 mixed precision requested without DeepSpeed: loading weights in "
            "float32 (autocast still computes in float16)."
        )
        load_dtype = torch.float32

    # A resumed run reads its config and weights from the checkpoint directory.
    model_source = resolve_model_source(
        model_path=args.model_path,
        resume_from_checkpoint=args.resume_from_checkpoint,
    )

    base_config = AutoConfig.from_pretrained(
        model_source,
        local_files_only=True,
    )

    # A stock Qwen3 config does not carry this MAGEL-specific key. Fail with an
    # actionable message instead of a bare AttributeError.
    raw_num_audio_token = getattr(base_config, "magel_num_audio_token", None)
    if raw_num_audio_token is None:
        raise ValueError(
            f"Config at {model_source!r} has no 'magel_num_audio_token' entry; "
            "add it to the checkpoint's config.json before training."
        )
    num_audio_token = int(raw_num_audio_token)
    print(f"Using num_audio_token={num_audio_token}")

    train_dataset = create_dataset(
        dataset_path=args.dataset_path,
        tokenizer_path=args.tokenizer_path,
        num_audio_token=num_audio_token,
    )

    # Resize the model's embeddings to the tokenizer the dataset actually uses.
    model = create_model(
        model_path=model_source,
        model_dtype=load_dtype,
        attn_implementation=args.attn_implementation,
        target_vocab_size=train_dataset.tokenizer_vocab_size,
    )

    training_args = _build_training_args(args)

    if args.wandb_project and "wandb" in args.report_to:
        os.environ["WANDB_PROJECT"] = args.wandb_project

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        data_collator=DataCollate(),
    )

    if args.resume_from_checkpoint:
        print(f"Resuming training from checkpoint: {args.resume_from_checkpoint}")
    else:
        print("Starting training from current model initialization.")

    trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)

    final_dir = os.path.join(args.output_dir, "final")
    # Every rank must call save_model (DeepSpeed may gather weights there),
    # but only the saving process writes files. The tokenizer follows the same
    # rule so that concurrent ranks do not race on the same files.
    trainer.save_model(final_dir)
    if training_args.should_save:
        train_dataset.tokenizer.save_pretrained(final_dir)

    print(f"Training complete. Final model saved to: {final_dir}")
|
|
|
|
if __name__ == "__main__":
    # Run training only when executed as a script, never on import.
    main()
|
|