import json

import torch
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler

from .preprocessor import StepPreprocessor


class AgentTrajectoryDataset(Dataset):
    """
    Dataset in which each item is one complete agent trajectory.

    The samples are read from a JSON file; each sample is accessed with
    ``.get(...)`` and handed to ``preprocessor.encode_trajectory``, so the
    file is presumably a list of dicts (confirm against the data producer).

    The DataLoader must use batch_size=1 because trajectories have variable
    numbers of steps — there is no cross-trajectory batching.

    Args:
        json_path: Path to the UTF-8 JSON file holding the samples.
        preprocessor: Object providing ``encode_trajectory(sample)`` and a
            ``max_len`` attribute. When ``None``, a
            ``StepPreprocessor(max_len=max_len)`` is created.
        max_len: Only used to build the default preprocessor; ignored when
            ``preprocessor`` is supplied.
        max_steps: Maximum number of steps kept per trajectory (>= 1).

    Raises:
        ValueError: If ``max_steps`` is smaller than 1.
    """

    def __init__(self, json_path, preprocessor=None, max_len=512, max_steps=20):
        # A non-positive window cannot hold any step: ``steps[-0:]`` would
        # silently keep the whole trajectory and the centred window would be
        # empty, so reject it up front instead of failing later in stack().
        if max_steps < 1:
            raise ValueError(f"max_steps must be >= 1, got {max_steps!r}")

        with open(json_path, "r", encoding="utf-8") as f:
            self.samples = json.load(f)

        # ``is None`` rather than truthiness so that a caller-supplied
        # preprocessor is never replaced just because it evaluates falsy.
        if preprocessor is None:
            preprocessor = StepPreprocessor(max_len=max_len)
        self.preprocessor = preprocessor
        self.max_steps = max_steps

    @staticmethod
    def _as_bool(flag) -> bool:
        """Normalise an ``is_hallucination`` value (bool, str, None, ...) to bool."""
        if isinstance(flag, str):
            return flag.strip().lower() == "true"
        return bool(flag)

    def _select_window(self, steps, is_hal, hal_step):
        """Truncate ``steps`` to at most ``max_steps`` entries.

        When the trajectory is flagged as hallucinated and the hallucination
        step is known, the window is centred on that step so it is always
        included; otherwise the last ``max_steps`` steps are kept.
        """
        n = len(steps)
        if n <= self.max_steps:
            return steps

        if not (is_hal and hal_step is not None):
            # No annotation — keep the last max_steps (model likely
            # hallucinates late).
            return steps[-self.max_steps:]

        # Position of the annotated step; fall back to the final step when
        # no encoded step carries that ``step_idx``.
        hal_pos = next(
            (i for i, s in enumerate(steps) if s["step_idx"] == hal_step),
            n - 1,
        )

        # Window of max_steps centred on hal_pos, shifted back inside
        # [0, n) when it would overrun either end.
        half = self.max_steps // 2
        start = max(0, hal_pos - half)
        end = min(n, start + self.max_steps)
        start = max(0, end - self.max_steps)
        return steps[start:end]

    def __len__(self) -> int:
        """Return the number of trajectories loaded from the JSON file."""
        return len(self.samples)

    def __getitem__(self, idx) -> dict:
        """Encode trajectory ``idx`` into per-step tensors.

        Returns:
            Dict with ``input_ids`` and ``attention_mask`` of shape
            ``(N_steps, max_len)`` and ``hal_label`` of shape ``(N_steps,)``.
            An empty trajectory yields a single all-zero placeholder step.

        Raises:
            ValueError / TypeError: If ``hallucination_step`` is present but
                cannot be converted with ``int()``.
        """
        sample = self.samples[idx]
        is_hal = self._as_bool(sample.get("is_hallucination", False))

        hal_step = sample.get("hallucination_step")
        if hal_step is not None:
            hal_step = int(hal_step)

        steps = self.preprocessor.encode_trajectory(sample)

        # Graceful empty-trajectory handling: one zero-filled step with a
        # clean (0.0) label.
        # NOTE(review): an all-zero attention mask may produce NaNs in some
        # encoders — confirm the downstream model tolerates it.
        if not steps:
            ml = self.preprocessor.max_len
            return {
                "input_ids": torch.zeros((1, ml), dtype=torch.long),
                "attention_mask": torch.zeros((1, ml), dtype=torch.long),
                "hal_label": torch.tensor([0.0]),
            }

        steps = self._select_window(steps, is_hal, hal_step)

        # squeeze(0) presumably drops a leading batch dimension of size 1
        # added by the tokenizer — TODO confirm against StepPreprocessor.
        input_ids = torch.stack(
            [s["encoding"]["input_ids"].squeeze(0) for s in steps]
        )
        attention_masks = torch.stack(
            [s["encoding"]["attention_mask"].squeeze(0) for s in steps]
        )
        hal_labels = torch.tensor([float(s["label"]) for s in steps])

        return {
            "input_ids": input_ids,             # (N_steps, max_len)
            "attention_mask": attention_masks,  # (N_steps, max_len)
            "hal_label": hal_labels,            # (N_steps,)
        }

    def get_class_weights(self) -> list:
        """
        Return per-sample weights for WeightedRandomSampler so that each
        training batch has roughly equal representation of hallucinated and
        clean trajectories.

        Every sample is weighted by the inverse of its class count; a class
        with no members falls back to weight 1.0 (which is then unused).
        """
        flags = [
            self._as_bool(s.get("is_hallucination", False))
            for s in self.samples
        ]
        n_hal = sum(flags)
        n_clean = len(flags) - n_hal

        w_hal = 1.0 / n_hal if n_hal else 1.0
        w_clean = 1.0 / n_clean if n_clean else 1.0
        return [w_hal if f else w_clean for f in flags]


def get_dataloader(json_path, preprocessor, batch_size=1, shuffle=True,
                   use_weighted_sampler=False):
    """
    Build a DataLoader over an AgentTrajectoryDataset.

    batch_size must be 1 because each item is a variable-length trajectory.
    use_weighted_sampler=True balances hallucinated vs. clean at the batch
    level; it only takes effect when ``shuffle`` is also True, otherwise a
    plain (unshuffled) DataLoader is returned.

    Args:
        json_path: Path to the JSON file with the trajectories.
        preprocessor: Preprocessor forwarded to the dataset (``None`` lets
            the dataset create its default one).
        batch_size: Forwarded to ``DataLoader``.
        shuffle: Whether to randomise the sample order.
        use_weighted_sampler: Draw samples with replacement using the
            dataset's class weights instead of uniform shuffling.

    Returns:
        A ``torch.utils.data.DataLoader``.
    """
    dataset = AgentTrajectoryDataset(json_path, preprocessor=preprocessor)

    if use_weighted_sampler and shuffle:
        weights = dataset.get_class_weights()
        # One epoch still draws len(dataset) items, but with replacement.
        sampler = WeightedRandomSampler(
            weights=weights,
            num_samples=len(weights),
            replacement=True,
        )
        return DataLoader(dataset, batch_size=batch_size, sampler=sampler)

    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)