dexifried
Replace with tiny-router trainer (ZeroGPU/H200)
3bfff54
from __future__ import annotations
import math
from pathlib import Path
from typing import Any
from datasets import DatasetDict, load_dataset
from transformers import PreTrainedTokenizerBase
from tiny_router.constants import (
ACTION_VOCAB,
DEFAULT_MAX_LENGTH,
DEFAULT_RECENCY_MAX,
FEATURE_MODES,
HEAD_LABELS,
INTERACTION_SENTINELS,
OUTCOME_VOCAB,
)
def _vocab_index(vocab):
    """Return a dict mapping each label in *vocab* to its position in iteration order."""
    return {label: position for position, label in enumerate(vocab)}


# Label -> integer id lookups derived from the vocabularies in tiny_router.constants.
ACTION_TO_ID = _vocab_index(ACTION_VOCAB)
OUTCOME_TO_ID = _vocab_index(OUTCOME_VOCAB)
# One label -> id lookup per classification head.
LABEL_TO_ID = {head: _vocab_index(head_labels) for head, head_labels in HEAD_LABELS.items()}
# Legacy action names (presumably from older data exports) mapped straight to a
# canonical action. canonicalize_action consults this table after an exact
# vocabulary match and before any keyword matching.
LEGACY_ACTION_MAP = dict(
    created_reminder="schedule",
    updated_item="update",
    sent_reply="send",
    stored_memory="store",
    routed_to_queue="route",
    scheduled_task="schedule",
    dismissed="dismissed",
    requested_clarification="clarify",
)
# Ordered (canonical_action, keywords) table used by canonicalize_action as a
# keyword fallback. The first entry whose keywords match wins, so entry order is
# significant. Keywords are written as space-separated strings and split into
# tuples; multi-word keywords use "_" because action strings are normalized that way.
ACTION_KEYWORDS = [
    (canonical, tuple(keyword_line.split()))
    for canonical, keyword_line in (
        ("schedule", "schedule calendar meeting remind booking book"),
        ("send", "send reply email message respond share"),
        ("route", "route assign queue forward handoff hand_off triage"),
        ("store", "store save remember record log note"),
        ("update", "update edit change modify patch fix"),
        ("create", "create add open draft new"),
        ("clarify", "clarify ask question confirm_detail follow_up_question"),
        ("search", "search find lookup retrieve fetch"),
        ("notify", "notify alert ping nudge inform"),
        ("cancel", "cancel stop abort undo revoke"),
        ("complete", "close complete resolve finish done"),
        ("dismissed", "dismiss ignore archive mute snooze"),
    )
]
def canonicalize_action(action: Any) -> tuple[str, str]:
    """Map a free-form action value onto the canonical action vocabulary.

    Resolution order:

    1. Empty / falsy input -> ``("", "none")``.
    2. The normalized form (lower-cased, ``-`` and spaces replaced by ``_``)
       is already a key of ``ACTION_TO_ID``.
    3. The normalized form is a key of ``LEGACY_ACTION_MAP``.
    4. A keyword from ``ACTION_KEYWORDS`` occurs as a whole ``_``-delimited
       token (or token sequence, for keywords like ``"hand_off"``), checking
       the table in order.
    5. A keyword occurs anywhere as a substring, checking the table in order.
    6. Otherwise ``"other"``.

    Step 4 exists because substring matching alone misroutes common actions:
    ``"ask"`` (a ``clarify`` keyword) is a substring of ``"task"``. Because
    ``clarify`` precedes ``cancel``/``complete``/``search`` in the table,
    ``"complete_task"``, ``"cancel_task"`` and ``"search_tasks"`` would all
    become ``clarify``. Step 5 keeps the inflected/prefixed forms that only
    substring matching catches (``"set_reminder"``, ``"resend"``,
    ``"reschedule"``).

    Args:
        action: Raw action value. ``None`` and other falsy values count as
            empty; anything else is converted with ``str()``.

    Returns:
        ``(raw_action, canonical_action)`` where ``raw_action`` is the stripped
        string form of the input.
    """
    raw_action = str(action or "").strip()
    if not raw_action:
        return "", "none"
    normalized = raw_action.lower().replace("-", "_").replace(" ", "_")
    if normalized in ACTION_TO_ID:
        return raw_action, normalized
    if normalized in LEGACY_ACTION_MAP:
        return raw_action, LEGACY_ACTION_MAP[normalized]

    # Pad with separators so a keyword only matches on token boundaries,
    # including at the very start or end of the string.
    bounded = f"_{normalized}_"
    for canonical, keywords in ACTION_KEYWORDS:
        if any(f"_{keyword}_" in bounded for keyword in keywords):
            return raw_action, canonical

    # Fallback: the original substring matching, for inflected/prefixed forms.
    for canonical, keywords in ACTION_KEYWORDS:
        if any(keyword in normalized for keyword in keywords):
            return raw_action, canonical
    return raw_action, "other"
def normalize_interaction(
    interaction: dict[str, Any] | None,
    recency_max: int = DEFAULT_RECENCY_MAX,
) -> dict[str, Any]:
    """Normalize a raw previous-interaction dict into a fixed feature dict.

    The result starts from a copy of ``INTERACTION_SENTINELS``. Those defaults
    presumably cover ``previous_text``, ``previous_action``,
    ``previous_outcome`` and ``recency_seconds``; TODO confirm against
    tiny_router.constants. The following keys are always set here:
    ``previous_action_raw``, ``previous_action_canonical``,
    ``log_recency_seconds``, ``has_recency`` and ``has_interaction``.

    Args:
        interaction: Raw interaction mapping, or ``None``/empty when there is
            no previous interaction.
        recency_max: Upper clamp applied to a non-negative ``recency_seconds``.

    Returns:
        A new dict. For an empty/``None`` interaction, every flag is 0 and the
        sentinels are kept. Otherwise:

        * ``previous_action`` is canonicalized via ``canonicalize_action``.
        * ``previous_outcome`` falls back to ``"unknown"`` when not in
          ``OUTCOME_TO_ID``.
        * ``recency_seconds`` is clamped to ``[0, recency_max]``, or is ``-1``
          when missing or unparseable, with ``log_recency_seconds = log1p``
          of the clamped value.
        * ``has_interaction`` is 1 when any field carries real (non-sentinel)
          information.
    """
    normalized = dict(INTERACTION_SENTINELS)
    if not interaction:
        normalized["log_recency_seconds"] = 0.0
        normalized["has_interaction"] = 0
        normalized["has_recency"] = 0
        normalized["previous_action_raw"] = ""
        normalized["previous_action_canonical"] = "none"
        return normalized

    normalized["previous_text"] = str(interaction.get("previous_text", "") or "")
    previous_action_raw, previous_action_canonical = canonicalize_action(
        interaction.get("previous_action", "none")
    )
    previous_outcome = str(interaction.get("previous_outcome", "unknown") or "unknown")
    if previous_outcome not in OUTCOME_TO_ID:
        previous_outcome = "unknown"
    normalized["previous_action"] = previous_action_raw
    normalized["previous_action_raw"] = previous_action_raw
    normalized["previous_action_canonical"] = previous_action_canonical
    normalized["previous_outcome"] = previous_outcome

    # Compare each field against its "no information" sentinel. The raw action
    # string cannot be used: a missing previous_action defaults to "none" above,
    # which is a non-empty string and would mark every non-empty interaction
    # dict (even one holding only sentinel values) as a real interaction.
    has_context_fields = int(
        bool(normalized["previous_text"])
        or previous_action_canonical != "none"
        or previous_outcome != "unknown"
    )

    raw_recency = interaction.get("recency_seconds", -1)
    try:
        raw_recency = int(raw_recency)
    except (TypeError, ValueError, OverflowError):
        # TypeError: None / non-numeric objects; ValueError: bad strings or NaN;
        # OverflowError: infinite floats. All are treated as "recency unknown".
        raw_recency = -1

    if raw_recency >= 0:
        clamped = min(raw_recency, recency_max)
        normalized["recency_seconds"] = clamped
        normalized["log_recency_seconds"] = math.log1p(clamped)
        normalized["has_recency"] = 1
    else:
        normalized["recency_seconds"] = -1
        normalized["log_recency_seconds"] = 0.0
        normalized["has_recency"] = 0

    normalized["has_interaction"] = int(has_context_fields or normalized["has_recency"])
    return normalized
def normalize_example(example: dict[str, Any], recency_max: int = DEFAULT_RECENCY_MAX) -> dict[str, Any]:
    """Build a cleaned ``{current_text, interaction, labels}`` dict from a raw example.

    * ``current_text``: the raw value as a string, stripped (missing or falsy -> ``""``).
    * ``interaction``: the result of ``normalize_interaction`` with ``recency_max``.
    * ``labels``: the raw ``labels`` value, or ``{}`` when missing or falsy.
    """
    cleaned_interaction = normalize_interaction(example.get("interaction"), recency_max=recency_max)
    label_values = example.get("labels") or {}
    text = str(example.get("current_text") or "").strip()
    return dict(current_text=text, interaction=cleaned_interaction, labels=label_values)
def build_prompt(example: dict[str, Any], feature_mode: str = "full_interaction") -> str:
    """Render a normalized example as the newline-separated text the tokenizer sees.

    Modes:

    * ``"current_text_only"``: only the ``Current:`` line.
    * ``"current_plus_previous_text"``: adds the ``Previous:`` line.
    * Any other mode in ``FEATURE_MODES``: also adds the ``Previous action:``
      and ``Previous outcome:`` lines.

    When ``interaction["has_interaction"]`` is falsy, every previous-* field is
    rendered as the literal ``None``.

    Raises:
        ValueError: if ``feature_mode`` is not in ``FEATURE_MODES``.
    """
    if feature_mode not in FEATURE_MODES:
        raise ValueError(f"Unsupported feature mode: {feature_mode}")

    current = example["current_text"]
    context = example["interaction"]
    rendered = [f"Current: {current}"]
    if feature_mode == "current_text_only":
        return rendered[0]

    def context_field(key: str) -> Any:
        # Fields are only meaningful when an interaction is present.
        if context["has_interaction"]:
            return context[key]
        return "None"

    rendered.append(f"Previous: {context_field('previous_text')}")
    if feature_mode != "current_plus_previous_text":
        rendered.append(f"Previous action: {context_field('previous_action_raw')}")
        rendered.append(f"Previous outcome: {context_field('previous_outcome')}")
    return "\n".join(rendered)
def prepare_record(
    example: dict[str, Any],
    tokenizer: PreTrainedTokenizerBase,
    feature_mode: str = "full_interaction",
    max_length: int = DEFAULT_MAX_LENGTH,
    recency_max: int = DEFAULT_RECENCY_MAX,
) -> dict[str, Any]:
    """Turn one raw example into a model-ready feature dict.

    The example is normalized, rendered with ``build_prompt`` for
    ``feature_mode``, and tokenized with truncation to ``max_length`` and no
    padding. Padding is left to the batch collator.

    The structured interaction features are filled in only when
    ``feature_mode == "full_interaction"`` and the interaction is present.
    Otherwise they take neutral values: the ``"none"`` action id, the
    ``"unknown"`` outcome id, ``0.0`` log-recency, and 0 flags.

    Returns:
        A dict with ``input_ids``, ``attention_mask``, ``previous_action_id``,
        ``previous_outcome_id``, ``log_recency_seconds``, ``has_interaction``
        and ``has_recency``. It also has one ``labels_<head>`` entry for each
        head in ``LABEL_TO_ID`` whose label is present (not ``None``) in the
        example.

    Raises:
        KeyError: if a present label is not in its head's label vocabulary.
    """
    normalized = normalize_example(example, recency_max=recency_max)
    encoding = tokenizer(
        build_prompt(normalized, feature_mode=feature_mode),
        truncation=True,
        max_length=max_length,
        padding=False,
    )
    context = normalized["interaction"]
    structured = feature_mode == "full_interaction" and bool(context["has_interaction"])

    input_ids = encoding["input_ids"]
    attention_mask = encoding["attention_mask"]
    if structured:
        action_id = ACTION_TO_ID[context["previous_action_canonical"]]
        outcome_id = OUTCOME_TO_ID[context["previous_outcome"]]
        log_recency = float(context["log_recency_seconds"])
        recency_flag = int(context["has_recency"])
    else:
        action_id = ACTION_TO_ID["none"]
        outcome_id = OUTCOME_TO_ID["unknown"]
        log_recency = 0.0
        recency_flag = 0

    record: dict[str, Any] = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "previous_action_id": action_id,
        "previous_outcome_id": outcome_id,
        "log_recency_seconds": log_recency,
        "has_interaction": int(structured),
        "has_recency": recency_flag,
    }

    gold = normalized["labels"]
    for head, label_ids in LABEL_TO_ID.items():
        gold_label = gold.get(head)
        if gold_label is None:
            continue
        record[f"labels_{head}"] = label_ids[gold_label]
    return record
def build_dataset_dict(
    train_file: str | None,
    validation_file: str | None,
    test_file: str | None = None,
) -> DatasetDict:
    """Load the given JSON data files as a ``DatasetDict`` keyed by split name.

    Only splits whose path is truthy are included, in the order train,
    validation, test.

    Raises:
        ValueError: if no file path was given at all.
    """
    split_paths = (
        ("train", train_file),
        ("validation", validation_file),
        ("test", test_file),
    )
    data_files = {split: str(Path(path)) for split, path in split_paths if path}
    if len(data_files) == 0:
        raise ValueError("At least one dataset file is required.")
    return load_dataset("json", data_files=data_files)
def tokenize_dataset_dict(
    dataset_dict: DatasetDict,
    tokenizer: PreTrainedTokenizerBase,
    feature_mode: str = "full_interaction",
    max_length: int = DEFAULT_MAX_LENGTH,
    recency_max: int = DEFAULT_RECENCY_MAX,
) -> DatasetDict:
    """Apply ``prepare_record`` to every split and drop the raw columns.

    Each split removes its *own* original columns. Reusing the first split's
    column list for every split fails when splits differ, because
    ``Dataset.map`` raises ``ValueError`` for a ``remove_columns`` entry the
    split does not have. That happens, for example, when an unlabeled split
    lacks a ``labels`` column. Computing the list per split also makes an
    empty ``DatasetDict`` yield an empty result instead of ``StopIteration``.

    Returns:
        A new ``DatasetDict`` with the same split names, whose rows are the
        feature dicts produced by ``prepare_record``.
    """

    def mapper(example: dict[str, Any]) -> dict[str, Any]:
        return prepare_record(
            example,
            tokenizer=tokenizer,
            feature_mode=feature_mode,
            max_length=max_length,
            recency_max=recency_max,
        )

    return DatasetDict(
        {
            split: dataset.map(mapper, remove_columns=list(dataset.column_names))
            for split, dataset in dataset_dict.items()
        }
    )
class RouterCollator:
    """Batch collator that pads token sequences and fixes feature dtypes.

    Calling an instance on a list of feature dicts pads them with
    ``tokenizer.pad(..., padding=True, return_tensors="pt")``. It then casts:

    * the interaction id and flag columns, and any present
      ``labels_<head>`` columns, with ``.long()``;
    * ``log_recency_seconds`` with ``.float()``.

    Label columns are optional. ``prepare_record`` only emits a
    ``labels_<head>`` key when that label is present, so only the label keys
    that exist in the padded batch are cast.
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase):
        self.tokenizer = tokenizer

    def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
        padded = self.tokenizer.pad(features, padding=True, return_tensors="pt")
        integer_columns = ["previous_action_id", "previous_outcome_id", "has_interaction", "has_recency"]
        for column in integer_columns:
            padded[column] = padded[column].long()
        padded["log_recency_seconds"] = padded["log_recency_seconds"].float()
        candidate_labels = [f"labels_{head}" for head in HEAD_LABELS]
        for column in [name for name in candidate_labels if name in padded]:
            padded[column] = padded[column].long()
        return padded