test-true2 / data.py
BRlkl's picture
Upload folder using huggingface_hub
bc7437c verified
Raw
History Blame Contribute Delete
18.8 kB
from __future__ import annotations
import hashlib
import math
import sys
from pathlib import Path
from typing import Any
import torch
from datasets import Dataset as HFDataset
from datasets import load_dataset
from torch.utils.data import BatchSampler, Dataset
# Directory containing this file; it is put at the front of ``sys.path`` so the
# sibling ``config`` module imported below resolves regardless of the CWD.
SCRIPT_DIR = Path(__file__).resolve().parent
if not any(entry == str(SCRIPT_DIR) for entry in sys.path):
    sys.path.insert(0, str(SCRIPT_DIR))
from config import ensure_dir, fingerprint_payload
# Integer ids stored in the ``observation_role`` field of each tick event.
OBS_ROLE_NONE = 0
OBS_ROLE_USER = 1
OBS_ROLE_AGENT_FEEDBACK = 2

# Mixed into the preprocessing cache fingerprint.
SFT_CACHE_SCHEMA_VERSION = 2

# Duration bucket name -> sort rank (shortest bucket first).
DURATION_BUCKET_ORDER = {
    bucket_name: rank
    for rank, bucket_name in enumerate(("short", "medium", "long", "hour_scale"))
}
def load_clean_split(dataset_cfg: dict[str, Any], split: str) -> HFDataset:
    """Load one split of the cleaned dataset.

    Resolution order:

    1. ``<cleaned_local_parquet_dir>/<split>.parquet`` when that file exists.
    2. ``cleaned_repo_id`` with ``source_split`` set: the source split is
       loaded and, for ``"train"`` / ``"validation"``, filtered down to the
       rows that ``assign_deterministic_split`` assigns to ``split``. Any
       other ``split`` value returns the source split unfiltered.
    3. ``cleaned_repo_id`` loaded directly with ``split``.

    Raises:
        ValueError: if no usable local parquet file exists and
            ``cleaned_repo_id`` is missing or empty.
    """
    local_dir = dataset_cfg.get("cleaned_local_parquet_dir")
    if local_dir:
        parquet_file = Path(local_dir) / f"{split}.parquet"
        if Path(local_dir).exists() and parquet_file.exists():
            return load_dataset("parquet", data_files={split: str(parquet_file)}, split=split)

    repo_id = dataset_cfg.get("cleaned_repo_id")
    if not repo_id:
        raise ValueError("Dataset config must provide either cleaned_local_parquet_dir or cleaned_repo_id.")

    source_split = dataset_cfg.get("source_split")
    if not source_split:
        return load_dataset(repo_id, split=split)

    full_dataset = load_dataset(repo_id, split=str(source_split))
    if split not in {"train", "validation"}:
        return full_dataset

    ratio = float(dataset_cfg.get("validation_ratio", 0.02))
    seed = int(dataset_cfg.get("split_seed", 17))
    kept_positions = []
    for position, row in enumerate(full_dataset):
        if assign_deterministic_split(str(row["id"]), ratio, seed) == split:
            kept_positions.append(position)
    return full_dataset.select(kept_positions)
def tokenize_text(tokenizer: Any, text: str | None, max_length: int) -> tuple[list[int], list[int]]:
    """Tokenize ``text`` with truncation and return ``(input_ids, attention_mask)``.

    Empty or ``None`` text is not passed to the tokenizer; a single pad token
    with a zero attention mask is returned instead.
    """
    if text:
        encoding = tokenizer(
            text,
            add_special_tokens=True,
            truncation=True,
            max_length=max_length,
        )
        return encoding["input_ids"], encoding["attention_mask"]
    return [tokenizer.pad_token_id], [0]
def tokenizer_supports_chat_template(tokenizer: Any) -> bool:
    """Return True when ``tokenizer`` has a truthy ``chat_template`` attribute."""
    template = getattr(tokenizer, "chat_template", None)
    return bool(template)
def should_use_base_chat_template(config: dict[str, Any], tokenizer: Any) -> bool:
    """Decide whether text should be rendered through the tokenizer's chat template.

    ``config["inference"]["use_base_chat_template"]`` wins when that key is
    present; otherwise ``config["dataset"]["use_base_chat_template"]`` is used
    (default ``False``). Even when requested, the result is only True if the
    tokenizer actually provides a chat template.
    """
    inference_cfg = config.get("inference", {})
    if "use_base_chat_template" in inference_cfg:
        requested = inference_cfg["use_base_chat_template"]
    else:
        requested = config.get("dataset", {}).get("use_base_chat_template", False)
    if not requested:
        return False
    return tokenizer_supports_chat_template(tokenizer)
def build_user_generation_observation(tokenizer: Any, text: str, use_base_chat_template: bool) -> str:
    """Return the observation text for a user message.

    With the chat template enabled, ``text`` is rendered as a single user turn
    followed by the generation prompt; otherwise it is returned unchanged.
    """
    if not use_base_chat_template:
        return text
    user_turn = [{"role": "user", "content": text}]
    return tokenizer.apply_chat_template(
        user_turn,
        tokenize=False,
        add_generation_prompt=True,
    )
def build_assistant_feedback_observation(tokenizer: Any, text: str, use_base_chat_template: bool) -> str:
    """Return the observation text that feeds the agent's own utterance back to it.

    NOTE(review): the turn markers are hard-coded here rather than derived from
    the tokenizer's chat template; confirm they match the base model in use.
    """
    del tokenizer  # The tokenizer is accepted but not consulted.
    return f"<start_of_turn>model\n{text}<end_of_turn>\n" if use_base_chat_template else text
def build_assistant_decoder_target(tokenizer: Any, text: str, use_base_chat_template: bool) -> str:
    """Return the decoder target string for an assistant utterance.

    Without the chat template the target is ``text`` itself. With it, the
    target is whatever the template appends after the generation prompt when
    the assistant turn is rendered: both an empty-user-turn prompt and the
    full two-turn conversation are rendered, and the prompt prefix is cut off.

    Raises:
        ValueError: if the full rendering does not start with the prompt.
    """
    if not use_base_chat_template:
        return text

    generation_prefix = tokenizer.apply_chat_template(
        [{"role": "user", "content": ""}],
        tokenize=False,
        add_generation_prompt=True,
    )
    rendered_conversation = tokenizer.apply_chat_template(
        [
            {"role": "user", "content": ""},
            {"role": "assistant", "content": text},
        ],
        tokenize=False,
        add_generation_prompt=False,
    )
    if rendered_conversation.startswith(generation_prefix):
        return rendered_conversation[len(generation_prefix) :]
    raise ValueError("Chat template full text did not start with the generation prompt prefix.")
def duration_bucket_rank(bucket: str | None) -> int:
    """Return the sort rank of ``bucket``; unknown buckets rank after all known ones."""
    unknown_rank = len(DURATION_BUCKET_ORDER)
    return DURATION_BUCKET_ORDER.get(str(bucket), unknown_rank)
def assign_deterministic_split(row_id: str, validation_ratio: float, seed: int) -> str:
    """Assign a row to ``"train"`` or ``"validation"`` from a hash of its id.

    The first 32 bits of ``sha1(f"{seed}:{row_id}")`` are scaled into
    ``[0, 1]``; the row is a validation row when that score is strictly below
    ``validation_ratio``. The same inputs always give the same answer.
    """
    hashed = hashlib.sha1(f"{seed}:{row_id}".encode("utf-8")).digest()
    # The first four digest bytes, big-endian, equal the first eight hex digits.
    leading_bits = int.from_bytes(hashed[:4], "big")
    score = leading_bits / 0xFFFFFFFF
    if score < validation_ratio:
        return "validation"
    return "train"
def stable_tick_index(time_seconds: float, tick_seconds: float) -> int:
    """Map a timestamp to the index of the first tick at or after it.

    Negative times are clamped to zero. A small epsilon is subtracted before
    rounding up so a quotient that lands a hair above a whole number because
    of floating-point error is not pushed into the next tick.
    """
    elapsed = max(0.0, float(time_seconds))
    return int(math.ceil(elapsed / tick_seconds - 1e-9))
def quantize_message_sequence(
    messages: list[dict[str, Any]],
    *,
    tick_seconds: float,
    feedback_delay_seconds: float,
) -> list[dict[str, Any]]:
    """Turn timestamped messages into events placed on distinct ticks.

    Each message supplies ``speaker``, ``text`` and ``t`` (seconds). A
    ``"user"`` message yields one ``"user"`` event. Any other speaker yields an
    ``"agent_speak"`` event at ``t`` plus an ``"agent_feedback"`` event with
    the same text at ``t + feedback_delay_seconds``.

    Events are ordered by time (ties keep creation order) and each is given
    the tick from ``stable_tick_index``, bumped forward when necessary so that
    ticks are strictly increasing.

    Returns:
        A list of ``{"tick", "kind", "text"}`` dicts in tick order.
    """
    # Entries are (time_seconds, creation_order, kind, text).
    timeline: list[tuple[float, int, str, str]] = []
    for message in messages:
        speaker = str(message["speaker"])
        text = str(message["text"])
        timestamp = float(message["t"])
        if speaker == "user":
            timeline.append((timestamp, len(timeline), "user", text))
        else:
            timeline.append((timestamp, len(timeline), "agent_speak", text))
            timeline.append((timestamp + feedback_delay_seconds, len(timeline), "agent_feedback", text))

    timeline.sort(key=lambda entry: entry[:2])

    quantized: list[dict[str, Any]] = []
    last_tick = -1
    for timestamp, _, kind, text in timeline:
        # Never reuse a tick: colliding events slide to the next free one.
        tick = max(stable_tick_index(timestamp, tick_seconds), last_tick + 1)
        quantized.append({"tick": tick, "kind": kind, "text": text})
        last_tick = tick
    return quantized
def resolve_bucket_horizon_ticks(duration_bucket: str, rollout_cfg: dict[str, Any]) -> int:
    """Return the tick horizon for ``duration_bucket``.

    A per-bucket value from ``rollout_cfg["horizon_ticks_by_bucket"]`` takes
    precedence; otherwise ``rollout_cfg["max_horizon_ticks"]`` (default 36000)
    is used.
    """
    per_bucket = rollout_cfg.get("horizon_ticks_by_bucket", {})
    fallback = int(rollout_cfg.get("max_horizon_ticks", 36000))
    horizon = per_bucket.get(duration_bucket, fallback)
    return int(horizon)
def build_fixed_tick_conversation(
    *,
    row: dict[str, Any],
    tokenizer: Any,
    config: dict[str, Any],
    rollout_cfg: dict[str, Any],
    max_observation_tokens: int,
    max_decoder_tokens: int,
) -> dict[str, Any] | None:
    """Convert one dataset row into a chunked, tokenized fixed-tick example.

    The row's messages are quantized onto ticks, clipped to the horizon of the
    row's duration bucket, and grouped into chunks of ``chunk_ticks`` ticks.
    When the conversation exceeds the horizon, the ``"tail"`` window strategy
    keeps the last ``horizon`` ticks and any other strategy keeps the first.

    Per event kind:

    * ``user`` / ``agent_feedback``: a tokenized observation, gate target 0 and
      no decoder labels.
    * ``agent_speak``: a pad-only observation, gate target 1 and the tokenized
      assistant text as decoder labels.

    Returns:
        A dict describing the example (ids, bucket info, window bounds, chunk
        lengths and per-chunk events), or ``None`` when the row is longer than
        its horizon and ``rollout_cfg["drop_overlong_examples"]`` is set.
    """
    tick_seconds = float(rollout_cfg["tick_seconds"])
    chunk_ticks = int(rollout_cfg["chunk_ticks"])
    feedback_delay_seconds = float(rollout_cfg["post_speech_feedback_delay_seconds"])
    window_strategy = str(rollout_cfg.get("long_window_strategy", "tail"))
    duration_bucket = str(row.get("meta", {}).get("duration_bucket", "short"))
    use_base_chat_template = should_use_base_chat_template(config, tokenizer)

    quantized_events = quantize_message_sequence(
        messages=row["messages"],
        tick_seconds=tick_seconds,
        feedback_delay_seconds=feedback_delay_seconds,
    )
    total_ticks = max((event["tick"] for event in quantized_events), default=0) + 1
    horizon_ticks = resolve_bucket_horizon_ticks(duration_bucket, rollout_cfg)

    exceeds_horizon = total_ticks > horizon_ticks
    if exceeds_horizon and bool(rollout_cfg.get("drop_overlong_examples", False)):
        return None

    # Pick the window of ticks that is kept.
    if not exceeds_horizon:
        window_start, window_ticks = 0, total_ticks
    elif window_strategy == "tail":
        window_start, window_ticks = total_ticks - horizon_ticks, horizon_ticks
    else:
        window_start, window_ticks = 0, horizon_ticks
    window_end = window_start + window_ticks

    chunk_count = max(1, math.ceil(window_ticks / chunk_ticks))
    chunk_lengths = [
        min(chunk_ticks, max(0, window_ticks - chunk_index * chunk_ticks))
        for chunk_index in range(chunk_count)
    ]
    chunk_events: list[list[dict[str, Any]]] = [[] for _ in range(chunk_count)]

    for event in quantized_events:
        absolute_tick = int(event["tick"])
        if not window_start <= absolute_tick < window_end:
            continue
        chunk_index, offset = divmod(absolute_tick - window_start, chunk_ticks)

        kind = event["kind"]
        event_text = str(event["text"])
        decoder_text = None
        gate_target = 0
        if kind == "user":
            observation_role = OBS_ROLE_USER
            observation_text = build_user_generation_observation(
                tokenizer,
                event_text,
                use_base_chat_template,
            )
        elif kind == "agent_feedback":
            observation_role = OBS_ROLE_AGENT_FEEDBACK
            observation_text = build_assistant_feedback_observation(
                tokenizer,
                event_text,
                use_base_chat_template,
            )
        else:
            # The agent speaks on this tick: nothing is observed, and the
            # utterance becomes the decoder target.
            observation_role = OBS_ROLE_NONE
            observation_text = None
            gate_target = 1
            decoder_text = build_assistant_decoder_target(
                tokenizer,
                event_text,
                use_base_chat_template,
            )

        observation_input_ids, observation_attention_mask = tokenize_text(
            tokenizer=tokenizer,
            text=observation_text,
            max_length=max_observation_tokens,
        )
        decoder_labels, _ = tokenize_text(
            tokenizer=tokenizer,
            text=decoder_text,
            max_length=max_decoder_tokens,
        )
        if gate_target == 0:
            # Discard the pad placeholder produced for an empty decoder text.
            decoder_labels = []

        chunk_events[chunk_index].append(
            {
                "offset": int(offset),
                "absolute_tick": absolute_tick,
                "observation_role": int(observation_role),
                "observation_input_ids": observation_input_ids,
                "observation_attention_mask": observation_attention_mask,
                "gate_target": int(gate_target),
                "decoder_labels": decoder_labels,
            }
        )

    event_count = sum(len(events) for events in chunk_events)
    return {
        "row_id": str(row["id"]),
        "scenario_id": row.get("meta", {}).get("scenario_id"),
        "duration_bucket": duration_bucket,
        "bucket_rank": duration_bucket_rank(duration_bucket),
        "total_ticks": int(total_ticks),
        "effective_start_tick": int(window_start),
        "effective_total_ticks": int(window_ticks),
        "chunk_count": int(chunk_count),
        "chunk_lengths": chunk_lengths,
        "chunk_events": chunk_events,
        "event_count": int(event_count),
    }
def build_chunk_batch(
    *,
    conversations: list[dict[str, Any]],
    chunk_index: int,
    pad_token_id: int,
    chunk_ticks: int,
    tick_seconds: float,
) -> dict[str, torch.Tensor] | None:
    """Collate chunk ``chunk_index`` of several conversations into padded tensors.

    Args:
        conversations: examples as produced by ``build_fixed_tick_conversation``
            (only ``chunk_count``, ``chunk_lengths`` and ``chunk_events`` are read).
        chunk_index: which chunk of every conversation to collate.
        pad_token_id: fill value for unused observation token slots.
        chunk_ticks: nominal number of ticks per chunk, used to derive the
            global tick index of each step.
        tick_seconds: duration of one tick.

    Returns:
        ``None`` when no conversation has any tick in this chunk. Otherwise a
        dict of tensors whose leading dims are ``(batch, max_steps)``:

        * ``tick_mask`` (bool): True for ticks that exist in the conversation.
        * ``gate_target`` (float32), ``observation_role`` (long).
        * ``observation_input_ids`` / ``observation_attention_mask`` (long),
          padded with ``pad_token_id`` / 0 along a trailing token dim.
        * ``decoder_labels`` (long), padded with -100 along a trailing token dim.
        * ``delta_seconds`` / ``elapsed_seconds`` (float32): time since the
          previous tick (0 for global tick 0) and since global tick 0; both
          stay 0 on masked-out steps.
    """
    active_lengths: list[int] = []
    max_observation_tokens = 1
    max_decoder_tokens = 1
    for conversation in conversations:
        if chunk_index >= int(conversation["chunk_count"]):
            # This conversation ended before the requested chunk.
            active_lengths.append(0)
            continue
        active_lengths.append(int(conversation["chunk_lengths"][chunk_index]))
        for event in conversation["chunk_events"][chunk_index]:
            max_observation_tokens = max(max_observation_tokens, len(event["observation_input_ids"]))
            max_decoder_tokens = max(max_decoder_tokens, len(event["decoder_labels"]))

    max_steps = max(active_lengths, default=0)
    if max_steps <= 0:
        return None

    batch_size = len(conversations)
    observation_input_ids = torch.full(
        (batch_size, max_steps, max_observation_tokens),
        fill_value=pad_token_id,
        dtype=torch.long,
    )
    observation_attention_mask = torch.zeros((batch_size, max_steps, max_observation_tokens), dtype=torch.long)
    decoder_labels = torch.full((batch_size, max_steps, max_decoder_tokens), fill_value=-100, dtype=torch.long)
    tick_mask = torch.zeros((batch_size, max_steps), dtype=torch.bool)
    gate_target = torch.zeros((batch_size, max_steps), dtype=torch.float32)
    observation_role = torch.zeros((batch_size, max_steps), dtype=torch.long)
    delta_seconds = torch.zeros((batch_size, max_steps), dtype=torch.float32)
    elapsed_seconds = torch.zeros((batch_size, max_steps), dtype=torch.float32)

    # The timing values depend only on the step index, so build them once for
    # the longest row instead of writing one scalar per (row, tick). The
    # product is taken in float64 and then narrowed, which matches assigning a
    # Python float into a float32 tensor element.
    first_global_tick = chunk_index * chunk_ticks
    global_ticks = torch.arange(first_global_tick, first_global_tick + max_steps, dtype=torch.float64)
    elapsed_row = (global_ticks * float(tick_seconds)).to(torch.float32)
    delta_row = torch.full((max_steps,), float(tick_seconds), dtype=torch.float32)
    delta_row[global_ticks == 0] = 0.0  # Nothing precedes the very first tick.

    for batch_index, (conversation, step_count) in enumerate(zip(conversations, active_lengths)):
        if step_count <= 0:
            continue
        tick_mask[batch_index, :step_count] = True
        delta_seconds[batch_index, :step_count] = delta_row[:step_count]
        elapsed_seconds[batch_index, :step_count] = elapsed_row[:step_count]

        for event in conversation["chunk_events"][chunk_index]:
            offset = int(event["offset"])
            observation_role[batch_index, offset] = int(event["observation_role"])
            gate_target[batch_index, offset] = float(event["gate_target"])

            observation_ids = event["observation_input_ids"]
            observation_mask = event["observation_attention_mask"]
            observation_input_ids[batch_index, offset, : len(observation_ids)] = torch.tensor(
                observation_ids,
                dtype=torch.long,
            )
            observation_attention_mask[batch_index, offset, : len(observation_mask)] = torch.tensor(
                observation_mask,
                dtype=torch.long,
            )

            labels = event["decoder_labels"]
            if labels:
                decoder_labels[batch_index, offset, : len(labels)] = torch.tensor(labels, dtype=torch.long)

    return {
        "tick_mask": tick_mask,
        "gate_target": gate_target,
        "observation_role": observation_role,
        "observation_input_ids": observation_input_ids,
        "observation_attention_mask": observation_attention_mask,
        "decoder_labels": decoder_labels,
        "delta_seconds": delta_seconds,
        "elapsed_seconds": elapsed_seconds,
    }
class ThoughtLoopConversationDataset(Dataset):
    """Map-style dataset of preprocessed fixed-tick conversations for one split.

    Examples are produced by ``build_fixed_tick_conversation`` and cached on
    disk under ``config["cache"]["preprocessed_root"]``. The cache file name
    carries a fingerprint of the preprocessing settings, so a settings change
    results in a fresh build instead of a stale load.
    """

    def __init__(self, config: dict[str, Any], tokenizer: Any, split: str) -> None:
        """Store the inputs and eagerly load (or build and cache) every example."""
        self.config = config
        self.tokenizer = tokenizer
        self.split = split
        self.examples = self._load_or_build()

    def _cache_path(self) -> Path:
        """Return the cache file path for the current split and settings."""
        dataset_cfg = self.config["dataset"]
        model_cfg = self.config["model"]
        rollout_cfg = self.config["rollout"]
        cache_cfg = self.config["cache"]
        cache_root = ensure_dir(cache_cfg["preprocessed_root"])
        payload = {
            "cache_schema_version": SFT_CACHE_SCHEMA_VERSION,
            "split": self.split,
            "dataset": dataset_cfg,
            "model_name": model_cfg["base_model_name"],
            "rollout": rollout_cfg,
            "max_observation_tokens": model_cfg["max_observation_tokens"],
            "max_decoder_tokens": model_cfg["max_decoder_tokens"],
            "tokenizer_vocab_size": getattr(self.tokenizer, "vocab_size", None),
            # The resolved chat-template switch changes the tokenized text and
            # can come from config["inference"] or from the tokenizer itself,
            # neither of which is otherwise part of this payload.
            "use_base_chat_template": should_use_base_chat_template(self.config, self.tokenizer),
        }
        return cache_root / f"{self.split}_{fingerprint_payload(payload)}.pt"

    def _load_or_build(self) -> list[dict[str, Any]]:
        """Return the cached examples, building and caching them on a miss.

        Rows for which ``build_fixed_tick_conversation`` returns ``None`` are
        skipped. Unless ``rollout["sort_by_duration_bucket"]`` is false, the
        examples are sorted by ``(bucket_rank, total_ticks, row_id)``.
        """
        import os

        cache_path = self._cache_path()
        if cache_path.exists():
            # NOTE(review): torch.load unpickles the file; the cache root must
            # only ever contain files written by this code.
            return torch.load(cache_path, map_location="cpu")

        dataset = load_clean_split(self.config["dataset"], self.split)
        rollout_cfg = self.config["rollout"]
        model_cfg = self.config["model"]
        max_observation_tokens = int(model_cfg["max_observation_tokens"])
        max_decoder_tokens = int(model_cfg["max_decoder_tokens"])

        examples: list[dict[str, Any]] = []
        for row in dataset:
            example = build_fixed_tick_conversation(
                row=row,
                tokenizer=self.tokenizer,
                config=self.config,
                rollout_cfg=rollout_cfg,
                max_observation_tokens=max_observation_tokens,
                max_decoder_tokens=max_decoder_tokens,
            )
            if example is not None:
                examples.append(example)

        if bool(rollout_cfg.get("sort_by_duration_bucket", True)):
            examples.sort(key=lambda item: (int(item["bucket_rank"]), int(item["total_ticks"]), str(item["row_id"])))

        # Write to a process-private temp file and rename it into place, so an
        # interrupted save can never leave a truncated file at ``cache_path``.
        temp_path = cache_path.with_name(f"{cache_path.name}.{os.getpid()}.tmp")
        try:
            torch.save(examples, temp_path)
            temp_path.replace(cache_path)
        finally:
            # No-op after a successful rename; removes the partial file otherwise.
            temp_path.unlink(missing_ok=True)
        return examples

    def __len__(self) -> int:
        """Return the number of examples."""
        return len(self.examples)

    def __getitem__(self, index: int) -> dict[str, Any]:
        """Return the example dict at ``index``."""
        return self.examples[index]
class DurationBucketBatchSampler(BatchSampler):
    """Batch sampler that never mixes duration buckets within a batch.

    Batches are precomputed in dataset order: a batch is closed as soon as it
    holds ``batch_size`` indices or the next example belongs to a different
    bucket. Only consecutive examples are grouped, so the dataset is expected
    to be ordered by bucket already for the batches to be full.
    """

    def __init__(
        self,
        dataset: ThoughtLoopConversationDataset,
        batch_size: int,
    ) -> None:
        """Precompute the list of index batches for ``dataset``."""
        self.dataset = dataset
        self.batch_size = batch_size
        self.batches: list[list[int]] = []

        pending: list[int] = []
        pending_bucket: str | None = None
        for position, example in enumerate(self.dataset.examples):
            bucket = str(example["duration_bucket"])
            # A bucket change closes whatever partial batch is open.
            if pending and bucket != pending_bucket:
                self.batches.append(pending)
                pending = []
            pending_bucket = bucket
            pending.append(position)
            if len(pending) >= self.batch_size:
                self.batches.append(pending)
                pending = []
        if pending:
            self.batches.append(pending)

    def __iter__(self) -> Any:
        """Iterate over the precomputed batches in order."""
        return iter(self.batches)

    def __len__(self) -> int:
        """Return the number of batches."""
        return len(self.batches)
def identity_collate(batch: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Collate function that returns the list of examples untouched."""
    return batch
def estimate_total_chunk_microsteps(
    *,
    dataset: ThoughtLoopConversationDataset,
    batch_sampler: DurationBucketBatchSampler,
) -> int:
    """Return the total number of chunk steps across all batches.

    Each batch contributes the largest ``chunk_count`` among its examples.
    """
    return sum(
        max(int(dataset.examples[index]["chunk_count"]) for index in batch_indices)
        for batch_indices in batch_sampler.batches
    )