Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import math | |
| from pathlib import Path | |
| from typing import Any | |
| from datasets import DatasetDict, load_dataset | |
| from transformers import PreTrainedTokenizerBase | |
| from tiny_router.constants import ( | |
| ACTION_VOCAB, | |
| DEFAULT_MAX_LENGTH, | |
| DEFAULT_RECENCY_MAX, | |
| FEATURE_MODES, | |
| HEAD_LABELS, | |
| INTERACTION_SENTINELS, | |
| OUTCOME_VOCAB, | |
| ) | |
# Label -> integer id lookups. Ids are the positions of the labels in the
# corresponding vocabulary sequences imported from tiny_router.constants.
ACTION_TO_ID = {name: position for position, name in enumerate(ACTION_VOCAB)}
OUTCOME_TO_ID = {name: position for position, name in enumerate(OUTCOME_VOCAB)}

# One label -> id table per entry of HEAD_LABELS, keyed by the head name.
LABEL_TO_ID = {
    head_name: {name: position for position, name in enumerate(head_labels)}
    for head_name, head_labels in HEAD_LABELS.items()
}
# Exact-match translation from older, past-tense action names to the
# canonical action label used by canonicalize_action. Keys are compared
# after lower-casing and replacing "-" / " " with "_".
LEGACY_ACTION_MAP = {
    # Scheduling-style actions.
    "created_reminder": "schedule",
    # Item / message / memory / queue actions.
    "updated_item": "update",
    "sent_reply": "send",
    "stored_memory": "store",
    "routed_to_queue": "route",
    # Second legacy spelling that also resolves to "schedule".
    "scheduled_task": "schedule",
    # Actions that keep (or nearly keep) their own name.
    "dismissed": "dismissed",
    "requested_clarification": "clarify",
}
# Ordered (canonical_action, keywords) pairs used as a substring fallback by
# canonicalize_action. Order matters: the first entry with any keyword
# contained in the normalized action string wins.
ACTION_KEYWORDS = [
    (
        "schedule",
        ("schedule", "calendar", "meeting", "remind", "booking", "book"),
    ),
    (
        "send",
        ("send", "reply", "email", "message", "respond", "share"),
    ),
    (
        "route",
        ("route", "assign", "queue", "forward", "handoff", "hand_off", "triage"),
    ),
    (
        "store",
        ("store", "save", "remember", "record", "log", "note"),
    ),
    (
        "update",
        ("update", "edit", "change", "modify", "patch", "fix"),
    ),
    (
        "create",
        ("create", "add", "open", "draft", "new"),
    ),
    (
        "clarify",
        ("clarify", "ask", "question", "confirm_detail", "follow_up_question"),
    ),
    (
        "search",
        ("search", "find", "lookup", "retrieve", "fetch"),
    ),
    (
        "notify",
        ("notify", "alert", "ping", "nudge", "inform"),
    ),
    (
        "cancel",
        ("cancel", "stop", "abort", "undo", "revoke"),
    ),
    (
        "complete",
        ("close", "complete", "resolve", "finish", "done"),
    ),
    (
        "dismissed",
        ("dismiss", "ignore", "archive", "mute", "snooze"),
    ),
]
def canonicalize_action(action: Any) -> tuple[str, str]:
    """Resolve a free-form action value to a canonical action label.

    Args:
        action: Any value; falsy values are treated as an empty string and
            everything else is converted with ``str`` and stripped.

    Returns:
        A ``(raw, canonical)`` pair. ``raw`` is the stripped string form of
        ``action``. ``canonical`` is chosen by the first rule that applies:

        1. empty input -> ``("", "none")``;
        2. the normalized name is a key of ``ACTION_TO_ID``;
        3. the normalized name is a key of ``LEGACY_ACTION_MAP``;
        4. the first ``ACTION_KEYWORDS`` entry with a keyword that is a
           substring of the normalized name;
        5. otherwise ``"other"``.
    """
    stripped = str(action or "").strip()
    if not stripped:
        return "", "none"

    # Normalized form: lower-case with "-" and " " folded to "_".
    separators = str.maketrans({"-": "_", " ": "_"})
    lookup_key = stripped.lower().translate(separators)

    if lookup_key in ACTION_TO_ID:
        return stripped, lookup_key

    legacy_target = LEGACY_ACTION_MAP.get(lookup_key)
    if legacy_target is not None:
        return stripped, legacy_target

    # Substring fallback; list order decides ties between overlapping keywords.
    keyword_match = next(
        (
            candidate
            for candidate, needles in ACTION_KEYWORDS
            if any(needle in lookup_key for needle in needles)
        ),
        "other",
    )
    return stripped, keyword_match
def normalize_interaction(
    interaction: dict[str, Any] | None,
    recency_max: int = DEFAULT_RECENCY_MAX,
) -> dict[str, Any]:
    """Normalize a raw previous-interaction mapping into a fixed set of fields.

    The result starts as a copy of ``INTERACTION_SENTINELS`` and is then
    overridden with values derived from ``interaction``.

    Args:
        interaction: Raw mapping that may contain ``previous_text``,
            ``previous_action``, ``previous_outcome`` and ``recency_seconds``.
            ``None`` or an empty mapping means there is no interaction.
        recency_max: Upper bound applied to a non-negative recency value.

    Returns:
        A new dict. When ``interaction`` is falsy only
        ``log_recency_seconds`` (0.0), ``has_interaction`` (0),
        ``has_recency`` (0), ``previous_action_raw`` ("") and
        ``previous_action_canonical`` ("none") are set on top of the
        sentinels. Otherwise it also carries ``previous_text``,
        ``previous_action`` (same as the raw action), ``previous_outcome``
        and ``recency_seconds``.

        ``has_interaction`` is 1 when there is previous text, a previous
        action other than the "none" placeholder, a known outcome other than
        "unknown", or a valid recency.
    """
    normalized = dict(INTERACTION_SENTINELS)
    if not interaction:
        normalized["log_recency_seconds"] = 0.0
        normalized["has_interaction"] = 0
        normalized["has_recency"] = 0
        normalized["previous_action_raw"] = ""
        normalized["previous_action_canonical"] = "none"
        return normalized

    normalized["previous_text"] = str(interaction.get("previous_text", "") or "")
    previous_action_raw, previous_action_canonical = canonicalize_action(
        interaction.get("previous_action", "none")
    )
    previous_outcome = str(interaction.get("previous_outcome", "unknown") or "unknown")
    if previous_outcome not in OUTCOME_TO_ID:
        # Outcomes outside the vocabulary collapse to the "unknown" placeholder.
        previous_outcome = "unknown"

    normalized["previous_action"] = previous_action_raw
    normalized["previous_action_raw"] = previous_action_raw
    normalized["previous_action_canonical"] = previous_action_canonical
    normalized["previous_outcome"] = previous_outcome

    # The "none" action is a placeholder, exactly like the "unknown" outcome.
    # Testing the raw string for truthiness would count the "none" default
    # used for a missing key (or an explicit "none") as real context, so a
    # mapping holding only placeholder values would report has_interaction=1.
    has_context_fields = int(
        bool(normalized["previous_text"])
        or previous_action_canonical != "none"
        or previous_outcome != "unknown"
    )

    raw_recency = interaction.get("recency_seconds", -1)
    try:
        raw_recency = int(raw_recency)
    except (TypeError, ValueError, OverflowError):
        # OverflowError: int(float("inf")) — treat like any other unusable value.
        raw_recency = -1

    if raw_recency >= 0:
        clamped = min(raw_recency, recency_max)
        normalized["recency_seconds"] = clamped
        normalized["log_recency_seconds"] = math.log1p(clamped)
        normalized["has_recency"] = 1
    else:
        # Negative or unparsable recency is stored as -1 with the flag cleared.
        normalized["recency_seconds"] = -1
        normalized["log_recency_seconds"] = 0.0
        normalized["has_recency"] = 0

    normalized["has_interaction"] = int(has_context_fields or normalized["has_recency"])
    return normalized
def normalize_example(example: dict[str, Any], recency_max: int = DEFAULT_RECENCY_MAX) -> dict[str, Any]:
    """Return a cleaned copy of a raw example with three keys.

    Args:
        example: Raw example mapping; ``current_text``, ``interaction`` and
            ``labels`` are read with ``.get`` so each may be absent.
        recency_max: Forwarded to ``normalize_interaction``.

    Returns:
        A dict with ``current_text`` (stringified and stripped, "" when
        missing or falsy), ``interaction`` (output of
        ``normalize_interaction``) and ``labels`` (the original value, or an
        empty dict when missing or falsy).
    """
    cleaned_interaction = normalize_interaction(
        example.get("interaction"),
        recency_max=recency_max,
    )
    label_values = example.get("labels", {}) or {}
    text = str(example.get("current_text", "") or "")
    return {
        "current_text": text.strip(),
        "interaction": cleaned_interaction,
        "labels": label_values,
    }
def build_prompt(example: dict[str, Any], feature_mode: str = "full_interaction") -> str:
    """Render a normalized example as the text prompt for the given mode.

    Args:
        example: Mapping with ``current_text`` and ``interaction`` keys (the
            shape produced by ``normalize_example``).
        feature_mode: Must be a member of ``FEATURE_MODES``.
            ``"current_text_only"`` emits only the current text,
            ``"current_plus_previous_text"`` adds the previous text, and any
            other accepted mode emits all four lines.

    Returns:
        The newline-joined prompt. Previous-interaction fields are rendered
        as the literal ``"None"`` when ``interaction["has_interaction"]`` is
        falsy.

    Raises:
        ValueError: If ``feature_mode`` is not in ``FEATURE_MODES``.
    """
    if feature_mode not in FEATURE_MODES:
        raise ValueError(f"Unsupported feature mode: {feature_mode}")

    current_text = example["current_text"]
    interaction = example["interaction"]

    prompt_lines = [f"Current: {current_text}"]
    if feature_mode == "current_text_only":
        return prompt_lines[0]

    has_prior = interaction["has_interaction"]

    def field_or_placeholder(key: str) -> Any:
        # Fields are only read when an interaction is actually present.
        if has_prior:
            return interaction[key]
        return "None"

    previous_text = field_or_placeholder("previous_text")
    prompt_lines.append(f"Previous: {previous_text}")
    if feature_mode == "current_plus_previous_text":
        return "\n".join(prompt_lines)

    previous_action = field_or_placeholder("previous_action_raw")
    previous_outcome = field_or_placeholder("previous_outcome")
    prompt_lines.append(f"Previous action: {previous_action}")
    prompt_lines.append(f"Previous outcome: {previous_outcome}")
    return "\n".join(prompt_lines)
def prepare_record(
    example: dict[str, Any],
    tokenizer: PreTrainedTokenizerBase,
    feature_mode: str = "full_interaction",
    max_length: int = DEFAULT_MAX_LENGTH,
    recency_max: int = DEFAULT_RECENCY_MAX,
) -> dict[str, Any]:
    """Convert one raw example into a model-ready feature record.

    Args:
        example: Raw example mapping, normalized via ``normalize_example``.
        tokenizer: Callable tokenizer; invoked with truncation enabled,
            ``max_length`` and no padding.
        feature_mode: Prompt layout passed to ``build_prompt``. Structured
            interaction features are only populated for
            ``"full_interaction"``.
        max_length: Truncation length handed to the tokenizer.
        recency_max: Recency clamp handed to ``normalize_example``.

    Returns:
        A dict with ``input_ids``, ``attention_mask``, the structured fields
        ``previous_action_id``, ``previous_outcome_id``,
        ``log_recency_seconds``, ``has_interaction`` and ``has_recency``,
        plus one ``labels_<head>`` id for every head in ``LABEL_TO_ID`` whose
        label is present (not ``None``) in the example.
    """
    normalized = normalize_example(example, recency_max=recency_max)
    prompt_text = build_prompt(normalized, feature_mode=feature_mode)
    encoding = tokenizer(
        prompt_text,
        truncation=True,
        max_length=max_length,
        padding=False,
    )

    prior = normalized["interaction"]
    input_ids = encoding["input_ids"]
    attention_mask = encoding["attention_mask"]

    # Structured features are zeroed out unless the mode asks for them and
    # the example actually has an interaction.
    if feature_mode == "full_interaction" and prior["has_interaction"]:
        structured_flag = 1
        action_id = ACTION_TO_ID[prior["previous_action_canonical"]]
        outcome_id = OUTCOME_TO_ID[prior["previous_outcome"]]
        log_recency = float(prior["log_recency_seconds"])
        recency_flag = int(prior["has_recency"])
    else:
        structured_flag = 0
        action_id = ACTION_TO_ID["none"]
        outcome_id = OUTCOME_TO_ID["unknown"]
        log_recency = 0.0
        recency_flag = 0

    record = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "previous_action_id": action_id,
        "previous_outcome_id": outcome_id,
        "log_recency_seconds": log_recency,
        "has_interaction": structured_flag,
        "has_recency": recency_flag,
    }

    provided_labels = normalized["labels"]
    for head_name, label_ids in LABEL_TO_ID.items():
        raw_label = provided_labels.get(head_name)
        if raw_label is None:
            # Heads without a label are simply left out of the record.
            continue
        record[f"labels_{head_name}"] = label_ids[raw_label]
    return record
def build_dataset_dict(
    train_file: str | None,
    validation_file: str | None,
    test_file: str | None = None,
) -> DatasetDict:
    """Load the given JSON files as the splits of a ``DatasetDict``.

    Args:
        train_file: Path for the ``"train"`` split, or a falsy value to skip.
        validation_file: Path for the ``"validation"`` split, or falsy to skip.
        test_file: Path for the ``"test"`` split, or falsy to skip.

    Returns:
        The result of ``load_dataset("json", data_files=...)`` for the
        splits that were supplied.

    Raises:
        ValueError: If no file was supplied at all.
    """
    split_sources = (
        ("train", train_file),
        ("validation", validation_file),
        ("test", test_file),
    )
    # Each path is passed through pathlib before being handed to the loader.
    data_files = {
        split_name: str(Path(location))
        for split_name, location in split_sources
        if location
    }
    if not data_files:
        raise ValueError("At least one dataset file is required.")
    return load_dataset("json", data_files=data_files)
def tokenize_dataset_dict(
    dataset_dict: DatasetDict,
    tokenizer: PreTrainedTokenizerBase,
    feature_mode: str = "full_interaction",
    max_length: int = DEFAULT_MAX_LENGTH,
    recency_max: int = DEFAULT_RECENCY_MAX,
) -> DatasetDict:
    """Apply ``prepare_record`` to every example of every split.

    Args:
        dataset_dict: Splits of raw examples.
        tokenizer: Tokenizer forwarded to ``prepare_record``.
        feature_mode: Forwarded to ``prepare_record``.
        max_length: Forwarded to ``prepare_record``.
        recency_max: Forwarded to ``prepare_record``.

    Returns:
        A new ``DatasetDict`` with the same split names, in which each
        split's original columns are replaced by the fields returned by
        ``prepare_record``.
    """

    def mapper(example: dict[str, Any]) -> dict[str, Any]:
        return prepare_record(
            example,
            tokenizer=tokenizer,
            feature_mode=feature_mode,
            max_length=max_length,
            recency_max=recency_max,
        )

    # Remove each split's own columns rather than the first split's: splits
    # may not share a schema (e.g. an unlabeled test split), and an empty
    # DatasetDict then yields an empty result instead of a StopIteration.
    return DatasetDict(
        {
            split_name: split.map(mapper, remove_columns=list(split.column_names))
            for split_name, split in dataset_dict.items()
        }
    )
class RouterCollator:
    """Batch collator that pads tokenized records and fixes tensor dtypes."""

    def __init__(self, tokenizer: PreTrainedTokenizerBase):
        """Store the tokenizer whose ``pad`` method builds each batch."""
        self.tokenizer = tokenizer

    def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
        """Pad ``features`` into one batch and cast the non-token fields.

        Args:
            features: Per-example records (the shape produced by
                ``prepare_record``).

        Returns:
            The padded batch from ``tokenizer.pad(..., return_tensors="pt")``
            with id/flag fields and any present ``labels_<head>`` fields cast
            via ``.long()`` and ``log_recency_seconds`` cast via ``.float()``.
        """
        padded = self.tokenizer.pad(features, padding=True, return_tensors="pt")

        integer_fields = (
            "previous_action_id",
            "previous_outcome_id",
            "has_interaction",
            "has_recency",
        )
        for field_name in integer_fields:
            padded[field_name] = padded[field_name].long()

        padded["log_recency_seconds"] = padded["log_recency_seconds"].float()

        # Label columns are optional, so only the ones in the batch are cast.
        label_fields = [f"labels_{head}" for head in HEAD_LABELS]
        for field_name in label_fields:
            if field_name in padded:
                padded[field_name] = padded[field_name].long()
        return padded