# pachet's picture
# Deploy Theme Lab Docker app
# 187bf9a
# Raw
# History Blame Contribute Delete
# 10.5 kB
"""Command-line entry points for generation engines."""
from __future__ import annotations
import argparse
from pathlib import Path
from .common import duration_options
from .corpus import endpoint_priors, load_sequences, symbol_stats
from .engines.markov import generate_markov
from .engines.transformer import (
TransformerConfig,
generate_transformer,
load_transformer_checkpoint,
sample_transformer_checkpoint,
train_and_save_checkpoint,
)
from .io import write_samples
from .reports import format_allowed_durations, write_generation_report
def add_common_args(parser: argparse.ArgumentParser) -> None:
    """Register the options shared by every generation engine subcommand.

    Covers the corpus database path, output directory, sample length/count,
    output key, endpoint weighting, RNG seed, the duration-vocabulary filters
    (which apply to both training data and generated output), and the optional
    ABC / MusicXML side outputs.

    Args:
        parser: The (sub)parser that receives the arguments; mutated in place.
    """
    add = parser.add_argument

    # Corpus location and basic generation shape.
    add("--db", type=Path, default=Path("audit/themes_audit.sqlite"))
    add("--output-dir", type=Path)
    add("--length", type=int, default=24)
    add("--samples", type=int, default=12)
    add("--key", default="C")
    add("--endpoint-strength", type=float, default=1.0)
    add("--seed", type=int, default=42)

    # Duration vocabulary filters.
    add("--min-duration", default="16th", help="Shortest generated/training duration label.")
    add(
        "--duration-grid",
        default="16th",
        help="Require generated/training durations to be multiples of this value.",
    )
    add(
        "--no-triplets",
        action="store_true",
        help="Exclude regular triplet durations from the generated/training vocabulary.",
    )
    add(
        "--loose-triplets",
        action="store_true",
        help="Allow triplet durations outside complete beat-aligned groups.",
    )

    # Optional notation files written next to the MIDI output.
    add("--write-abc", action="store_true", help="Also write ABC files next to the MIDIs.")
    add("--write-musicxml", action="store_true", help="Also write MusicXML files next to the MIDIs.")
def load_generation_inputs(args: argparse.Namespace, *, min_len: int):
    """Load the filtered corpus and derived data every engine needs.

    Args:
        args: Parsed CLI namespace; reads ``min_duration``, ``duration_grid``,
            ``no_triplets`` and ``db``.
        min_len: Minimum sequence length, forwarded to ``load_sequences``.

    Returns:
        A 4-tuple ``(allowed_durations, sequences, stats, priors)`` where
        ``stats`` is ``symbol_stats(sequences)`` and ``priors`` is
        ``endpoint_priors(args.db)``.

    Raises:
        ValueError: If the duration filters leave no allowed durations, or if
            no sequences survive the duration/length filters.
    """
    include_triplets = not args.no_triplets
    allowed = duration_options(args.min_duration, args.duration_grid, include_triplets)
    if not allowed:
        raise ValueError(f"No allowed durations remain for min duration {args.min_duration!r}")

    corpus = load_sequences(args.db, allowed, min_len=min_len)
    if not corpus:
        raise ValueError("No training sequences matched the selected duration and length settings")

    stats = symbol_stats(corpus)
    priors = endpoint_priors(args.db)
    return allowed, corpus, stats, priors
def base_settings(args: argparse.Namespace, stats: dict, allowed_durations: set[str]) -> dict[str, object]:
    """Build the engine-independent part of the generation report settings.

    Entries are inserted in a fixed order: corpus statistics, generation
    options, duration filters, then endpoint strength.

    Args:
        args: Parsed CLI namespace (shared options from ``add_common_args``).
        stats: Corpus statistics; must provide ``sequence_count``,
            ``event_count`` and ``vocab_size``.
        allowed_durations: Duration labels kept by the duration filters.

    Returns:
        A new, mutable dict of human-readable setting names to values.
    """
    corpus_part = {
        "sequences": stats["sequence_count"],
        "events": stats["event_count"],
        "vocabulary size": stats["vocab_size"],
    }
    generation_part = {
        "generated note length": args.length,
        "samples": args.samples,
        "output key": args.key,
    }
    duration_part = {
        "minimum duration": args.min_duration,
        "duration grid": args.duration_grid,
        "triplets allowed": not args.no_triplets,
        "triplets grouped": not args.loose_triplets,
        "allowed durations": format_allowed_durations(allowed_durations),
    }
    return {
        **corpus_part,
        **generation_part,
        **duration_part,
        "endpoint strength": args.endpoint_strength,
    }
def run_markov(args: argparse.Namespace) -> None:
    """Run the variable-order Markov ("vo_regular") engine end to end.

    Loads the filtered corpus and endpoint priors, samples ``args.samples``
    sequences with ``generate_markov``, writes them (plus optional ABC /
    MusicXML files) to ``args.output_dir``, writes a generation report there,
    and prints the output directory.
    """
    # Require at least max(6, max_order + 1) events per training sequence;
    # presumably so each one holds a full max-order context plus a continuation.
    shortest = max(6, args.max_order + 1)
    allowed_durations, sequences, stats, (start_weights, end_weights) = load_generation_inputs(
        args, min_len=shortest
    )

    triplet_groups = not args.loose_triplets
    generated, diagnostics = generate_markov(
        sequences=sequences,
        length=args.length,
        samples=args.samples,
        max_order=args.max_order,
        start_weights=start_weights,
        end_weights=end_weights,
        endpoint_strength=args.endpoint_strength,
        enforce_triplet_groups=triplet_groups,
        seed=args.seed,
    )

    out_dir = args.output_dir
    write_samples(
        generated,
        output_dir=out_dir,
        key_name=args.key,
        engine_name="vo_regular baseline",
        write_abc=args.write_abc,
        write_musicxml_files=args.write_musicxml,
    )

    # Engine diagnostics are applied last, so they override same-named settings.
    report_settings = {**base_settings(args, stats, allowed_durations), "max order": args.max_order}
    report_settings.update(diagnostics)

    write_generation_report(
        output_dir=out_dir,
        title="VO-Regular Baseline Generation",
        description="This is the key-relative variable-order Markov baseline.",
        settings=report_settings,
        stats=stats,
        generated=generated,
        write_abc=args.write_abc,
        write_musicxml=args.write_musicxml,
    )
    print(f"Wrote {out_dir}")
def run_transformer(args: argparse.Namespace) -> None:
    """Run the tiny transformer engine and write samples plus a report.

    The model source depends on the checkpoint flags:

    * ``--load-checkpoint``: sample from a saved checkpoint (this branch wins
      if ``--save-checkpoint`` is also set).
    * ``--save-checkpoint``: train with ``cfg``, save the checkpoint, then
      sample from it.
    * neither: train and sample in a single ``generate_transformer`` call.

    The checkpoint path that was loaded or saved is recorded in the report
    diagnostics. Samples (plus optional ABC / MusicXML files) and the report
    are written to ``args.output_dir``.

    Raises:
        ValueError: Propagated from ``load_generation_inputs`` when the
            duration/length filters leave no usable data.
    """
    cfg = TransformerConfig(
        block_size=args.block_size,
        d_model=args.d_model,
        nhead=args.nhead,
        num_layers=args.layers,
        dim_feedforward=args.feedforward,
        dropout=args.dropout,
        batch_size=args.batch_size,
        steps=args.steps,
        learning_rate=args.learning_rate,
        temperature=args.temperature,
        top_k=args.top_k,
        max_retries=args.max_retries,
    )
    # Sequences shorter than max(6, block_size // 4) events are filtered out.
    allowed_durations, sequences, stats, priors = load_generation_inputs(
        args, min_len=max(6, cfg.block_size // 4)
    )
    start_weights, end_weights = priors

    def sample_from(checkpoint):
        # Sampling call shared by the load and save paths.
        return sample_transformer_checkpoint(
            checkpoint=checkpoint,
            length=args.length,
            samples=args.samples,
            start_weights=start_weights,
            end_weights=end_weights,
            endpoint_strength=args.endpoint_strength,
            enforce_triplet_groups=not args.loose_triplets,
            seed=args.seed,
            temperature=args.temperature,
            top_k=args.top_k,
            max_retries=args.max_retries,
        )

    if args.load_checkpoint:
        checkpoint = load_transformer_checkpoint(args.load_checkpoint, requested_device=args.device)
        generated, diagnostics = sample_from(checkpoint)
        # Record which checkpoint produced the samples, mirroring the
        # "saved checkpoint" entry of the save path, so the report is traceable.
        diagnostics["loaded checkpoint"] = str(args.load_checkpoint)
    elif args.save_checkpoint:
        checkpoint = train_and_save_checkpoint(
            sequences=sequences,
            cfg=cfg,
            seed=args.seed,
            requested_device=args.device,
            path=args.save_checkpoint,
        )
        generated, diagnostics = sample_from(checkpoint)
        diagnostics["saved checkpoint"] = str(args.save_checkpoint)
    else:
        generated, diagnostics = generate_transformer(
            sequences=sequences,
            length=args.length,
            samples=args.samples,
            start_weights=start_weights,
            end_weights=end_weights,
            endpoint_strength=args.endpoint_strength,
            enforce_triplet_groups=not args.loose_triplets,
            seed=args.seed,
            cfg=cfg,
            device=args.device,
        )

    write_samples(
        generated,
        output_dir=args.output_dir,
        key_name=args.key,
        engine_name="transformer baseline",
        write_abc=args.write_abc,
        write_musicxml_files=args.write_musicxml,
    )

    settings = base_settings(args, stats, allowed_durations)
    # NOTE(review): these are the CLI-supplied values; with --load-checkpoint the
    # loaded model's own architecture/training settings may differ — confirm
    # whether the report should read them from the checkpoint instead.
    settings.update(
        {
            "block size": cfg.block_size,
            "d model": cfg.d_model,
            "heads": cfg.nhead,
            "layers": cfg.num_layers,
            "feedforward": cfg.dim_feedforward,
            "dropout": cfg.dropout,
            "batch size": cfg.batch_size,
            "steps": cfg.steps,
            "learning rate": cfg.learning_rate,
            "temperature": cfg.temperature,
            "top k": cfg.top_k,
            "max retries": cfg.max_retries,
        }
    )
    # Engine diagnostics are applied last, so they override same-named settings.
    settings.update(diagnostics)

    write_generation_report(
        output_dir=args.output_dir,
        title="Transformer Baseline Generation",
        description="This is the first key-relative tiny transformer baseline.",
        settings=settings,
        stats=stats,
        generated=generated,
        write_abc=args.write_abc,
        write_musicxml=args.write_musicxml,
    )
    print(f"Wrote {args.output_dir}")
def build_parser() -> argparse.ArgumentParser:
    """Build the top-level CLI parser with one subcommand per engine.

    Each subcommand receives the shared options from ``add_common_args``, its
    own default ``output_dir``, and a ``func`` default naming the runner that
    ``main`` dispatches to.

    Returns:
        The configured parser; a subcommand (``markov`` or ``transformer``)
        is required.
    """
    parser = argparse.ArgumentParser(description="Generate short key-relative theme samples.")
    subparsers = parser.add_subparsers(dest="engine", required=True)

    markov = subparsers.add_parser("markov", help="Run the vo_regular variable-order Markov engine.")
    add_common_args(markov)
    markov.set_defaults(output_dir=Path("outputs/vo_regular_baseline"), func=run_markov)
    markov.add_argument("--max-order", type=int, default=4)

    transformer = subparsers.add_parser("transformer", help="Run the tiny PyTorch transformer engine.")
    add_common_args(transformer)
    transformer.set_defaults(output_dir=Path("outputs/transformer_baseline"), func=run_transformer)
    transformer.add_argument("--block-size", type=int, default=64)
    transformer.add_argument("--d-model", type=int, default=96)
    transformer.add_argument("--nhead", type=int, default=4)
    transformer.add_argument("--layers", type=int, default=3)
    transformer.add_argument("--feedforward", type=int, default=192)
    transformer.add_argument("--dropout", type=float, default=0.1)
    transformer.add_argument("--batch-size", type=int, default=64)
    transformer.add_argument("--steps", type=int, default=800)
    transformer.add_argument("--learning-rate", type=float, default=3e-4)
    transformer.add_argument("--temperature", type=float, default=1.0)
    transformer.add_argument("--top-k", type=int, default=16)
    transformer.add_argument("--max-retries", type=int, default=100)
    transformer.add_argument("--device", default="auto", help="auto, cpu, cuda, or mps.")

    # run_transformer checks --load-checkpoint first and only falls through to
    # --save-checkpoint otherwise, so passing both would silently skip the save.
    # Reject that combination at parse time instead.
    checkpoint_mode = transformer.add_mutually_exclusive_group()
    checkpoint_mode.add_argument("--save-checkpoint", type=Path, help="Train once and save a reusable checkpoint.")
    checkpoint_mode.add_argument(
        "--load-checkpoint",
        type=Path,
        help="Generate from a saved checkpoint instead of training.",
    )
    return parser
def main(argv: list[str] | None = None) -> None:
    """CLI entry point: parse ``argv`` and dispatch to the selected engine.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``
            (argparse's default).

    Expected user-facing failures are reported as a one-line ``error: ...``
    message with exit status 1 instead of a traceback.
    """
    parser = build_parser()
    args = parser.parse_args(argv)
    try:
        args.func(args)
    except (OSError, RuntimeError, ValueError) as exc:
        # load_generation_inputs reports unusable duration/length settings via
        # ValueError, which previously escaped as a raw traceback. OSError is
        # included for file-system problems (e.g. unreadable checkpoint or
        # unwritable output dir — assumed, the loaders/writers aren't visible).
        # NOTE(review): this also turns unexpected ValueErrors from deeper code
        # into one-line errors; drop it if full tracebacks are preferred there.
        parser.exit(1, f"error: {exc}\n")
if __name__ == "__main__":
    # main() returns None on success, so SystemExit(None) exits with status 0;
    # failures exit earlier via argparse's own SystemExit.
    raise SystemExit(main())