precis / init /tuning /data.py
compendious's picture
Output rendering, documentation, model readiness UI
851f234
Raw
History Blame Contribute Delete
2.35 kB
"""Data preparation utilities for training."""
import logging
from typing import Dict, List, Optional, Any
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizer
from src.config import DataConfig
logger = logging.getLogger(__name__)
class SummarizationDataset(Dataset):
    """PyTorch Dataset that turns text/summary records into LM training tensors.

    Each record is rendered as ``prompt + target + eos`` (the prompt comes from
    ``config.format_prompt``), tokenized as a single sequence, and truncated or
    padded to ``config.max_input_length + config.max_target_length`` tokens.

    Args:
        data: Records holding at least the columns named by
            ``config.input_column`` and ``config.target_column``.
        tokenizer: Tokenizer used to encode the rendered text.
        config: Data settings; a default ``DataConfig()`` is built when omitted.
    """

    # Label value written at padding positions. -100 is the default
    # ``ignore_index`` of ``torch.nn.CrossEntropyLoss``, so those positions
    # do not contribute to the training loss.
    IGNORE_INDEX = -100

    def __init__(
        self,
        data: List[Dict[str, str]],
        tokenizer: PreTrainedTokenizer,
        config: Optional[DataConfig] = None,
    ):
        self.data = data
        self.tokenizer = tokenizer
        # Explicit None check so a caller-supplied config is never replaced.
        self.config = config if config is not None else DataConfig()

    def __len__(self) -> int:
        """Return the number of records in the dataset."""
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Tokenize record ``idx`` and return its training tensors.

        Returns:
            A dict with 1-D ``input_ids``, ``attention_mask`` and ``labels``.
            ``labels`` is an independent copy of ``input_ids`` with every
            padding position (``attention_mask == 0``) set to
            ``IGNORE_INDEX``.
        """
        record = self.data[idx]
        prompt = self.config.format_prompt(record[self.config.input_column])
        # Some tokenizers define no EOS token; fall back to an empty string
        # rather than failing on ``str + None``.
        eos = self.tokenizer.eos_token or ""
        full_text = prompt + record[self.config.target_column] + eos

        encoding = self.tokenizer(
            full_text,
            truncation=True,
            max_length=self.config.max_input_length + self.config.max_target_length,
            padding="max_length",
            return_tensors="pt",
        )

        # Drop only the leading batch dimension; a bare squeeze() would also
        # collapse the sequence dimension when it happens to have length 1.
        input_ids = encoding["input_ids"].squeeze(0)
        attention_mask = encoding["attention_mask"].squeeze(0)

        # Clone so labels do not alias input_ids, then mask padding. The
        # attention mask is used instead of pad_token_id because many
        # tokenizers reuse the EOS token as the pad token, and the real EOS
        # must stay in the labels.
        labels = input_ids.clone()
        labels[attention_mask == 0] = self.IGNORE_INDEX

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }
def create_dummy_data(num_samples: int = 10) -> List[Dict[str, str]]:
    """Build synthetic text/summary records for exercising the training pipeline.

    Args:
        num_samples: How many records to generate.

    Returns:
        A list of dicts with ``"text"`` and ``"summary"`` keys. Record ``i``
        mentions document ``i`` and topic ``i % 3``.
    """

    def _make_record(index: int) -> Dict[str, str]:
        # Topics cycle through 0, 1, 2.
        topic = index % 3
        text = (
            f"This is sample document {index}. It contains information about topic {topic}. "
            "The document discusses various aspects and provides detailed analysis. "
            f"Key points include methodology, results, and conclusions for study {index}."
        )
        summary = (
            f"Document {index} analyzes topic {topic}, covering methodology, results, and conclusions."
        )
        return {"text": text, "summary": summary}

    records = [_make_record(index) for index in range(num_samples)]
    logger.info(f"Created {num_samples} dummy samples")
    return records
def prepare_dataset(
    data: List[Dict[str, str]],
    tokenizer: PreTrainedTokenizer,
    config: Optional[DataConfig] = None,
) -> SummarizationDataset:
    """Wrap raw records in a ``SummarizationDataset``.

    Args:
        data: Records to train on.
        tokenizer: Tokenizer handed to the dataset.
        config: Optional data settings, forwarded unchanged (``None`` included).

    Returns:
        The constructed ``SummarizationDataset``.
    """
    dataset = SummarizationDataset(data=data, tokenizer=tokenizer, config=config)
    return dataset