# Source: afridialeval/src/multi_localize.py
# Repository page header (not code): millicentochieng's picture ·
# "Upload folder using huggingface_hub" · commit e2b8b61 (verified)
import argparse
import json
from pathlib import Path
from typing import Any, Dict, List, Optional
from src.config import LOCALIZED_DIR, PROMPTS_DIR
from src.generator import Generator
def parse_args() -> argparse.Namespace:
    """Parse the command-line options for a multi-model localization run.

    Returns:
        Namespace carrying ``input_path``, ``output_path``, ``language``,
        ``region``, ``context_prompt``, ``localize_prompt``, ``models``,
        ``max_instances``, ``start_idx``, ``end_idx`` and ``dont_use_cached``.
    """
    cli = argparse.ArgumentParser(
        description="Generate multiple localization candidates using different models."
    )

    # Options are declared as data; they are registered (and listed by
    # --help) in exactly this order.
    option_specs = [
        (
            "--input_path",
            dict(type=str, required=True, help="Path to encoded JSON file."),
        ),
        (
            "--output_path",
            dict(
                type=str,
                default=None,
                help="Path to save multi-localized output JSON.",
            ),
        ),
        (
            "--language",
            dict(
                type=str,
                required=True,
                help="Target language label, e.g. 'Swahili'.",
            ),
        ),
        (
            "--region",
            dict(
                type=str,
                default="",
                help="Optional target region/community, e.g. 'Kenya - Nairobi'.",
            ),
        ),
        (
            "--context_prompt",
            dict(
                type=str,
                default=str(PROMPTS_DIR / "das_localize_context.md"),
                help="Path to DAS context localization prompt.",
            ),
        ),
        (
            "--localize_prompt",
            dict(
                type=str,
                default=str(PROMPTS_DIR / "das_localize.md"),
                help="Path to DAS localization prompt.",
            ),
        ),
        (
            "--models",
            dict(
                type=str,
                required=True,
                help='Comma-separated model aliases, e.g. "gpt4,gemma,qwen".',
            ),
        ),
        (
            "--max_instances",
            dict(
                type=int,
                default=None,
                help="Optional cap on number of dialogues to process.",
            ),
        ),
        (
            "--start_idx",
            dict(
                type=int,
                default=0,
                help="Optional start index for slicing input data.",
            ),
        ),
        (
            "--end_idx",
            dict(
                type=int,
                default=None,
                help="Optional end index for slicing input data.",
            ),
        ),
        (
            "--dont_use_cached",
            dict(action="store_true", help="Disable cached prompt responses."),
        ),
    ]
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    return cli.parse_args()
def load_json(path: str) -> Any:
    """Read the UTF-8 encoded JSON file at ``path`` and return the decoded value."""
    with open(path, encoding="utf-8") as handle:
        return json.load(handle)
def save_json(path: str, data: Any) -> None:
    """Write ``data`` to ``path`` as indented, UTF-8 encoded JSON.

    Missing parent directories are created first. Non-ASCII characters are
    written as-is rather than escaped.

    Args:
        path: Destination file path.
        data: Any JSON-serializable value.
    """
    destination = Path(path)
    destination.parent.mkdir(parents=True, exist_ok=True)

    serialized = json.dumps(data, indent=2, ensure_ascii=False)
    with destination.open("w", encoding="utf-8") as handle:
        handle.write(serialized)
def normalize_speaker_id(speaker_id: Any) -> str:
    """Map numeric speaker ids onto letter labels.

    A value whose ``str()`` form is ``"1"`` becomes ``"A"`` and ``"2"``
    becomes ``"B"``; anything else is returned as its ``str()`` form.
    """
    label = str(speaker_id)
    return {"1": "A", "2": "B"}.get(label, label)
def stringify_functions(functions: Any) -> str:
    """Render a turn's functions as a single string.

    A list is joined with ``"; "`` after converting each element with
    ``str()``; any other value is simply passed through ``str()``.
    """
    if not isinstance(functions, list):
        return str(functions)
    return "; ".join(map(str, functions))
def preprocess_conversation(das_encoding: List[Dict[str, Any]]) -> str:
    """Render a DAS encoding as one numbered line per turn.

    Each line has the form ``<n>: <speaker>.<functions>``, with ``n``
    counting from 1. A turn without ``speaker_id`` defaults to speaker
    ``"A"``; a turn without ``functions`` is treated as an empty list.

    Args:
        das_encoding: Sequence of turn dictionaries.

    Returns:
        The rendered lines joined with newlines (empty string for no turns).
    """

    def render(position: int, turn: Dict[str, Any]) -> str:
        speaker = normalize_speaker_id(turn.get("speaker_id", "A"))
        functions = stringify_functions(turn.get("functions", []))
        return f"{position}: {speaker}.{functions}"

    return "\n".join(
        render(position, turn)
        for position, turn in enumerate(das_encoding, start=1)
    )
def preprocess_localize_input(
    data: List[Dict[str, Any]],
    language: str,
    region: str,
) -> List[Dict[str, Any]]:
    """Attach localization metadata and rendered turns to every input item.

    Each returned item is a shallow copy of the input item with ``language``,
    ``region`` and ``turns`` (the text rendering of ``das_encoding``) set.
    The input items themselves are not modified.

    Args:
        data: Input dialogue items; each must contain ``das_encoding`` and
            ``context``.
        language: Target language label stored on every item.
        region: Target region label stored on every item.

    Returns:
        The enriched copies, in input order.

    Raises:
        ValueError: If an item lacks ``das_encoding`` or ``context``.
    """

    def prepare(item: Dict[str, Any]) -> Dict[str, Any]:
        if "das_encoding" not in item:
            raise ValueError("Each input item must contain 'das_encoding'.")
        if "context" not in item:
            raise ValueError("Each input item must contain 'context'.")
        enriched = dict(item)
        enriched.update(
            language=language,
            region=region,
            turns=preprocess_conversation(item["das_encoding"]),
        )
        return enriched

    return [prepare(item) for item in data]
def default_output_path(input_path: str, language: str, region: str) -> str:
    """Derive the default output file path from the input name and target.

    The file name is ``<input stem>_<language>[_<region>]_multi_localized.json``
    inside ``LOCALIZED_DIR``. Language and region are stripped and lowercased;
    spaces become underscores, and in the region ``/`` also becomes an
    underscore. A blank region contributes nothing to the name.
    """
    language_slug = language.strip().lower().replace(" ", "_")

    region_clean = region.strip()
    if region_clean:
        separators = str.maketrans({" ": "_", "/": "_"})
        region_slug = region_clean.lower().translate(separators)
        suffix = f"{language_slug}_{region_slug}"
    else:
        suffix = language_slug

    input_stem = Path(input_path).stem
    return str(LOCALIZED_DIR / f"{input_stem}_{suffix}_multi_localized.json")
def parse_model_list(raw_models: str) -> List[str]:
    """Split a comma-separated ``--models`` value into unique model aliases.

    Whitespace around each alias is stripped and empty entries are dropped.
    Repeated aliases are collapsed to their first occurrence: results are
    stored per alias by the caller in this file, so running the same alias
    twice would only repeat the work and overwrite the earlier result.

    Args:
        raw_models: Raw option value, e.g. ``"gpt4, gemma,qwen"``.

    Returns:
        The distinct aliases in first-seen order.

    Raises:
        ValueError: If no non-empty alias remains.
    """
    stripped = (chunk.strip() for chunk in raw_models.split(","))
    # dict.fromkeys removes duplicates while preserving insertion order.
    models = list(dict.fromkeys(alias for alias in stripped if alias))
    if not models:
        raise ValueError("You must provide at least one model alias in --models.")
    return models
def _run_prompt_stage(
    generator: "Generator",
    prompt_path: str,
    items: List[Dict[str, Any]],
    result_key: str,
    dont_use_cached: bool,
) -> List[Dict[str, Any]]:
    """Run one prompt over ``items`` and copy ``result_key`` from each reply.

    Returns shallow copies of ``items`` with ``result_key`` set to the value
    found under the same key in the parsed JSON response.

    Raises:
        ValueError: If the number of responses differs from the number of
            items (pairing them up would silently drop or misalign data).
        KeyError: If a parsed response lacks ``result_key``.
    """
    prompts, response_format = generator.build_prompts(prompt_path, items)
    responses = list(
        generator.prompt(
            prompts=prompts,
            response_format=response_format,
            dont_use_cached=dont_use_cached,
        )
    )
    # zip() stops at the shorter input, so check the counts explicitly.
    if len(responses) != len(items):
        raise ValueError(
            f"Expected {len(items)} responses for '{result_key}', "
            f"got {len(responses)}."
        )

    enriched: List[Dict[str, Any]] = []
    for item, response_text in zip(items, responses):
        response_json = Generator.parse_json_response(response_text)
        new_item = dict(item)
        new_item[result_key] = response_json[result_key]
        enriched.append(new_item)
    return enriched


def build_candidate_for_model(
    processed_data: List[Dict[str, Any]],
    model_alias: str,
    context_prompt: str,
    localize_prompt: str,
    dont_use_cached: bool,
) -> List[Dict[str, Any]]:
    """Produce one localization candidate per item using a single model.

    Two prompt stages run in sequence: the context prompt adds
    ``localized_context`` to each item, then the localization prompt (which
    sees the context-enriched items) adds ``localized_das``. Each returned
    item is also tagged with ``candidate_model``.

    Args:
        processed_data: Items prepared for prompting; not modified.
        model_alias: Alias handed to ``Generator`` and recorded on each item.
        context_prompt: Path to the context localization prompt.
        localize_prompt: Path to the localization prompt.
        dont_use_cached: When true, caching is disabled both on the
            generator (``use_cache=False``) and on each prompt call.

    Returns:
        One enriched copy per input item, in input order.

    Raises:
        ValueError: If a stage returns a different number of responses than
            it was given items.
        KeyError: If a parsed response lacks the expected key.
    """
    generator = Generator(
        model_alias=model_alias,
        use_cache=not dont_use_cached,
    )

    context_localized_items = _run_prompt_stage(
        generator,
        context_prompt,
        processed_data,
        "localized_context",
        dont_use_cached,
    )
    final_items = _run_prompt_stage(
        generator,
        localize_prompt,
        context_localized_items,
        "localized_das",
        dont_use_cached,
    )

    # The stage returned fresh copies, so tagging them in place is safe.
    for final_item in final_items:
        final_item["candidate_model"] = model_alias
    return final_items
def assemble_multi_localized_output(
    base_items: List[Dict[str, Any]],
    model_outputs: Dict[str, List[Dict[str, Any]]],
) -> List[Dict[str, Any]]:
    """Merge per-model localization results into one record per base item.

    Each returned record is a shallow copy of the base item with a
    ``localizer_candidates`` list holding, for every model in
    ``model_outputs`` order, a dict with ``model``, ``localized_context``
    and ``localized_das`` taken from that model's output at the same index.

    Args:
        base_items: The items the models were run on; not modified.
        model_outputs: Mapping of model name to its per-item outputs.

    Returns:
        The merged records, or an empty list when ``model_outputs`` is empty.

    Raises:
        ValueError: If any model's output count differs from the number of
            base items. Candidates are paired with base items by index, so a
            mismatch would drop dialogues or attach them to the wrong item.
    """
    if not model_outputs:
        return []

    expected = len(base_items)
    for model_name, outputs in model_outputs.items():
        if len(outputs) != expected:
            raise ValueError(
                f"Model '{model_name}' produced {len(outputs)} items, expected {expected}."
            )

    final_data: List[Dict[str, Any]] = []
    for idx, base_item in enumerate(base_items):
        merged = dict(base_item)
        merged["localizer_candidates"] = [
            {
                "model": model_name,
                "localized_context": outputs[idx]["localized_context"],
                "localized_das": outputs[idx]["localized_das"],
            }
            for model_name, outputs in model_outputs.items()
        ]
        final_data.append(merged)
    return final_data
def main() -> None:
    """Run the multi-model localization pipeline from the command line.

    Loads the input JSON, selects the requested slice of dialogues, runs
    every requested model over the prepared items, merges the candidates
    and writes the result to the output path.

    Raises:
        ValueError: If the input JSON is not a list.
    """
    args = parse_args()

    dialogues = load_json(args.input_path)
    if not isinstance(dialogues, list):
        raise ValueError("Input JSON must be a list of dialogue objects.")

    # Apply the index window first, then the optional cap on its result.
    selected = dialogues[slice(args.start_idx, args.end_idx)]
    if args.max_instances is not None:
        selected = selected[: args.max_instances]

    # An explicit (non-empty) --output_path wins over the derived default.
    if args.output_path:
        output_path = args.output_path
    else:
        output_path = default_output_path(
            args.input_path,
            args.language,
            args.region,
        )

    model_aliases = parse_model_list(args.models)
    prepared_items = preprocess_localize_input(
        data=selected,
        language=args.language,
        region=args.region,
    )

    outputs_by_model: Dict[str, List[Dict[str, Any]]] = {}
    for model_alias in model_aliases:
        print(f"\n[MultiLocalize] Running localization with model: {model_alias}")
        candidates = build_candidate_for_model(
            processed_data=prepared_items,
            model_alias=model_alias,
            context_prompt=args.context_prompt,
            localize_prompt=args.localize_prompt,
            dont_use_cached=args.dont_use_cached,
        )
        outputs_by_model[model_alias] = candidates

    merged = assemble_multi_localized_output(
        base_items=prepared_items,
        model_outputs=outputs_by_model,
    )
    save_json(output_path, merged)
    print(f"Saved multi-localized data to: {output_path}")
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()