# afridialeval / src/orchestrator.py
# Uploaded by: millicentochieng
# Upload message: "Upload folder using huggingface_hub"
# Commit: edf8cae (verified)
import argparse
import json
import subprocess
import sys
import time
from pathlib import Path
from typing import Any, List, Optional, Tuple
from src.config import (
DECODED_DIR,
ENCODED_DIR,
LOCALIZED_DIR,
RAW_DATA_PATH,
)
# Model aliases used when the caller does not override them.
# Passed as --model to src.encode:
DEFAULT_ENCODING_MODEL = "gemini-3-flash"
# Passed as --model to src.localize:
DEFAULT_LOCALIZE_MODEL = "gemini-3-flash"
# Default value of --decode_models; each alias is decoded and written separately.
DEFAULT_DECODE_MODELS = ["gemini-3-flash", "gpt-5.1", "gemma-3-27b-it"]
def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse and validate command-line options for the pipeline.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, exactly as before; passing a list allows
            the parser to be driven programmatically.

    Returns:
        The parsed namespace.

    Exits (via ``parser.error``, status 2) when a slicing option is negative
    or ``--batch_size`` is not positive.
    """
    parser = argparse.ArgumentParser(
        description="Run the African Dialogue Framework pipeline."
    )
    parser.add_argument(
        "--input_path",
        type=str,
        default=str(RAW_DATA_PATH),
        help="Path to raw DailyDialog JSON file.",
    )
    parser.add_argument(
        "--language",
        type=str,
        required=True,
        help="Target language, e.g. 'Swahili'.",
    )
    parser.add_argument(
        "--region",
        type=str,
        default="",
        help="Optional region/community, e.g. 'Kenya - Nairobi'.",
    )
    parser.add_argument(
        "--decode_models",
        type=str,
        default=",".join(DEFAULT_DECODE_MODELS),
        help="Comma-separated model aliases for decode comparison.",
    )
    parser.add_argument(
        "--max_instances",
        type=int,
        default=None,
        help="Optional cap on number of dialogues to process.",
    )
    parser.add_argument(
        "--start_idx",
        type=int,
        default=0,
        help="Optional start index for slicing input data.",
    )
    parser.add_argument(
        "--end_idx",
        type=int,
        default=None,
        help="Optional end index for slicing input data.",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=20,
        help="Number of examples to process per batch.",
    )
    parser.add_argument(
        "--rest_seconds",
        type=int,
        default=300,
        help="Seconds to rest between batches.",
    )
    parser.add_argument(
        "--dont_use_cached",
        action="store_true",
        help="Disable cached prompt responses.",
    )
    args = parser.parse_args(argv)

    # Reject values that would produce meaningless batch ranges up front, with
    # a usage message, rather than letting them reach the subprocesses or fail
    # later with a bare traceback.
    if args.start_idx < 0:
        parser.error("--start_idx must be >= 0.")
    if args.end_idx is not None and args.end_idx < 0:
        parser.error("--end_idx must be >= 0.")
    if args.max_instances is not None and args.max_instances < 0:
        parser.error("--max_instances must be >= 0.")
    if args.batch_size <= 0:
        parser.error("--batch_size must be greater than 0.")
    return args
def slugify(text: str) -> str:
    """Normalize *text* for use in file names.

    Leading/trailing whitespace is stripped, the result is lower-cased, and
    every space, forward slash, and hyphen becomes an underscore. No other
    characters are touched.
    """
    separators_to_underscore = str.maketrans(" /-", "___")
    normalized = text.strip().lower()
    return normalized.translate(separators_to_underscore)
def build_run_tag(language: str, region: str) -> str:
    """Return the file-name tag identifying a language (and optional region).

    The slugified language is always present; the slugified region is joined
    with an underscore only when *region* contains non-whitespace text.
    """
    tag_parts = [slugify(language)]
    if region.strip():
        tag_parts.append(slugify(region))
    return "_".join(tag_parts)
def parse_model_csv(raw: str) -> List[str]:
    """Split a comma-separated model list into unique, whitespace-stripped aliases.

    Blank entries are ignored. Duplicates are dropped, keeping the order of
    first appearance. Each alias maps to a single decoded output file, so a
    repeated alias would otherwise run every batch twice and append it twice
    to the same file.

    Raises:
        ValueError: if *raw* contains no non-blank aliases.
    """
    # dict.fromkeys keeps insertion order, giving an order-preserving dedupe.
    models = list(
        dict.fromkeys(item.strip() for item in raw.split(",") if item.strip())
    )
    if not models:
        raise ValueError("Model list cannot be empty.")
    return models
def load_json(path: Path) -> Any:
    """Read *path* as UTF-8 text and return the decoded JSON value."""
    with path.open(encoding="utf-8") as handle:
        return json.load(handle)
def save_json(path: Path, data: Any) -> None:
    """Write *data* to *path* as pretty-printed UTF-8 JSON, creating parent dirs.

    The text is written to a sibling ``<name>.tmp`` file and then moved over
    *path* with ``Path.replace``. An interrupted write therefore cannot leave
    a truncated JSON file at *path*; on POSIX the final rename is atomic.
    Non-ASCII characters are written as-is (``ensure_ascii=False``).

    Raises:
        TypeError: if *data* is not JSON-serializable (nothing is written).
        OSError: if writing or renaming fails (the temp file is removed).
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    # Serialize first so an unserializable value never touches the disk.
    payload = json.dumps(data, indent=2, ensure_ascii=False)
    tmp_path = path.with_name(path.name + ".tmp")
    try:
        tmp_path.write_text(payload, encoding="utf-8")
        tmp_path.replace(path)
    except BaseException:
        # Don't leave a stray half-written temp file behind.
        tmp_path.unlink(missing_ok=True)
        raise
def ensure_parent(path: Path) -> None:
    """Create the directory that will contain *path*, plus any missing ancestors."""
    parent_dir = path.parent
    parent_dir.mkdir(exist_ok=True, parents=True)
def reset_output_file(path: Path) -> None:
    """Overwrite *path* with an empty JSON list, creating parent directories first."""
    empty_list: List[Any] = []
    ensure_parent(path)
    save_json(path, empty_list)
def append_json_list(destination: Path, batch_data: List[Any]) -> None:
    """Append the items of *batch_data* to the JSON list stored at *destination*.

    A missing *destination* is treated as an empty list. The whole file is
    read, extended, and rewritten.

    Raises:
        ValueError: if *batch_data* is not a list, or if *destination* exists
            but does not contain a JSON list.
    """
    # Callers pass raw load_json() output from a subprocess. Extending with a
    # dict would silently append its keys, so reject anything but a list.
    if not isinstance(batch_data, list):
        raise ValueError(
            f"Expected a JSON list batch for {destination}, "
            f"got {type(batch_data).__name__}"
        )
    ensure_parent(destination)
    existing = load_json(destination) if destination.exists() else []
    if not isinstance(existing, list):
        raise ValueError(f"Expected JSON list in existing file: {destination}")
    existing.extend(batch_data)
    save_json(destination, existing)
def cleanup_file(path: Path) -> None:
    """Delete *path* if it exists; a missing file is not an error.

    ``missing_ok=True`` removes the gap between a separate ``exists()`` check
    and ``unlink()``, where a concurrently deleted file would have raised.
    """
    path.unlink(missing_ok=True)
def run_python_module(module_name: str, args: List[str]) -> None:
    """Run ``python -m <module_name> <args...>`` with the current interpreter.

    The command is logged before it runs. Output from the child process goes
    to the inherited stdout/stderr.

    Raises:
        subprocess.CalledProcessError: if the child exits with a non-zero status.
    """
    import shlex  # local: only needed to render the logged command

    cmd = [sys.executable, "-m", module_name, *args]
    # shlex.join quotes arguments containing spaces (e.g. --region
    # "Kenya - Nairobi") so the logged command is unambiguous and copy-pasteable.
    # flush=True keeps the log line ahead of the child's output when stdout is
    # redirected to a file or pipe.
    print("\n[RUN]", shlex.join(cmd), flush=True)
    subprocess.run(cmd, check=True)
def compute_effective_end_idx(args: argparse.Namespace) -> Optional[int]:
    """Return the exclusive end index implied by ``--end_idx`` and ``--max_instances``.

    Each option that is set contributes a bound: ``end_idx`` directly, and
    ``start_idx + max_instances`` for the cap. The tighter (smaller) bound
    wins. ``None`` means neither option was given (no upper bound).
    """
    bounds: List[int] = []
    if args.end_idx is not None:
        bounds.append(args.end_idx)
    if args.max_instances is not None:
        bounds.append(args.start_idx + args.max_instances)
    return min(bounds) if bounds else None
def build_batch_ranges(
    start_idx: int,
    effective_end_idx: Optional[int],
    batch_size: int,
) -> List[Tuple[int, Optional[int]]]:
    """Split ``[start_idx, effective_end_idx)`` into consecutive chunks of *batch_size*.

    The last chunk may be shorter. When *effective_end_idx* is ``None`` the
    extent is unknown, so a single open-ended range ``(start_idx, None)`` is
    returned. An empty or inverted interval yields ``[]``.

    Raises:
        ValueError: if *batch_size* is not positive.
    """
    if batch_size <= 0:
        raise ValueError("--batch_size must be greater than 0.")
    if effective_end_idx is None:
        return [(start_idx, None)]
    # range() is empty whenever effective_end_idx <= start_idx.
    return [
        (chunk_start, min(chunk_start + batch_size, effective_end_idx))
        for chunk_start in range(start_idx, effective_end_idx, batch_size)
    ]
def build_slice_args(
    batch_start: int,
    batch_end: Optional[int],
    dont_use_cached: bool,
    model: Optional[str] = None,
) -> List[str]:
    """Build the CLI flags that select one batch slice for a stage subprocess.

    ``--start_idx`` is always emitted. ``--end_idx`` and ``--model`` are
    emitted only when their value is not ``None``. ``--dont_use_cached`` is
    appended last when requested.
    """
    optional_pairs = [
        ("--start_idx", str(batch_start)),
        ("--end_idx", None if batch_end is None else str(batch_end)),
        ("--model", model),
    ]
    cli: List[str] = []
    for flag, value in optional_pairs:
        if value is not None:
            cli.extend((flag, value))
    if dont_use_cached:
        cli.append("--dont_use_cached")
    return cli
def model_tag(model_alias: str) -> str:
    """Return a file-name-safe tag for *model_alias*.

    The alias is slugified and then every dot is replaced with an underscore.
    """
    slug = slugify(model_alias)
    return slug.replace(".", "_")
def make_paths(args: argparse.Namespace) -> dict[str, Any]:
    """Derive every stage's input/output path for one pipeline run.

    Returns a dict with keys:
      - ``"encoded"``: one encoded file per input file. It has no language or
        region tag, so it is shared across all runs on that input.
      - ``"localized"``: one localized file per input + language/region tag.
      - ``"decoded"``: per-decode-model output files, keyed by model alias.
      - ``"temp_localized"``: scratch file for a single localize batch,
        under ``LOCALIZED_DIR / "_tmp"``.
      - ``"temp_decoded"``: per-model scratch files for a single decode batch.

    NOTE(review): no path includes the --start_idx/--end_idx slice. Cached
    encoded/localized files produced for a different slice therefore resolve
    to the same path; confirm that reuse is intended.
    """
    run_tag = build_run_tag(args.language, args.region)
    stem = Path(args.input_path).stem
    run_prefix = f"{stem}_{run_tag}"
    per_model_prefix = {
        alias: f"{run_prefix}_{model_tag(alias)}"
        for alias in parse_model_csv(args.decode_models)
    }
    return {
        "encoded": ENCODED_DIR / f"{stem}_encoded.json",
        "localized": LOCALIZED_DIR / f"{run_prefix}_localized.json",
        "decoded": {
            alias: DECODED_DIR / f"{prefix}_decoded.json"
            for alias, prefix in per_model_prefix.items()
        },
        "temp_localized": LOCALIZED_DIR / "_tmp" / f"{run_prefix}_localized_tmp.json",
        "temp_decoded": {
            alias: DECODED_DIR / f"{prefix}_decoded_tmp.json"
            for alias, prefix in per_model_prefix.items()
        },
    }
def ensure_encoded_file(
    args: argparse.Namespace,
    encoded_path: Path,
) -> None:
    """Run ``src.encode`` on the raw input unless a cached encoded file can be reused.

    The subprocess writes directly to *encoded_path* using
    ``DEFAULT_ENCODING_MODEL``. The raw ``--start_idx``, ``--end_idx`` and
    ``--max_instances`` options are forwarded verbatim, together with
    ``--dont_use_cached`` when set.

    NOTE(review): reuse is decided only by ``encoded_path.exists()``. A file
    written earlier for a different slice is reused as-is; confirm that is
    intended.

    Raises:
        subprocess.CalledProcessError: if ``src.encode`` exits non-zero.
    """
    already_encoded = encoded_path.exists()
    if already_encoded and not args.dont_use_cached:
        print(f"[ENCODE] Reusing existing encoded file: {encoded_path}")
        return
    print(f"[ENCODE] Using model: {DEFAULT_ENCODING_MODEL}")
    trailing_flags: List[str] = []
    if args.end_idx is not None:
        trailing_flags += ["--end_idx", str(args.end_idx)]
    if args.max_instances is not None:
        trailing_flags += ["--max_instances", str(args.max_instances)]
    if args.dont_use_cached:
        trailing_flags.append("--dont_use_cached")
    run_python_module(
        "src.encode",
        [
            "--input_path", args.input_path,
            "--output_path", str(encoded_path),
            "--model", DEFAULT_ENCODING_MODEL,
            "--start_idx", str(args.start_idx),
            *trailing_flags,
        ],
    )
def ensure_localized_file(
    args: argparse.Namespace,
    encoded_path: Path,
    localized_path: Path,
    batch_ranges: List[Tuple[int, Optional[int]]],
    temp_localized_path: Path,
) -> None:
    """Localize the encoded file batch by batch, unless a cached result can be reused.

    For each ``(start_idx, end_idx)`` range, ``src.localize`` is run on
    *encoded_path* and writes to *temp_localized_path*. That batch is then
    appended to an in-progress sibling of *localized_path*. Only after every
    batch has succeeded is the in-progress file moved onto *localized_path*.

    Args:
        args: Parsed CLI options (language, region, rest_seconds, dont_use_cached).
        encoded_path: Input produced by the encode step.
        localized_path: Final combined output. Reused untouched when it exists
            and ``--dont_use_cached`` was not given.
        batch_ranges: Slices to localize, in order.
        temp_localized_path: Scratch file for one batch's output.

    Raises:
        subprocess.CalledProcessError: if ``src.localize`` exits non-zero.

    NOTE(review): batch indices are forwarded unchanged as indices into
    *encoded_path*. Confirm this matches how ``src.encode`` lays out its
    output when ``--start_idx > 0``.
    """
    if localized_path.exists() and not args.dont_use_cached:
        print(f"[LOCALIZE] Reusing existing localized file: {localized_path}")
        return
    print(f"[LOCALIZE] Using model: {DEFAULT_LOCALIZE_MODEL}")
    # The existence of localized_path is the cache signal checked above.
    # Creating it before all batches finish would let a crashed run be
    # silently reused as a complete result next time, so build into a separate
    # file and move it into place only at the end.
    partial_path = localized_path.with_name(localized_path.name + ".partial")
    reset_output_file(partial_path)
    total_batches = len(batch_ranges)
    for batch_index, (batch_start, batch_end) in enumerate(batch_ranges):
        print(
            f"\n========== LOCALIZE BATCH {batch_index + 1}/{total_batches} "
            f"(start_idx={batch_start}, end_idx={batch_end}) =========="
        )
        slice_args = build_slice_args(
            batch_start=batch_start,
            batch_end=batch_end,
            model=DEFAULT_LOCALIZE_MODEL,
            dont_use_cached=args.dont_use_cached,
        )
        run_python_module(
            "src.localize",
            [
                "--input_path", str(encoded_path),
                "--output_path", str(temp_localized_path),
                "--language", args.language,
                "--region", args.region,
                *slice_args,
            ],
        )
        localized_batch = load_json(temp_localized_path)
        append_json_list(partial_path, localized_batch)
        cleanup_file(temp_localized_path)
        is_last_batch = batch_index == total_batches - 1
        if not is_last_batch and args.rest_seconds > 0:
            print(f"\n[REST] Sleeping for {args.rest_seconds} seconds before next batch...")
            time.sleep(args.rest_seconds)
    # Every batch succeeded: publish the combined result.
    partial_path.replace(localized_path)
def run_decode_batches(
    args: argparse.Namespace,
    localized_path: Path,
    decoded_paths: dict[str, Path],
    temp_decoded_paths: dict[str, Path],
    batch_ranges: List[Tuple[int, Optional[int]]],
) -> None:
    """Decode every batch of the localized file with each decode model.

    Every model's output file in *decoded_paths* is reset to an empty list
    first. For each batch, every model runs ``src.decode`` into its scratch
    file in *temp_decoded_paths*, and the result is appended to that model's
    output.

    A failure for one model on one batch is logged and skipped so the remaining
    models still run. That model's output file is then missing the batch, so
    all failures are summarized at the end.

    NOTE(review): batch_start/batch_end are absolute indices, but the localized
    file is built by concatenating per-batch localize outputs. If
    ``src.localize`` emits only the requested slice, the file is re-indexed
    from 0, and a run with ``--start_idx > 0`` would decode the wrong rows.
    Confirm against ``src.decode``.
    """
    decode_models = parse_model_csv(args.decode_models)
    for model in decode_models:
        reset_output_file(decoded_paths[model])
    failures: List[Tuple[str, int]] = []
    total_batches = len(batch_ranges)
    for batch_index, (batch_start, batch_end) in enumerate(batch_ranges):
        print(
            f"\n========== DECODE BATCH {batch_index + 1}/{total_batches} "
            f"(start_idx={batch_start}, end_idx={batch_end}) =========="
        )
        for model in decode_models:
            print(f"[DECODE] Running model: {model}")
            temp_decoded_path = temp_decoded_paths[model]
            decode_args = [
                "--input_path", str(localized_path),
                "--output_path", str(temp_decoded_path),
                "--language", args.language,
                "--region", args.region,
                *build_slice_args(
                    batch_start=batch_start,
                    batch_end=batch_end,
                    model=model,
                    dont_use_cached=args.dont_use_cached,
                ),
            ]
            try:
                run_python_module("src.decode", decode_args)
                decoded_batch = load_json(temp_decoded_path)
                append_json_list(decoded_paths[model], decoded_batch)
            except Exception as exc:
                # Deliberate best-effort boundary: one model's failure must not
                # abort the comparison, but it is recorded and reported below.
                print(f"\n[ERROR] Model '{model}' failed on batch {batch_index + 1}: {exc}")
                print("[SKIP] Continuing with remaining models...")
                failures.append((model, batch_index + 1))
            finally:
                cleanup_file(temp_decoded_path)
        is_last_batch = batch_index == total_batches - 1
        if not is_last_batch and args.rest_seconds > 0:
            print(f"\n[REST] Sleeping for {args.rest_seconds} seconds before next batch...")
            time.sleep(args.rest_seconds)
    if failures:
        print(f"\n[DECODE FAILURES] {len(failures)} model/batch run(s) failed; "
              "the corresponding outputs are missing those batches:")
        for model, batch_number in failures:
            print(f"  - {model}: batch {batch_number}/{total_batches}")
def main() -> None:
    """Run the pipeline from the command line: encode, then localize, then decode.

    Encode and localize outputs are reused from disk when they already exist,
    unless ``--dont_use_cached`` is passed. Decode outputs are reset and
    regenerated for every model on each run.
    """
    args = parse_args()
    paths = make_paths(args)
    decode_models = parse_model_csv(args.decode_models)
    effective_end_idx = compute_effective_end_idx(args)
    batch_ranges = build_batch_ranges(
        start_idx=args.start_idx,
        effective_end_idx=effective_end_idx,
        batch_size=args.batch_size,
    )
    if not batch_ranges:
        print("No examples to process with the given slicing arguments.")
        return
    print(
        f"language='{args.language}' | region='{args.region}' | "
        f"batches={len(batch_ranges)} | batch_size={args.batch_size} | "
        f"rest_seconds={args.rest_seconds}"
    )
    # Report each stage's model separately. The encode and localize defaults are
    # independent constants, even though they currently hold the same alias.
    print(f"Encode model: {DEFAULT_ENCODING_MODEL}")
    print(f"Localize model: {DEFAULT_LOCALIZE_MODEL}")
    print(f"Decode models: {', '.join(decode_models)}")
    # Step 1: Encode once. The encoded path carries no language/region tag,
    # so it is shared across regions.
    ensure_encoded_file(args, paths["encoded"])
    # Step 2: Localize once per language/region.
    ensure_localized_file(
        args=args,
        encoded_path=paths["encoded"],
        localized_path=paths["localized"],
        batch_ranges=batch_ranges,
        temp_localized_path=paths["temp_localized"],
    )
    # Step 3: Decode with every model from the shared localized file.
    run_decode_batches(
        args=args,
        localized_path=paths["localized"],
        decoded_paths=paths["decoded"],
        temp_decoded_paths=paths["temp_decoded"],
        batch_ranges=batch_ranges,
    )
    print("\n[FINAL OUTPUTS]")
    print(f"Encoded: {paths['encoded']}")
    print(f"Localized: {paths['localized']}")
    for model in decode_models:
        print(f" {model} decoded: {paths['decoded'][model]}")
    print("\nAll batches completed.")
# Script entry point: run the pipeline only when executed directly
# (e.g. `python -m src.orchestrator`), not when imported.
if __name__ == "__main__":
    main()