from __future__ import annotations import math from pathlib import Path from typing import Any from datasets import DatasetDict, load_dataset from transformers import PreTrainedTokenizerBase from tiny_router.constants import ( ACTION_VOCAB, DEFAULT_MAX_LENGTH, DEFAULT_RECENCY_MAX, FEATURE_MODES, HEAD_LABELS, INTERACTION_SENTINELS, OUTCOME_VOCAB, ) ACTION_TO_ID = {label: idx for idx, label in enumerate(ACTION_VOCAB)} OUTCOME_TO_ID = {label: idx for idx, label in enumerate(OUTCOME_VOCAB)} LABEL_TO_ID = { head: {label: idx for idx, label in enumerate(labels)} for head, labels in HEAD_LABELS.items() } LEGACY_ACTION_MAP = { "created_reminder": "schedule", "updated_item": "update", "sent_reply": "send", "stored_memory": "store", "routed_to_queue": "route", "scheduled_task": "schedule", "dismissed": "dismissed", "requested_clarification": "clarify", } ACTION_KEYWORDS = [ ("schedule", ("schedule", "calendar", "meeting", "remind", "booking", "book")), ("send", ("send", "reply", "email", "message", "respond", "share")), ("route", ("route", "assign", "queue", "forward", "handoff", "hand_off", "triage")), ("store", ("store", "save", "remember", "record", "log", "note")), ("update", ("update", "edit", "change", "modify", "patch", "fix")), ("create", ("create", "add", "open", "draft", "new")), ("clarify", ("clarify", "ask", "question", "confirm_detail", "follow_up_question")), ("search", ("search", "find", "lookup", "retrieve", "fetch")), ("notify", ("notify", "alert", "ping", "nudge", "inform")), ("cancel", ("cancel", "stop", "abort", "undo", "revoke")), ("complete", ("close", "complete", "resolve", "finish", "done")), ("dismissed", ("dismiss", "ignore", "archive", "mute", "snooze")), ] def canonicalize_action(action: Any) -> tuple[str, str]: raw_action = str(action or "").strip() if not raw_action: return "", "none" normalized = raw_action.lower().replace("-", "_").replace(" ", "_") if normalized in ACTION_TO_ID: return raw_action, normalized if normalized in LEGACY_ACTION_MAP: return raw_action, LEGACY_ACTION_MAP[normalized] for canonical, keywords in ACTION_KEYWORDS: if any(keyword in normalized for keyword in keywords): return raw_action, canonical return raw_action, "other" def normalize_interaction( interaction: dict[str, Any] | None, recency_max: int = DEFAULT_RECENCY_MAX, ) -> dict[str, Any]: normalized = dict(INTERACTION_SENTINELS) if not interaction: normalized["log_recency_seconds"] = 0.0 normalized["has_interaction"] = 0 normalized["has_recency"] = 0 normalized["previous_action_raw"] = "" normalized["previous_action_canonical"] = "none" return normalized normalized["previous_text"] = str(interaction.get("previous_text", "") or "") previous_action_raw, previous_action_canonical = canonicalize_action( interaction.get("previous_action", "none") ) previous_outcome = str(interaction.get("previous_outcome", "unknown") or "unknown") if previous_outcome not in OUTCOME_TO_ID: previous_outcome = "unknown" normalized["previous_action"] = previous_action_raw normalized["previous_action_raw"] = previous_action_raw normalized["previous_action_canonical"] = previous_action_canonical normalized["previous_outcome"] = previous_outcome has_context_fields = int( bool(normalized["previous_text"]) or bool(previous_action_raw) or previous_outcome != "unknown" ) raw_recency = interaction.get("recency_seconds", -1) try: raw_recency = int(raw_recency) except (TypeError, ValueError): raw_recency = -1 if raw_recency >= 0: clamped = min(raw_recency, recency_max) normalized["recency_seconds"] = clamped normalized["log_recency_seconds"] = math.log1p(clamped) normalized["has_recency"] = 1 else: normalized["recency_seconds"] = -1 normalized["log_recency_seconds"] = 0.0 normalized["has_recency"] = 0 normalized["has_interaction"] = int(has_context_fields or normalized["has_recency"]) return normalized def normalize_example(example: dict[str, Any], recency_max: int = DEFAULT_RECENCY_MAX) -> dict[str, Any]: interaction = normalize_interaction(example.get("interaction"), recency_max=recency_max) labels = example.get("labels", {}) or {} normalized = { "current_text": str(example.get("current_text", "") or "").strip(), "interaction": interaction, "labels": labels, } return normalized def build_prompt(example: dict[str, Any], feature_mode: str = "full_interaction") -> str: if feature_mode not in FEATURE_MODES: raise ValueError(f"Unsupported feature mode: {feature_mode}") current_text = example["current_text"] interaction = example["interaction"] if feature_mode == "current_text_only": return f"Current: {current_text}" if feature_mode == "current_plus_previous_text": previous_text = interaction["previous_text"] if interaction["has_interaction"] else "None" return f"Current: {current_text}\nPrevious: {previous_text}" previous_text = interaction["previous_text"] if interaction["has_interaction"] else "None" previous_action = interaction["previous_action_raw"] if interaction["has_interaction"] else "None" previous_outcome = interaction["previous_outcome"] if interaction["has_interaction"] else "None" return ( f"Current: {current_text}\n" f"Previous: {previous_text}\n" f"Previous action: {previous_action}\n" f"Previous outcome: {previous_outcome}" ) def prepare_record( example: dict[str, Any], tokenizer: PreTrainedTokenizerBase, feature_mode: str = "full_interaction", max_length: int = DEFAULT_MAX_LENGTH, recency_max: int = DEFAULT_RECENCY_MAX, ) -> dict[str, Any]: normalized = normalize_example(example, recency_max=recency_max) prompt = build_prompt(normalized, feature_mode=feature_mode) tokenized = tokenizer( prompt, truncation=True, max_length=max_length, padding=False, ) interaction = normalized["interaction"] use_structured = 1 if feature_mode == "full_interaction" and interaction["has_interaction"] else 0 record = { "input_ids": tokenized["input_ids"], "attention_mask": tokenized["attention_mask"], "previous_action_id": ACTION_TO_ID[interaction["previous_action_canonical"]] if use_structured else ACTION_TO_ID["none"], "previous_outcome_id": OUTCOME_TO_ID[interaction["previous_outcome"]] if use_structured else OUTCOME_TO_ID["unknown"], "log_recency_seconds": float(interaction["log_recency_seconds"]) if use_structured else 0.0, "has_interaction": int(use_structured), "has_recency": int(interaction["has_recency"]) if use_structured else 0, } for head, mapping in LABEL_TO_ID.items(): label_value = normalized["labels"].get(head) if label_value is not None: record[f"labels_{head}"] = mapping[label_value] return record def build_dataset_dict( train_file: str | None, validation_file: str | None, test_file: str | None = None, ) -> DatasetDict: data_files = {} if train_file: data_files["train"] = str(Path(train_file)) if validation_file: data_files["validation"] = str(Path(validation_file)) if test_file: data_files["test"] = str(Path(test_file)) if not data_files: raise ValueError("At least one dataset file is required.") return load_dataset("json", data_files=data_files) def tokenize_dataset_dict( dataset_dict: DatasetDict, tokenizer: PreTrainedTokenizerBase, feature_mode: str = "full_interaction", max_length: int = DEFAULT_MAX_LENGTH, recency_max: int = DEFAULT_RECENCY_MAX, ) -> DatasetDict: def mapper(example: dict[str, Any]) -> dict[str, Any]: return prepare_record( example, tokenizer=tokenizer, feature_mode=feature_mode, max_length=max_length, recency_max=recency_max, ) remove_columns = list(next(iter(dataset_dict.values())).column_names) return dataset_dict.map(mapper, remove_columns=remove_columns) class RouterCollator: def __init__(self, tokenizer: PreTrainedTokenizerBase): self.tokenizer = tokenizer def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]: batch = self.tokenizer.pad(features, padding=True, return_tensors="pt") for key in ("previous_action_id", "previous_outcome_id", "has_interaction", "has_recency"): batch[key] = batch[key].long() batch["log_recency_seconds"] = batch["log_recency_seconds"].float() for head in HEAD_LABELS: label_key = f"labels_{head}" if label_key in batch: batch[label_key] = batch[label_key].long() return batch