#!/usr/bin/env python
# Page-scrape residue kept for provenance: File size: 7,924 Bytes | 8bdd018
"""
Batch caption a music dataset with Qwen2-Audio and export LoRA-ready sidecars.
"""
from __future__ import annotations
import argparse
import os
import tempfile
from pathlib import Path
from typing import List
from huggingface_hub import HfApi, snapshot_download
from loguru import logger
from tqdm import tqdm
from qwen_audio_captioning import (
DEFAULT_ANALYSIS_PROMPT,
DEFAULT_LONG_ANALYSIS_PROMPT,
DEFAULT_MODEL_ID,
build_captioner,
export_annotation_records,
generate_track_annotation,
list_audio_files,
read_prompt_file,
)
def build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the batch captioning script.

    The options cover four areas: where the audio comes from, which
    captioning backend is used, prompt/generation settings, and how the
    results are exported (plus an optional upload of the export folder).

    Returns:
        A configured ``argparse.ArgumentParser``; nothing is parsed here.
    """
    parser = argparse.ArgumentParser(
        description="Qwen2-Audio batch captioning for LoRA datasets",
    )
    add = parser.add_argument

    # ---- Data source -----------------------------------------------------
    add("--dataset-dir", type=str, default="", help="Local dataset folder")
    add("--dataset-repo", type=str, default="", help="HF dataset repo id")
    add("--dataset-revision", type=str, default="main", help="HF dataset revision")
    add("--dataset-subdir", type=str, default="", help="Subdirectory inside dataset")

    # ---- Backend ---------------------------------------------------------
    add("--backend", type=str, default="local", choices=["local", "hf_endpoint"])
    add("--model-id", type=str, default=DEFAULT_MODEL_ID)
    add("--endpoint-url", type=str, default="")
    add("--hf-token", type=str, default="", help="HF token (or use HF_TOKEN env var)")
    add("--device", type=str, default="auto", choices=["auto", "cuda", "cpu", "mps"])
    add(
        "--torch-dtype",
        type=str,
        default="auto",
        choices=["auto", "float16", "bfloat16", "float32"],
    )

    # ---- Prompt + generation controls -------------------------------------
    add("--prompt", type=str, default=DEFAULT_ANALYSIS_PROMPT)
    add("--prompt-file", type=str, default="", help="Text file to override --prompt")
    add(
        "--include-long-analysis",
        action="store_true",
        help="Also request long prose analysis",
    )
    add("--long-analysis-prompt", type=str, default=DEFAULT_LONG_ANALYSIS_PROMPT)
    add(
        "--long-analysis-prompt-file",
        type=str,
        default="",
        help="Text file to override --long-analysis-prompt",
    )
    add("--long-analysis-max-new-tokens", type=int, default=1200)
    add("--long-analysis-temperature", type=float, default=0.1)
    add("--segment-seconds", type=float, default=30.0)
    add("--overlap-seconds", type=float, default=2.0)
    add("--max-new-tokens", type=int, default=384)
    add("--temperature", type=float, default=0.1)
    add(
        "--keep-raw-outputs",
        action="store_true",
        help="Store per-segment raw outputs in sidecar JSON",
    )

    # ---- Export ------------------------------------------------------------
    add("--output-dir", type=str, default="qwen_annotations")
    add(
        "--copy-audio",
        action="store_true",
        help="Copy audio files into output_dir/dataset",
    )
    # BooleanOptionalAction also registers the --no-write-inplace-sidecars form.
    add(
        "--write-inplace-sidecars",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Write sidecars next to source audio (default: true). Use --no-write-inplace-sidecars to disable.",
    )

    # ---- Optional upload of exported folder --------------------------------
    add(
        "--upload-repo",
        type=str,
        default="",
        help="Optional HF dataset repo to upload exports",
    )
    add(
        "--upload-private",
        action="store_true",
        help="Create upload repo as private",
    )
    add(
        "--upload-path",
        type=str,
        default="",
        help="Optional path inside upload repo",
    )
    return parser
def resolve_dataset_dir(args) -> str:
    """Return the local directory that holds the audio files to caption.

    A local ``--dataset-dir`` takes precedence. Otherwise the dataset repo
    named by ``--dataset-repo`` is downloaded into a fresh temporary
    directory. In both cases ``--dataset-subdir`` (when set) is applied on
    top of the resolved root, so the option behaves the same for local and
    downloaded datasets.

    Args:
        args: Parsed CLI namespace. Reads ``dataset_dir``, ``dataset_repo``,
            ``dataset_revision``, ``dataset_subdir`` and ``hf_token``.

    Returns:
        Path of the directory to scan for audio files.

    Raises:
        FileNotFoundError: The local folder or the requested subdirectory
            does not exist.
        ValueError: Neither ``--dataset-dir`` nor ``--dataset-repo`` was given.
    """
    # Function-scope import: only the download failure path needs it.
    import shutil

    # getattr keeps callers working that pass a namespace without the option.
    subdir = getattr(args, "dataset_subdir", "")

    if args.dataset_dir:
        if not Path(args.dataset_dir).is_dir():
            raise FileNotFoundError(f"Dataset folder not found: {args.dataset_dir}")
        return _apply_dataset_subdir(args.dataset_dir, subdir)

    if not args.dataset_repo:
        raise ValueError("Provide --dataset-dir or --dataset-repo")

    token = args.hf_token or os.getenv("HF_TOKEN", "")
    temp_root = tempfile.mkdtemp(prefix="qwen_caption_dataset_")
    local_dir = os.path.join(temp_root, "dataset")
    logger.info(f"Downloading dataset {args.dataset_repo}@{args.dataset_revision} -> {local_dir}")
    try:
        snapshot_download(
            repo_id=args.dataset_repo,
            repo_type="dataset",
            revision=args.dataset_revision,
            local_dir=local_dir,
            # NOTE(review): this flag may be deprecated in newer
            # huggingface_hub releases; confirm against the pinned version.
            local_dir_use_symlinks=False,
            token=token or None,
        )
        return _apply_dataset_subdir(local_dir, subdir)
    except BaseException:
        # The temp directory is only useful when we hand it back to the
        # caller; on any failure remove it so partial downloads do not pile up.
        shutil.rmtree(temp_root, ignore_errors=True)
        raise


def _apply_dataset_subdir(base_dir: str, subdir: str) -> str:
    """Return ``base_dir`` itself, or ``base_dir/subdir`` after checking it exists."""
    if not subdir:
        return base_dir
    sub = os.path.join(base_dir, subdir)
    if not Path(sub).is_dir():
        raise FileNotFoundError(f"Dataset subdir not found: {sub}")
    return sub
def upload_export_if_requested(args, output_dir: str):
    """Push the export folder to a Hugging Face dataset repo, if one was requested.

    Does nothing when ``--upload-repo`` is empty. Otherwise the repo is
    created (an existing repo is accepted) and ``output_dir`` is uploaded to
    it, optionally under ``--upload-path``.

    Args:
        args: Parsed CLI namespace. Reads ``upload_repo``, ``hf_token``,
            ``upload_private`` and ``upload_path``.
        output_dir: Local folder whose contents are uploaded.

    Raises:
        RuntimeError: An upload was requested but no token is available from
            ``--hf-token`` or the ``HF_TOKEN`` environment variable.
    """
    repo_id = args.upload_repo
    if not repo_id:
        return

    hf_token = args.hf_token or os.getenv("HF_TOKEN", "")
    if not hf_token:
        raise RuntimeError("HF token missing. Set --hf-token or HF_TOKEN.")

    client = HfApi(token=hf_token)
    client.create_repo(
        repo_id=repo_id,
        repo_type="dataset",
        private=bool(args.upload_private),
        exist_ok=True,
    )

    # Normalise the destination: no surrounding whitespace or slashes.
    if args.upload_path:
        target = args.upload_path.strip().strip("/")
    else:
        target = ""

    logger.info(f"Uploading {output_dir} -> {repo_id}/{target}")
    client.upload_folder(
        repo_id=repo_id,
        repo_type="dataset",
        folder_path=output_dir,
        path_in_repo=target,
        commit_message="Upload Qwen2-Audio annotations",
    )
    logger.info("Upload complete")
def main() -> int:
    """Run the batch captioning pipeline and return the process exit code.

    Steps: parse CLI options, resolve the dataset directory, caption every
    audio file found there, export the collected annotations, and optionally
    upload the export folder.

    A failure on a single track is logged and recorded; the remaining tracks
    are still processed.

    Returns:
        ``0`` once export (and the optional upload) has completed.

    Raises:
        RuntimeError: No audio files were found, or ``--upload-repo`` was
            given without a token (``--hf-token`` / ``HF_TOKEN``).
    """
    args = build_parser().parse_args()

    token = args.hf_token or os.getenv("HF_TOKEN", "")
    # Fail fast: the upload runs only after every track has been captioned,
    # so a missing token would otherwise surface after all the work is done.
    if args.upload_repo and not token:
        raise RuntimeError("HF token missing. Set --hf-token or HF_TOKEN.")

    # A prompt file, when given, overrides the corresponding inline prompt.
    prompt = read_prompt_file(args.prompt_file) if args.prompt_file else args.prompt
    long_prompt = (
        read_prompt_file(args.long_analysis_prompt_file)
        if args.long_analysis_prompt_file
        else args.long_analysis_prompt
    )

    dataset_dir = resolve_dataset_dir(args)
    audio_files: List[str] = list_audio_files(dataset_dir)
    if not audio_files:
        raise RuntimeError(f"No audio files found in {dataset_dir}")
    logger.info(f"Found {len(audio_files)} audio files")

    captioner = build_captioner(
        backend=args.backend,
        model_id=args.model_id,
        endpoint_url=args.endpoint_url,
        token=token,
        device=args.device,
        torch_dtype=args.torch_dtype,
    )

    records = []
    failed = []
    for path in tqdm(audio_files, desc="Captioning audio"):
        try:
            # argparse already coerces these options to float/int/bool,
            # so they are passed through without re-casting.
            sidecar = generate_track_annotation(
                audio_path=path,
                captioner=captioner,
                prompt=prompt,
                segment_seconds=args.segment_seconds,
                overlap_seconds=args.overlap_seconds,
                max_new_tokens=args.max_new_tokens,
                temperature=args.temperature,
                keep_raw_outputs=args.keep_raw_outputs,
                include_long_analysis=args.include_long_analysis,
                long_analysis_prompt=long_prompt,
                long_analysis_max_new_tokens=args.long_analysis_max_new_tokens,
                long_analysis_temperature=args.long_analysis_temperature,
            )
            records.append({"audio_path": path, "sidecar": sidecar})
        except Exception as exc:
            # Deliberate best-effort: one bad track must not abort the batch.
            failed.append(f"{Path(path).name}: {exc}")
            logger.exception(f"Failed: {path}")

    export_result = export_annotation_records(
        records=records,
        output_dir=args.output_dir,
        copy_audio=args.copy_audio,
        write_inplace_sidecars=args.write_inplace_sidecars,
    )
    logger.info(
        "Done. analyzed={} failed={} manifest={}",
        len(records),
        len(failed),
        export_result["manifest_path"],
    )
    if failed:
        # Cap the summary at 20 entries to keep the log readable.
        logger.warning("First failures:\n" + "\n".join(failed[:20]))

    upload_export_if_requested(args, args.output_dir)
    return 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    raise SystemExit(main())