from typing import Dict, List

import torch
from torch.utils.data import Dataset

from config import PATHS, TRAINING_CONFIG
from utils import read_jsonl


def format_prompt(instruction: str, response: str) -> str:
    """Render an (instruction, response) pair into the training template.

    Passing ``response=""`` yields exactly the prompt prefix up to and
    including the ``### Response:`` header — the dataset relies on this to
    measure how many leading tokens to mask out of the loss.
    """
    return (
        f"### Instruction:\n{instruction}\n\n"
        f"### Response:\n{response}"
    )


class LocalJsonlInstructionDataset(Dataset):
    """Instruction-tuning dataset backed by a local JSONL file.

    Each record read from ``PATHS.train_jsonl`` must provide "instruction"
    and "response" keys. Labels are set to -100 over the prompt tokens and
    over padding, so the loss is computed only on the response.
    """

    def __init__(self, tokenizer, max_length: int = TRAINING_CONFIG.max_length):
        """Load all samples eagerly; raise if the file yields none.

        Args:
            tokenizer: a Hugging Face-style callable tokenizer.
            max_length: fixed sequence length used for truncation/padding.

        Raises:
            ValueError: when the JSONL file contains no samples.
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.samples: List[Dict[str, str]] = read_jsonl(PATHS.train_jsonl)
        if not self.samples:
            raise ValueError(f"No training samples found in {PATHS.train_jsonl}")

    def __len__(self) -> int:
        """Number of loaded samples."""
        return len(self.samples)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Tokenize one sample and mask prompt + padding in the labels.

        Returns a dict of 1-D tensors: ``input_ids``, ``attention_mask``,
        and ``labels`` (with -100 at positions excluded from the loss).
        """
        sample = self.samples[idx]
        instruction = sample["instruction"]
        response = sample["response"]

        # Build the prompt prefix and full text from the SAME template so
        # the two tokenizations below stay aligned. (Previously the prompt
        # string was duplicated inline here and could drift from
        # format_prompt's template.)
        prompt = format_prompt(instruction, "")
        full_text = format_prompt(instruction, response)

        # Tokenize the full text to a fixed-length, padded sequence.
        encoded = self.tokenizer(
            full_text,
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        )
        input_ids = encoded["input_ids"].squeeze(0)
        attention_mask = encoded["attention_mask"].squeeze(0)
        labels = input_ids.clone()

        # Mask the prompt so the loss covers only response tokens.
        # NOTE(review): this assumes tokenizing `prompt` alone yields the
        # same leading token ids as tokenizing `full_text` — true for
        # tokenizers that only *prepend* special tokens; a tokenizer that
        # appends e.g. EOS would mask one extra response token. Confirm
        # for the tokenizer in use.
        prompt_ids = self.tokenizer(
            prompt,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )["input_ids"].squeeze(0)
        prompt_len = min(len(prompt_ids), self.max_length)
        labels[:prompt_len] = -100

        # Padding positions never contribute to the loss.
        labels[attention_mask == 0] = -100

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }