| from typing import Dict, List |
|
|
| import torch |
| from torch.utils.data import Dataset |
|
|
| from config import PATHS, TRAINING_CONFIG |
| from utils import read_jsonl |
|
|
|
|
def format_prompt(instruction: str, response: str = "") -> str:
    """Render an instruction/response pair in the Alpaca-style prompt format.

    Generalized so that ``response`` defaults to ``""``: calling with only an
    instruction yields the bare prompt (ending right after ``### Response:\n``),
    which is exactly the prefix a dataset needs for prompt-token label masking.
    Existing two-argument callers are unaffected.
    """
    return (
        f"### Instruction:\n{instruction}\n\n"
        f"### Response:\n{response}"
    )
|
|
|
|
class LocalJsonlInstructionDataset(Dataset):
    """Instruction-tuning dataset backed by a local JSONL file.

    Each JSONL record must contain ``"instruction"`` and ``"response"`` keys.
    ``__getitem__`` returns fixed-length ``input_ids``, ``attention_mask`` and
    ``labels`` tensors; labels are set to -100 on the prompt portion and on
    padding positions so the loss is computed only over response tokens.
    """

    def __init__(self, tokenizer, max_length: int = TRAINING_CONFIG.max_length):
        # ``tokenizer`` is a HuggingFace-style tokenizer (callable, returning
        # "input_ids"/"attention_mask"); left unannotated because the concrete
        # class is chosen by project configuration.
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.samples: List[Dict[str, str]] = read_jsonl(PATHS.train_jsonl)

        # Fail fast on an empty/missing file rather than during training.
        if not self.samples:
            raise ValueError(f"No training samples found in {PATHS.train_jsonl}")

    def __len__(self) -> int:
        """Number of JSONL records loaded."""
        return len(self.samples)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Tokenize sample ``idx`` and build loss-masked labels.

        Raises:
            KeyError: if the record lacks an ``instruction`` or ``response``
                key, with the sample index included for debuggability.
        """
        sample = self.samples[idx]

        try:
            instruction = sample["instruction"]
            response = sample["response"]
        except KeyError as err:
            # Surface which record is malformed instead of a bare KeyError.
            raise KeyError(
                f"Sample {idx} is missing required key {err} "
                "(expected 'instruction' and 'response')"
            ) from err

        # Keep this template in sync with format_prompt().
        prompt = f"### Instruction:\n{instruction}\n\n### Response:\n"
        full_text = prompt + response

        encoded = self.tokenizer(
            full_text,
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        )

        input_ids = encoded["input_ids"].squeeze(0)
        attention_mask = encoded["attention_mask"].squeeze(0)

        labels = input_ids.clone()

        # Tokenize the prompt alone to find how many leading tokens to mask.
        prompt_ids = self.tokenizer(
            prompt,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )["input_ids"].squeeze(0)

        # BUG FIX: tokenizers that append EOS as a special token include it at
        # the end of prompt_ids; previously this over-counted the prompt by one
        # and masked the first response token out of the loss. Strip a trailing
        # EOS before measuring the prompt length. (The EOS at the end of
        # full_text belongs to the response and should stay supervised.)
        eos_id = getattr(self.tokenizer, "eos_token_id", None)
        if eos_id is not None and prompt_ids.numel() > 0 and prompt_ids[-1].item() == eos_id:
            prompt_ids = prompt_ids[:-1]

        # NOTE(review): this assumes tokenize(prompt) is a strict prefix of
        # tokenize(prompt + response); BPE merges at the boundary can shift
        # this by one token for some tokenizers — confirm for the tokenizer
        # actually in use.
        prompt_len = min(len(prompt_ids), self.max_length)
        labels[:prompt_len] = -100

        # Padding positions must not contribute to the loss.
        labels[attention_mask == 0] = -100

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }