# Encode dialogues into DAS format and generate conversational context.
| import argparse | |
| import json | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Optional | |
| from src.config import ( | |
| ENCODED_DIR, | |
| PROMPTS_DIR, | |
| RAW_DATA_PATH, | |
| ) | |
| from src.generator import Generator | |
def parse_args() -> argparse.Namespace:
    """Define and parse the CLI options for the DAS encoding pipeline."""
    p = argparse.ArgumentParser(
        description="Encode dialogues into DAS format and generate context."
    )
    p.add_argument("--input_path", type=str, default=str(RAW_DATA_PATH),
                   help="Path to raw dialogue JSON file.")
    p.add_argument("--output_path", type=str,
                   default=str(ENCODED_DIR / "dailydialog_encoded.json"),
                   help="Path to save encoded output JSON.")
    p.add_argument("--encode_prompt", type=str,
                   default=str(PROMPTS_DIR / "das_encode.md"),
                   help="Path to DAS encode prompt file.")
    p.add_argument("--context_prompt", type=str,
                   default=str(PROMPTS_DIR / "das_context.md"),
                   help="Path to DAS context prompt file.")
    p.add_argument("--functions_path", type=str,
                   default=str(PROMPTS_DIR / "das_functions.json"),
                   help="Path to DAS function definitions JSON.")
    p.add_argument("--model", type=str, default=None,
                   help="Model alias from model_registry.py")
    p.add_argument("--max_instances", type=int, default=None,
                   help="Optional cap on number of dialogues to process.")
    p.add_argument("--start_idx", type=int, default=0,
                   help="Optional start index for slicing input data.")
    p.add_argument("--end_idx", type=int, default=None,
                   help="Optional end index for slicing input data.")
    p.add_argument("--dont_use_cached", action="store_true",
                   help="Disable cached prompt responses.")
    return p.parse_args()
def load_json(path: str) -> Any:
    """Read and deserialize a UTF-8 encoded JSON file at *path*."""
    raw_text = Path(path).read_text(encoding="utf-8")
    return json.loads(raw_text)
def save_json(path: str, data: Any) -> None:
    """Serialize *data* as pretty-printed UTF-8 JSON, creating parent dirs as needed."""
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)
    serialized = json.dumps(data, indent=2, ensure_ascii=False)
    target.write_text(serialized, encoding="utf-8")
def preprocess_conversation(dialogue_turns: List[str]) -> str:
    """Collapse a list of dialogue turns into one newline-separated transcript."""
    separator = "\n"
    return separator.join(dialogue_turns)
def preprocess_dailydialogue(
    data: List[Dict[str, Any]],
    functions: List[Dict[str, Any]],
) -> List[Dict[str, Any]]:
    """Return shallow copies of *data* items augmented for the encode prompt.

    Each copy gains a 'conversation' key (the newline-joined 'dialogue' turns)
    and a 'functions' key (the shared DAS function definitions).
    """
    return [
        {
            **entry,
            "conversation": "\n".join(entry["dialogue"]),
            "functions": functions,
        }
        for entry in data
    ]
def create_encoding_prompts(
    generator: Generator,
    data: List[Dict[str, Any]],
    prompt_path: str,
) -> tuple[list[list[dict[str, str]]], Optional[Dict[str, Any]]]:
    """Build the encode-stage chat prompts (plus optional response format).

    Thin wrapper that delegates to the generator's prompt builder.
    """
    prompts_and_format = generator.build_prompts(prompt_path, data)
    return prompts_and_format
def merge_encoding_responses(
    original_data: List[Dict[str, Any]],
    responses: List[str],
) -> List[Dict[str, Any]]:
    """Attach the parsed 'das_encoding' of each model response to its source item.

    Items are shallow-copied, so *original_data* is not mutated.

    Raises:
        ValueError: if the response count differs from the item count — the
            previous bare ``zip`` would silently truncate and drop dialogues.
        KeyError: if a parsed response lacks the 'das_encoding' field.
    """
    if len(original_data) != len(responses):
        raise ValueError(
            f"Expected {len(original_data)} responses, got {len(responses)}."
        )
    merged: List[Dict[str, Any]] = []
    for item, response_text in zip(original_data, responses):
        response_json = Generator.parse_json_response(response_text)
        new_item = dict(item)
        new_item["das_encoding"] = response_json["das_encoding"]
        merged.append(new_item)
    return merged
def create_context_input(
    data_with_encoding: List[Dict[str, Any]],
) -> List[Dict[str, Any]]:
    """Prepare items for the context prompt.

    The DAS context prompt loops over 'das_encoding', so each item is passed
    through unchanged apart from a protective shallow copy.
    """
    return [dict(entry) for entry in data_with_encoding]
def merge_context_responses(
    data_with_encoding: List[Dict[str, Any]],
    responses: List[str],
) -> List[Dict[str, Any]]:
    """Attach the parsed 'context' of each model response to its encoded item.

    Items are shallow-copied, so *data_with_encoding* is not mutated.

    Raises:
        ValueError: if the response count differs from the item count — the
            previous bare ``zip`` would silently truncate and drop dialogues.
        KeyError: if a parsed response lacks the 'context' field.
    """
    if len(data_with_encoding) != len(responses):
        raise ValueError(
            f"Expected {len(data_with_encoding)} responses, got {len(responses)}."
        )
    merged: List[Dict[str, Any]] = []
    for item, response_text in zip(data_with_encoding, responses):
        response_json = Generator.parse_json_response(response_text)
        new_item = dict(item)
        new_item["context"] = response_json["context"]
        merged.append(new_item)
    return merged
def main() -> None:
    """Run the full pipeline: load raw dialogues, DAS-encode them, generate
    per-dialogue context, and save the merged result to JSON."""
    args = parse_args()
    # Input must be a JSON list of dialogue objects.
    raw_data = load_json(args.input_path)
    if not isinstance(raw_data, list):
        raise ValueError("Input JSON must be a list of dialogue objects.")
    # Optional windowing of the work set: slice first, then cap.
    sliced_data = raw_data[args.start_idx:args.end_idx]
    if args.max_instances is not None:
        sliced_data = sliced_data[: args.max_instances]
    # Shared DAS function definitions injected into every encode prompt.
    functions = load_json(args.functions_path)
    if not isinstance(functions, list):
        raise ValueError("das_functions.json must be a list.")
    generator = Generator(
        model_alias=args.model,
        use_cache=not args.dont_use_cached,
    )
    # Step 1: build encode prompts
    processed_data = preprocess_dailydialogue(sliced_data, functions)
    encode_prompts, encode_response_format = create_encoding_prompts(
        generator=generator,
        data=processed_data,
        prompt_path=args.encode_prompt,
    )
    # Step 2: run encoding
    # NOTE(review): cache control appears twice — via use_cache at construction
    # and dont_use_cached per call; confirm Generator treats these consistently.
    encode_responses = generator.prompt(
        prompts=encode_prompts,
        response_format=encode_response_format,
        dont_use_cached=args.dont_use_cached,
    )
    # NOTE(review): merges onto sliced_data rather than processed_data, so the
    # derived 'conversation'/'functions' fields are not persisted in the output
    # — presumably intentional to keep the saved JSON lean; confirm.
    encoded_data = merge_encoding_responses(sliced_data, encode_responses)
    # Step 3: build context prompts
    context_input = create_context_input(encoded_data)
    context_prompts, context_response_format = generator.build_prompts(
        args.context_prompt,
        context_input,
    )
    # Step 4: run context generation
    context_responses = generator.prompt(
        prompts=context_prompts,
        response_format=context_response_format,
        dont_use_cached=args.dont_use_cached,
    )
    final_data = merge_context_responses(encoded_data, context_responses)
    # Step 5: save
    save_json(args.output_path, final_data)
    print(f"Saved encoded data to: {args.output_path}")
    generator.print_usage_summary(stage="Encode")
| if __name__ == "__main__": | |
| main() |