"""RL JSONL dataset with async-only streaming.""" from typing import Dict import torch from taoTrain.config import TrainingConfig from taoTrain.data.jsonl_base import BaseJSONLDataset class RLJSONLDataset(BaseJSONLDataset): """Dataset for RL training with local JSONL files with chunked loading.""" def _preprocess_chunk(self): """Prepare prompts for RL from current chunk.""" if not self._current_chunk_data or "text" not in self._current_chunk_data: return max_seq_length = self.config.model.max_seq_length texts = self._current_chunk_data["text"] # Tokenize all prompts in this chunk all_token_ids = [] all_attention_masks = [] for text in texts: tokenized = self.tokenizer( text, truncation=True, max_length=max_seq_length, padding="max_length", return_attention_mask=True, ) all_token_ids.append(tokenized["input_ids"]) all_attention_masks.append(tokenized["attention_mask"]) self._current_chunk_data = { "input_ids": all_token_ids, "attention_mask": all_attention_masks, } def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: """Get preprocessed prompt, loading chunk if needed.""" # Load appropriate chunk if using streaming if self.chunk_manager: chunk_num = self._get_chunk_for_idx(idx) if chunk_num != self._current_chunk_num: self._load_chunk(chunk_num) local_idx = self._get_local_idx_in_chunk(idx) else: local_idx = idx input_ids = torch.tensor(self._current_chunk_data["input_ids"][local_idx], dtype=torch.long) attention_mask = torch.tensor(self._current_chunk_data["attention_mask"][local_idx], dtype=torch.long) # For RL, no labels yet (generated during training) return { "input_ids": input_ids, "attention_mask": attention_mask, }