# Source: ace-step15-endpoint/scripts/annotations/qwen_caption_dataset.py
# Author: Andrew — commit 8bdd018 ("Consolidate AF3/Qwen pipelines, endpoint templates, and setup docs")
#!/usr/bin/env python
"""
Batch caption a music dataset with Qwen2-Audio and export LoRA-ready sidecars.
"""
from __future__ import annotations
import argparse
import os
import tempfile
from pathlib import Path
from typing import List
from huggingface_hub import HfApi, snapshot_download
from loguru import logger
from tqdm import tqdm
from qwen_audio_captioning import (
DEFAULT_ANALYSIS_PROMPT,
DEFAULT_LONG_ANALYSIS_PROMPT,
DEFAULT_MODEL_ID,
build_captioner,
export_annotation_records,
generate_track_annotation,
list_audio_files,
read_prompt_file,
)
def build_parser() -> argparse.ArgumentParser:
    """Construct the CLI parser for Qwen2-Audio batch captioning.

    Flags fall into four groups: data source, inference backend,
    prompt/generation controls, and export/upload options.
    """
    parser = argparse.ArgumentParser(description="Qwen2-Audio batch captioning for LoRA datasets")
    add = parser.add_argument

    # --- Data source: either a local folder or a HF dataset repo ---
    add("--dataset-dir", type=str, default="", help="Local dataset folder")
    add("--dataset-repo", type=str, default="", help="HF dataset repo id")
    add("--dataset-revision", type=str, default="main", help="HF dataset revision")
    add("--dataset-subdir", type=str, default="", help="Subdirectory inside dataset")

    # --- Inference backend selection ---
    add("--backend", type=str, default="local", choices=["local", "hf_endpoint"])
    add("--model-id", type=str, default=DEFAULT_MODEL_ID)
    add("--endpoint-url", type=str, default="")
    add("--hf-token", type=str, default="", help="HF token (or use HF_TOKEN env var)")
    add("--device", type=str, default="auto", choices=["auto", "cuda", "cpu", "mps"])
    add("--torch-dtype", type=str, default="auto", choices=["auto", "float16", "bfloat16", "float32"])

    # --- Prompt + generation controls ---
    add("--prompt", type=str, default=DEFAULT_ANALYSIS_PROMPT)
    add("--prompt-file", type=str, default="", help="Text file to override --prompt")
    add("--include-long-analysis", action="store_true", help="Also request long prose analysis")
    add("--long-analysis-prompt", type=str, default=DEFAULT_LONG_ANALYSIS_PROMPT)
    add("--long-analysis-prompt-file", type=str, default="", help="Text file to override --long-analysis-prompt")
    add("--long-analysis-max-new-tokens", type=int, default=1200)
    add("--long-analysis-temperature", type=float, default=0.1)
    add("--segment-seconds", type=float, default=30.0)
    add("--overlap-seconds", type=float, default=2.0)
    add("--max-new-tokens", type=int, default=384)
    add("--temperature", type=float, default=0.1)
    add("--keep-raw-outputs", action="store_true", help="Store per-segment raw outputs in sidecar JSON")

    # --- Export options ---
    add("--output-dir", type=str, default="qwen_annotations")
    add("--copy-audio", action="store_true", help="Copy audio files into output_dir/dataset")
    add(
        "--write-inplace-sidecars",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Write sidecars next to source audio (default: true). Use --no-write-inplace-sidecars to disable.",
    )

    # --- Optional upload of the exported folder ---
    add("--upload-repo", type=str, default="", help="Optional HF dataset repo to upload exports")
    add("--upload-private", action="store_true", help="Create upload repo as private")
    add("--upload-path", type=str, default="", help="Optional path inside upload repo")
    return parser
def resolve_dataset_dir(args) -> str:
    """Return a local directory containing the dataset audio.

    A provided --dataset-dir is validated and returned as-is; otherwise
    --dataset-repo is downloaded from the Hugging Face Hub into a fresh
    temporary directory (optionally narrowed to --dataset-subdir).

    Raises:
        FileNotFoundError: the local folder or subdirectory does not exist.
        ValueError: neither --dataset-dir nor --dataset-repo was supplied.
    """
    local_path = args.dataset_dir
    if local_path:
        if not Path(local_path).is_dir():
            raise FileNotFoundError(f"Dataset folder not found: {local_path}")
        return local_path

    if not args.dataset_repo:
        raise ValueError("Provide --dataset-dir or --dataset-repo")

    hub_token = args.hf_token or os.getenv("HF_TOKEN", "")
    staging_root = tempfile.mkdtemp(prefix="qwen_caption_dataset_")
    download_dir = os.path.join(staging_root, "dataset")
    logger.info(f"Downloading dataset {args.dataset_repo}@{args.dataset_revision} -> {download_dir}")
    snapshot_download(
        repo_id=args.dataset_repo,
        repo_type="dataset",
        revision=args.dataset_revision,
        local_dir=download_dir,
        local_dir_use_symlinks=False,
        token=hub_token or None,
    )

    if not args.dataset_subdir:
        return download_dir
    candidate = os.path.join(download_dir, args.dataset_subdir)
    if not Path(candidate).is_dir():
        raise FileNotFoundError(f"Dataset subdir not found: {candidate}")
    return candidate
def upload_export_if_requested(args, output_dir: str):
    """Upload the export folder to a HF dataset repo when --upload-repo is set.

    No-op when --upload-repo is empty. The target repo is created if needed
    (idempotent via exist_ok=True).

    Raises:
        RuntimeError: no HF token available via --hf-token or HF_TOKEN.
    """
    if not args.upload_repo:
        return
    hub_token = args.hf_token or os.getenv("HF_TOKEN", "")
    if not hub_token:
        raise RuntimeError("HF token missing. Set --hf-token or HF_TOKEN.")

    api = HfApi(token=hub_token)
    api.create_repo(
        repo_id=args.upload_repo,
        repo_type="dataset",
        private=bool(args.upload_private),
        exist_ok=True,
    )

    # Normalise the optional destination path inside the repo.
    destination = ""
    if args.upload_path:
        destination = args.upload_path.strip().strip("/")
    logger.info(f"Uploading {output_dir} -> {args.upload_repo}/{destination}")
    api.upload_folder(
        repo_id=args.upload_repo,
        repo_type="dataset",
        folder_path=output_dir,
        path_in_repo=destination,
        commit_message="Upload Qwen2-Audio annotations",
    )
    logger.info("Upload complete")
def main() -> int:
    """Run the batch-captioning pipeline end to end.

    Steps: resolve prompt overrides, locate the dataset (local or HF Hub),
    build the captioning backend, annotate every audio file, export the
    sidecars + manifest, and optionally upload the export folder.

    Returns:
        0 when at least one track was annotated; 1 when every track failed
        (previously the script always returned 0, hiding total failure from
        shell pipelines and CI).
    """
    args = build_parser().parse_args()

    # File-based prompt overrides win over the inline --prompt defaults.
    prompt = read_prompt_file(args.prompt_file) if args.prompt_file else args.prompt
    long_prompt = (
        read_prompt_file(args.long_analysis_prompt_file)
        if args.long_analysis_prompt_file
        else args.long_analysis_prompt
    )
    token = args.hf_token or os.getenv("HF_TOKEN", "")

    dataset_dir = resolve_dataset_dir(args)
    audio_files: List[str] = list_audio_files(dataset_dir)
    if not audio_files:
        raise RuntimeError(f"No audio files found in {dataset_dir}")
    logger.info(f"Found {len(audio_files)} audio files")

    captioner = build_captioner(
        backend=args.backend,
        model_id=args.model_id,
        endpoint_url=args.endpoint_url,
        token=token,
        device=args.device,
        torch_dtype=args.torch_dtype,
    )

    records: List[dict] = []
    failed: List[str] = []
    for path in tqdm(audio_files, desc="Captioning audio"):
        try:
            sidecar = generate_track_annotation(
                audio_path=path,
                captioner=captioner,
                prompt=prompt,
                segment_seconds=float(args.segment_seconds),
                overlap_seconds=float(args.overlap_seconds),
                max_new_tokens=int(args.max_new_tokens),
                temperature=float(args.temperature),
                keep_raw_outputs=bool(args.keep_raw_outputs),
                include_long_analysis=bool(args.include_long_analysis),
                long_analysis_prompt=long_prompt,
                long_analysis_max_new_tokens=int(args.long_analysis_max_new_tokens),
                long_analysis_temperature=float(args.long_analysis_temperature),
            )
            records.append({"audio_path": path, "sidecar": sidecar})
        except Exception as exc:
            # Best-effort batch: record the failure and keep going.
            failed.append(f"{Path(path).name}: {exc}")
            logger.exception(f"Failed: {path}")

    export_result = export_annotation_records(
        records=records,
        output_dir=args.output_dir,
        copy_audio=bool(args.copy_audio),
        write_inplace_sidecars=bool(args.write_inplace_sidecars),
    )
    logger.info(
        "Done. analyzed={} failed={} manifest={}",
        len(records),
        len(failed),
        export_result["manifest_path"],
    )
    if failed:
        logger.warning("First failures:\n" + "\n".join(failed[:20]))

    upload_export_if_requested(args, args.output_dir)

    # Fix: surface total failure via the exit code instead of always
    # returning 0. Partial success still exits 0 (best-effort semantics).
    return 0 if records else 1
if __name__ == "__main__":
    # Propagate main()'s integer return value as the process exit code.
    raise SystemExit(main())