from __future__ import annotations import hashlib import math import sys from pathlib import Path from typing import Any import torch from datasets import Dataset as HFDataset from datasets import load_dataset from torch.utils.data import BatchSampler, Dataset SCRIPT_DIR = Path(__file__).resolve().parent if str(SCRIPT_DIR) not in sys.path: sys.path.insert(0, str(SCRIPT_DIR)) from config import ensure_dir, fingerprint_payload OBS_ROLE_NONE = 0 OBS_ROLE_USER = 1 OBS_ROLE_AGENT_FEEDBACK = 2 SFT_CACHE_SCHEMA_VERSION = 2 DURATION_BUCKET_ORDER = { "short": 0, "medium": 1, "long": 2, "hour_scale": 3, } def load_clean_split(dataset_cfg: dict[str, Any], split: str) -> HFDataset: local_parquet_dir = dataset_cfg.get("cleaned_local_parquet_dir") if local_parquet_dir and Path(local_parquet_dir).exists(): parquet_path = Path(local_parquet_dir) / f"{split}.parquet" if parquet_path.exists(): return load_dataset("parquet", data_files={split: str(parquet_path)}, split=split) cleaned_repo_id = dataset_cfg.get("cleaned_repo_id") if not cleaned_repo_id: raise ValueError("Dataset config must provide either cleaned_local_parquet_dir or cleaned_repo_id.") source_split = dataset_cfg.get("source_split") if source_split: dataset = load_dataset(cleaned_repo_id, split=str(source_split)) if split not in {"train", "validation"}: return dataset validation_ratio = float(dataset_cfg.get("validation_ratio", 0.02)) split_seed = int(dataset_cfg.get("split_seed", 17)) selected_indices = [ index for index, row in enumerate(dataset) if assign_deterministic_split(str(row["id"]), validation_ratio, split_seed) == split ] return dataset.select(selected_indices) return load_dataset(cleaned_repo_id, split=split) def tokenize_text(tokenizer: Any, text: str | None, max_length: int) -> tuple[list[int], list[int]]: if not text: return [tokenizer.pad_token_id], [0] encoded = tokenizer( text, add_special_tokens=True, truncation=True, max_length=max_length, ) return encoded["input_ids"], encoded["attention_mask"] def tokenizer_supports_chat_template(tokenizer: Any) -> bool: return bool(getattr(tokenizer, "chat_template", None)) def should_use_base_chat_template(config: dict[str, Any], tokenizer: Any) -> bool: dataset_cfg = config.get("dataset", {}) inference_cfg = config.get("inference", {}) if "use_base_chat_template" in inference_cfg: return bool(inference_cfg["use_base_chat_template"]) and tokenizer_supports_chat_template(tokenizer) return bool(dataset_cfg.get("use_base_chat_template", False)) and tokenizer_supports_chat_template(tokenizer) def build_user_generation_observation(tokenizer: Any, text: str, use_base_chat_template: bool) -> str: if use_base_chat_template: return tokenizer.apply_chat_template( [{"role": "user", "content": text}], tokenize=False, add_generation_prompt=True, ) return text def build_assistant_feedback_observation(tokenizer: Any, text: str, use_base_chat_template: bool) -> str: del tokenizer if use_base_chat_template: return f"model\n{text}\n" return text def build_assistant_decoder_target(tokenizer: Any, text: str, use_base_chat_template: bool) -> str: if not use_base_chat_template: return text prompt_text = tokenizer.apply_chat_template( [{"role": "user", "content": ""}], tokenize=False, add_generation_prompt=True, ) full_text = tokenizer.apply_chat_template( [ {"role": "user", "content": ""}, {"role": "assistant", "content": text}, ], tokenize=False, add_generation_prompt=False, ) if not full_text.startswith(prompt_text): raise ValueError("Chat template full text did not start with the generation prompt prefix.") return full_text[len(prompt_text) :] def duration_bucket_rank(bucket: str | None) -> int: return DURATION_BUCKET_ORDER.get(str(bucket), len(DURATION_BUCKET_ORDER)) def assign_deterministic_split(row_id: str, validation_ratio: float, seed: int) -> str: digest = hashlib.sha1(f"{seed}:{row_id}".encode("utf-8")).hexdigest() score = int(digest[:8], 16) / 0xFFFFFFFF return "validation" if score < validation_ratio else "train" def stable_tick_index(time_seconds: float, tick_seconds: float) -> int: ratio = max(0.0, float(time_seconds)) / tick_seconds return int(math.ceil(ratio - 1e-9)) def quantize_message_sequence( messages: list[dict[str, Any]], *, tick_seconds: float, feedback_delay_seconds: float, ) -> list[dict[str, Any]]: raw_events: list[dict[str, Any]] = [] event_order = 0 for message in messages: speaker = str(message["speaker"]) text = str(message["text"]) time_seconds = float(message["t"]) if speaker == "user": raw_events.append( { "time_seconds": time_seconds, "order": event_order, "kind": "user", "text": text, } ) event_order += 1 continue raw_events.append( { "time_seconds": time_seconds, "order": event_order, "kind": "agent_speak", "text": text, } ) event_order += 1 raw_events.append( { "time_seconds": time_seconds + feedback_delay_seconds, "order": event_order, "kind": "agent_feedback", "text": text, } ) event_order += 1 raw_events.sort(key=lambda item: (float(item["time_seconds"]), int(item["order"]))) quantized_events: list[dict[str, Any]] = [] previous_tick = -1 for event in raw_events: desired_tick = stable_tick_index(float(event["time_seconds"]), tick_seconds) assigned_tick = max(desired_tick, previous_tick + 1) quantized_events.append( { "tick": assigned_tick, "kind": event["kind"], "text": event["text"], } ) previous_tick = assigned_tick return quantized_events def resolve_bucket_horizon_ticks(duration_bucket: str, rollout_cfg: dict[str, Any]) -> int: bucket_map = rollout_cfg.get("horizon_ticks_by_bucket", {}) default_horizon = int(rollout_cfg.get("max_horizon_ticks", 36000)) return int(bucket_map.get(duration_bucket, default_horizon)) def build_fixed_tick_conversation( *, row: dict[str, Any], tokenizer: Any, config: dict[str, Any], rollout_cfg: dict[str, Any], max_observation_tokens: int, max_decoder_tokens: int, ) -> dict[str, Any] | None: tick_seconds = float(rollout_cfg["tick_seconds"]) chunk_ticks = int(rollout_cfg["chunk_ticks"]) feedback_delay_seconds = float(rollout_cfg["post_speech_feedback_delay_seconds"]) window_strategy = str(rollout_cfg.get("long_window_strategy", "tail")) duration_bucket = str(row.get("meta", {}).get("duration_bucket", "short")) use_base_chat_template = should_use_base_chat_template(config, tokenizer) quantized_events = quantize_message_sequence( messages=row["messages"], tick_seconds=tick_seconds, feedback_delay_seconds=feedback_delay_seconds, ) total_ticks = max((event["tick"] for event in quantized_events), default=0) + 1 horizon_ticks = resolve_bucket_horizon_ticks(duration_bucket, rollout_cfg) if bool(rollout_cfg.get("drop_overlong_examples", False)) and total_ticks > horizon_ticks: return None if total_ticks <= horizon_ticks: effective_start_tick = 0 effective_total_ticks = total_ticks elif window_strategy == "tail": effective_start_tick = total_ticks - horizon_ticks effective_total_ticks = horizon_ticks else: effective_start_tick = 0 effective_total_ticks = horizon_ticks chunk_count = max(1, math.ceil(effective_total_ticks / chunk_ticks)) chunk_lengths = [ min(chunk_ticks, max(0, effective_total_ticks - chunk_index * chunk_ticks)) for chunk_index in range(chunk_count) ] chunk_events: list[list[dict[str, Any]]] = [[] for _ in range(chunk_count)] event_count = 0 for event in quantized_events: absolute_tick = int(event["tick"]) if absolute_tick < effective_start_tick or absolute_tick >= effective_start_tick + effective_total_ticks: continue relative_tick = absolute_tick - effective_start_tick chunk_index = relative_tick // chunk_ticks offset = relative_tick % chunk_ticks if event["kind"] == "user": observation_role = OBS_ROLE_USER observation_text = build_user_generation_observation( tokenizer, str(event["text"]), use_base_chat_template, ) gate_target = 0 decoder_text = None elif event["kind"] == "agent_feedback": observation_role = OBS_ROLE_AGENT_FEEDBACK observation_text = build_assistant_feedback_observation( tokenizer, str(event["text"]), use_base_chat_template, ) gate_target = 0 decoder_text = None else: observation_role = OBS_ROLE_NONE observation_text = None gate_target = 1 decoder_text = build_assistant_decoder_target( tokenizer, str(event["text"]), use_base_chat_template, ) observation_input_ids, observation_attention_mask = tokenize_text( tokenizer=tokenizer, text=observation_text, max_length=max_observation_tokens, ) decoder_labels, _ = tokenize_text( tokenizer=tokenizer, text=decoder_text, max_length=max_decoder_tokens, ) if gate_target == 0: decoder_labels = [] chunk_events[chunk_index].append( { "offset": int(offset), "absolute_tick": absolute_tick, "observation_role": int(observation_role), "observation_input_ids": observation_input_ids, "observation_attention_mask": observation_attention_mask, "gate_target": int(gate_target), "decoder_labels": decoder_labels, } ) event_count += 1 return { "row_id": str(row["id"]), "scenario_id": row.get("meta", {}).get("scenario_id"), "duration_bucket": duration_bucket, "bucket_rank": duration_bucket_rank(duration_bucket), "total_ticks": int(total_ticks), "effective_start_tick": int(effective_start_tick), "effective_total_ticks": int(effective_total_ticks), "chunk_count": int(chunk_count), "chunk_lengths": chunk_lengths, "chunk_events": chunk_events, "event_count": int(event_count), } def build_chunk_batch( *, conversations: list[dict[str, Any]], chunk_index: int, pad_token_id: int, chunk_ticks: int, tick_seconds: float, ) -> dict[str, torch.Tensor] | None: active_lengths: list[int] = [] max_observation_tokens = 1 max_decoder_tokens = 1 for conversation in conversations: if chunk_index >= int(conversation["chunk_count"]): active_lengths.append(0) continue step_count = int(conversation["chunk_lengths"][chunk_index]) active_lengths.append(step_count) for event in conversation["chunk_events"][chunk_index]: max_observation_tokens = max(max_observation_tokens, len(event["observation_input_ids"])) max_decoder_tokens = max(max_decoder_tokens, len(event["decoder_labels"]) or 1) max_steps = max(active_lengths, default=0) if max_steps <= 0: return None batch_size = len(conversations) observation_input_ids = torch.full( (batch_size, max_steps, max_observation_tokens), fill_value=pad_token_id, dtype=torch.long, ) observation_attention_mask = torch.zeros((batch_size, max_steps, max_observation_tokens), dtype=torch.long) decoder_labels = torch.full((batch_size, max_steps, max_decoder_tokens), fill_value=-100, dtype=torch.long) tick_mask = torch.zeros((batch_size, max_steps), dtype=torch.bool) gate_target = torch.zeros((batch_size, max_steps), dtype=torch.float32) observation_role = torch.zeros((batch_size, max_steps), dtype=torch.long) delta_seconds = torch.zeros((batch_size, max_steps), dtype=torch.float32) elapsed_seconds = torch.zeros((batch_size, max_steps), dtype=torch.float32) for batch_index, conversation in enumerate(conversations): step_count = active_lengths[batch_index] if step_count <= 0: continue tick_mask[batch_index, :step_count] = True for local_tick in range(step_count): global_tick = chunk_index * chunk_ticks + local_tick delta_seconds[batch_index, local_tick] = 0.0 if global_tick == 0 else tick_seconds elapsed_seconds[batch_index, local_tick] = global_tick * tick_seconds for event in conversation["chunk_events"][chunk_index]: offset = int(event["offset"]) observation_role[batch_index, offset] = int(event["observation_role"]) gate_target[batch_index, offset] = float(event["gate_target"]) observation_ids = event["observation_input_ids"] observation_mask = event["observation_attention_mask"] observation_input_ids[batch_index, offset, : len(observation_ids)] = torch.tensor( observation_ids, dtype=torch.long, ) observation_attention_mask[batch_index, offset, : len(observation_mask)] = torch.tensor( observation_mask, dtype=torch.long, ) labels = event["decoder_labels"] if labels: decoder_labels[batch_index, offset, : len(labels)] = torch.tensor(labels, dtype=torch.long) return { "tick_mask": tick_mask, "gate_target": gate_target, "observation_role": observation_role, "observation_input_ids": observation_input_ids, "observation_attention_mask": observation_attention_mask, "decoder_labels": decoder_labels, "delta_seconds": delta_seconds, "elapsed_seconds": elapsed_seconds, } class ThoughtLoopConversationDataset(Dataset): def __init__(self, config: dict[str, Any], tokenizer: Any, split: str) -> None: self.config = config self.tokenizer = tokenizer self.split = split self.examples = self._load_or_build() def _cache_path(self) -> Path: dataset_cfg = self.config["dataset"] model_cfg = self.config["model"] rollout_cfg = self.config["rollout"] cache_cfg = self.config["cache"] cache_root = ensure_dir(cache_cfg["preprocessed_root"]) payload = { "cache_schema_version": SFT_CACHE_SCHEMA_VERSION, "split": self.split, "dataset": dataset_cfg, "model_name": model_cfg["base_model_name"], "rollout": rollout_cfg, "max_observation_tokens": model_cfg["max_observation_tokens"], "max_decoder_tokens": model_cfg["max_decoder_tokens"], "tokenizer_vocab_size": getattr(self.tokenizer, "vocab_size", None), } return cache_root / f"{self.split}_{fingerprint_payload(payload)}.pt" def _load_or_build(self) -> list[dict[str, Any]]: cache_path = self._cache_path() if cache_path.exists(): return torch.load(cache_path, map_location="cpu") dataset = load_clean_split(self.config["dataset"], self.split) rollout_cfg = self.config["rollout"] model_cfg = self.config["model"] examples: list[dict[str, Any]] = [] for row in dataset: example = build_fixed_tick_conversation( row=row, tokenizer=self.tokenizer, config=self.config, rollout_cfg=rollout_cfg, max_observation_tokens=int(model_cfg["max_observation_tokens"]), max_decoder_tokens=int(model_cfg["max_decoder_tokens"]), ) if example is not None: examples.append(example) if bool(rollout_cfg.get("sort_by_duration_bucket", True)): examples.sort(key=lambda item: (int(item["bucket_rank"]), int(item["total_ticks"]), str(item["row_id"]))) torch.save(examples, cache_path) return examples def __len__(self) -> int: return len(self.examples) def __getitem__(self, index: int) -> dict[str, Any]: return self.examples[index] class DurationBucketBatchSampler(BatchSampler): def __init__( self, dataset: ThoughtLoopConversationDataset, batch_size: int, ) -> None: self.dataset = dataset self.batch_size = batch_size self.batches: list[list[int]] = [] current_bucket: str | None = None current_batch: list[int] = [] for index, example in enumerate(self.dataset.examples): bucket = str(example["duration_bucket"]) if current_batch and (bucket != current_bucket or len(current_batch) >= self.batch_size): self.batches.append(current_batch) current_batch = [] current_bucket = bucket current_batch.append(index) if len(current_batch) >= self.batch_size: self.batches.append(current_batch) current_batch = [] if current_batch: self.batches.append(current_batch) def __iter__(self) -> Any: return iter(self.batches) def __len__(self) -> int: return len(self.batches) def identity_collate(batch: list[dict[str, Any]]) -> list[dict[str, Any]]: return batch def estimate_total_chunk_microsteps( *, dataset: ThoughtLoopConversationDataset, batch_sampler: DurationBucketBatchSampler, ) -> int: total_microsteps = 0 for batch_indices in batch_sampler.batches: total_microsteps += max(int(dataset.examples[index]["chunk_count"]) for index in batch_indices) return total_microsteps