"""Encode dialogues into DAS format and generate context.

Command-line pipeline: load a JSON list of dialogue objects, build and run the
encode prompts, attach each parsed ``das_encoding`` to its dialogue, build and
run the context prompts, attach each parsed ``context``, and save the result.
"""

import argparse
import json
from pathlib import Path
from typing import Any, Dict, List, Optional

from src.config import (
    ENCODED_DIR,
    PROMPTS_DIR,
    RAW_DATA_PATH,
)
from src.generator import Generator


def parse_args() -> argparse.Namespace:
    """Parse the command-line options for the encoding pipeline.

    Returns:
        The parsed namespace with ``input_path``, ``output_path``,
        ``encode_prompt``, ``context_prompt``, ``functions_path``, ``model``,
        ``max_instances``, ``start_idx``, ``end_idx`` and ``dont_use_cached``.
    """
    parser = argparse.ArgumentParser(
        description="Encode dialogues into DAS format and generate context."
    )
    parser.add_argument(
        "--input_path",
        type=str,
        default=str(RAW_DATA_PATH),
        help="Path to raw dialogue JSON file.",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default=str(ENCODED_DIR / "dailydialog_encoded.json"),
        help="Path to save encoded output JSON.",
    )
    parser.add_argument(
        "--encode_prompt",
        type=str,
        default=str(PROMPTS_DIR / "das_encode.md"),
        help="Path to DAS encode prompt file.",
    )
    parser.add_argument(
        "--context_prompt",
        type=str,
        default=str(PROMPTS_DIR / "das_context.md"),
        help="Path to DAS context prompt file.",
    )
    parser.add_argument(
        "--functions_path",
        type=str,
        default=str(PROMPTS_DIR / "das_functions.json"),
        help="Path to DAS function definitions JSON.",
    )
    parser.add_argument(
        "--model",
        type=str,
        default=None,
        help="Model alias from model_registry.py",
    )
    parser.add_argument(
        "--max_instances",
        type=int,
        default=None,
        help="Optional cap on number of dialogues to process.",
    )
    parser.add_argument(
        "--start_idx",
        type=int,
        default=0,
        help="Optional start index for slicing input data.",
    )
    parser.add_argument(
        "--end_idx",
        type=int,
        default=None,
        help="Optional end index for slicing input data.",
    )
    parser.add_argument(
        "--dont_use_cached",
        action="store_true",
        help="Disable cached prompt responses.",
    )
    return parser.parse_args()


def load_json(path: str) -> Any:
    """Read the file at ``path`` as UTF-8 text and return the parsed JSON."""
    return json.loads(Path(path).read_text(encoding="utf-8"))


def save_json(path: str, data: Any) -> None:
    """Write ``data`` to ``path`` as indented UTF-8 JSON.

    Missing parent directories are created first. Non-ASCII characters are
    written as-is rather than escaped (``ensure_ascii=False``).
    """
    output_path = Path(path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    output_path.write_text(
        json.dumps(data, indent=2, ensure_ascii=False),
        encoding="utf-8",
    )


def preprocess_conversation(dialogue_turns: List[str]) -> str:
    """Join the dialogue turns into one string, one turn per line."""
    return "\n".join(dialogue_turns)


def preprocess_dailydialogue(
    data: List[Dict[str, Any]],
    functions: List[Dict[str, Any]],
) -> List[Dict[str, Any]]:
    """Return shallow copies of the items with prompt inputs attached.

    Each copy gains a ``"conversation"`` key (the item's ``"dialogue"`` turns
    joined by newlines) and a ``"functions"`` key.

    Args:
        data: Dialogue objects; each must have a ``"dialogue"`` key.
        functions: Function definitions attached to every item. The same list
            object is shared by all items, not copied.

    Raises:
        KeyError: If an item has no ``"dialogue"`` key.
    """
    processed: List[Dict[str, Any]] = []
    for item in data:
        new_item = dict(item)
        new_item["conversation"] = preprocess_conversation(item["dialogue"])
        new_item["functions"] = functions
        processed.append(new_item)
    return processed


def create_encoding_prompts(
    generator: Generator,
    data: List[Dict[str, Any]],
    prompt_path: str,
) -> tuple[list[list[dict[str, str]]], Optional[Dict[str, Any]]]:
    """Build the encode prompts by delegating to ``generator.build_prompts``.

    Returns:
        Whatever ``generator.build_prompts(prompt_path, data)`` returns; the
        caller unpacks it as ``(prompts, response_format)``.
    """
    return generator.build_prompts(prompt_path, data)


def _check_response_count(
    items: List[Dict[str, Any]],
    responses: List[str],
    stage: str,
) -> None:
    """Raise ``ValueError`` unless there is exactly one response per item."""
    if len(items) != len(responses):
        raise ValueError(
            f"{stage}: expected {len(items)} responses, got {len(responses)}."
        )


def merge_encoding_responses(
    original_data: List[Dict[str, Any]],
    responses: List[str],
) -> List[Dict[str, Any]]:
    """Attach each parsed ``das_encoding`` to a copy of its dialogue item.

    Args:
        original_data: Dialogue items, in the order the prompts were built.
        responses: Raw response texts, one per item, in the same order.

    Returns:
        Shallow copies of the items, each with a ``"das_encoding"`` key taken
        from the parsed response.

    Raises:
        ValueError: If the number of responses differs from the number of
            items.
        KeyError: If a parsed response has no ``"das_encoding"`` key.
    """
    # zip() would silently drop trailing items on a count mismatch, so fail
    # loudly instead of writing an incomplete output file.
    _check_response_count(original_data, responses, "Encoding merge")

    merged: List[Dict[str, Any]] = []
    for item, response_text in zip(original_data, responses):
        response_json = Generator.parse_json_response(response_text)
        new_item = dict(item)
        new_item["das_encoding"] = response_json["das_encoding"]
        merged.append(new_item)
    return merged


def create_context_input(
    data_with_encoding: List[Dict[str, Any]],
) -> List[Dict[str, Any]]:
    """
    Prepare items for the context prompt.

    The DAS context prompt loops over 'das_encoding', so we just pass the data
    through. Each item is shallow-copied; no keys are added or removed.
    """
    return [dict(item) for item in data_with_encoding]


def merge_context_responses(
    data_with_encoding: List[Dict[str, Any]],
    responses: List[str],
) -> List[Dict[str, Any]]:
    """Attach each parsed ``context`` to a copy of its encoded item.

    Args:
        data_with_encoding: Encoded items, in the order the prompts were built.
        responses: Raw response texts, one per item, in the same order.

    Returns:
        Shallow copies of the items, each with a ``"context"`` key taken from
        the parsed response.

    Raises:
        ValueError: If the number of responses differs from the number of
            items.
        KeyError: If a parsed response has no ``"context"`` key.
    """
    # Same guard as the encoding merge: never truncate silently.
    _check_response_count(data_with_encoding, responses, "Context merge")

    merged: List[Dict[str, Any]] = []
    for item, response_text in zip(data_with_encoding, responses):
        response_json = Generator.parse_json_response(response_text)
        new_item = dict(item)
        new_item["context"] = response_json["context"]
        merged.append(new_item)
    return merged


def main() -> None:
    """Run the encode-then-context pipeline and save the result.

    Raises:
        ValueError: If the input JSON or the functions JSON is not a list, or
            if a generation step returns a different number of responses than
            prompts' source items.
    """
    args = parse_args()

    raw_data = load_json(args.input_path)
    if not isinstance(raw_data, list):
        raise ValueError("Input JSON must be a list of dialogue objects.")

    # The index window is applied first, then the optional cap.
    sliced_data = raw_data[args.start_idx:args.end_idx]
    if args.max_instances is not None:
        sliced_data = sliced_data[: args.max_instances]

    functions = load_json(args.functions_path)
    if not isinstance(functions, list):
        raise ValueError("das_functions.json must be a list.")

    generator = Generator(
        model_alias=args.model,
        use_cache=not args.dont_use_cached,
    )

    # Step 1: build encode prompts
    processed_data = preprocess_dailydialogue(sliced_data, functions)
    encode_prompts, encode_response_format = create_encoding_prompts(
        generator=generator,
        data=processed_data,
        prompt_path=args.encode_prompt,
    )

    # Step 2: run encoding
    encode_responses = generator.prompt(
        prompts=encode_prompts,
        response_format=encode_response_format,
        dont_use_cached=args.dont_use_cached,
    )
    # Merge into the sliced input rather than processed_data, so the prompt-only
    # "conversation" and "functions" keys are not carried into the output.
    encoded_data = merge_encoding_responses(sliced_data, encode_responses)

    # Step 3: build context prompts
    context_input = create_context_input(encoded_data)
    context_prompts, context_response_format = generator.build_prompts(
        args.context_prompt,
        context_input,
    )

    # Step 4: run context generation
    context_responses = generator.prompt(
        prompts=context_prompts,
        response_format=context_response_format,
        dont_use_cached=args.dont_use_cached,
    )
    final_data = merge_context_responses(encoded_data, context_responses)

    # Step 5: save
    save_json(args.output_path, final_data)
    print(f"Saved encoded data to: {args.output_path}")
    generator.print_usage_summary(stage="Encode")


if __name__ == "__main__":
    main()