Lora-ace-step / scripts /pipeline /refine_dataset_json_with_openai.py
Andrew
Consolidate AF3/Qwen pipelines, endpoint templates, and setup docs
8bdd018
#!/usr/bin/env python
"""
Refine existing dataset JSON annotations into Ace-Step 1.5 LoRA-ready sidecars.
This script:
1. Reads existing JSON files (typically containing AF3 `generated_text`).
2. Uses OpenAI cleanup (optionally with web search) to normalize/expand metadata.
3. Writes normalized sidecar JSON in-place (or to an output directory).
4. Creates backup copies before overwrite by default.
"""
from __future__ import annotations
import argparse
import json
import shutil
import sys
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Tuple
from tqdm import tqdm
PROJECT_ROOT = Path(__file__).resolve().parents[2]
if str(PROJECT_ROOT) not in sys.path:
sys.path.insert(0, str(PROJECT_ROOT))
from af3_chatgpt_pipeline import ( # noqa: E402
DEFAULT_AF3_PROMPT,
DEFAULT_OPENAI_MODEL,
build_lora_sidecar,
cleanup_with_chatgpt,
)
from qwen_audio_captioning import AUDIO_EXTENSIONS # noqa: E402
from utils.env_config import get_env, load_project_env # noqa: E402
def _parse_args() -> argparse.Namespace:
    """Build the command-line interface and parse the process arguments.

    ``load_project_env()`` is called before the parser is constructed; the
    ``--openai-model`` default is then looked up through ``get_env`` with
    ``DEFAULT_OPENAI_MODEL`` as the final fallback.

    Returns:
        The namespace produced by ``ArgumentParser.parse_args()``.
    """
    load_project_env()

    parser = argparse.ArgumentParser(
        description="Refine dataset JSONs into Ace-Step 1.5 LoRA-ready metadata using OpenAI."
    )
    add = parser.add_argument

    # --- input selection -------------------------------------------------
    add("--dataset-dir", default="train-dataset", help="Directory containing source JSON files")
    add("--recursive", action="store_true", help="Include nested folders")
    add("--pattern", default="*.json", help="Filename glob pattern")

    # --- output / backup behaviour ---------------------------------------
    add("--output-dir", default="", help="Optional output folder. Default: overwrite in place")
    add(
        "--backup-ext",
        default=".backup-before-openai.json",
        help="Backup extension for in-place writes",
    )
    add("--no-backup", action="store_true", help="Disable backup creation for in-place writes")
    add("--limit", type=int, default=0, help="Process only first N files (0 = all)")

    # --- metadata hints --------------------------------------------------
    add("--artist-default", default="Andrew Spacey", help="Fallback artist if parsing fails")
    add("--user-context", default="", help="Extra guidance passed to OpenAI cleanup")

    # --- OpenAI configuration --------------------------------------------
    add("--openai-api-key", default="", help="Overrides OPENAI_API_KEY from .env")
    add(
        "--openai-model",
        default=get_env("OPENAI_MODEL", "openai_model", default=DEFAULT_OPENAI_MODEL),
        help="OpenAI model id",
    )
    add(
        "--enable-web-search",
        action="store_true",
        help="Enable web search tool for artist/track context lookup",
    )

    # --- run control -----------------------------------------------------
    add("--fail-fast", action="store_true", help="Stop on first failure")
    add("--dry-run", action="store_true", help="Do not write files")

    return parser.parse_args()
def _iter_json_files(dataset_dir: Path, pattern: str, recursive: bool) -> List[Path]:
    """Return the sorted paths under ``dataset_dir`` that match ``pattern``.

    Args:
        dataset_dir: Directory to search.
        pattern: Glob pattern applied to entries.
        recursive: When true, nested folders are searched as well (``rglob``);
            otherwise only direct children are considered (``glob``).

    Returns:
        A sorted list of matching paths.
    """
    search = dataset_dir.rglob if recursive else dataset_dir.glob
    return sorted(search(pattern))
def _load_json(path: Path) -> Dict:
    """Read ``path`` and return its top-level JSON object.

    The file is decoded as ``utf-8-sig`` so that both plain UTF-8 and UTF-8
    with a byte-order mark are accepted.

    Raises:
        ValueError: If the parsed document is not a JSON object (dict).
    """
    payload = json.loads(path.read_text(encoding="utf-8-sig"))
    if isinstance(payload, dict):
        return payload
    raise ValueError("Top-level JSON is not an object")
def _detect_audio_path(json_path: Path) -> Optional[Path]:
    """Locate the audio file that sits next to ``json_path``.

    First, each extension in ``AUDIO_EXTENSIONS`` is swapped in for the JSON
    suffix and the first existing candidate wins. If none exists, the parent
    directory is scanned for a file with the same stem whose lower-cased
    suffix is in ``AUDIO_EXTENSIONS`` (this catches suffixes such as ``.WAV``).

    Returns:
        The matching audio path, or ``None`` when nothing is found.
    """
    swapped = (json_path.with_suffix(ext) for ext in AUDIO_EXTENSIONS)
    direct_hit = next((candidate for candidate in swapped if candidate.exists()), None)
    if direct_hit is not None:
        return direct_hit

    # Second pass: compare suffixes case-insensitively among the siblings.
    wanted_stem = json_path.stem
    for entry in json_path.parent.iterdir():
        if not entry.is_file():
            continue
        if entry.stem != wanted_stem:
            continue
        if entry.suffix.lower() in AUDIO_EXTENSIONS:
            return entry
    return None
def _try_duration_seconds(audio_path: Optional[Path], fallback: float = 0.0) -> float:
    """Best-effort audio duration in seconds, falling back to ``fallback``.

    ``soundfile`` is imported lazily and any error raised while importing it
    or probing the file is deliberately ignored, so a missing dependency or an
    unreadable file simply yields the fallback value.

    Args:
        audio_path: Audio file to probe; ``None`` or a non-existent path
            short-circuits to the fallback.
        fallback: Value used when the duration cannot be measured. Falsy
            values are treated as ``0.0``.

    Returns:
        ``frames / samplerate`` when both are reported as non-zero, otherwise
        ``float(fallback or 0.0)``.
    """
    if audio_path is not None and audio_path.exists():
        try:
            import soundfile as sf

            meta = sf.info(str(audio_path))
            if meta.samplerate and meta.frames:
                return float(meta.frames) / float(meta.samplerate)
        except Exception:
            # Intentional best-effort: fall through to the fallback below.
            pass
    return float(fallback or 0.0)
def _parse_artist_track_from_stem(stem: str, artist_default: str) -> Tuple[str, str]:
    """Split a ``"Artist - Track"`` style file stem into ``(artist, track)``.

    Only the first ``" - "`` separator is significant. When the separator is
    absent, or either side is empty after stripping, the whole stripped stem
    is used as the track name and the artist falls back to ``artist_default``
    (or ``"Unknown Artist"`` if that is blank).
    """
    left, separator, right = stem.partition(" - ")
    if separator:
        artist_name = left.strip()
        track_title = right.strip()
        if artist_name and track_title:
            return artist_name, track_title

    fallback_artist = artist_default.strip() or "Unknown Artist"
    return fallback_artist, stem.strip()
def _extract_raw_analysis(data: Dict) -> str:
    """Return the first non-blank analysis text found in ``data``.

    Fields are tried in this priority order: ``generated_text``,
    ``af3_analysis``, ``analysis_long``, ``analysis_short``,
    ``music_analysis.summary_long`` (only when ``music_analysis`` is a dict),
    and finally ``caption``. Non-string and whitespace-only values are skipped.

    Returns:
        The stripped text, or ``""`` when no field qualifies.
    """
    nested = data.get("music_analysis")
    nested_summary = nested.get("summary_long") if isinstance(nested, dict) else None

    candidates = [
        data.get(key)
        for key in ("generated_text", "af3_analysis", "analysis_long", "analysis_short")
    ]
    candidates.append(nested_summary)
    candidates.append(data.get("caption"))

    usable = (
        text.strip()
        for text in candidates
        if isinstance(text, str) and text.strip()
    )
    return next(usable, "")
def _ensure_output_path(src_json: Path, output_dir: Optional[Path]) -> Path:
    """Resolve where the refined JSON for ``src_json`` should be written.

    Without an output directory the source path itself is returned (in-place
    mode). Otherwise the directory is created if needed (including parents)
    and the result is a file of the same name inside it.
    """
    if not output_dir:
        return src_json
    output_dir.mkdir(parents=True, exist_ok=True)
    return output_dir / src_json.name
def _create_backup(src: Path, backup_ext: str) -> Optional[Path]:
    """Copy ``src`` to a sibling backup file unless one already exists.

    The backup is named ``<src stem><backup_ext>`` and lives next to ``src``.
    An existing backup is never overwritten, so the earliest copy is kept.

    Returns:
        The backup path (whether it was just created or already present).
    """
    backup_name = "".join((src.stem, backup_ext))
    destination = src.with_name(backup_name)
    if not destination.exists():
        # copy2 also carries over file metadata such as timestamps.
        shutil.copy2(src, destination)
    return destination
def _finalize_sidecar(
    *,
    cleaned: Dict,
    raw_analysis: str,
    duration: float,
    source_audio: Optional[Path],
    source_json: Path,
    artist: str,
    track_name: str,
    openai_model: str,
    web_search_used: bool,
) -> Dict:
    """Assemble the final sidecar dict for one refined JSON file.

    ``build_lora_sidecar`` produces the base document from the cleaned
    metadata; this function then sets ``artist`` and ``track_name`` and
    attaches a ``source`` record describing where the data came from.

    Args:
        cleaned: Output of the OpenAI cleanup step.
        raw_analysis: Original analysis text, passed on as ``af3_text``.
        duration: Track duration in seconds.
        source_audio: Detected audio file, or ``None`` (stored as ``""``).
        source_json: The JSON file being refined.
        artist: Artist name written to the sidecar.
        track_name: Track name written to the sidecar.
        openai_model: Model id recorded as ``chatgpt_model``.
        web_search_used: Whether web search was active for the cleanup call.

    Returns:
        The sidecar dict returned by ``build_lora_sidecar`` with the extra
        keys applied.
    """
    audio_ref = "" if not source_audio else str(source_audio)

    # Provenance record stored under the "source" key.
    provenance = {
        "input_json": str(source_json),
        "input_audio": audio_ref,
        "refined_from_existing_json": True,
    }

    result = build_lora_sidecar(
        cleaned,
        af3_text=raw_analysis,
        af3_prompt=DEFAULT_AF3_PROMPT,
        af3_backend="existing_json_refine",
        af3_model_id="nvidia/audio-flamingo-3",
        source_audio=audio_ref,
        duration=duration,
        chatgpt_model=openai_model,
        web_search_used=web_search_used,
    )

    # These assignments override whatever the builder put under the same keys.
    result["artist"] = artist
    result["track_name"] = track_name
    result["source"] = provenance
    return result
def _cleanup_with_web_fallback(
    raw_analysis: str,
    *,
    enable_web_search: bool,
    **cleanup_kwargs,
) -> Tuple[Dict, bool]:
    """Run the OpenAI cleanup, retrying once without web search if it fails.

    Args:
        raw_analysis: Analysis text to clean up.
        enable_web_search: Whether the first attempt should use web search.
        **cleanup_kwargs: Remaining keyword arguments forwarded unchanged to
            ``cleanup_with_chatgpt``.

    Returns:
        ``(cleaned, web_search_used)`` where the flag reflects the attempt
        that actually succeeded.

    Raises:
        Exception: Whatever ``cleanup_with_chatgpt`` raises when web search
            was not requested, or when the retry without it fails too.
    """
    try:
        cleaned = cleanup_with_chatgpt(
            raw_analysis,
            enable_web_search=enable_web_search,
            **cleanup_kwargs,
        )
        return cleaned, enable_web_search
    except Exception:
        # If web-search tool compatibility fails on this runtime, retry without it.
        if not enable_web_search:
            raise
        cleaned = cleanup_with_chatgpt(
            raw_analysis,
            enable_web_search=False,
            **cleanup_kwargs,
        )
        return cleaned, False


def main() -> int:
    """Refine every matching dataset JSON and print a JSON summary.

    For each file: load it, extract the raw analysis text, resolve artist /
    track / duration, run the OpenAI cleanup, build the sidecar and write it
    (in place with an optional backup, or into ``--output-dir``). Per-file
    errors are collected rather than raised; ``--fail-fast`` stops after the
    first one.

    Returns:
        ``0`` when every attempted file succeeded, ``2`` if any file failed.

    Raises:
        FileNotFoundError: If the dataset directory does not exist.
        RuntimeError: If no OpenAI API key is available or no input file
            matches the pattern.
    """
    args = _parse_args()

    dataset_dir = Path(args.dataset_dir)
    if not dataset_dir.exists():
        raise FileNotFoundError(f"Dataset directory not found: {dataset_dir}")

    openai_key = args.openai_api_key or get_env("OPENAI_API_KEY", "openai_api_key")
    if not openai_key:
        raise RuntimeError("Missing OPENAI_API_KEY (set in .env or pass --openai-api-key).")

    files = _iter_json_files(dataset_dir, pattern=args.pattern, recursive=bool(args.recursive))

    # Backups written by an earlier in-place run are named "<stem><backup-ext>"
    # and, with the default extension, still match "*.json". Skip them so a
    # re-run neither spends API calls on them nor overwrites the preserved
    # originals. The guard on an empty extension avoids filtering everything.
    backup_ext = str(args.backup_ext or "")
    if backup_ext:
        files = [path for path in files if not path.name.endswith(backup_ext)]

    if args.limit and args.limit > 0:
        files = files[: int(args.limit)]
    if not files:
        raise RuntimeError(f"No files matched {args.pattern} in {dataset_dir}")

    output_dir = Path(args.output_dir) if args.output_dir else None
    web_search_requested = bool(args.enable_web_search)

    failures: List[str] = []
    saved: List[str] = []
    backups: List[str] = []
    attempted = 0  # files actually started; differs from len(files) after --fail-fast

    for json_path in tqdm(files, desc="Refine JSON"):
        attempted += 1
        try:
            data = _load_json(json_path)
            raw_analysis = _extract_raw_analysis(data)
            if not raw_analysis:
                raise ValueError("No analysis text found (generated_text/analysis/caption missing)")

            # Values stored in the JSON take precedence over the filename guess.
            artist, track_name = _parse_artist_track_from_stem(json_path.stem, args.artist_default)
            artist = str(data.get("artist") or artist).strip() or artist
            track_name = str(data.get("track_name") or data.get("title") or track_name).strip() or track_name

            source_audio = _detect_audio_path(json_path)
            duration = _try_duration_seconds(source_audio, fallback=float(data.get("duration") or 0.0))

            cleaned, web_used = _cleanup_with_web_fallback(
                raw_analysis,
                enable_web_search=web_search_requested,
                openai_api_key=openai_key,
                model=args.openai_model,
                duration=duration,
                user_context=args.user_context,
                artist_name=artist,
                track_name=track_name,
            )

            sidecar = _finalize_sidecar(
                cleaned=cleaned,
                raw_analysis=raw_analysis,
                duration=duration,
                source_audio=source_audio,
                source_json=json_path,
                artist=artist,
                track_name=track_name,
                openai_model=args.openai_model,
                web_search_used=web_used,
            )

            if args.dry_run:
                # A dry run must not touch the filesystem, so compute the
                # target without creating the output directory.
                out_path = output_dir / json_path.name if output_dir else json_path
            else:
                out_path = _ensure_output_path(json_path, output_dir)
                if output_dir is None and not args.no_backup:
                    backup = _create_backup(json_path, args.backup_ext)
                    if backup:
                        backups.append(str(backup))
                out_path.write_text(json.dumps(sidecar, indent=2, ensure_ascii=False), encoding="utf-8")

            # In dry-run mode this records the path that would have been written.
            saved.append(str(out_path))
        except Exception as exc:
            failures.append(f"{json_path.name}: {exc}")
            if args.fail_fast:
                break

    summary = {
        "processed": attempted,
        "saved": len(saved),
        "failed": len(failures),
        "backup_count": len(backups),
        "output_mode": "separate_dir" if output_dir else ("dry_run" if args.dry_run else "in_place"),
        "sample_saved": saved[:10],
        "sample_failures": failures[:10],
    }
    print(json.dumps(summary, indent=2, ensure_ascii=False))
    return 0 if not failures else 2
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    sys.exit(main())