# Page-header residue from the hosting site ("Spaces: Running"); not part of the program.
import argparse
import json
import shlex
import subprocess
import sys
import time
from pathlib import Path
from typing import Any, List, Optional, Tuple

from src.config import (
    DECODED_DIR,
    ENCODED_DIR,
    LOCALIZED_DIR,
    RAW_DATA_PATH,
)
# Model alias passed to the encode stage.
DEFAULT_ENCODING_MODEL = "gemini-3-flash"

# Model alias passed to the localize stage.
DEFAULT_LOCALIZE_MODEL = "gemini-3-flash"

# Model aliases used for the decode stage when --decode_models is not given.
DEFAULT_DECODE_MODELS = ["gemini-3-flash", "gpt-5.1", "gemma-3-27b-it"]
def parse_args() -> argparse.Namespace:
    """Parse the pipeline's command-line options.

    Returns:
        The namespace produced by ``ArgumentParser.parse_args()``, with
        attributes ``input_path``, ``language``, ``region``,
        ``decode_models``, ``max_instances``, ``start_idx``, ``end_idx``,
        ``batch_size``, ``rest_seconds`` and ``dont_use_cached``.
    """
    cli = argparse.ArgumentParser(
        description="Run the African Dialogue Framework pipeline."
    )

    # One (flag, add_argument keyword arguments) entry per option, registered
    # in this order so the generated --help listing keeps the same sequence.
    option_specs: List[Tuple[str, dict]] = [
        (
            "--input_path",
            {
                "type": str,
                "default": str(RAW_DATA_PATH),
                "help": "Path to raw DailyDialog JSON file.",
            },
        ),
        (
            "--language",
            {
                "type": str,
                "required": True,
                "help": "Target language, e.g. 'Swahili'.",
            },
        ),
        (
            "--region",
            {
                "type": str,
                "default": "",
                "help": "Optional region/community, e.g. 'Kenya - Nairobi'.",
            },
        ),
        (
            "--decode_models",
            {
                "type": str,
                "default": ",".join(DEFAULT_DECODE_MODELS),
                "help": "Comma-separated model aliases for decode comparison.",
            },
        ),
        (
            "--max_instances",
            {
                "type": int,
                "default": None,
                "help": "Optional cap on number of dialogues to process.",
            },
        ),
        (
            "--start_idx",
            {
                "type": int,
                "default": 0,
                "help": "Optional start index for slicing input data.",
            },
        ),
        (
            "--end_idx",
            {
                "type": int,
                "default": None,
                "help": "Optional end index for slicing input data.",
            },
        ),
        (
            "--batch_size",
            {
                "type": int,
                "default": 20,
                "help": "Number of examples to process per batch.",
            },
        ),
        (
            "--rest_seconds",
            {
                "type": int,
                "default": 300,
                "help": "Seconds to rest between batches.",
            },
        ),
        (
            "--dont_use_cached",
            {
                "action": "store_true",
                "help": "Disable cached prompt responses.",
            },
        ),
    ]
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    return cli.parse_args()
def slugify(text: str) -> str:
    """Return *text* trimmed, lower-cased, with spaces, '/' and '-' turned into '_'."""
    # Each separator maps to the same single character, so one translate pass
    # gives the same result as replacing them one after another.
    separators_to_underscore = str.maketrans({" ": "_", "/": "_", "-": "_"})
    normalized = text.strip().lower()
    return normalized.translate(separators_to_underscore)
def build_run_tag(language: str, region: str) -> str:
    """Build the tag used in output file names for a language/region pair.

    The tag is the slugified language, followed by ``_`` and the slugified
    region when the region is not blank.
    """
    tag_parts = [slugify(language)]
    if region.strip():
        tag_parts.append(slugify(region))
    return "_".join(tag_parts)
def parse_model_csv(raw: str) -> List[str]:
    """Split a comma-separated list of model aliases into a clean list.

    Whitespace around each alias is stripped and empty entries are dropped.
    Repeated aliases are collapsed to their first occurrence: the decode stage
    writes one output file per alias, so a duplicate would be decoded twice
    into the same file.

    Args:
        raw: Comma-separated model aliases, e.g. ``"gemini-3-flash, gpt-5.1"``.

    Returns:
        The unique aliases in their original order.

    Raises:
        ValueError: If no non-empty alias remains.
    """
    stripped = (item.strip() for item in raw.split(","))
    # dict.fromkeys keeps insertion order, giving an order-preserving de-dup.
    models = list(dict.fromkeys(alias for alias in stripped if alias))
    if not models:
        raise ValueError("Model list cannot be empty.")
    return models
def load_json(path: Path) -> Any:
    """Read *path* as UTF-8 text and return the parsed JSON value."""
    with path.open("r", encoding="utf-8") as handle:
        return json.load(handle)
def save_json(path: Path, data: Any) -> None:
    """Write *data* to *path* as indented UTF-8 JSON, creating parent directories.

    Non-ASCII characters are written as-is rather than as ``\\u`` escapes.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    # Serialize before opening the file so a serialization error cannot
    # truncate an existing file.
    payload = json.dumps(data, indent=2, ensure_ascii=False)
    with path.open("w", encoding="utf-8") as handle:
        handle.write(payload)
def ensure_parent(path: Path) -> None:
    """Create the directory that will contain *path*, including missing ancestors."""
    parent_dir = path.parent
    parent_dir.mkdir(parents=True, exist_ok=True)
def reset_output_file(path: Path) -> None:
    """Overwrite *path* with an empty JSON list, creating its directory first."""
    ensure_parent(path)
    empty_records: List[Any] = []
    save_json(path, empty_records)
def append_json_list(destination: Path, batch_data: List[Any]) -> None:
    """Append the items of *batch_data* to the JSON list stored in *destination*.

    A missing destination file is treated as an empty list. The whole list is
    rewritten on every call.

    Args:
        destination: JSON file holding the accumulated list.
        batch_data: Items to append; must be a list.

    Raises:
        ValueError: If *batch_data* is not a list, or if the existing file
            does not contain a JSON list.
    """
    # Callers pass freshly loaded JSON here. Without this check a dict or
    # string would be spliced in silently as keys or characters by extend().
    if not isinstance(batch_data, list):
        raise ValueError(
            f"Expected JSON list in batch data, got {type(batch_data).__name__}"
        )

    if destination.exists():
        existing = load_json(destination)
        if not isinstance(existing, list):
            raise ValueError(f"Expected JSON list in existing file: {destination}")
    else:
        existing = []

    existing.extend(batch_data)
    # save_json creates the parent directory itself.
    save_json(destination, existing)
def cleanup_file(path: Path) -> None:
    """Delete *path* if it is present; do nothing when it is already gone."""
    # missing_ok avoids the check-then-delete race of exists() + unlink().
    path.unlink(missing_ok=True)
def run_python_module(module_name: str, args: List[str]) -> None:
    """Run ``python -m <module_name> <args...>`` with the current interpreter.

    The command is echoed before it runs.

    Args:
        module_name: Dotted module name passed to ``-m``.
        args: Command-line arguments forwarded to the module.

    Raises:
        subprocess.CalledProcessError: If the module exits with a non-zero status.
    """
    cmd = [sys.executable, "-m", module_name] + args
    # shlex.join quotes arguments containing spaces (e.g. a region such as
    # "Kenya - Nairobi"), so the logged line shows argument boundaries.
    print("\n[RUN]", shlex.join(cmd))
    subprocess.run(cmd, check=True)
def compute_effective_end_idx(args: argparse.Namespace) -> Optional[int]:
    """Combine ``end_idx`` and ``max_instances`` into one exclusive end index.

    Returns:
        The smaller of ``end_idx`` and ``start_idx + max_instances`` when both
        are set, whichever one is set otherwise, or ``None`` when neither is.
    """
    upper_bounds: List[int] = []
    if args.end_idx is not None:
        upper_bounds.append(args.end_idx)
    if args.max_instances is not None:
        upper_bounds.append(args.start_idx + args.max_instances)

    return min(upper_bounds) if upper_bounds else None
def build_batch_ranges(
    start_idx: int,
    effective_end_idx: Optional[int],
    batch_size: int,
) -> List[Tuple[int, Optional[int]]]:
    """Split ``[start_idx, effective_end_idx)`` into consecutive batch ranges.

    Args:
        start_idx: First index to process.
        effective_end_idx: Exclusive end index, or ``None`` for "no upper bound".
        batch_size: Maximum number of indices per batch.

    Returns:
        A list of ``(start, end)`` pairs with exclusive ends. With no upper
        bound the result is the single open-ended range ``(start_idx, None)``.
        An empty list is returned when the end does not exceed the start.

    Raises:
        ValueError: If *batch_size* is not positive.
    """
    if batch_size <= 0:
        raise ValueError("--batch_size must be greater than 0.")

    if effective_end_idx is None:
        return [(start_idx, None)]

    # range() is already empty when the end does not exceed the start.
    return [
        (lower, min(lower + batch_size, effective_end_idx))
        for lower in range(start_idx, effective_end_idx, batch_size)
    ]
def build_slice_args(
    batch_start: int,
    batch_end: Optional[int],
    dont_use_cached: bool,
    model: Optional[str] = None,
) -> List[str]:
    """Build the slicing-related command-line arguments for one batch.

    Returns:
        ``--start_idx`` always, followed by ``--end_idx`` and ``--model`` when
        their values are not ``None``, and ``--dont_use_cached`` when requested.
    """
    end_part = [] if batch_end is None else ["--end_idx", str(batch_end)]
    model_part = [] if model is None else ["--model", model]
    cache_part = ["--dont_use_cached"] if dont_use_cached else []

    return ["--start_idx", str(batch_start), *end_part, *model_part, *cache_part]
def model_tag(model_alias: str) -> str:
    """Return a file-name-friendly tag for *model_alias*.

    The alias is slugified and any remaining dots are replaced by underscores.
    """
    slug = slugify(model_alias)
    return slug.replace(".", "_")
def make_paths(args: argparse.Namespace) -> dict[str, Any]:
    """Compute every output and temp path used by one pipeline run.

    Returns:
        A dict with the keys ``"encoded"``, ``"localized"`` and
        ``"temp_localized"`` (single paths), and ``"decoded"`` and
        ``"temp_decoded"`` (dicts mapping each decode model alias to its path).

    Raises:
        ValueError: If ``args.decode_models`` contains no model alias.
    """
    run_tag = build_run_tag(args.language, args.region)
    input_stem = Path(args.input_path).stem
    # Every language/region-specific file name shares this prefix.
    run_prefix = f"{input_stem}_{run_tag}"

    # The encoded file name carries no language/region tag, so all regions share it.
    encoded_path = ENCODED_DIR / f"{input_stem}_encoded.json"

    # One localized file per language/region, shared by all decode models.
    localized_path = LOCALIZED_DIR / f"{run_prefix}_localized.json"
    temp_localized = LOCALIZED_DIR / "_tmp" / f"{run_prefix}_localized_tmp.json"

    # One final and one temp decoded file per decode model.
    decoded_paths = {}
    temp_decoded = {}
    for model in parse_model_csv(args.decode_models):
        tag = model_tag(model)
        decoded_paths[model] = DECODED_DIR / f"{run_prefix}_{tag}_decoded.json"
        temp_decoded[model] = DECODED_DIR / f"{run_prefix}_{tag}_decoded_tmp.json"

    return {
        "encoded": encoded_path,
        "localized": localized_path,
        "decoded": decoded_paths,
        "temp_localized": temp_localized,
        "temp_decoded": temp_decoded,
    }
def ensure_encoded_file(
    args: argparse.Namespace,
    encoded_path: Path,
) -> None:
    """Run the encode stage unless an encoded file can be reused.

    An existing *encoded_path* is reused unless ``--dont_use_cached`` was
    given. Otherwise ``src.encode`` is run with the input path, the output
    path, the default encoding model and the slicing options from *args*.

    NOTE(review): the reuse check only looks at whether the file exists, so a
    file produced with different slicing options is reused as-is — confirm
    this is intended.

    Raises:
        subprocess.CalledProcessError: If the encode module exits non-zero.
    """
    force_refresh = args.dont_use_cached
    if not force_refresh and encoded_path.exists():
        print(f"[ENCODE] Reusing existing encoded file: {encoded_path}")
        return

    print(f"[ENCODE] Using model: {DEFAULT_ENCODING_MODEL}")

    # Optional flags are only forwarded when they were set.
    optional_args: List[str] = []
    if args.end_idx is not None:
        optional_args.extend(["--end_idx", str(args.end_idx)])
    if args.max_instances is not None:
        optional_args.extend(["--max_instances", str(args.max_instances)])
    if force_refresh:
        optional_args.append("--dont_use_cached")

    run_python_module(
        "src.encode",
        [
            "--input_path", args.input_path,
            "--output_path", str(encoded_path),
            "--model", DEFAULT_ENCODING_MODEL,
            "--start_idx", str(args.start_idx),
            *optional_args,
        ],
    )
def ensure_localized_file(
    args: argparse.Namespace,
    encoded_path: Path,
    localized_path: Path,
    batch_ranges: List[Tuple[int, Optional[int]]],
    temp_localized_path: Path,
) -> None:
    """Produce the localized file batch by batch, unless it can be reused.

    An existing *localized_path* is reused unless ``--dont_use_cached`` was
    given. Otherwise ``src.localize`` is run once per batch range, writing to
    *temp_localized_path*, and each batch is appended to the result.

    Batches are collected in a staging file next to *localized_path* and only
    renamed onto it after every batch has succeeded. A run that fails part-way
    therefore never leaves a partial file at *localized_path* that a later run
    would reuse as a complete cache. The staging file is left behind on
    failure and overwritten by the next run.

    Args:
        args: Parsed CLI options (language, region, rest_seconds, dont_use_cached).
        encoded_path: Input file for the localize stage.
        localized_path: Final output file.
        batch_ranges: ``(start, end)`` index pairs, one per batch.
        temp_localized_path: Per-batch output file, removed after each batch.

    Raises:
        subprocess.CalledProcessError: If a localize run exits non-zero.
        ValueError: If a batch output is not a JSON list.
    """
    if localized_path.exists() and not args.dont_use_cached:
        print(f"[LOCALIZE] Reusing existing localized file: {localized_path}")
        return

    print(f"[LOCALIZE] Using model: {DEFAULT_LOCALIZE_MODEL}")

    staging_path = localized_path.with_name(localized_path.name + ".partial")
    reset_output_file(staging_path)

    total_batches = len(batch_ranges)
    try:
        for batch_index, (batch_start, batch_end) in enumerate(batch_ranges):
            print(
                f"\n========== LOCALIZE BATCH {batch_index + 1}/{total_batches} "
                f"(start_idx={batch_start}, end_idx={batch_end}) =========="
            )
            slice_args = build_slice_args(
                batch_start=batch_start,
                batch_end=batch_end,
                model=DEFAULT_LOCALIZE_MODEL,
                dont_use_cached=args.dont_use_cached,
            )
            run_python_module(
                "src.localize",
                [
                    "--input_path", str(encoded_path),
                    "--output_path", str(temp_localized_path),
                    "--language", args.language,
                    "--region", args.region,
                    *slice_args,
                ],
            )
            append_json_list(staging_path, load_json(temp_localized_path))
            cleanup_file(temp_localized_path)

            is_last_batch = batch_index == total_batches - 1
            if not is_last_batch and args.rest_seconds > 0:
                print(f"\n[REST] Sleeping for {args.rest_seconds} seconds before next batch...")
                time.sleep(args.rest_seconds)
    finally:
        # Do not leave a per-batch temp file behind when a batch fails.
        cleanup_file(temp_localized_path)

    # Publish the complete result in one step, replacing any older file.
    staging_path.replace(localized_path)
def run_decode_batches(
    args: argparse.Namespace,
    localized_path: Path,
    decoded_paths: dict[str, Path],
    temp_decoded_paths: dict[str, Path],
    batch_ranges: List[Tuple[int, Optional[int]]],
) -> None:
    """Run the decode stage for every model over every batch range.

    Each model's output file is reset first. For every batch, ``src.decode``
    is run once per model, writing to that model's temp file, and the batch is
    appended to the model's output. A failure for one model on one batch is
    reported and skipped so the other models keep running. All failures are
    listed again at the end, because a skipped batch is missing from that
    model's output file.

    NOTE(review): the batch indices are forwarded unchanged to ``src.decode``.
    When the localized file was built by this run it starts at position 0
    regardless of ``--start_idx``; confirm how ``src.decode`` slices its input
    when ``start_idx > 0``.

    Args:
        args: Parsed CLI options (language, region, decode_models,
            rest_seconds, dont_use_cached).
        localized_path: Input file for the decode stage.
        decoded_paths: Final output path per model alias.
        temp_decoded_paths: Per-batch temp output path per model alias.
        batch_ranges: ``(start, end)`` index pairs, one per batch.

    Raises:
        ValueError: If ``args.decode_models`` contains no model alias.
        KeyError: If a model alias has no entry in the path dicts.
    """
    decode_models = parse_model_csv(args.decode_models)
    for model in decode_models:
        reset_output_file(decoded_paths[model])

    # (model alias, 1-based batch number) for every run that failed.
    failures: List[Tuple[str, int]] = []
    total_batches = len(batch_ranges)

    for batch_index, (batch_start, batch_end) in enumerate(batch_ranges):
        print(
            f"\n========== DECODE BATCH {batch_index + 1}/{total_batches} "
            f"(start_idx={batch_start}, end_idx={batch_end}) =========="
        )
        for model in decode_models:
            print(f"[DECODE] Running model: {model}")
            temp_decoded_path = temp_decoded_paths[model]
            # Same slice-argument builder as the localize stage.
            decode_args = [
                "--input_path", str(localized_path),
                "--output_path", str(temp_decoded_path),
                "--language", args.language,
                "--region", args.region,
                *build_slice_args(
                    batch_start=batch_start,
                    batch_end=batch_end,
                    dont_use_cached=args.dont_use_cached,
                    model=model,
                ),
            ]
            try:
                run_python_module("src.decode", decode_args)
                append_json_list(decoded_paths[model], load_json(temp_decoded_path))
            except Exception as exc:
                # Deliberate best effort: one failing model must not stop the rest.
                print(f"\n[ERROR] Model '{model}' failed on batch {batch_index + 1}: {exc}")
                print("[SKIP] Continuing with remaining models...")
                failures.append((model, batch_index + 1))
            finally:
                cleanup_file(temp_decoded_path)

        is_last_batch = batch_index == total_batches - 1
        if not is_last_batch and args.rest_seconds > 0:
            print(f"\n[REST] Sleeping for {args.rest_seconds} seconds before next batch...")
            time.sleep(args.rest_seconds)

    if failures:
        print(
            f"\n[WARN] {len(failures)} decode run(s) failed; "
            "these batches are missing from the decoded outputs:"
        )
        for failed_model, batch_number in failures:
            print(f"  - model '{failed_model}', batch {batch_number}")
def main() -> None:
    """Run the pipeline end to end: encode, localize, then decode with every model."""
    args = parse_args()
    paths = make_paths(args)

    batch_ranges = build_batch_ranges(
        start_idx=args.start_idx,
        effective_end_idx=compute_effective_end_idx(args),
        batch_size=args.batch_size,
    )
    if not batch_ranges:
        print("No examples to process with the given slicing arguments.")
        return

    run_summary = " | ".join(
        [
            f"language='{args.language}'",
            f"region='{args.region}'",
            f"batches={len(batch_ranges)}",
            f"batch_size={args.batch_size}",
            f"rest_seconds={args.rest_seconds}",
        ]
    )
    print(run_summary)
    print(f"Encode/Localize model: {DEFAULT_LOCALIZE_MODEL}")
    print(f"Decode models: {args.decode_models}")

    encoded_path = paths["encoded"]
    localized_path = paths["localized"]

    # Step 1: encode (skipped when an encoded file is reused).
    ensure_encoded_file(args, encoded_path)

    # Step 2: localize in batches (skipped when a localized file is reused).
    ensure_localized_file(
        args=args,
        encoded_path=encoded_path,
        localized_path=localized_path,
        batch_ranges=batch_ranges,
        temp_localized_path=paths["temp_localized"],
    )

    # Step 3: decode with every model from the shared localized file.
    run_decode_batches(
        args=args,
        localized_path=localized_path,
        decoded_paths=paths["decoded"],
        temp_decoded_paths=paths["temp_decoded"],
        batch_ranges=batch_ranges,
    )

    print("\n[FINAL OUTPUTS]")
    print(f"Encoded: {paths['encoded']}")
    print(f"Localized: {paths['localized']}")
    for model in parse_model_csv(args.decode_models):
        print(f" {model} decoded: {paths['decoded'][model]}")
    print("\nAll batches completed.")


if __name__ == "__main__":
    main()