| from __future__ import annotations |
|
|
| import hashlib |
| import math |
| import sys |
| from pathlib import Path |
| from typing import Any |
|
|
| import torch |
| from datasets import Dataset as HFDataset |
| from datasets import load_dataset |
| from torch.utils.data import BatchSampler, Dataset |
|
|
# Directory containing this file. It is pushed to the front of ``sys.path`` (once) so that
# sibling modules such as ``config`` (imported just below) resolve when this file is run
# as a script rather than imported as part of a package.
SCRIPT_DIR = Path(__file__).resolve().parent
if str(SCRIPT_DIR) not in sys.path:
    sys.path.insert(0, str(SCRIPT_DIR))
|
|
| from config import ensure_dir, fingerprint_payload |
|
|
|
|
# Observation-role ids written per tick into "observation_role":
OBS_ROLE_NONE = 0  # nothing is observed at this tick
OBS_ROLE_USER = 1  # a user utterance is observed
OBS_ROLE_AGENT_FEEDBACK = 2  # the agent's own earlier speech is fed back as an observation

# Part of the preprocessed-cache fingerprint; bump it to invalidate caches built by older code.
SFT_CACHE_SCHEMA_VERSION = 2

# Sort rank of each known duration bucket (shortest first); unknown buckets rank after all
# of these (see duration_bucket_rank).
DURATION_BUCKET_ORDER = {
    bucket_name: rank
    for rank, bucket_name in enumerate(("short", "medium", "long", "hour_scale"))
}
|
|
|
|
def load_clean_split(dataset_cfg: dict[str, Any], split: str) -> HFDataset:
    """Load one split of the cleaned conversation dataset.

    Sources are tried in this order:

    1. ``<cleaned_local_parquet_dir>/<split>.parquet`` when both the directory and the file exist.
    2. ``cleaned_repo_id`` together with ``source_split``: that single source split is loaded.
       For ``split`` in ``{"train", "validation"}``, rows are assigned deterministically by
       their ``id`` via ``assign_deterministic_split`` (``validation_ratio`` defaults to 0.02,
       ``split_seed`` to 17). Any other split name returns the whole source split unfiltered.
    3. ``cleaned_repo_id`` loaded directly with ``split``.

    Raises:
        ValueError: if no local parquet file was used and ``cleaned_repo_id`` is not set.
    """
    local_parquet_dir = dataset_cfg.get("cleaned_local_parquet_dir")
    if local_parquet_dir and Path(local_parquet_dir).exists():
        parquet_path = Path(local_parquet_dir) / f"{split}.parquet"
        if parquet_path.exists():
            return load_dataset("parquet", data_files={split: str(parquet_path)}, split=split)

    cleaned_repo_id = dataset_cfg.get("cleaned_repo_id")
    if not cleaned_repo_id:
        raise ValueError("Dataset config must provide either cleaned_local_parquet_dir or cleaned_repo_id.")

    source_split = dataset_cfg.get("source_split")
    if not source_split:
        return load_dataset(cleaned_repo_id, split=split)

    dataset = load_dataset(cleaned_repo_id, split=str(source_split))
    if split not in {"train", "validation"}:
        return dataset

    validation_ratio = float(dataset_cfg.get("validation_ratio", 0.02))
    split_seed = int(dataset_cfg.get("split_seed", 17))
    # Only the id column is needed to assign rows to a split. Reading that column directly
    # avoids decoding every full row (including its message list) just to hash the id.
    row_ids = dataset["id"]
    selected_indices = [
        index
        for index, row_id in enumerate(row_ids)
        if assign_deterministic_split(str(row_id), validation_ratio, split_seed) == split
    ]
    return dataset.select(selected_indices)
|
|
|
|
def tokenize_text(tokenizer: Any, text: str | None, max_length: int) -> tuple[list[int], list[int]]:
    """Tokenize ``text`` into ``(input_ids, attention_mask)``, truncated to ``max_length`` tokens.

    Empty or missing text yields a single placeholder token with attention mask ``[0]``, so
    every event still carries a non-empty id list. The placeholder is ``pad_token_id``. For
    tokenizers without a pad token (``pad_token_id is None``), it falls back to
    ``eos_token_id`` and then to ``0``. The id must be an int because the lists are later
    turned into ``torch.long`` tensors, which fails on ``None``.
    """
    if not text:
        placeholder_id = getattr(tokenizer, "pad_token_id", None)
        if placeholder_id is None:
            placeholder_id = getattr(tokenizer, "eos_token_id", None)
        if placeholder_id is None:
            placeholder_id = 0
        return [placeholder_id], [0]
    # NOTE(review): callers may pass chat-template-rendered text. If that text already begins
    # with a BOS token, add_special_tokens=True could prepend a second one. Confirm for the
    # tokenizer in use.
    encoded = tokenizer(
        text,
        add_special_tokens=True,
        truncation=True,
        max_length=max_length,
    )
    return encoded["input_ids"], encoded["attention_mask"]
|
|
|
|
def tokenizer_supports_chat_template(tokenizer: Any) -> bool:
    """Return True when ``tokenizer`` has a non-empty ``chat_template`` attribute."""
    template = getattr(tokenizer, "chat_template", None)
    return bool(template)
|
|
|
|
def should_use_base_chat_template(config: dict[str, Any], tokenizer: Any) -> bool:
    """Decide whether observations/targets should be rendered with the tokenizer's chat template.

    ``config["inference"]["use_base_chat_template"]`` wins when that key is present; otherwise
    ``config["dataset"]["use_base_chat_template"]`` is used (default False). The result is
    always False when the tokenizer has no chat template.
    """
    inference_cfg = config.get("inference", {})
    if "use_base_chat_template" in inference_cfg:
        requested = inference_cfg["use_base_chat_template"]
    else:
        requested = config.get("dataset", {}).get("use_base_chat_template", False)
    return bool(requested) and tokenizer_supports_chat_template(tokenizer)
|
|
|
|
def build_user_generation_observation(tokenizer: Any, text: str, use_base_chat_template: bool) -> str:
    """Render a user utterance as observation text.

    With the chat template enabled, ``text`` becomes a single user turn followed by the
    template's generation prompt. Otherwise the raw text is returned unchanged.
    """
    if not use_base_chat_template:
        return text
    user_turn = [{"role": "user", "content": text}]
    return tokenizer.apply_chat_template(user_turn, tokenize=False, add_generation_prompt=True)
|
|
|
|
def build_assistant_feedback_observation(tokenizer: Any, text: str, use_base_chat_template: bool) -> str:
    """Render the agent's own earlier speech as observation text.

    With the chat template flag set, ``text`` is wrapped in literal
    ``<start_of_turn>model ... <end_of_turn>`` markup; otherwise it is returned unchanged.
    ``tokenizer`` is accepted for signature parity with the other builders but is unused.
    """
    # NOTE(review): the markup is hard-coded rather than taken from the tokenizer's chat
    # template, so it only matches templates that use these exact turn markers. Confirm it
    # agrees with the base model's template.
    del tokenizer
    if not use_base_chat_template:
        return text
    return f"<start_of_turn>model\n{text}<end_of_turn>\n"
|
|
|
|
| def build_assistant_decoder_target(tokenizer: Any, text: str, use_base_chat_template: bool) -> str: |
| if not use_base_chat_template: |
| return text |
|
|
| prompt_text = tokenizer.apply_chat_template( |
| [{"role": "user", "content": ""}], |
| tokenize=False, |
| add_generation_prompt=True, |
| ) |
| full_text = tokenizer.apply_chat_template( |
| [ |
| {"role": "user", "content": ""}, |
| {"role": "assistant", "content": text}, |
| ], |
| tokenize=False, |
| add_generation_prompt=False, |
| ) |
| if not full_text.startswith(prompt_text): |
| raise ValueError("Chat template full text did not start with the generation prompt prefix.") |
| return full_text[len(prompt_text) :] |
|
|
|
|
def duration_bucket_rank(bucket: str | None) -> int:
    """Return the sort rank of ``bucket`` from DURATION_BUCKET_ORDER.

    ``bucket`` is compared as ``str(bucket)``. Unknown buckets (including None) rank after
    every known bucket.
    """
    bucket_key = str(bucket)
    if bucket_key in DURATION_BUCKET_ORDER:
        return DURATION_BUCKET_ORDER[bucket_key]
    return len(DURATION_BUCKET_ORDER)
|
|
|
|
def assign_deterministic_split(row_id: str, validation_ratio: float, seed: int) -> str:
    """Deterministically assign ``row_id`` to ``"train"`` or ``"validation"``.

    The first 32 bits of ``sha1(f"{seed}:{row_id}")`` are scaled into [0, 1]. Rows whose
    score falls below ``validation_ratio`` go to validation. The same id, seed and ratio
    always produce the same answer.
    """
    hashed = hashlib.sha1(f"{seed}:{row_id}".encode("utf-8")).hexdigest()
    leading_bits = int(hashed[:8], 16)
    if leading_bits / 0xFFFFFFFF < validation_ratio:
        return "validation"
    return "train"
|
|
|
|
def stable_tick_index(time_seconds: float, tick_seconds: float) -> int:
    """Map a time in seconds to the first tick boundary at or after it.

    Negative times clamp to tick 0. The 1e-9 slack keeps floating-point noise from bumping an
    exact multiple to the next tick (e.g. ``0.3 / 0.1 == 2.9999999999999996`` stays 3, and
    ``0.2 / 0.1 == 2.0`` stays 2).
    """
    elapsed = max(0.0, float(time_seconds))
    return int(math.ceil(elapsed / tick_seconds - 1e-9))
|
|
|
|
def quantize_message_sequence(
    messages: list[dict[str, Any]],
    *,
    tick_seconds: float,
    feedback_delay_seconds: float,
) -> list[dict[str, Any]]:
    """Convert timestamped messages into events on a discrete tick grid.

    Each message is read as ``speaker`` / ``text`` / ``t`` (seconds).

    * A ``"user"`` message yields one ``"user"`` event at ``t``.
    * Any other speaker is treated as the agent. It yields an ``"agent_speak"`` event at
      ``t`` and an ``"agent_feedback"`` event with the same text at
      ``t + feedback_delay_seconds``.

    Events are ordered by time, with ties kept in emission order, and mapped to ticks with
    ``stable_tick_index``. An event whose tick is not later than the previous event's is
    pushed to the next free tick, so ticks are strictly increasing.

    Returns:
        ``{"tick", "kind", "text"}`` dicts in tick order.
    """
    # Each entry is (time_seconds, emission_order, kind, text).
    timeline: list[tuple[float, int, str, str]] = []
    for message in messages:
        speaker = str(message["speaker"])
        text = str(message["text"])
        start = float(message["t"])
        if speaker == "user":
            timeline.append((start, len(timeline), "user", text))
        else:
            timeline.append((start, len(timeline), "agent_speak", text))
            timeline.append((start + feedback_delay_seconds, len(timeline), "agent_feedback", text))

    timeline.sort(key=lambda entry: (float(entry[0]), int(entry[1])))

    quantized: list[dict[str, Any]] = []
    last_tick = -1
    for moment, _order, kind, text in timeline:
        # At most one event per tick: collisions are pushed to the next free tick.
        tick = max(stable_tick_index(float(moment), tick_seconds), last_tick + 1)
        quantized.append({"tick": tick, "kind": kind, "text": text})
        last_tick = tick
    return quantized
|
|
|
|
def resolve_bucket_horizon_ticks(duration_bucket: str, rollout_cfg: dict[str, Any]) -> int:
    """Return the maximum number of ticks kept for conversations in ``duration_bucket``.

    ``rollout_cfg["horizon_ticks_by_bucket"][duration_bucket]`` is used when present.
    Otherwise the result is ``rollout_cfg["max_horizon_ticks"]``, which defaults to 36000.
    """
    per_bucket = rollout_cfg.get("horizon_ticks_by_bucket", {})
    fallback_ticks = int(rollout_cfg.get("max_horizon_ticks", 36000))
    chosen = per_bucket.get(duration_bucket, fallback_ticks)
    return int(chosen)
|
|
|
|
def _resolve_tick_window(total_ticks: int, horizon_ticks: int, window_strategy: str) -> tuple[int, int]:
    """Return ``(start_tick, tick_count)`` of the slice of a conversation that is kept.

    Conversations no longer than ``horizon_ticks`` are kept whole. Longer ones keep their last
    ``horizon_ticks`` ticks when ``window_strategy == "tail"``, and their first
    ``horizon_ticks`` ticks for any other strategy value.
    """
    if total_ticks <= horizon_ticks:
        return 0, total_ticks
    if window_strategy == "tail":
        return total_ticks - horizon_ticks, horizon_ticks
    return 0, horizon_ticks


def _encode_tick_event(
    event: dict[str, Any],
    *,
    tokenizer: Any,
    use_base_chat_template: bool,
    max_observation_tokens: int,
    max_decoder_tokens: int,
) -> tuple[int, list[int], list[int], int, list[int]]:
    """Encode one quantized event as ``(role, obs_ids, obs_mask, gate_target, decoder_labels)``.

    ``"user"`` and ``"agent_feedback"`` events are observations: gate target 0 and no decoder
    labels. Every other kind is the agent speaking: no observation text, gate target 1, and
    decoder labels tokenized from the assistant target text.
    """
    kind = event["kind"]
    text = str(event["text"])
    if kind == "user":
        role, gate = OBS_ROLE_USER, 0
        observation_text = build_user_generation_observation(tokenizer, text, use_base_chat_template)
        decoder_text = None
    elif kind == "agent_feedback":
        role, gate = OBS_ROLE_AGENT_FEEDBACK, 0
        observation_text = build_assistant_feedback_observation(tokenizer, text, use_base_chat_template)
        decoder_text = None
    else:
        role, gate = OBS_ROLE_NONE, 1
        observation_text = None
        decoder_text = build_assistant_decoder_target(tokenizer, text, use_base_chat_template)

    observation_ids, observation_mask = tokenize_text(
        tokenizer=tokenizer,
        text=observation_text,
        max_length=max_observation_tokens,
    )
    labels, _ = tokenize_text(
        tokenizer=tokenizer,
        text=decoder_text,
        max_length=max_decoder_tokens,
    )
    # Observation-only ticks carry no decoder supervision.
    return role, observation_ids, observation_mask, gate, (labels if gate else [])


def build_fixed_tick_conversation(
    *,
    row: dict[str, Any],
    tokenizer: Any,
    config: dict[str, Any],
    rollout_cfg: dict[str, Any],
    max_observation_tokens: int,
    max_decoder_tokens: int,
) -> dict[str, Any] | None:
    """Turn one dataset row into a chunked, tokenized fixed-tick training example.

    The row's messages are quantized onto ticks of ``rollout_cfg["tick_seconds"]``. The
    conversation is then cut to at most the bucket's horizon (see
    ``resolve_bucket_horizon_ticks`` and ``rollout_cfg["long_window_strategy"]``) and split
    into chunks of ``rollout_cfg["chunk_ticks"]`` ticks. Each event inside the window is
    stored in its chunk with its in-chunk ``offset`` and tokenized observation/decoder fields.

    Returns:
        ``None`` when ``rollout_cfg["drop_overlong_examples"]`` is set and the conversation
        exceeds its horizon. Otherwise a dict with ``row_id``, ``scenario_id``,
        ``duration_bucket``, ``bucket_rank``, ``total_ticks``, ``effective_start_tick``,
        ``effective_total_ticks``, ``chunk_count``, ``chunk_lengths``, ``chunk_events`` and
        ``event_count``.
    """
    tick_seconds = float(rollout_cfg["tick_seconds"])
    chunk_ticks = int(rollout_cfg["chunk_ticks"])
    feedback_delay_seconds = float(rollout_cfg["post_speech_feedback_delay_seconds"])
    window_strategy = str(rollout_cfg.get("long_window_strategy", "tail"))
    meta = row.get("meta", {})
    duration_bucket = str(meta.get("duration_bucket", "short"))
    use_base_chat_template = should_use_base_chat_template(config, tokenizer)

    quantized_events = quantize_message_sequence(
        messages=row["messages"],
        tick_seconds=tick_seconds,
        feedback_delay_seconds=feedback_delay_seconds,
    )
    # Ticks are 0-based, so the conversation spans last_tick + 1 ticks (1 when empty).
    total_ticks = max((event["tick"] for event in quantized_events), default=0) + 1

    horizon_ticks = resolve_bucket_horizon_ticks(duration_bucket, rollout_cfg)
    if bool(rollout_cfg.get("drop_overlong_examples", False)) and total_ticks > horizon_ticks:
        return None

    window_start, window_length = _resolve_tick_window(total_ticks, horizon_ticks, window_strategy)
    window_end = window_start + window_length

    chunk_count = max(1, math.ceil(window_length / chunk_ticks))
    chunk_lengths = [
        min(chunk_ticks, max(0, window_length - position * chunk_ticks))
        for position in range(chunk_count)
    ]
    chunk_events: list[list[dict[str, Any]]] = [[] for _ in range(chunk_count)]

    for event in quantized_events:
        absolute_tick = int(event["tick"])
        if not (window_start <= absolute_tick < window_end):
            continue
        chunk_position, offset = divmod(absolute_tick - window_start, chunk_ticks)
        role, observation_ids, observation_mask, gate, labels = _encode_tick_event(
            event,
            tokenizer=tokenizer,
            use_base_chat_template=use_base_chat_template,
            max_observation_tokens=max_observation_tokens,
            max_decoder_tokens=max_decoder_tokens,
        )
        chunk_events[chunk_position].append(
            {
                "offset": int(offset),
                "absolute_tick": absolute_tick,
                "observation_role": int(role),
                "observation_input_ids": observation_ids,
                "observation_attention_mask": observation_mask,
                "gate_target": int(gate),
                "decoder_labels": labels,
            }
        )

    event_count = sum(len(events) for events in chunk_events)
    return {
        "row_id": str(row["id"]),
        "scenario_id": meta.get("scenario_id"),
        "duration_bucket": duration_bucket,
        "bucket_rank": duration_bucket_rank(duration_bucket),
        "total_ticks": int(total_ticks),
        "effective_start_tick": int(window_start),
        "effective_total_ticks": int(window_length),
        "chunk_count": int(chunk_count),
        "chunk_lengths": chunk_lengths,
        "chunk_events": chunk_events,
        "event_count": int(event_count),
    }
|
|
|
|
def build_chunk_batch(
    *,
    conversations: list[dict[str, Any]],
    chunk_index: int,
    pad_token_id: int,
    chunk_ticks: int,
    tick_seconds: float,
) -> dict[str, torch.Tensor] | None:
    """Collate chunk ``chunk_index`` of every conversation into padded batch tensors.

    Row ``b`` covers ``chunk_lengths[chunk_index]`` ticks of ``conversations[b]``, or 0 ticks
    when that conversation has fewer chunks. Shorter rows are right-padded and excluded via
    ``tick_mask``. Events are placed at their ``offset`` within the chunk.

    Args:
        conversations: dicts with ``chunk_count``, ``chunk_lengths`` and ``chunk_events``.
        chunk_index: which chunk of each conversation to collate.
        pad_token_id: fill value for unused observation token slots.
        chunk_ticks: ticks per full chunk. Used to recover each tick's index from the start
            of the conversation window.
        tick_seconds: duration of one tick in seconds.

    Returns:
        ``None`` if no conversation has ticks in this chunk. Otherwise a dict of tensors, with
        B = len(conversations), T = longest active row, and Lo / Ld = longest observation /
        decoder token list in the chunk (at least 1):

        * ``tick_mask`` bool (B, T): True on real ticks.
        * ``gate_target`` float32 (B, T): the event's gate target, 0 where there is no event.
        * ``observation_role`` long (B, T): the event's role, 0 where there is no event.
        * ``observation_input_ids`` long (B, T, Lo): padded with ``pad_token_id``.
        * ``observation_attention_mask`` long (B, T, Lo): padded with 0.
        * ``decoder_labels`` long (B, T, Ld): padded with -100 (presumably the loss ignore
          index; confirm against the model).
        * ``delta_seconds`` float32 (B, T): ``tick_seconds``, except 0.0 at window tick 0.
        * ``elapsed_seconds`` float32 (B, T): window tick index times ``tick_seconds``.
    """
    active_lengths: list[int] = []
    max_observation_tokens = 1
    max_decoder_tokens = 1
    for conversation in conversations:
        if chunk_index >= int(conversation["chunk_count"]):
            active_lengths.append(0)
            continue
        active_lengths.append(int(conversation["chunk_lengths"][chunk_index]))
        for event in conversation["chunk_events"][chunk_index]:
            max_observation_tokens = max(max_observation_tokens, len(event["observation_input_ids"]))
            max_decoder_tokens = max(max_decoder_tokens, len(event["decoder_labels"]) or 1)

    max_steps = max(active_lengths, default=0)
    if max_steps <= 0:
        return None

    batch_size = len(conversations)
    step_shape = (batch_size, max_steps)
    observation_input_ids = torch.full(
        (*step_shape, max_observation_tokens),
        fill_value=pad_token_id,
        dtype=torch.long,
    )
    observation_attention_mask = torch.zeros((*step_shape, max_observation_tokens), dtype=torch.long)
    decoder_labels = torch.full((*step_shape, max_decoder_tokens), fill_value=-100, dtype=torch.long)
    tick_mask = torch.zeros(step_shape, dtype=torch.bool)
    gate_target = torch.zeros(step_shape, dtype=torch.float32)
    observation_role = torch.zeros(step_shape, dtype=torch.long)
    delta_seconds = torch.zeros(step_shape, dtype=torch.float32)
    elapsed_seconds = torch.zeros(step_shape, dtype=torch.float32)

    # Timing features depend only on the tick position, not on the conversation. Compute them
    # once for the longest row and slice per row, instead of writing them element by element
    # inside a per-row, per-tick Python loop. The arithmetic is done in float64, like Python
    # float math, and cast to float32 afterwards, so the values are bit-identical to
    # per-element assignment of ``global_tick * tick_seconds``.
    global_ticks = torch.arange(max_steps, dtype=torch.float64) + float(chunk_index * chunk_ticks)
    elapsed_template = (global_ticks * tick_seconds).to(torch.float32)
    delta_template = torch.full((max_steps,), float(tick_seconds), dtype=torch.float32)
    # The first tick of the window has no predecessor, so its delta is zero.
    delta_template[global_ticks == 0] = 0.0

    for batch_index, (conversation, step_count) in enumerate(zip(conversations, active_lengths)):
        if step_count <= 0:
            continue

        tick_mask[batch_index, :step_count] = True
        delta_seconds[batch_index, :step_count] = delta_template[:step_count]
        elapsed_seconds[batch_index, :step_count] = elapsed_template[:step_count]

        for event in conversation["chunk_events"][chunk_index]:
            offset = int(event["offset"])
            observation_role[batch_index, offset] = int(event["observation_role"])
            gate_target[batch_index, offset] = float(event["gate_target"])

            observation_ids = event["observation_input_ids"]
            observation_mask = event["observation_attention_mask"]
            observation_input_ids[batch_index, offset, : len(observation_ids)] = torch.tensor(
                observation_ids,
                dtype=torch.long,
            )
            observation_attention_mask[batch_index, offset, : len(observation_mask)] = torch.tensor(
                observation_mask,
                dtype=torch.long,
            )

            labels = event["decoder_labels"]
            if labels:
                decoder_labels[batch_index, offset, : len(labels)] = torch.tensor(labels, dtype=torch.long)

    return {
        "tick_mask": tick_mask,
        "gate_target": gate_target,
        "observation_role": observation_role,
        "observation_input_ids": observation_input_ids,
        "observation_attention_mask": observation_attention_mask,
        "decoder_labels": decoder_labels,
        "delta_seconds": delta_seconds,
        "elapsed_seconds": elapsed_seconds,
    }
|
|
|
|
class ThoughtLoopConversationDataset(Dataset):
    """Map-style dataset of fixed-tick conversation examples for one split.

    Examples are produced by ``build_fixed_tick_conversation`` for every row of the split.
    They are optionally sorted by duration bucket, then total ticks, then row id. The result
    is cached with ``torch.save`` under ``config["cache"]["preprocessed_root"]``. The cache
    file name embeds a fingerprint of the settings that shape the examples, so changing any
    of them builds a fresh cache instead of reusing a stale one.

    NOTE: ``torch.load`` unpickles the cache file; only point ``preprocessed_root`` at trusted
    locations.
    """

    def __init__(self, config: dict[str, Any], tokenizer: Any, split: str) -> None:
        self.config = config
        self.tokenizer = tokenizer
        self.split = split
        self.examples = self._load_or_build()

    def _cache_path(self) -> Path:
        """Return the cache file for this split, named by a fingerprint of the build settings."""
        dataset_cfg = self.config["dataset"]
        model_cfg = self.config["model"]
        rollout_cfg = self.config["rollout"]
        cache_cfg = self.config["cache"]
        cache_root = ensure_dir(cache_cfg["preprocessed_root"])
        payload = {
            "cache_schema_version": SFT_CACHE_SCHEMA_VERSION,
            "split": self.split,
            "dataset": dataset_cfg,
            "model_name": model_cfg["base_model_name"],
            "rollout": rollout_cfg,
            "max_observation_tokens": model_cfg["max_observation_tokens"],
            "max_decoder_tokens": model_cfg["max_decoder_tokens"],
            "tokenizer_vocab_size": getattr(self.tokenizer, "vocab_size", None),
            # The effective chat-template flag changes how observations and decoder targets are
            # rendered. It can come from config["inference"], which is not otherwise part of
            # this key, so without it toggling the flag would reuse examples built the other way.
            "use_base_chat_template": should_use_base_chat_template(self.config, self.tokenizer),
        }
        return cache_root / f"{self.split}_{fingerprint_payload(payload)}.pt"

    def _load_or_build(self) -> list[dict[str, Any]]:
        """Load the cached examples for this split, or build, sort and cache them."""
        cache_path = self._cache_path()
        if cache_path.exists():
            return torch.load(cache_path, map_location="cpu")

        dataset = load_clean_split(self.config["dataset"], self.split)
        rollout_cfg = self.config["rollout"]
        model_cfg = self.config["model"]

        examples: list[dict[str, Any]] = []
        for row in dataset:
            example = build_fixed_tick_conversation(
                row=row,
                tokenizer=self.tokenizer,
                config=self.config,
                rollout_cfg=rollout_cfg,
                max_observation_tokens=int(model_cfg["max_observation_tokens"]),
                max_decoder_tokens=int(model_cfg["max_decoder_tokens"]),
            )
            # None means the row was dropped as overlong.
            if example is not None:
                examples.append(example)

        if bool(rollout_cfg.get("sort_by_duration_bucket", True)):
            examples.sort(key=lambda item: (int(item["bucket_rank"]), int(item["total_ticks"]), str(item["row_id"])))

        self._save_cache_atomically(examples, cache_path)
        return examples

    @staticmethod
    def _save_cache_atomically(examples: list[dict[str, Any]], cache_path: Path) -> None:
        """Write ``examples`` to ``cache_path`` through a temporary file and a rename.

        Writing in place means an interrupted or concurrent save can leave a truncated file.
        ``cache_path.exists()`` would accept that file on the next run, and ``torch.load``
        would then fail until it is deleted by hand. ``os.replace`` swaps the complete file in
        (atomically on POSIX when both paths share a filesystem, as they do here).
        """
        import os

        tmp_path = cache_path.with_name(f"{cache_path.name}.tmp-{os.getpid()}")
        try:
            torch.save(examples, tmp_path)
            os.replace(tmp_path, cache_path)
        except BaseException:
            if tmp_path.exists():
                tmp_path.unlink()
            raise

    def __len__(self) -> int:
        return len(self.examples)

    def __getitem__(self, index: int) -> dict[str, Any]:
        return self.examples[index]
|
|
|
|
class DurationBucketBatchSampler(BatchSampler):
    """Batch sampler that never mixes duration buckets within a batch.

    Consecutive dataset examples with the same ``duration_bucket`` form a run. Each run is cut
    into batches of up to ``batch_size`` indices, and the final batch of a run may be shorter.
    Batches are computed once in ``__init__`` and yielded in the same order on every
    iteration (no shuffling).
    """

    def __init__(
        self,
        dataset: ThoughtLoopConversationDataset,
        batch_size: int,
    ) -> None:
        self.dataset = dataset
        self.batch_size = batch_size
        self.batches: list[list[int]] = []

        # Group indices into runs of consecutive examples sharing a bucket.
        runs: list[list[int]] = []
        previous_bucket: str | None = None
        for position, example in enumerate(self.dataset.examples):
            bucket = str(example["duration_bucket"])
            if not runs or bucket != previous_bucket:
                runs.append([])
            runs[-1].append(position)
            previous_bucket = bucket

        # Cut each run into batches, closing a batch as soon as it reaches batch_size.
        for run in runs:
            pending: list[int] = []
            for position in run:
                pending.append(position)
                if len(pending) >= self.batch_size:
                    self.batches.append(pending)
                    pending = []
            if pending:
                self.batches.append(pending)

    def __iter__(self) -> Any:
        yield from self.batches

    def __len__(self) -> int:
        return len(self.batches)
|
|
|
|
def identity_collate(batch: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Collate function that returns the list of examples unchanged (the same list object).

    No tensors are built here; batches are presumably turned into tensors later, per chunk
    (e.g. with ``build_chunk_batch``). Confirm against the training loop.
    """
    collated = batch
    return collated
|
|
|
|
def estimate_total_chunk_microsteps(
    *,
    dataset: ThoughtLoopConversationDataset,
    batch_sampler: DurationBucketBatchSampler,
) -> int:
    """Count chunk-level steps in one pass over ``batch_sampler``.

    A batch needs as many chunk steps as the largest ``chunk_count`` among its examples; the
    result is the sum of that maximum over all batches (0 when there are no batches).
    """
    examples = dataset.examples
    return sum(
        max(int(examples[index]["chunk_count"]) for index in batch_indices)
        for batch_indices in batch_sampler.batches
    )
|
|