"""Decode localized dialogue-act sequences (DAS) into natural dialogue.

Reads a localized JSON file (a list of dialogue objects carrying
``localized_das`` and ``localized_context``), builds one decode prompt per
dialogue, queries the generator model, and writes the decoded conversations
to an output JSON file.
"""

import argparse
import json
import re
from pathlib import Path
from typing import Any, Dict, List

from src.config import DECODED_DIR, PROMPTS_DIR
from src.generator import Generator
from src.region_registry import get_region_description


def parse_args() -> argparse.Namespace:
    """Parse command-line arguments for the DAS decode step."""
    parser = argparse.ArgumentParser(
        description="Decode localized DAS into natural dialogue."
    )
    parser.add_argument(
        "--input_path",
        type=str,
        required=True,
        help="Path to localized JSON file.",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default=None,
        help="Path to save decoded output JSON.",
    )
    parser.add_argument(
        "--language",
        type=str,
        required=True,
        help="Target language label, e.g. 'Swahili'.",
    )
    parser.add_argument(
        "--region",
        type=str,
        default="",
        help="Optional target region/community, e.g. 'Kenya - Nairobi'.",
    )
    parser.add_argument(
        "--decode_prompt",
        type=str,
        default=str(PROMPTS_DIR / "das_decode.md"),
        help="Path to DAS decode prompt.",
    )
    parser.add_argument(
        "--model",
        type=str,
        default=None,
        help="Model alias from model_registry.py",
    )
    parser.add_argument(
        "--max_instances",
        type=int,
        default=None,
        help="Optional cap on number of dialogues to process.",
    )
    parser.add_argument(
        "--start_idx",
        type=int,
        default=0,
        help="Optional start index for slicing input data.",
    )
    parser.add_argument(
        "--end_idx",
        type=int,
        default=None,
        help="Optional end index for slicing input data.",
    )
    parser.add_argument(
        "--dont_use_cached",
        action="store_true",
        help="Disable cached prompt responses.",
    )
    return parser.parse_args()


def load_json(path: str) -> Any:
    """Load and return the JSON content of *path* (UTF-8)."""
    return json.loads(Path(path).read_text(encoding="utf-8"))


def save_json(path: str, data: Any) -> None:
    """Write *data* as pretty-printed UTF-8 JSON, creating parent dirs."""
    output_path = Path(path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    output_path.write_text(
        json.dumps(data, indent=2, ensure_ascii=False),
        encoding="utf-8",
    )


def normalize_speaker_id(speaker_id: Any) -> str:
    """Map numeric speaker ids ('1'/'2') to 'A'/'B'; pass others through."""
    speaker = str(speaker_id)
    return {"1": "A", "2": "B"}.get(speaker, speaker)


def preprocess_localized_conversation(localized_das: List[Dict[str, Any]]) -> str:
    """Render localized DAS turns as numbered 'idx: speaker.functions' lines.

    Args:
        localized_das: One dict per turn, with optional ``speaker_id`` and
            ``functions`` (a string or a list of function labels).

    Returns:
        The turns joined with newlines, numbered from 1.
    """
    formatted_turns: List[str] = []
    for idx, turn in enumerate(localized_das, start=1):
        speaker = normalize_speaker_id(turn.get("speaker_id", "A"))
        functions = turn.get("functions", "")
        # Functions may arrive as a list of labels or as a single string.
        if isinstance(functions, list):
            functions_str = "; ".join(str(f) for f in functions)
        else:
            functions_str = str(functions)
        formatted_turns.append(f"{idx}: {speaker}.{functions_str}")
    return "\n".join(formatted_turns)


def get_original_dialogue(item: Dict[str, Any]) -> List[str]:
    """Return the original dialogue list stored under any known key.

    Keys are checked in priority order; the first list value wins.
    Returns an empty list when none are present.
    """
    for key in ("original", "conversation", "dialogue", "utterances"):
        value = item.get(key)
        if isinstance(value, list):
            return value
    return []


def preprocess_decode_input(
    data: List[Dict[str, Any]],
    language: str,
    region: str,
) -> List[Dict[str, Any]]:
    """Validate input items and attach language/region/turn fields.

    Each returned item is a shallow copy of the input item, extended with
    ``language``, ``region``, ``region_description``, the formatted ``turns``
    string, and ``context`` (aliasing ``localized_context`` for the prompt
    template).

    Raises:
        ValueError: If an item lacks ``localized_das``/``localized_context``,
            or ``localized_das`` is not a list.
    """
    decode_desc = get_region_description(region, "decode", language) or ""
    processed: List[Dict[str, Any]] = []
    for item in data:
        if "localized_das" not in item:
            raise ValueError("Each input item must contain 'localized_das'.")
        if "localized_context" not in item:
            raise ValueError("Each input item must contain 'localized_context'.")
        if not isinstance(item["localized_das"], list):
            raise ValueError(
                f"localized_das must be a list, got {type(item['localized_das']).__name__}"
            )
        new_item = dict(item)
        new_item["language"] = language
        new_item["region"] = region
        new_item["region_description"] = decode_desc
        new_item["turns"] = preprocess_localized_conversation(item["localized_das"])
        # 'localized_context' is already present via dict(item); expose the
        # same value under 'context' as well for the prompt template.
        new_item["context"] = item["localized_context"]
        processed.append(new_item)
    return processed


def strip_code_fences(text: str) -> str:
    """Remove a leading and trailing markdown code fence from *text*."""
    text = text.strip()
    if text.startswith("```"):
        # Drop the opening fence (with optional language tag) and the
        # closing fence, if present.
        text = re.sub(r"^```[a-zA-Z0-9_+-]*\n?", "", text)
        text = re.sub(r"\n?```$", "", text)
    return text.strip()


def parse_numbered_dialogue_string(raw: str) -> List[str]:
    """Split a numbered dialogue string into a list of turn texts.

    Turns are delimited by line-leading markers like ``1:`` or ``2.``.
    When no markers are found, falls back to non-empty lines.
    """
    raw = strip_code_fences(raw)
    # Split on numbered turns like "1: ...", "2. ..."
    matches = list(re.finditer(r"(?m)^\s*(\d+)[:.]\s*", raw))
    if not matches:
        return [line.strip() for line in raw.splitlines() if line.strip()]
    turns: List[str] = []
    for idx, match in enumerate(matches):
        start = match.end()
        end = matches[idx + 1].start() if idx + 1 < len(matches) else len(raw)
        turn_text = raw[start:end].strip()
        if turn_text:
            turns.append(turn_text)
    return turns


def normalize_generated_conversation(generated_conversation: Any) -> List[str]:
    """Coerce model output into a list of plain-text turns.

    Accepts a list of ``{"text": ...}`` dicts, a list of strings, or one
    big numbered string.

    Raises:
        ValueError: For any other shape or list-item type.
    """
    # Case 1: expected list of objects with text
    if isinstance(generated_conversation, list):
        text_only_dialogue: List[str] = []
        for turn in generated_conversation:
            if isinstance(turn, dict) and "text" in turn:
                text_only_dialogue.append(str(turn["text"]).strip())
            elif isinstance(turn, str):
                text_only_dialogue.append(turn.strip())
            else:
                raise ValueError(
                    f"Unsupported generated_conversation list item: {turn}"
                )
        return text_only_dialogue
    # Case 2: model returned one big numbered string
    if isinstance(generated_conversation, str):
        return parse_numbered_dialogue_string(generated_conversation)
    raise ValueError(
        f"Unsupported generated_conversation type: {type(generated_conversation).__name__}"
    )


def merge_decoded_responses(
    base_data: List[Dict[str, Any]],
    responses: List[str],
    language: str,
) -> List[Dict[str, Any]]:
    """Pair each input item with its model response and build output records.

    Items whose generation failed (``None`` response) are skipped with a
    console note. Each output record carries ``dialogue_id`` (when the input
    had one), the original dialogue, and the decoded turns under a
    language-derived key like ``decoded_swahili``.

    Raises:
        ValueError: If a parsed response lacks ``generated_conversation``.
    """
    merged: List[Dict[str, Any]] = []
    decoded_key = f"decoded_{language.strip().lower().replace(' ', '_')}"
    for item, response_text in zip(base_data, responses):
        if response_text is None:
            # skip_failures=True makes the generator yield None on failure.
            print("[Decode] Skipping item with failed generation")
            continue
        response_json = Generator.parse_json_response(response_text)
        if "generated_conversation" not in response_json:
            raise ValueError(
                f"Missing 'generated_conversation' in model response:\n{response_text}"
            )
        generated_conversation = response_json["generated_conversation"]
        text_only_dialogue = normalize_generated_conversation(generated_conversation)
        output_item: Dict[str, Any] = {}
        if "dialogue_id" in item:
            output_item["dialogue_id"] = item["dialogue_id"]
        elif "id" in item:
            output_item["dialogue_id"] = item["id"]
        output_item["original"] = get_original_dialogue(item)
        output_item[decoded_key] = text_only_dialogue
        merged.append(output_item)
    return merged


def default_output_path(input_path: str, language: str, region: str) -> str:
    """Build the default decoded-output path under DECODED_DIR.

    The filename combines the input stem with normalized language (and,
    when given, region) labels, e.g. ``foo_swahili_kenya_decoded.json``.
    """
    stem = Path(input_path).stem
    suffix_parts = [language.strip().lower().replace(" ", "_")]
    if region.strip():
        suffix_parts.append(
            region.strip().lower().replace(" ", "_").replace("/", "_")
        )
    suffix = "_".join(suffix_parts)
    return str(DECODED_DIR / f"{stem}_{suffix}_decoded.json")


def main() -> None:
    """Entry point: load input, slice, prompt the model, save decoded output."""
    args = parse_args()
    raw_data = load_json(args.input_path)
    if not isinstance(raw_data, list):
        raise ValueError("Input JSON must be a list of dialogue objects.")
    sliced_data = raw_data[args.start_idx:args.end_idx]
    if args.max_instances is not None:
        sliced_data = sliced_data[: args.max_instances]
    output_path = args.output_path or default_output_path(
        args.input_path,
        args.language,
        args.region,
    )
    generator = Generator(
        model_alias=args.model,
        use_cache=not args.dont_use_cached,
    )
    processed_data = preprocess_decode_input(
        data=sliced_data,
        language=args.language,
        region=args.region,
    )
    print(f"[Decode] Building decoded dialogue for {len(processed_data)} items...")
    decode_prompts, decode_response_format = generator.build_prompts(
        args.decode_prompt,
        processed_data,
    )
    decode_responses = generator.prompt(
        prompts=decode_prompts,
        response_format=decode_response_format,
        dont_use_cached=args.dont_use_cached,
        skip_failures=True,
    )
    final_data = merge_decoded_responses(
        processed_data,
        decode_responses,
        args.language,
    )
    save_json(output_path, final_data)
    print(f"Saved decoded data to: {output_path}")
    generator.print_usage_summary(stage="Decode")


if __name__ == "__main__":
    main()