| |
| """ |
| Batch caption a music dataset with Qwen2-Audio and export LoRA-ready sidecars. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import os |
| import tempfile |
| from pathlib import Path |
| from typing import List |
|
|
| from huggingface_hub import HfApi, snapshot_download |
| from loguru import logger |
| from tqdm import tqdm |
|
|
| from qwen_audio_captioning import ( |
| DEFAULT_ANALYSIS_PROMPT, |
| DEFAULT_LONG_ANALYSIS_PROMPT, |
| DEFAULT_MODEL_ID, |
| build_captioner, |
| export_annotation_records, |
| generate_track_annotation, |
| list_audio_files, |
| read_prompt_file, |
| ) |
|
|
|
|
def build_parser() -> argparse.ArgumentParser:
    """Build the CLI parser for Qwen2-Audio batch captioning.

    Flags fall into five groups: dataset source, inference backend,
    prompting/generation controls, export options, and optional Hub upload.
    """
    parser = argparse.ArgumentParser(description="Qwen2-Audio batch captioning for LoRA datasets")
    add = parser.add_argument

    # Dataset source: a local folder, or a Hugging Face dataset repo to download.
    add("--dataset-dir", type=str, default="", help="Local dataset folder")
    add("--dataset-repo", type=str, default="", help="HF dataset repo id")
    add("--dataset-revision", type=str, default="main", help="HF dataset revision")
    add("--dataset-subdir", type=str, default="", help="Subdirectory inside dataset")

    # Inference backend and model selection.
    add("--backend", type=str, default="local", choices=["local", "hf_endpoint"])
    add("--model-id", type=str, default=DEFAULT_MODEL_ID)
    add("--endpoint-url", type=str, default="")
    add("--hf-token", type=str, default="", help="HF token (or use HF_TOKEN env var)")
    add("--device", type=str, default="auto", choices=["auto", "cuda", "cpu", "mps"])
    add("--torch-dtype", type=str, default="auto", choices=["auto", "float16", "bfloat16", "float32"])

    # Prompting and generation controls.
    add("--prompt", type=str, default=DEFAULT_ANALYSIS_PROMPT)
    add("--prompt-file", type=str, default="", help="Text file to override --prompt")
    add("--include-long-analysis", action="store_true", help="Also request long prose analysis")
    add("--long-analysis-prompt", type=str, default=DEFAULT_LONG_ANALYSIS_PROMPT)
    add("--long-analysis-prompt-file", type=str, default="", help="Text file to override --long-analysis-prompt")
    add("--long-analysis-max-new-tokens", type=int, default=1200)
    add("--long-analysis-temperature", type=float, default=0.1)
    add("--segment-seconds", type=float, default=30.0)
    add("--overlap-seconds", type=float, default=2.0)
    add("--max-new-tokens", type=int, default=384)
    add("--temperature", type=float, default=0.1)
    add("--keep-raw-outputs", action="store_true", help="Store per-segment raw outputs in sidecar JSON")

    # Export options.
    add("--output-dir", type=str, default="qwen_annotations")
    add("--copy-audio", action="store_true", help="Copy audio files into output_dir/dataset")
    add(
        "--write-inplace-sidecars",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Write sidecars next to source audio (default: true). Use --no-write-inplace-sidecars to disable.",
    )

    # Optional upload of the export folder to the Hub.
    add("--upload-repo", type=str, default="", help="Optional HF dataset repo to upload exports")
    add("--upload-private", action="store_true", help="Create upload repo as private")
    add("--upload-path", type=str, default="", help="Optional path inside upload repo")

    return parser
|
|
|
|
def resolve_dataset_dir(args) -> str:
    """Return a local directory containing the dataset audio.

    A local --dataset-dir wins when given; otherwise the --dataset-repo
    snapshot is downloaded into a fresh temporary directory.

    Raises:
        FileNotFoundError: local folder (or requested subdir) does not exist.
        ValueError: neither --dataset-dir nor --dataset-repo was supplied.
    """
    if args.dataset_dir:
        if not Path(args.dataset_dir).is_dir():
            raise FileNotFoundError(f"Dataset folder not found: {args.dataset_dir}")
        return args.dataset_dir

    if not args.dataset_repo:
        raise ValueError("Provide --dataset-dir or --dataset-repo")

    token = args.hf_token or os.getenv("HF_TOKEN", "")
    local_dir = os.path.join(tempfile.mkdtemp(prefix="qwen_caption_dataset_"), "dataset")
    logger.info(f"Downloading dataset {args.dataset_repo}@{args.dataset_revision} -> {local_dir}")
    snapshot_download(
        repo_id=args.dataset_repo,
        repo_type="dataset",
        revision=args.dataset_revision,
        local_dir=local_dir,
        local_dir_use_symlinks=False,
        token=token or None,
    )

    if not args.dataset_subdir:
        return local_dir
    sub = os.path.join(local_dir, args.dataset_subdir)
    if not Path(sub).is_dir():
        raise FileNotFoundError(f"Dataset subdir not found: {sub}")
    return sub
|
|
|
|
def upload_export_if_requested(args, output_dir: str):
    """Push *output_dir* to the HF dataset repo named by --upload-repo.

    No-op when --upload-repo is empty. Requires a token from --hf-token or
    the HF_TOKEN environment variable; creates the target repo if missing.
    """
    if not args.upload_repo:
        return

    token = args.hf_token or os.getenv("HF_TOKEN", "")
    if not token:
        raise RuntimeError("HF token missing. Set --hf-token or HF_TOKEN.")

    # Normalize the optional repo-internal path ("" means repo root).
    path_in_repo = args.upload_path.strip().strip("/") if args.upload_path else ""

    api = HfApi(token=token)
    api.create_repo(
        repo_id=args.upload_repo,
        repo_type="dataset",
        private=bool(args.upload_private),
        exist_ok=True,
    )
    logger.info(f"Uploading {output_dir} -> {args.upload_repo}/{path_in_repo}")
    api.upload_folder(
        repo_id=args.upload_repo,
        repo_type="dataset",
        folder_path=output_dir,
        path_in_repo=path_in_repo,
        commit_message="Upload Qwen2-Audio annotations",
    )
    logger.info("Upload complete")
|
|
|
def main() -> int:
    """CLI entry point: caption every audio file, export sidecars, optionally upload.

    Returns 0 on completion; per-file captioning failures are collected and
    logged rather than aborting the run.
    """
    args = build_parser().parse_args()
    token = args.hf_token or os.getenv("HF_TOKEN", "")

    # File-based prompts take precedence over the inline flag values.
    prompt = read_prompt_file(args.prompt_file) if args.prompt_file else args.prompt
    if args.long_analysis_prompt_file:
        long_prompt = read_prompt_file(args.long_analysis_prompt_file)
    else:
        long_prompt = args.long_analysis_prompt

    dataset_dir = resolve_dataset_dir(args)
    audio_files: List[str] = list_audio_files(dataset_dir)
    if not audio_files:
        raise RuntimeError(f"No audio files found in {dataset_dir}")
    logger.info(f"Found {len(audio_files)} audio files")

    captioner = build_captioner(
        backend=args.backend,
        model_id=args.model_id,
        endpoint_url=args.endpoint_url,
        token=token,
        device=args.device,
        torch_dtype=args.torch_dtype,
    )

    records = []
    failed = []
    for audio_path in tqdm(audio_files, desc="Captioning audio"):
        try:
            sidecar = generate_track_annotation(
                audio_path=audio_path,
                captioner=captioner,
                prompt=prompt,
                segment_seconds=float(args.segment_seconds),
                overlap_seconds=float(args.overlap_seconds),
                max_new_tokens=int(args.max_new_tokens),
                temperature=float(args.temperature),
                keep_raw_outputs=bool(args.keep_raw_outputs),
                include_long_analysis=bool(args.include_long_analysis),
                long_analysis_prompt=long_prompt,
                long_analysis_max_new_tokens=int(args.long_analysis_max_new_tokens),
                long_analysis_temperature=float(args.long_analysis_temperature),
            )
        except Exception as exc:
            # Keep going: one bad file should not sink the whole batch.
            failed.append(f"{Path(audio_path).name}: {exc}")
            logger.exception(f"Failed: {audio_path}")
        else:
            records.append({"audio_path": audio_path, "sidecar": sidecar})

    export_result = export_annotation_records(
        records=records,
        output_dir=args.output_dir,
        copy_audio=bool(args.copy_audio),
        write_inplace_sidecars=bool(args.write_inplace_sidecars),
    )

    logger.info(
        "Done. analyzed={} failed={} manifest={}",
        len(records),
        len(failed),
        export_result["manifest_path"],
    )
    if failed:
        logger.warning("First failures:\n" + "\n".join(failed[:20]))

    upload_export_if_requested(args, args.output_dir)
    return 0
|
|
|
|
| if __name__ == "__main__": |
| raise SystemExit(main()) |
|
|