File size: 4,643 Bytes
8bdd018
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
#!/usr/bin/env python
"""
Annotate one audio file with Qwen2-Audio and save a sidecar JSON.
"""

from __future__ import annotations

import argparse
import json
import os
from pathlib import Path

from qwen_audio_captioning import (
    DEFAULT_ANALYSIS_PROMPT,
    DEFAULT_LONG_ANALYSIS_PROMPT,
    DEFAULT_MODEL_ID,
    build_captioner,
    generate_track_annotation,
    read_prompt_file,
)


def read_dotenv_value(path: str, key: str) -> str:
    """Return the value for *key* from a dotenv-style file, or "" if absent.

    Blank lines, ``#`` comment lines, and lines without ``=`` are skipped.
    A matching value is stripped of surrounding whitespace and of leading/
    trailing double then single quote characters. Returns "" when the file
    does not exist or the key is not found.
    """
    dotenv = Path(path)
    if not dotenv.exists():
        return ""
    for entry in dotenv.read_text(encoding="utf-8").splitlines():
        stripped = entry.strip()
        if not stripped or stripped.startswith("#"):
            continue
        name, sep, value = stripped.partition("=")
        if not sep:
            # No "=" on this line; not a key/value pair.
            continue
        if name.strip() == key:
            return value.strip().strip('"').strip("'")
    return ""


def main() -> int:
    """CLI entry point: annotate one audio file and write a sidecar JSON.

    Parses command-line options, resolves prompts and the HF token, builds
    a captioner, runs the annotation, writes the JSON sidecar, and prints a
    short JSON summary to stdout.

    Returns:
        0 on success.

    Raises:
        FileNotFoundError: if ``--audio`` does not point to an existing file.
    """
    parser = argparse.ArgumentParser(description="Annotate a single audio file with Qwen2-Audio")
    add = parser.add_argument
    add("--audio", required=True, help="Audio file path")
    add("--backend", default="hf_endpoint", choices=["local", "hf_endpoint"])
    add("--model-id", default=DEFAULT_MODEL_ID)
    add("--endpoint-url", default=os.getenv("HF_QWEN_ENDPOINT_URL", ""))
    add("--token", default="")
    add("--device", default="auto", choices=["auto", "cuda", "cpu", "mps"])
    add("--torch-dtype", default="auto", choices=["auto", "float16", "bfloat16", "float32"])
    add("--prompt", default=DEFAULT_ANALYSIS_PROMPT)
    add("--prompt-file", default="")
    add("--include-long-analysis", action="store_true")
    add("--long-analysis-prompt", default=DEFAULT_LONG_ANALYSIS_PROMPT)
    add("--long-analysis-prompt-file", default="")
    add("--long-analysis-max-new-tokens", type=int, default=1200)
    add("--long-analysis-temperature", type=float, default=0.1)
    add("--segment-seconds", type=float, default=30.0)
    add("--overlap-seconds", type=float, default=2.0)
    add("--max-new-tokens", type=int, default=384)
    add("--temperature", type=float, default=0.1)
    add("--keep-raw-outputs", action="store_true")
    add("--output-json", default="", help="Output JSON path (default: audio sidecar)")
    opts = parser.parse_args()

    track = Path(opts.audio)
    if not track.is_file():
        raise FileNotFoundError(f"Audio not found: {track}")

    # A prompt file, when given, overrides the corresponding inline prompt.
    if opts.prompt_file:
        prompt = read_prompt_file(opts.prompt_file)
    else:
        prompt = opts.prompt
    if opts.long_analysis_prompt_file:
        long_prompt = read_prompt_file(opts.long_analysis_prompt_file)
    else:
        long_prompt = opts.long_analysis_prompt

    # Token resolution order: CLI flag, environment, then .env fallbacks
    # (upper- and lower-case key variants).
    token = (
        opts.token
        or os.getenv("HF_TOKEN", "")
        or read_dotenv_value(".env", "HF_TOKEN")
        or read_dotenv_value(".env", "hf_token")
    )

    captioner = build_captioner(
        backend=opts.backend,
        model_id=opts.model_id,
        endpoint_url=opts.endpoint_url,
        token=token,
        device=opts.device,
        torch_dtype=opts.torch_dtype,
    )

    sidecar = generate_track_annotation(
        audio_path=str(track),
        captioner=captioner,
        prompt=prompt,
        segment_seconds=float(opts.segment_seconds),
        overlap_seconds=float(opts.overlap_seconds),
        max_new_tokens=int(opts.max_new_tokens),
        temperature=float(opts.temperature),
        keep_raw_outputs=bool(opts.keep_raw_outputs),
        include_long_analysis=bool(opts.include_long_analysis),
        long_analysis_prompt=long_prompt,
        long_analysis_max_new_tokens=int(opts.long_analysis_max_new_tokens),
        long_analysis_temperature=float(opts.long_analysis_temperature),
    )

    # Default destination is a sidecar JSON next to the audio file.
    if opts.output_json:
        destination = Path(opts.output_json)
    else:
        destination = track.with_suffix(".json")
    destination.write_text(json.dumps(sidecar, indent=2, ensure_ascii=False), encoding="utf-8")

    # Print a compact summary of the annotation for quick inspection.
    summary = {
        "saved_to": str(destination),
        "caption": sidecar.get("caption", ""),
        "bpm": sidecar.get("bpm"),
        "keyscale": sidecar.get("keyscale", ""),
        "duration": sidecar.get("duration"),
        "segment_count": sidecar.get("music_analysis", {}).get("segment_count"),
    }
    print(json.dumps(summary, indent=2, ensure_ascii=False))
    return 0


if __name__ == "__main__":
    raise SystemExit(main())