# afridialeval / src / encode.py
# Uploaded via huggingface_hub by millicentochieng (commit e2b8b61, verified).
import argparse
import json
from pathlib import Path
from typing import Any, Dict, List, Optional
from src.config import (
ENCODED_DIR,
PROMPTS_DIR,
RAW_DATA_PATH,
)
from src.generator import Generator
def parse_args() -> argparse.Namespace:
    """Parse CLI options for the DAS encode/context pipeline.

    Returns:
        argparse.Namespace carrying input/output/prompt paths, the model
        alias, optional slicing controls, and the cache toggle.
    """
    parser = argparse.ArgumentParser(description="Encode dialogues into DAS format and generate context.")

    # All path-valued string options share the same shape: (flag, default, help).
    path_options = [
        ("--input_path", str(RAW_DATA_PATH), "Path to raw dialogue JSON file."),
        ("--output_path", str(ENCODED_DIR / "dailydialog_encoded.json"), "Path to save encoded output JSON."),
        ("--encode_prompt", str(PROMPTS_DIR / "das_encode.md"), "Path to DAS encode prompt file."),
        ("--context_prompt", str(PROMPTS_DIR / "das_context.md"), "Path to DAS context prompt file."),
        ("--functions_path", str(PROMPTS_DIR / "das_functions.json"), "Path to DAS function definitions JSON."),
    ]
    for flag, default, help_text in path_options:
        parser.add_argument(flag, type=str, default=default, help=help_text)

    parser.add_argument(
        "--model",
        type=str,
        default=None,
        help="Model alias from model_registry.py",
    )

    # Slicing controls: a window [start_idx, end_idx) plus an optional cap.
    parser.add_argument(
        "--max_instances",
        type=int,
        default=None,
        help="Optional cap on number of dialogues to process.",
    )
    parser.add_argument(
        "--start_idx",
        type=int,
        default=0,
        help="Optional start index for slicing input data.",
    )
    parser.add_argument(
        "--end_idx",
        type=int,
        default=None,
        help="Optional end index for slicing input data.",
    )
    parser.add_argument(
        "--dont_use_cached",
        action="store_true",
        help="Disable cached prompt responses.",
    )
    return parser.parse_args()
def load_json(path: str) -> Any:
    """Read *path* as UTF-8 and return the deserialized JSON value."""
    with Path(path).open(encoding="utf-8") as handle:
        return json.load(handle)
def save_json(path: str, data: Any) -> None:
    """Write *data* to *path* as pretty-printed UTF-8 JSON.

    Parent directories are created on demand so callers can point at
    not-yet-existing output trees.
    """
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)
    serialized = json.dumps(data, indent=2, ensure_ascii=False)
    target.write_text(serialized, encoding="utf-8")
def preprocess_conversation(dialogue_turns: List[str]) -> str:
    """Flatten a list of dialogue turns into one newline-separated transcript."""
    separator = "\n"
    return separator.join(dialogue_turns)
def preprocess_dailydialogue(
    data: List[Dict[str, Any]],
    functions: List[Dict[str, Any]],
) -> List[Dict[str, Any]]:
    """Return shallow copies of *data* augmented for the encode prompt.

    Each copy gains a 'conversation' key (the dialogue turns joined with
    newlines) and a 'functions' key (the shared DAS function definitions).
    The input items are never mutated.
    """
    return [
        {
            **item,
            "conversation": "\n".join(item["dialogue"]),
            "functions": functions,
        }
        for item in data
    ]
def create_encoding_prompts(
    generator: Generator,
    data: List[Dict[str, Any]],
    prompt_path: str,
) -> tuple[list[list[dict[str, str]]], Optional[Dict[str, Any]]]:
    """Delegate encode-stage prompt construction to the generator.

    Thin wrapper kept for naming symmetry with the context stage; returns
    whatever ``Generator.build_prompts`` produces for *prompt_path* and *data*.
    """
    prompts_and_format = generator.build_prompts(prompt_path, data)
    return prompts_and_format
def merge_encoding_responses(
    original_data: List[Dict[str, Any]],
    responses: List[str],
) -> List[Dict[str, Any]]:
    """Pair each item with its model response and attach the parsed encoding.

    Each response is parsed as JSON and its 'das_encoding' value is stored
    on a shallow copy of the corresponding input item; inputs are not mutated.
    Extra items/responses beyond the shorter list are dropped (zip semantics).
    """
    return [
        {**item, "das_encoding": Generator.parse_json_response(raw)["das_encoding"]}
        for item, raw in zip(original_data, responses)
    ]
def create_context_input(
    data_with_encoding: List[Dict[str, Any]],
) -> List[Dict[str, Any]]:
    """Prepare items for the context prompt.

    The DAS context prompt loops over 'das_encoding', so the data passes
    through unchanged except for a shallow copy of every item (protects the
    originals from downstream mutation).
    """
    return [dict(entry) for entry in data_with_encoding]
def merge_context_responses(
    data_with_encoding: List[Dict[str, Any]],
    responses: List[str],
) -> List[Dict[str, Any]]:
    """Attach each parsed 'context' value to a shallow copy of its source item.

    Mirrors merge_encoding_responses: responses are parsed as JSON, inputs
    are never mutated, and zip semantics drop any unpaired trailing items.
    """
    return [
        {**item, "context": Generator.parse_json_response(raw)["context"]}
        for item, raw in zip(data_with_encoding, responses)
    ]
def main() -> None:
    """Run the two-stage DAS pipeline: encode dialogues, then generate context.

    Loads raw dialogues, slices them per the CLI options, runs the encode
    prompt, merges the encodings, runs the context prompt over the encoded
    data, merges the contexts, and saves the final JSON.
    """
    args = parse_args()

    dialogues = load_json(args.input_path)
    if not isinstance(dialogues, list):
        raise ValueError("Input JSON must be a list of dialogue objects.")

    # Slice window first, then the optional instance cap on top of it.
    subset = dialogues[args.start_idx:args.end_idx]
    if args.max_instances is not None:
        subset = subset[: args.max_instances]

    function_defs = load_json(args.functions_path)
    if not isinstance(function_defs, list):
        raise ValueError("das_functions.json must be a list.")

    generator = Generator(
        model_alias=args.model,
        use_cache=not args.dont_use_cached,
    )

    # Stage 1: DAS encoding of each dialogue.
    prepared = preprocess_dailydialogue(subset, function_defs)
    encode_prompts, encode_format = create_encoding_prompts(
        generator=generator,
        data=prepared,
        prompt_path=args.encode_prompt,
    )
    encode_responses = generator.prompt(
        prompts=encode_prompts,
        response_format=encode_format,
        dont_use_cached=args.dont_use_cached,
    )
    # Merge onto the raw slice (not `prepared`) so the prompt-only helper
    # fields 'conversation'/'functions' are not persisted in the output.
    encoded = merge_encoding_responses(subset, encode_responses)

    # Stage 2: context generation conditioned on the DAS encoding.
    context_prompts, context_format = generator.build_prompts(
        args.context_prompt,
        create_context_input(encoded),
    )
    context_responses = generator.prompt(
        prompts=context_prompts,
        response_format=context_format,
        dont_use_cached=args.dont_use_cached,
    )
    final_data = merge_context_responses(encoded, context_responses)

    save_json(args.output_path, final_data)
    print(f"Saved encoded data to: {args.output_path}")
    generator.print_usage_summary(stage="Encode")
# Script entry point: e.g. `python -m src.encode --model <alias>`.
if __name__ == "__main__":
    main()