Spaces:
Sleeping
Sleeping
| import argparse | |
| import json | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Optional | |
| from src.config import LOCALIZED_DIR, PROMPTS_DIR | |
| from src.generator import Generator | |
def parse_args() -> argparse.Namespace:
    """Parse CLI options for the multi-model localization run."""
    parser = argparse.ArgumentParser(
        description="Generate multiple localization candidates using different models."
    )
    # (flag, keyword arguments) pairs, registered in the order they should
    # appear in --help output.
    option_specs = [
        ("--input_path",
         dict(type=str, required=True, help="Path to encoded JSON file.")),
        ("--output_path",
         dict(type=str, default=None, help="Path to save multi-localized output JSON.")),
        ("--language",
         dict(type=str, required=True, help="Target language label, e.g. 'Swahili'.")),
        ("--region",
         dict(type=str, default="",
              help="Optional target region/community, e.g. 'Kenya - Nairobi'.")),
        ("--context_prompt",
         dict(type=str, default=str(PROMPTS_DIR / "das_localize_context.md"),
              help="Path to DAS context localization prompt.")),
        ("--localize_prompt",
         dict(type=str, default=str(PROMPTS_DIR / "das_localize.md"),
              help="Path to DAS localization prompt.")),
        ("--models",
         dict(type=str, required=True,
              help='Comma-separated model aliases, e.g. "gpt4,gemma,qwen".')),
        ("--max_instances",
         dict(type=int, default=None,
              help="Optional cap on number of dialogues to process.")),
        ("--start_idx",
         dict(type=int, default=0,
              help="Optional start index for slicing input data.")),
        ("--end_idx",
         dict(type=int, default=None,
              help="Optional end index for slicing input data.")),
        ("--dont_use_cached",
         dict(action="store_true", help="Disable cached prompt responses.")),
    ]
    for flag, kwargs in option_specs:
        parser.add_argument(flag, **kwargs)
    return parser.parse_args()
def load_json(path: str) -> Any:
    """Read and deserialize a UTF-8 encoded JSON file at *path*."""
    with Path(path).open(encoding="utf-8") as handle:
        return json.load(handle)
def save_json(path: str, data: Any) -> None:
    """Serialize *data* as pretty-printed UTF-8 JSON, creating parent dirs."""
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)
    serialized = json.dumps(data, indent=2, ensure_ascii=False)
    target.write_text(serialized, encoding="utf-8")
def normalize_speaker_id(speaker_id: Any) -> str:
    """Map numeric speaker ids 1/2 to labels "A"/"B"; stringify everything else."""
    text = str(speaker_id)
    return {"1": "A", "2": "B"}.get(text, text)
def stringify_functions(functions: Any) -> str:
    """Render a dialogue-act function list as a "; "-joined string."""
    if not isinstance(functions, list):
        # Non-list values (e.g. a bare string) are stringified as-is.
        return str(functions)
    return "; ".join(map(str, functions))
def preprocess_conversation(das_encoding: List[Dict[str, Any]]) -> str:
    """Flatten DAS turns into one line per turn: "<n>: <speaker>.<functions>"."""
    rendered: List[str] = []
    for position, turn in enumerate(das_encoding, start=1):
        # Numeric speaker ids 1/2 become "A"/"B"; anything else is stringified.
        raw_speaker = str(turn.get("speaker_id", "A"))
        speaker = {"1": "A", "2": "B"}.get(raw_speaker, raw_speaker)
        raw_functions = turn.get("functions", [])
        if isinstance(raw_functions, list):
            functions = "; ".join(str(f) for f in raw_functions)
        else:
            functions = str(raw_functions)
        rendered.append(f"{position}: {speaker}.{functions}")
    return "\n".join(rendered)
def preprocess_localize_input(
    data: List[Dict[str, Any]],
    language: str,
    region: str,
) -> List[Dict[str, Any]]:
    """Return copies of each dialogue annotated with the target locale and a
    flattened turn listing ("turns") derived from its DAS encoding.

    Raises:
        ValueError: if any item lacks a 'das_encoding' or 'context' key.
    """
    enriched: List[Dict[str, Any]] = []
    for entry in data:
        # Both keys are required downstream; fail fast with a clear message.
        for required_key in ("das_encoding", "context"):
            if required_key not in entry:
                raise ValueError(f"Each input item must contain '{required_key}'.")
        annotated = dict(entry)
        annotated["language"] = language
        annotated["region"] = region
        annotated["turns"] = preprocess_conversation(entry["das_encoding"])
        enriched.append(annotated)
    return enriched
def default_output_path(input_path: str, language: str, region: str) -> str:
    """Derive an output path under LOCALIZED_DIR from the input stem and locale."""

    def _slug(text: str) -> str:
        # Lowercase, underscore-separated token for filesystem-safe names.
        return text.strip().lower().replace(" ", "_")

    pieces = [_slug(language)]
    if region.strip():
        # Regions may contain "/" (e.g. "Kenya/Nairobi"); keep paths flat.
        pieces.append(_slug(region).replace("/", "_"))
    locale_suffix = "_".join(pieces)
    stem = Path(input_path).stem
    return str(LOCALIZED_DIR / f"{stem}_{locale_suffix}_multi_localized.json")
def parse_model_list(raw_models: str) -> List[str]:
    """Split a comma-separated alias string into a non-empty list of aliases.

    Raises:
        ValueError: if no non-blank alias remains after splitting.
    """
    aliases = [alias for alias in (part.strip() for part in raw_models.split(",")) if alias]
    if not aliases:
        raise ValueError("You must provide at least one model alias in --models.")
    return aliases
def build_candidate_for_model(
    processed_data: List[Dict[str, Any]],
    model_alias: str,
    context_prompt: str,
    localize_prompt: str,
    dont_use_cached: bool,
) -> List[Dict[str, Any]]:
    """Run the two-stage localization (context, then dialogue acts) for one model.

    Returns an enriched copy of each input item carrying the model's
    'localized_context' and 'localized_das' plus a 'candidate_model' tag.
    """
    generator = Generator(
        model_alias=model_alias,
        use_cache=not dont_use_cached,
    )

    def run_stage(prompt_path: str, items: List[Dict[str, Any]]) -> List[str]:
        # Build the prompts for this stage and collect raw model responses.
        prompts, response_format = generator.build_prompts(prompt_path, items)
        return generator.prompt(
            prompts=prompts,
            response_format=response_format,
            dont_use_cached=dont_use_cached,
        )

    # Stage 1: localize the scenario context for every dialogue.
    with_context: List[Dict[str, Any]] = []
    for source_item, raw in zip(processed_data, run_stage(context_prompt, processed_data)):
        parsed = Generator.parse_json_response(raw)
        enriched = dict(source_item)
        enriched["localized_context"] = parsed["localized_context"]
        with_context.append(enriched)

    # Stage 2: localize the DAS turns, conditioned on the stage-1 context.
    results: List[Dict[str, Any]] = []
    for source_item, raw in zip(with_context, run_stage(localize_prompt, with_context)):
        parsed = Generator.parse_json_response(raw)
        enriched = dict(source_item)
        enriched["localized_das"] = parsed["localized_das"]
        enriched["candidate_model"] = model_alias
        results.append(enriched)
    return results
def assemble_multi_localized_output(
    base_items: List[Dict[str, Any]],
    model_outputs: Dict[str, List[Dict[str, Any]]],
) -> List[Dict[str, Any]]:
    """Merge per-model localization outputs into one candidate list per dialogue.

    Args:
        base_items: Preprocessed input items, one per dialogue.
        model_outputs: Mapping of model alias -> localized items, parallel to
            base_items (same length, same order).

    Returns:
        Copies of base_items, each extended with a 'localizer_candidates' list
        holding one {model, localized_context, localized_das} entry per model,
        in the insertion order of model_outputs. Empty list if model_outputs
        is empty.

    Raises:
        ValueError: if any model's output list — or base_items itself —
            disagrees in length with the others.
    """
    if not model_outputs:
        return []
    model_names = list(model_outputs.keys())
    num_items = len(model_outputs[model_names[0]])
    # Every model must have produced exactly one output per dialogue.
    for model_name, outputs in model_outputs.items():
        if len(outputs) != num_items:
            raise ValueError(
                f"Model '{model_name}' produced {len(outputs)} items, expected {num_items}."
            )
    # Fix: base_items was previously indexed unchecked — a length mismatch
    # surfaced as an opaque IndexError (shorter) or silently dropped the
    # trailing base items (longer). Validate up front instead.
    if len(base_items) != num_items:
        raise ValueError(
            f"Got {len(base_items)} base items, expected {num_items} to match model outputs."
        )
    final_data: List[Dict[str, Any]] = []
    for idx in range(num_items):
        merged_item = dict(base_items[idx])
        merged_item["localizer_candidates"] = [
            {
                "model": model_name,
                "localized_context": model_outputs[model_name][idx]["localized_context"],
                "localized_das": model_outputs[model_name][idx]["localized_das"],
            }
            for model_name in model_names
        ]
        final_data.append(merged_item)
    return final_data
def main() -> None:
    """CLI entry point: slice the input, run every model, save merged output."""
    args = parse_args()

    dialogues = load_json(args.input_path)
    if not isinstance(dialogues, list):
        raise ValueError("Input JSON must be a list of dialogue objects.")

    # Apply the optional [start_idx, end_idx) window, then the instance cap.
    selection = dialogues[args.start_idx:args.end_idx]
    if args.max_instances is not None:
        selection = selection[: args.max_instances]

    destination = args.output_path or default_output_path(
        args.input_path,
        args.language,
        args.region,
    )
    aliases = parse_model_list(args.models)

    prepared = preprocess_localize_input(
        data=selection,
        language=args.language,
        region=args.region,
    )

    candidates_by_model: Dict[str, List[Dict[str, Any]]] = {}
    for alias in aliases:
        print(f"\n[MultiLocalize] Running localization with model: {alias}")
        candidates_by_model[alias] = build_candidate_for_model(
            processed_data=prepared,
            model_alias=alias,
            context_prompt=args.context_prompt,
            localize_prompt=args.localize_prompt,
            dont_use_cached=args.dont_use_cached,
        )

    merged = assemble_multi_localized_output(
        base_items=prepared,
        model_outputs=candidates_by_model,
    )
    save_json(destination, merged)
    print(f"Saved multi-localized data to: {destination}")
# Run the CLI entry point only when executed as a script (not on import).
if __name__ == "__main__":
    main()