"""Run the African Dialogue Framework pipeline: encode -> localize -> decode.

Each stage is delegated to a ``src.*`` module executed in a subprocess. This
script slices the input into batches, rests between batches, reuses cached
stage outputs, and stitches per-batch outputs into one JSON list per stage.
"""

import argparse
import json
import subprocess
import sys
import time
from pathlib import Path
from typing import Any, List, Optional, Tuple

from src.config import (
    DECODED_DIR,
    ENCODED_DIR,
    LOCALIZED_DIR,
    RAW_DATA_PATH,
)

DEFAULT_ENCODING_MODEL = "gemini-3-flash"
DEFAULT_LOCALIZE_MODEL = "gemini-3-flash"
DEFAULT_DECODE_MODELS = [
    "gemini-3-flash",
    "gpt-5.1",
    "gemma-3-27b-it",
]


def parse_args() -> argparse.Namespace:
    """Parse the command-line options for one pipeline run."""
    parser = argparse.ArgumentParser(
        description="Run the African Dialogue Framework pipeline."
    )
    parser.add_argument(
        "--input_path",
        type=str,
        default=str(RAW_DATA_PATH),
        help="Path to raw DailyDialog JSON file.",
    )
    parser.add_argument(
        "--language",
        type=str,
        required=True,
        help="Target language, e.g. 'Swahili'.",
    )
    parser.add_argument(
        "--region",
        type=str,
        default="",
        help="Optional region/community, e.g. 'Kenya - Nairobi'.",
    )
    parser.add_argument(
        "--decode_models",
        type=str,
        default=",".join(DEFAULT_DECODE_MODELS),
        help="Comma-separated model aliases for decode comparison.",
    )
    parser.add_argument(
        "--max_instances",
        type=int,
        default=None,
        help="Optional cap on number of dialogues to process.",
    )
    parser.add_argument(
        "--start_idx",
        type=int,
        default=0,
        help="Optional start index for slicing input data.",
    )
    parser.add_argument(
        "--end_idx",
        type=int,
        default=None,
        help="Optional end index for slicing input data.",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=20,
        help="Number of examples to process per batch.",
    )
    parser.add_argument(
        "--rest_seconds",
        type=int,
        default=300,
        help="Seconds to rest between batches.",
    )
    parser.add_argument(
        "--dont_use_cached",
        action="store_true",
        help="Disable cached prompt responses.",
    )
    return parser.parse_args()


def slugify(text: str) -> str:
    """Lower-case *text* and replace spaces, slashes and dashes with ``_``."""
    return (
        text.strip()
        .lower()
        .replace(" ", "_")
        .replace("/", "_")
        .replace("-", "_")
    )


def build_run_tag(language: str, region: str) -> str:
    """Return the filename tag for a language, suffixed with the region if given."""
    lang = slugify(language)
    if region.strip():
        return f"{lang}_{slugify(region)}"
    return lang


def parse_model_csv(raw: str) -> List[str]:
    """Split a comma-separated model list into unique, non-empty aliases.

    Order of first appearance is preserved. Duplicates are dropped because each
    alias maps to exactly one output file; repeating an alias would decode and
    append every batch into that file more than once.

    Raises:
        ValueError: if no non-empty alias remains.
    """
    stripped = (item.strip() for item in raw.split(","))
    models = list(dict.fromkeys(item for item in stripped if item))
    if not models:
        raise ValueError("Model list cannot be empty.")
    return models


def load_json(path: Path) -> Any:
    """Read and parse a UTF-8 JSON file."""
    return json.loads(path.read_text(encoding="utf-8"))


def save_json(path: Path, data: Any) -> None:
    """Write *data* as indented UTF-8 JSON, creating parent directories."""
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(
        json.dumps(data, indent=2, ensure_ascii=False),
        encoding="utf-8",
    )


def ensure_parent(path: Path) -> None:
    """Create the parent directory of *path* if it does not exist."""
    path.parent.mkdir(parents=True, exist_ok=True)


def reset_output_file(path: Path) -> None:
    """Overwrite *path* with an empty JSON list."""
    ensure_parent(path)
    save_json(path, [])


def append_json_list(destination: Path, batch_data: List[Any]) -> None:
    """Append *batch_data* to the JSON list stored at *destination*.

    The whole file is re-read and re-written on each call.

    Raises:
        ValueError: if *batch_data* or the existing file content is not a list.
            (Without the first check, ``list.extend`` would silently add a
            dict's keys or a string's characters.)
    """
    if not isinstance(batch_data, list):
        raise ValueError(
            f"Expected JSON list batch for {destination}, got {type(batch_data).__name__}"
        )
    ensure_parent(destination)
    if destination.exists():
        existing = load_json(destination)
        if not isinstance(existing, list):
            raise ValueError(f"Expected JSON list in existing file: {destination}")
    else:
        existing = []
    existing.extend(batch_data)
    save_json(destination, existing)


def cleanup_file(path: Path) -> None:
    """Delete *path* if it exists."""
    if path.exists():
        path.unlink()


def run_python_module(module_name: str, args: List[str]) -> None:
    """Run ``python -m module_name *args`` with the current interpreter.

    Raises:
        subprocess.CalledProcessError: if the module exits non-zero.
    """
    cmd = [sys.executable, "-m", module_name] + args
    print("\n[RUN]", " ".join(cmd))
    subprocess.run(cmd, check=True)


def compute_effective_end_idx(args: argparse.Namespace) -> Optional[int]:
    """Combine --end_idx and --max_instances into one exclusive end index.

    When both are given the tighter bound wins; ``None`` means unbounded.
    """
    if args.end_idx is not None and args.max_instances is not None:
        capped_end = args.start_idx + args.max_instances
        return min(args.end_idx, capped_end)
    if args.end_idx is not None:
        return args.end_idx
    if args.max_instances is not None:
        return args.start_idx + args.max_instances
    return None


def build_batch_ranges(
    start_idx: int,
    effective_end_idx: Optional[int],
    batch_size: int,
) -> List[Tuple[int, Optional[int]]]:
    """Split ``[start_idx, effective_end_idx)`` into consecutive batch ranges.

    With no end index the data length is unknown here, so a single open-ended
    batch ``(start_idx, None)`` is returned and *batch_size* is not applied.
    An empty list is returned when the range is empty.

    Raises:
        ValueError: if *batch_size* is not positive.
    """
    if batch_size <= 0:
        raise ValueError("--batch_size must be greater than 0.")
    if effective_end_idx is None:
        return [(start_idx, None)]
    if effective_end_idx <= start_idx:
        return []
    ranges: List[Tuple[int, Optional[int]]] = []
    current = start_idx
    while current < effective_end_idx:
        batch_end = min(current + batch_size, effective_end_idx)
        ranges.append((current, batch_end))
        current = batch_end
    return ranges


def build_slice_args(
    batch_start: int,
    batch_end: Optional[int],
    dont_use_cached: bool,
    model: Optional[str] = None,
) -> List[str]:
    """Build the shared slicing/model/cache CLI arguments for a stage module."""
    args: List[str] = [
        "--start_idx",
        str(batch_start),
    ]
    if batch_end is not None:
        args += ["--end_idx", str(batch_end)]
    if model is not None:
        args += ["--model", model]
    if dont_use_cached:
        args += ["--dont_use_cached"]
    return args


def model_tag(model_alias: str) -> str:
    """Return a filename-safe tag for a model alias (dots become underscores)."""
    return slugify(model_alias).replace(".", "_")


def make_paths(args: argparse.Namespace) -> dict[str, Any]:
    """Compute the output and temporary file paths for every pipeline stage."""
    run_tag = build_run_tag(args.language, args.region)
    input_stem = Path(args.input_path).stem

    # Encoding is region-independent (English → DAS). One file for all regions.
    encoded_path = ENCODED_DIR / f"{input_stem}_encoded.json"

    # Localization is done once with Gemini (shared across all decode models)
    localized_path = LOCALIZED_DIR / f"{input_stem}_{run_tag}_localized.json"

    # Each decode model gets its own output
    decode_models = parse_model_csv(args.decode_models)
    decoded_paths = {
        model: DECODED_DIR / f"{input_stem}_{run_tag}_{model_tag(model)}_decoded.json"
        for model in decode_models
    }

    temp_dir = LOCALIZED_DIR / "_tmp"
    temp_localized = temp_dir / f"{input_stem}_{run_tag}_localized_tmp.json"
    temp_decoded = {
        model: DECODED_DIR / f"{input_stem}_{run_tag}_{model_tag(model)}_decoded_tmp.json"
        for model in decode_models
    }

    return {
        "encoded": encoded_path,
        "localized": localized_path,
        "decoded": decoded_paths,
        "temp_localized": temp_localized,
        "temp_decoded": temp_decoded,
    }


def ensure_encoded_file(
    args: argparse.Namespace,
    encoded_path: Path,
) -> None:
    """Run ``src.encode`` over the requested slice unless a cached file exists.

    NOTE(review): the cache key is the input stem only, so a file encoded for a
    different --start_idx/--end_idx/--max_instances is reused as-is. Confirm it
    covers the requested range (depends on how src.encode slices its output).
    """
    if encoded_path.exists() and not args.dont_use_cached:
        print(f"[ENCODE] Reusing existing encoded file: {encoded_path}")
        return

    print(f"[ENCODE] Using model: {DEFAULT_ENCODING_MODEL}")
    encode_args: List[str] = [
        "--input_path",
        args.input_path,
        "--output_path",
        str(encoded_path),
        "--model",
        DEFAULT_ENCODING_MODEL,
        "--start_idx",
        str(args.start_idx),
    ]
    if args.end_idx is not None:
        encode_args += ["--end_idx", str(args.end_idx)]
    if args.max_instances is not None:
        encode_args += ["--max_instances", str(args.max_instances)]
    if args.dont_use_cached:
        encode_args += ["--dont_use_cached"]
    run_python_module("src.encode", encode_args)


def _rest_before_next_batch(
    args: argparse.Namespace,
    batch_index: int,
    total_batches: int,
) -> None:
    """Sleep ``args.rest_seconds`` between batches; never after the last one."""
    is_last_batch = batch_index == total_batches - 1
    if not is_last_batch and args.rest_seconds > 0:
        print(f"\n[REST] Sleeping for {args.rest_seconds} seconds before next batch...")
        time.sleep(args.rest_seconds)


def ensure_localized_file(
    args: argparse.Namespace,
    encoded_path: Path,
    localized_path: Path,
    batch_ranges: List[Tuple[int, Optional[int]]],
    temp_localized_path: Path,
) -> None:
    """Localize the encoded file batch by batch unless a cached result exists.

    Each batch is written by ``src.localize`` to *temp_localized_path* and then
    appended to a staging file beside *localized_path*. The staging file is
    moved onto *localized_path* only after every batch succeeds. An interrupted
    run therefore never leaves a partial file at *localized_path* that a later
    run would reuse as a complete cache.
    """
    if localized_path.exists() and not args.dont_use_cached:
        print(f"[LOCALIZE] Reusing existing localized file: {localized_path}")
        return

    print(f"[LOCALIZE] Using model: {DEFAULT_LOCALIZE_MODEL}")
    staging_path = localized_path.with_name(
        f"{localized_path.stem}.partial{localized_path.suffix}"
    )
    reset_output_file(staging_path)
    # The temp file lives in a "_tmp" subdirectory; create it here rather than
    # relying on src.localize to create missing parents.
    ensure_parent(temp_localized_path)

    total_batches = len(batch_ranges)
    for batch_index, (batch_start, batch_end) in enumerate(batch_ranges):
        print(
            f"\n========== LOCALIZE BATCH {batch_index + 1}/{total_batches} "
            f"(start_idx={batch_start}, end_idx={batch_end}) =========="
        )
        slice_args = build_slice_args(
            batch_start=batch_start,
            batch_end=batch_end,
            model=DEFAULT_LOCALIZE_MODEL,
            dont_use_cached=args.dont_use_cached,
        )
        run_python_module(
            "src.localize",
            [
                "--input_path",
                str(encoded_path),
                "--output_path",
                str(temp_localized_path),
                "--language",
                args.language,
                "--region",
                args.region,
                *slice_args,
            ],
        )
        append_json_list(staging_path, load_json(temp_localized_path))
        cleanup_file(temp_localized_path)
        _rest_before_next_batch(args, batch_index, total_batches)

    # Path.replace overwrites an existing destination (e.g. a stale file when
    # --dont_use_cached forced a rebuild).
    staging_path.replace(localized_path)


def run_decode_batches(
    args: argparse.Namespace,
    localized_path: Path,
    decoded_paths: dict[str, Path],
    temp_decoded_paths: dict[str, Path],
    batch_ranges: List[Tuple[int, Optional[int]]],
) -> List[Tuple[str, int]]:
    """Decode every batch with every model, appending to per-model outputs.

    A failure of one model on one batch is logged and skipped so the remaining
    models and batches still run. Because of this, the affected model's output
    file lacks that batch.

    Returns:
        ``(model, batch_number)`` pairs (1-based batch numbers) for every
        failed model/batch combination; empty when everything succeeded.

    NOTE(review): when this run built the localized file, it holds only the
    rows for ``[start_idx, effective_end)``. Decode is still given global
    indices (``batch_start``/``batch_end``). If src.decode slices
    positionally, rows are offset by ``--start_idx`` whenever it is > 0.
    Confirm against src.decode.
    """
    decode_models = parse_model_csv(args.decode_models)
    for model in decode_models:
        reset_output_file(decoded_paths[model])

    failures: List[Tuple[str, int]] = []
    total_batches = len(batch_ranges)
    for batch_index, (batch_start, batch_end) in enumerate(batch_ranges):
        print(
            f"\n========== DECODE BATCH {batch_index + 1}/{total_batches} "
            f"(start_idx={batch_start}, end_idx={batch_end}) =========="
        )

        for model in decode_models:
            print(f"[DECODE] Running model: {model}")
            temp_decoded_path = temp_decoded_paths[model]
            try:
                decode_args = [
                    "--input_path",
                    str(localized_path),
                    "--output_path",
                    str(temp_decoded_path),
                    "--language",
                    args.language,
                    "--region",
                    args.region,
                    "--model",
                    model,
                    "--start_idx",
                    str(batch_start),
                ]
                if batch_end is not None:
                    decode_args += ["--end_idx", str(batch_end)]
                if args.dont_use_cached:
                    decode_args += ["--dont_use_cached"]

                run_python_module("src.decode", decode_args)
                append_json_list(decoded_paths[model], load_json(temp_decoded_path))
                cleanup_file(temp_decoded_path)
            except Exception as exc:  # per-model best-effort boundary
                print(f"\n[ERROR] Model '{model}' failed on batch {batch_index + 1}: {exc}")
                print("[SKIP] Continuing with remaining models...")
                cleanup_file(temp_decoded_path)
                failures.append((model, batch_index + 1))

        _rest_before_next_batch(args, batch_index, total_batches)

    return failures


def main() -> None:
    """Run encode, localize and decode for the parsed CLI arguments."""
    args = parse_args()
    paths = make_paths(args)

    effective_end_idx = compute_effective_end_idx(args)
    batch_ranges = build_batch_ranges(
        start_idx=args.start_idx,
        effective_end_idx=effective_end_idx,
        batch_size=args.batch_size,
    )

    if not batch_ranges:
        print("No examples to process with the given slicing arguments.")
        return

    print(
        f"language='{args.language}' | region='{args.region}' | "
        f"batches={len(batch_ranges)} | batch_size={args.batch_size} | "
        f"rest_seconds={args.rest_seconds}"
    )
    # Report each constant separately; they are independent settings.
    print(f"Encode model: {DEFAULT_ENCODING_MODEL}")
    print(f"Localize model: {DEFAULT_LOCALIZE_MODEL}")
    print(f"Decode models: {args.decode_models}")

    # Step 1: Encode (once, region-independent)
    ensure_encoded_file(args, paths["encoded"])

    # Step 2: Localize (once per language/region)
    ensure_localized_file(
        args=args,
        encoded_path=paths["encoded"],
        localized_path=paths["localized"],
        batch_ranges=batch_ranges,
        temp_localized_path=paths["temp_localized"],
    )

    # Step 3: Decode (all models, from shared localized file)
    failures = run_decode_batches(
        args=args,
        localized_path=paths["localized"],
        decoded_paths=paths["decoded"],
        temp_decoded_paths=paths["temp_decoded"],
        batch_ranges=batch_ranges,
    )

    print("\n[FINAL OUTPUTS]")
    print(f"Encoded: {paths['encoded']}")
    print(f"Localized: {paths['localized']}")
    for model, decoded_path in paths["decoded"].items():
        print(f"  {model} decoded: {decoded_path}")

    if failures:
        print(
            f"\n[WARN] {len(failures)} decode batch(es) failed; "
            "the outputs below are missing those batches:"
        )
        for model, batch_number in failures:
            print(f"  {model}: batch {batch_number}")
        print("\nFinished with errors.")
    else:
        print("\nAll batches completed.")


if __name__ == "__main__":
    main()