Spaces:
Running
Running
| """Command-line entry points for generation engines.""" | |
| from __future__ import annotations | |
| import argparse | |
| from pathlib import Path | |
| from .common import duration_options | |
| from .corpus import endpoint_priors, load_sequences, symbol_stats | |
| from .engines.markov import generate_markov | |
| from .engines.transformer import ( | |
| TransformerConfig, | |
| generate_transformer, | |
| load_transformer_checkpoint, | |
| sample_transformer_checkpoint, | |
| train_and_save_checkpoint, | |
| ) | |
| from .io import write_samples | |
| from .reports import format_allowed_durations, write_generation_report | |
def add_common_args(parser: argparse.ArgumentParser) -> None:
    """Register the command-line options shared by every engine sub-command.

    Adds corpus location, output, sampling, duration-vocabulary and export
    options to *parser*. Options are registered in a fixed order so that the
    generated ``--help`` text lists them consistently.
    """
    # (flag, add_argument keyword arguments), in registration order.
    option_specs = (
        ("--db", {"type": Path, "default": Path("audit/themes_audit.sqlite")}),
        ("--output-dir", {"type": Path}),
        ("--length", {"type": int, "default": 24}),
        ("--samples", {"type": int, "default": 12}),
        ("--key", {"default": "C"}),
        ("--endpoint-strength", {"type": float, "default": 1.0}),
        ("--seed", {"type": int, "default": 42}),
        (
            "--min-duration",
            {"default": "16th", "help": "Shortest generated/training duration label."},
        ),
        (
            "--duration-grid",
            {
                "default": "16th",
                "help": "Require generated/training durations to be multiples of this value.",
            },
        ),
        (
            "--no-triplets",
            {
                "action": "store_true",
                "help": "Exclude regular triplet durations from the generated/training vocabulary.",
            },
        ),
        (
            "--loose-triplets",
            {
                "action": "store_true",
                "help": "Allow triplet durations outside complete beat-aligned groups.",
            },
        ),
        (
            "--write-abc",
            {"action": "store_true", "help": "Also write ABC files next to the MIDIs."},
        ),
        (
            "--write-musicxml",
            {"action": "store_true", "help": "Also write MusicXML files next to the MIDIs."},
        ),
    )
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
def load_generation_inputs(args: argparse.Namespace, *, min_len: int):
    """Resolve the duration vocabulary and load the training corpus.

    Args:
        args: Parsed CLI namespace. Reads ``min_duration``, ``duration_grid``,
            ``no_triplets`` and ``db``.
        min_len: Minimum length forwarded to ``load_sequences``.

    Returns:
        A 4-tuple ``(allowed_durations, sequences, stats, priors)``, where
        ``stats`` is ``symbol_stats(sequences)`` and ``priors`` is
        ``endpoint_priors(args.db)``. Callers unpack ``priors`` as
        ``(start_weights, end_weights)``.

    Raises:
        ValueError: If no durations survive the duration settings, or if no
            training sequences match them.
    """
    triplets_allowed = not args.no_triplets
    allowed_durations = duration_options(args.min_duration, args.duration_grid, triplets_allowed)
    if not allowed_durations:
        # All three settings feed duration_options, so report all of them;
        # the grid or triplet switch may be what emptied the vocabulary.
        raise ValueError(
            f"No allowed durations remain for min duration {args.min_duration!r}, "
            f"duration grid {args.duration_grid!r}, triplets allowed: {triplets_allowed}"
        )
    sequences = load_sequences(args.db, allowed_durations, min_len=min_len)
    if not sequences:
        raise ValueError(
            "No training sequences matched the selected duration and length settings "
            f"(min length {min_len})"
        )
    return allowed_durations, sequences, symbol_stats(sequences), endpoint_priors(args.db)
def base_settings(args: argparse.Namespace, stats: dict, allowed_durations: set[str]) -> dict[str, object]:
    """Return the report settings shared by all engines.

    Combines corpus statistics (``sequence_count``, ``event_count`` and
    ``vocab_size`` from *stats*) with the common CLI options. The result is a
    fresh dict with a fixed key order, which callers extend with
    engine-specific entries.
    """
    report_items = [
        # Corpus statistics.
        ("sequences", stats["sequence_count"]),
        ("events", stats["event_count"]),
        ("vocabulary size", stats["vocab_size"]),
        # Sampling options.
        ("generated note length", args.length),
        ("samples", args.samples),
        ("output key", args.key),
        # Duration vocabulary options.
        ("minimum duration", args.min_duration),
        ("duration grid", args.duration_grid),
        ("triplets allowed", not args.no_triplets),
        ("triplets grouped", not args.loose_triplets),
        ("allowed durations", format_allowed_durations(allowed_durations)),
        ("endpoint strength", args.endpoint_strength),
    ]
    return dict(report_items)
def run_markov(args: argparse.Namespace) -> None:
    """Run the variable-order Markov engine end to end.

    Loads the corpus with ``min_len=max(6, args.max_order + 1)`` and samples
    with ``generate_markov``. It then writes the samples and a generation
    report into ``args.output_dir`` and prints that directory.
    """
    shortest_sequence = max(6, args.max_order + 1)
    allowed_durations, sequences, stats, (start_weights, end_weights) = load_generation_inputs(
        args, min_len=shortest_sequence
    )

    generated, diagnostics = generate_markov(
        sequences=sequences,
        length=args.length,
        samples=args.samples,
        max_order=args.max_order,
        seed=args.seed,
        start_weights=start_weights,
        end_weights=end_weights,
        endpoint_strength=args.endpoint_strength,
        enforce_triplet_groups=not args.loose_triplets,
    )

    destination = args.output_dir
    write_samples(
        generated,
        output_dir=destination,
        key_name=args.key,
        engine_name="vo_regular baseline",
        write_abc=args.write_abc,
        write_musicxml_files=args.write_musicxml,
    )

    # Engine diagnostics are merged last so they override same-named settings.
    report_settings = {**base_settings(args, stats, allowed_durations), "max order": args.max_order}
    report_settings.update(diagnostics)
    write_generation_report(
        output_dir=destination,
        title="VO-Regular Baseline Generation",
        description="This is the key-relative variable-order Markov baseline.",
        settings=report_settings,
        stats=stats,
        generated=generated,
        write_abc=args.write_abc,
        write_musicxml=args.write_musicxml,
    )
    print(f"Wrote {destination}")
def _sample_from_checkpoint(args: argparse.Namespace, checkpoint, start_weights, end_weights):
    """Sample ``args.samples`` sequences from an already-built transformer checkpoint.

    This is shared by the ``--load-checkpoint`` and ``--save-checkpoint`` paths
    so both sample with the same CLI settings. Returns whatever
    ``sample_transformer_checkpoint`` returns, unpacked by the caller as
    ``(generated, diagnostics)``.
    """
    return sample_transformer_checkpoint(
        checkpoint=checkpoint,
        length=args.length,
        samples=args.samples,
        start_weights=start_weights,
        end_weights=end_weights,
        endpoint_strength=args.endpoint_strength,
        enforce_triplet_groups=not args.loose_triplets,
        seed=args.seed,
        temperature=args.temperature,
        top_k=args.top_k,
        max_retries=args.max_retries,
    )


def run_transformer(args: argparse.Namespace) -> None:
    """Run the tiny transformer engine end to end.

    There are three modes:

    * ``--load-checkpoint``: load a saved checkpoint and sample from it.
    * ``--save-checkpoint``: train, save the checkpoint, then sample from it.
    * neither: train and sample in one step via ``generate_transformer``.

    If both checkpoint options are set, ``--load-checkpoint`` takes precedence.
    Samples and a generation report are written to ``args.output_dir``.
    """
    cfg = TransformerConfig(
        block_size=args.block_size,
        d_model=args.d_model,
        nhead=args.nhead,
        num_layers=args.layers,
        dim_feedforward=args.feedforward,
        dropout=args.dropout,
        batch_size=args.batch_size,
        steps=args.steps,
        learning_rate=args.learning_rate,
        temperature=args.temperature,
        top_k=args.top_k,
        max_retries=args.max_retries,
    )
    allowed_durations, sequences, stats, priors = load_generation_inputs(args, min_len=max(6, cfg.block_size // 4))
    start_weights, end_weights = priors

    if args.load_checkpoint:
        checkpoint = load_transformer_checkpoint(args.load_checkpoint, requested_device=args.device)
        generated, diagnostics = _sample_from_checkpoint(args, checkpoint, start_weights, end_weights)
        # Record provenance, mirroring the "saved checkpoint" entry below.
        diagnostics["loaded checkpoint"] = str(args.load_checkpoint)
    elif args.save_checkpoint:
        checkpoint = train_and_save_checkpoint(
            sequences=sequences,
            cfg=cfg,
            seed=args.seed,
            requested_device=args.device,
            path=args.save_checkpoint,
        )
        generated, diagnostics = _sample_from_checkpoint(args, checkpoint, start_weights, end_weights)
        diagnostics["saved checkpoint"] = str(args.save_checkpoint)
    else:
        generated, diagnostics = generate_transformer(
            sequences=sequences,
            length=args.length,
            samples=args.samples,
            start_weights=start_weights,
            end_weights=end_weights,
            endpoint_strength=args.endpoint_strength,
            enforce_triplet_groups=not args.loose_triplets,
            seed=args.seed,
            cfg=cfg,
            device=args.device,
        )

    write_samples(
        generated,
        output_dir=args.output_dir,
        key_name=args.key,
        engine_name="transformer baseline",
        write_abc=args.write_abc,
        write_musicxml_files=args.write_musicxml,
    )

    settings = base_settings(args, stats, allowed_durations)
    # NOTE(review): these values come from the CLI (cfg). With --load-checkpoint
    # they may not match the architecture stored in the checkpoint; confirm
    # whether the checkpoint exposes its own config to report instead.
    settings.update(
        {
            "block size": cfg.block_size,
            "d model": cfg.d_model,
            "heads": cfg.nhead,
            "layers": cfg.num_layers,
            "feedforward": cfg.dim_feedforward,
            "dropout": cfg.dropout,
            "batch size": cfg.batch_size,
            "steps": cfg.steps,
            "learning rate": cfg.learning_rate,
            "temperature": cfg.temperature,
            "top k": cfg.top_k,
            "max retries": cfg.max_retries,
        }
    )
    settings.update(diagnostics)
    write_generation_report(
        output_dir=args.output_dir,
        title="Transformer Baseline Generation",
        description="This is the first key-relative tiny transformer baseline.",
        settings=settings,
        stats=stats,
        generated=generated,
        write_abc=args.write_abc,
        write_musicxml=args.write_musicxml,
    )
    print(f"Wrote {args.output_dir}")
def build_parser() -> argparse.ArgumentParser:
    """Build the top-level CLI parser with one sub-command per engine.

    Each sub-command sets defaults for ``output_dir`` and for ``func``, the
    handler that ``main`` invokes as ``args.func(args)``.
    """
    parser = argparse.ArgumentParser(description="Generate short key-relative theme samples.")
    subparsers = parser.add_subparsers(dest="engine", required=True)

    markov = subparsers.add_parser("markov", help="Run the vo_regular variable-order Markov engine.")
    add_common_args(markov)
    markov.set_defaults(output_dir=Path("outputs/vo_regular_baseline"), func=run_markov)
    markov.add_argument("--max-order", type=int, default=4)

    transformer = subparsers.add_parser("transformer", help="Run the tiny PyTorch transformer engine.")
    add_common_args(transformer)
    transformer.set_defaults(output_dir=Path("outputs/transformer_baseline"), func=run_transformer)
    transformer.add_argument("--block-size", type=int, default=64)
    transformer.add_argument("--d-model", type=int, default=96)
    transformer.add_argument("--nhead", type=int, default=4)
    transformer.add_argument("--layers", type=int, default=3)
    transformer.add_argument("--feedforward", type=int, default=192)
    transformer.add_argument("--dropout", type=float, default=0.1)
    transformer.add_argument("--batch-size", type=int, default=64)
    transformer.add_argument("--steps", type=int, default=800)
    transformer.add_argument("--learning-rate", type=float, default=3e-4)
    transformer.add_argument("--temperature", type=float, default=1.0)
    transformer.add_argument("--top-k", type=int, default=16)
    transformer.add_argument("--max-retries", type=int, default=100)
    transformer.add_argument("--device", default="auto", help="auto, cpu, cuda, or mps.")
    # run_transformer checks --load-checkpoint first, so passing both options
    # would silently skip training/saving. Reject the combination at parse time.
    checkpoint_mode = transformer.add_mutually_exclusive_group()
    checkpoint_mode.add_argument("--save-checkpoint", type=Path, help="Train once and save a reusable checkpoint.")
    checkpoint_mode.add_argument(
        "--load-checkpoint", type=Path, help="Generate from a saved checkpoint instead of training."
    )
    return parser
def main(argv: list[str] | None = None) -> None:
    """Parse *argv* (``sys.argv[1:]`` when None) and run the selected engine.

    A ``RuntimeError`` or ``ValueError`` raised by the engine handler is
    reported as ``error: <message>`` and the process exits with status 1.
    Other exceptions propagate unchanged.
    """
    parser = build_parser()
    args = parser.parse_args(argv)
    try:
        args.func(args)
    except (RuntimeError, ValueError) as exc:
        # load_generation_inputs signals bad duration/length settings with
        # ValueError; report those like other CLI errors instead of a traceback.
        parser.exit(1, f"error: {exc}\n")
if __name__ == "__main__":
    # main() returns None on success, so SystemExit(None) exits with status 0.
    raise SystemExit(main())