#!/usr/bin/env python
"""
Annotate one audio file with Qwen2-Audio and save a sidecar JSON.
"""
from __future__ import annotations
import argparse
import json
import os
from pathlib import Path
from qwen_audio_captioning import (
DEFAULT_ANALYSIS_PROMPT,
DEFAULT_LONG_ANALYSIS_PROMPT,
DEFAULT_MODEL_ID,
build_captioner,
generate_track_annotation,
read_prompt_file,
)
def read_dotenv_value(path: str, key: str) -> str:
    """Return the value for *key* from a dotenv-style file.

    Lines are expected as ``KEY=VALUE``; an optional leading ``export ``
    (common in shell-sourceable .env files) is tolerated.  Blank lines,
    ``#`` comments, and lines without ``=`` are skipped.  Surrounding
    double and single quotes are stripped from the value.

    Returns "" when the file does not exist or the key is not found.
    """
    p = Path(path)
    if not p.exists():
        return ""
    for raw in p.read_text(encoding="utf-8").splitlines():
        line = raw.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        # Tolerate the shell-style `export KEY=VALUE` prefix so that
        # .env files written for `source` still resolve.
        if line.startswith("export "):
            line = line[len("export "):].lstrip()
        k, v = line.split("=", 1)
        if k.strip() == key:
            return v.strip().strip('"').strip("'")
    return ""
def main() -> int:
    """CLI entry point: annotate one audio file and write a JSON sidecar.

    Parses command-line options, resolves the HF token from the flag,
    environment, or a local .env file, runs the annotation pipeline, and
    writes the resulting sidecar JSON next to the audio (or to
    ``--output-json``).  Prints a short JSON summary and returns 0.

    Raises FileNotFoundError when the input audio path does not exist.
    """
    ap = argparse.ArgumentParser(description="Annotate a single audio file with Qwen2-Audio")
    ap.add_argument("--audio", required=True, help="Audio file path")
    ap.add_argument("--backend", default="hf_endpoint", choices=["local", "hf_endpoint"])
    ap.add_argument("--model-id", default=DEFAULT_MODEL_ID)
    ap.add_argument("--endpoint-url", default=os.getenv("HF_QWEN_ENDPOINT_URL", ""))
    ap.add_argument("--token", default="")
    ap.add_argument("--device", default="auto", choices=["auto", "cuda", "cpu", "mps"])
    ap.add_argument("--torch-dtype", default="auto", choices=["auto", "float16", "bfloat16", "float32"])
    ap.add_argument("--prompt", default=DEFAULT_ANALYSIS_PROMPT)
    ap.add_argument("--prompt-file", default="")
    ap.add_argument("--include-long-analysis", action="store_true")
    ap.add_argument("--long-analysis-prompt", default=DEFAULT_LONG_ANALYSIS_PROMPT)
    ap.add_argument("--long-analysis-prompt-file", default="")
    ap.add_argument("--long-analysis-max-new-tokens", type=int, default=1200)
    ap.add_argument("--long-analysis-temperature", type=float, default=0.1)
    ap.add_argument("--segment-seconds", type=float, default=30.0)
    ap.add_argument("--overlap-seconds", type=float, default=2.0)
    ap.add_argument("--max-new-tokens", type=int, default=384)
    ap.add_argument("--temperature", type=float, default=0.1)
    ap.add_argument("--keep-raw-outputs", action="store_true")
    ap.add_argument("--output-json", default="", help="Output JSON path (default: audio sidecar)")
    opts = ap.parse_args()

    src = Path(opts.audio)
    if not src.is_file():
        raise FileNotFoundError(f"Audio not found: {src}")

    # A --*-file flag wins over its inline prompt counterpart.
    if opts.prompt_file:
        prompt = read_prompt_file(opts.prompt_file)
    else:
        prompt = opts.prompt
    if opts.long_analysis_prompt_file:
        long_prompt = read_prompt_file(opts.long_analysis_prompt_file)
    else:
        long_prompt = opts.long_analysis_prompt

    # Resolve the HF token lazily: flag, then env, then .env entries.
    # Lambdas defer the .env file reads until actually needed.
    token_sources = (
        lambda: opts.token,
        lambda: os.getenv("HF_TOKEN", ""),
        lambda: read_dotenv_value(".env", "HF_TOKEN"),
        lambda: read_dotenv_value(".env", "hf_token"),
    )
    token = next((t for fn in token_sources if (t := fn())), "")

    captioner = build_captioner(
        backend=opts.backend,
        model_id=opts.model_id,
        endpoint_url=opts.endpoint_url,
        token=token,
        device=opts.device,
        torch_dtype=opts.torch_dtype,
    )
    sidecar = generate_track_annotation(
        audio_path=str(src),
        captioner=captioner,
        prompt=prompt,
        segment_seconds=float(opts.segment_seconds),
        overlap_seconds=float(opts.overlap_seconds),
        max_new_tokens=int(opts.max_new_tokens),
        temperature=float(opts.temperature),
        keep_raw_outputs=bool(opts.keep_raw_outputs),
        include_long_analysis=bool(opts.include_long_analysis),
        long_analysis_prompt=long_prompt,
        long_analysis_max_new_tokens=int(opts.long_analysis_max_new_tokens),
        long_analysis_temperature=float(opts.long_analysis_temperature),
    )

    # Default destination is a sidecar next to the audio: foo.wav -> foo.json.
    if opts.output_json:
        dest = Path(opts.output_json)
    else:
        dest = src.with_suffix(".json")
    dest.write_text(json.dumps(sidecar, indent=2, ensure_ascii=False), encoding="utf-8")

    # Emit a compact machine-readable summary for callers/scripts.
    summary = {
        "saved_to": str(dest),
        "caption": sidecar.get("caption", ""),
        "bpm": sidecar.get("bpm"),
        "keyscale": sidecar.get("keyscale", ""),
        "duration": sidecar.get("duration"),
        "segment_count": sidecar.get("music_analysis", {}).get("segment_count"),
    }
    print(json.dumps(summary, indent=2, ensure_ascii=False))
    return 0
# Script entry point: SystemExit propagates main()'s return code (0 on
# success) to the shell without printing a traceback.
if __name__ == "__main__":
    raise SystemExit(main())
|