"""Command-line entry points for generation engines.""" from __future__ import annotations import argparse from pathlib import Path from .common import duration_options from .corpus import endpoint_priors, load_sequences, symbol_stats from .engines.markov import generate_markov from .engines.transformer import ( TransformerConfig, generate_transformer, load_transformer_checkpoint, sample_transformer_checkpoint, train_and_save_checkpoint, ) from .io import write_samples from .reports import format_allowed_durations, write_generation_report def add_common_args(parser: argparse.ArgumentParser) -> None: parser.add_argument("--db", type=Path, default=Path("audit/themes_audit.sqlite")) parser.add_argument("--output-dir", type=Path) parser.add_argument("--length", type=int, default=24) parser.add_argument("--samples", type=int, default=12) parser.add_argument("--key", default="C") parser.add_argument("--endpoint-strength", type=float, default=1.0) parser.add_argument("--seed", type=int, default=42) parser.add_argument( "--min-duration", default="16th", help="Shortest generated/training duration label.", ) parser.add_argument( "--duration-grid", default="16th", help="Require generated/training durations to be multiples of this value.", ) parser.add_argument( "--no-triplets", action="store_true", help="Exclude regular triplet durations from the generated/training vocabulary.", ) parser.add_argument( "--loose-triplets", action="store_true", help="Allow triplet durations outside complete beat-aligned groups.", ) parser.add_argument("--write-abc", action="store_true", help="Also write ABC files next to the MIDIs.") parser.add_argument( "--write-musicxml", action="store_true", help="Also write MusicXML files next to the MIDIs.", ) def load_generation_inputs(args: argparse.Namespace, *, min_len: int): allowed_durations = duration_options(args.min_duration, args.duration_grid, not args.no_triplets) if not allowed_durations: raise ValueError(f"No allowed durations remain for min duration {args.min_duration!r}") sequences = load_sequences(args.db, allowed_durations, min_len=min_len) if not sequences: raise ValueError("No training sequences matched the selected duration and length settings") return allowed_durations, sequences, symbol_stats(sequences), endpoint_priors(args.db) def base_settings(args: argparse.Namespace, stats: dict, allowed_durations: set[str]) -> dict[str, object]: return { "sequences": stats["sequence_count"], "events": stats["event_count"], "vocabulary size": stats["vocab_size"], "generated note length": args.length, "samples": args.samples, "output key": args.key, "minimum duration": args.min_duration, "duration grid": args.duration_grid, "triplets allowed": not args.no_triplets, "triplets grouped": not args.loose_triplets, "allowed durations": format_allowed_durations(allowed_durations), "endpoint strength": args.endpoint_strength, } def run_markov(args: argparse.Namespace) -> None: allowed_durations, sequences, stats, priors = load_generation_inputs(args, min_len=max(6, args.max_order + 1)) start_weights, end_weights = priors generated, diagnostics = generate_markov( sequences=sequences, length=args.length, samples=args.samples, max_order=args.max_order, start_weights=start_weights, end_weights=end_weights, endpoint_strength=args.endpoint_strength, enforce_triplet_groups=not args.loose_triplets, seed=args.seed, ) write_samples( generated, output_dir=args.output_dir, key_name=args.key, engine_name="vo_regular baseline", write_abc=args.write_abc, write_musicxml_files=args.write_musicxml, ) settings = base_settings(args, stats, allowed_durations) settings["max order"] = args.max_order settings.update(diagnostics) write_generation_report( output_dir=args.output_dir, title="VO-Regular Baseline Generation", description="This is the key-relative variable-order Markov baseline.", settings=settings, stats=stats, generated=generated, write_abc=args.write_abc, write_musicxml=args.write_musicxml, ) print(f"Wrote {args.output_dir}") def run_transformer(args: argparse.Namespace) -> None: cfg = TransformerConfig( block_size=args.block_size, d_model=args.d_model, nhead=args.nhead, num_layers=args.layers, dim_feedforward=args.feedforward, dropout=args.dropout, batch_size=args.batch_size, steps=args.steps, learning_rate=args.learning_rate, temperature=args.temperature, top_k=args.top_k, max_retries=args.max_retries, ) allowed_durations, sequences, stats, priors = load_generation_inputs(args, min_len=max(6, cfg.block_size // 4)) start_weights, end_weights = priors if args.load_checkpoint: checkpoint = load_transformer_checkpoint(args.load_checkpoint, requested_device=args.device) generated, diagnostics = sample_transformer_checkpoint( checkpoint=checkpoint, length=args.length, samples=args.samples, start_weights=start_weights, end_weights=end_weights, endpoint_strength=args.endpoint_strength, enforce_triplet_groups=not args.loose_triplets, seed=args.seed, temperature=args.temperature, top_k=args.top_k, max_retries=args.max_retries, ) elif args.save_checkpoint: checkpoint = train_and_save_checkpoint( sequences=sequences, cfg=cfg, seed=args.seed, requested_device=args.device, path=args.save_checkpoint, ) generated, diagnostics = sample_transformer_checkpoint( checkpoint=checkpoint, length=args.length, samples=args.samples, start_weights=start_weights, end_weights=end_weights, endpoint_strength=args.endpoint_strength, enforce_triplet_groups=not args.loose_triplets, seed=args.seed, temperature=args.temperature, top_k=args.top_k, max_retries=args.max_retries, ) diagnostics["saved checkpoint"] = str(args.save_checkpoint) else: generated, diagnostics = generate_transformer( sequences=sequences, length=args.length, samples=args.samples, start_weights=start_weights, end_weights=end_weights, endpoint_strength=args.endpoint_strength, enforce_triplet_groups=not args.loose_triplets, seed=args.seed, cfg=cfg, device=args.device, ) write_samples( generated, output_dir=args.output_dir, key_name=args.key, engine_name="transformer baseline", write_abc=args.write_abc, write_musicxml_files=args.write_musicxml, ) settings = base_settings(args, stats, allowed_durations) settings.update( { "block size": cfg.block_size, "d model": cfg.d_model, "heads": cfg.nhead, "layers": cfg.num_layers, "feedforward": cfg.dim_feedforward, "dropout": cfg.dropout, "batch size": cfg.batch_size, "learning rate": cfg.learning_rate, "temperature": cfg.temperature, "top k": cfg.top_k, } ) settings.update(diagnostics) write_generation_report( output_dir=args.output_dir, title="Transformer Baseline Generation", description="This is the first key-relative tiny transformer baseline.", settings=settings, stats=stats, generated=generated, write_abc=args.write_abc, write_musicxml=args.write_musicxml, ) print(f"Wrote {args.output_dir}") def build_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser(description="Generate short key-relative theme samples.") subparsers = parser.add_subparsers(dest="engine", required=True) markov = subparsers.add_parser("markov", help="Run the vo_regular variable-order Markov engine.") add_common_args(markov) markov.set_defaults(output_dir=Path("outputs/vo_regular_baseline"), func=run_markov) markov.add_argument("--max-order", type=int, default=4) transformer = subparsers.add_parser("transformer", help="Run the tiny PyTorch transformer engine.") add_common_args(transformer) transformer.set_defaults(output_dir=Path("outputs/transformer_baseline"), func=run_transformer) transformer.add_argument("--block-size", type=int, default=64) transformer.add_argument("--d-model", type=int, default=96) transformer.add_argument("--nhead", type=int, default=4) transformer.add_argument("--layers", type=int, default=3) transformer.add_argument("--feedforward", type=int, default=192) transformer.add_argument("--dropout", type=float, default=0.1) transformer.add_argument("--batch-size", type=int, default=64) transformer.add_argument("--steps", type=int, default=800) transformer.add_argument("--learning-rate", type=float, default=3e-4) transformer.add_argument("--temperature", type=float, default=1.0) transformer.add_argument("--top-k", type=int, default=16) transformer.add_argument("--max-retries", type=int, default=100) transformer.add_argument("--device", default="auto", help="auto, cpu, cuda, or mps.") transformer.add_argument("--save-checkpoint", type=Path, help="Train once and save a reusable checkpoint.") transformer.add_argument("--load-checkpoint", type=Path, help="Generate from a saved checkpoint instead of training.") return parser def main(argv: list[str] | None = None) -> None: parser = build_parser() args = parser.parse_args(argv) try: args.func(args) except RuntimeError as exc: parser.exit(1, f"error: {exc}\n") if __name__ == "__main__": main()