""" src/data.py — Data preparation for DialogSum fine-tuning of Phi-3-mini with LoRA. Loads DialogSum from HuggingFace Hub, formats examples with the Phi-3 chat template, tokenizes with a 1024-token limit, and applies manual label masking so the Trainer computes loss only on the assistant's summary tokens. The tokenizer is NOT loaded here — it is passed in as a parameter by train.py. """ from __future__ import annotations import functools from typing import Callable import torch from datasets import Dataset, DatasetDict, load_dataset from transformers import PreTrainedTokenizerBase # --------------------------------------------------------------------------- # Constants # --------------------------------------------------------------------------- INSTRUCTION = "Summarize the following conversation in a few sentences." DATASET_NAME = "knkarthick/dialogsum" DEFAULT_MAX_LENGTH = 1024 # --------------------------------------------------------------------------- # Formatting # --------------------------------------------------------------------------- def format_example( example: dict, tokenizer: PreTrainedTokenizerBase, ) -> str: """ Format a single DialogSum example into a Phi-3 chat template string. Uses tokenizer.apply_chat_template with role=user for the instruction+dialogue and role=assistant for the summary. Returns the raw formatted text (not tokenized). This is the string that gets printed in the CLI sanity check. Args: example: A DialogSum example dict with keys 'dialogue' and 'summary'. tokenizer: A Phi-3 tokenizer with apply_chat_template support. Returns: Formatted string with Phi-3 special tokens embedded. """ messages = [ { "role": "user", "content": f"{INSTRUCTION}\n\nConversation:\n{example['dialogue']}", }, { "role": "assistant", "content": example["summary"], }, ] return tokenizer.apply_chat_template(messages, tokenize=False) # --------------------------------------------------------------------------- # Tokenization + label masking # --------------------------------------------------------------------------- def tokenize_and_mask( example: dict, tokenizer: PreTrainedTokenizerBase, max_length: int = DEFAULT_MAX_LENGTH, ) -> dict: """ Tokenize one DialogSum example and apply prompt-masking on labels. Builds input_ids from the full formatted sequence (prompt + summary). Builds labels as a copy of input_ids with prompt positions replaced by -100, so cross-entropy loss is computed ONLY on the assistant's summary tokens. The prompt boundary is determined by tokenizing just the prompt portion (with add_generation_prompt=True) and using its token count as the mask cutoff. If the full sequence exceeds max_length, transformers truncates from the right (cutting the end of the dialogue, not the summary). If prompt_len >= len(input_ids) after truncation (extremely long dialogue edge case), all labels are -100 and this example contributes zero loss — this is acceptable. Args: example: A DialogSum example dict with keys 'dialogue' and 'summary'. tokenizer: A Phi-3 tokenizer (must have apply_chat_template, padding_side='right'). max_length: Maximum token sequence length. Defaults to 1024. Returns: Dict with keys: - 'input_ids': List[int], length <= max_length - 'attention_mask': List[int], all 1s (pre-padding) - 'labels': List[int], same length as input_ids, prompt positions are -100 """ # Step 1: build full text (prompt + summary) messages_full = [ { "role": "user", "content": f"{INSTRUCTION}\n\nConversation:\n{example['dialogue']}", }, { "role": "assistant", "content": example["summary"], }, ] full_text: str = tokenizer.apply_chat_template(messages_full, tokenize=False) # Step 2: build prompt-only text (up to and including <|assistant|>\n) messages_prompt = [ { "role": "user", "content": f"{INSTRUCTION}\n\nConversation:\n{example['dialogue']}", }, ] prompt_text: str = tokenizer.apply_chat_template( messages_prompt, tokenize=False, add_generation_prompt=True, ) # Step 3: tokenize full sequence with truncation enc = tokenizer( full_text, add_special_tokens=False, max_length=max_length, truncation=True, ) input_ids: list[int] = enc["input_ids"] attention_mask: list[int] = enc["attention_mask"] # Step 4: tokenize prompt only to find the boundary prompt_enc = tokenizer( prompt_text, add_special_tokens=False, max_length=max_length, truncation=True, ) prompt_len: int = len(prompt_enc["input_ids"]) # Step 5: clamp and build labels # Guard: if an extremely long dialogue was truncated so hard that the summary # was completely cut off, prompt_len could equal len(input_ids). # In that case all labels are -100 (zero loss contribution). Acceptable. prompt_len = min(prompt_len, len(input_ids)) labels: list[int] = [-100] * prompt_len + input_ids[prompt_len:] return { "input_ids": input_ids, "attention_mask": attention_mask, "labels": labels, } # --------------------------------------------------------------------------- # Dataset preparation # --------------------------------------------------------------------------- def prepare_datasets( tokenizer: PreTrainedTokenizerBase, max_length: int = DEFAULT_MAX_LENGTH, dataset_name: str = DATASET_NAME, num_proc: int = 4, ) -> tuple[Dataset, Dataset, Dataset]: """ Load DialogSum, format, tokenize, and return train/val/test splits. Pipeline: 1. load_dataset(dataset_name) — uses HF Hub, returns DatasetDict with train (~12,460), validation (~500), test (~819) splits. 2. Dataset.map(tokenize_and_mask, ...) on each split. 3. Remove original 'id', 'dialogue', 'summary' columns. The returned datasets contain only 'input_ids', 'attention_mask', 'labels'. They have variable sequence lengths (no padding) — pass the collator returned by make_data_collator() to the Trainer's data_collator argument. Args: tokenizer: Loaded and configured Phi-3 tokenizer. Caller must ensure tokenizer.padding_side == 'right' for training. max_length: Max token length for truncation. Defaults to 1024. dataset_name: HuggingFace dataset identifier. Defaults to 'knkarthick/dialogsum'. num_proc: Number of processes for Dataset.map. Defaults to 4. Returns: Tuple of (train_dataset, val_dataset, test_dataset). """ raw: DatasetDict = load_dataset(dataset_name) # Subsample train split to reduce training time on free Colab T4. # Val and test splits are kept in full for accurate evaluation. train_size = min(4000, len(raw["train"])) raw["train"] = raw["train"].shuffle(seed=42).select(range(train_size)) _map_fn = functools.partial( tokenize_and_mask, tokenizer=tokenizer, max_length=max_length, ) original_columns = raw["train"].column_names # ['id', 'dialogue', 'summary'] tokenized: DatasetDict = raw.map( _map_fn, batched=False, num_proc=num_proc, remove_columns=original_columns, desc="Tokenizing and masking labels", ) return tokenized["train"], tokenized["validation"], tokenized["test"] # --------------------------------------------------------------------------- # Collator (padding) # --------------------------------------------------------------------------- def make_data_collator( tokenizer: PreTrainedTokenizerBase, ) -> Callable[[list[dict]], dict[str, torch.Tensor]]: """ Return a collate function that pads a batch of tokenized examples. Pads 'input_ids' to the longest sequence in the batch using tokenizer.pad_token_id. Pads 'attention_mask' with 0. Pads 'labels' with -100 (ignored by cross-entropy loss). Padding is applied on the RIGHT (required for causal LM training). This collator is intentionally minimal — it does NOT do label shifting, DataCollatorForSeq2Seq masking, or any other transformation. All label masking was already done in tokenize_and_mask(). Args: tokenizer: Must have pad_token_id set (Phi-3: <|endoftext|>, id=32000). Returns: A collate_fn compatible with HuggingFace Trainer's data_collator argument. """ pad_id: int = tokenizer.pad_token_id def collate_fn(batch: list[dict]) -> dict[str, torch.Tensor]: max_len = max(len(item["input_ids"]) for item in batch) input_ids_padded = [] attention_mask_padded = [] labels_padded = [] for item in batch: seq_len = len(item["input_ids"]) pad_len = max_len - seq_len input_ids_padded.append(item["input_ids"] + [pad_id] * pad_len) attention_mask_padded.append(item["attention_mask"] + [0] * pad_len) labels_padded.append(item["labels"] + [-100] * pad_len) return { "input_ids": torch.tensor(input_ids_padded, dtype=torch.long), "attention_mask": torch.tensor(attention_mask_padded, dtype=torch.long), "labels": torch.tensor(labels_padded, dtype=torch.long), } return collate_fn # --------------------------------------------------------------------------- # CLI entry point # --------------------------------------------------------------------------- if __name__ == "__main__": import random from dotenv import load_dotenv from transformers import AutoTokenizer load_dotenv() # loads HF_TOKEN from .env if present MODEL_ID = "microsoft/Phi-3-mini-4k-instruct" print(f"Loading tokenizer: {MODEL_ID} ...") tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True) tokenizer.padding_side = "right" # required for training print(f"Loading dataset: {DATASET_NAME} ...") raw = load_dataset(DATASET_NAME) print("\nSplit sizes:") print(f" train: {len(raw['train']):>6,}") print(f" validation: {len(raw['validation']):>6,}") print(f" test: {len(raw['test']):>6,}") # Show one formatted example example = raw["train"][0] formatted = format_example(example, tokenizer) print("\n--- Formatted example (raw chat template text) ---") print(formatted) # Sanity check: label masking tokenized_example = tokenize_and_mask(example, tokenizer) non_masked_ids = [t for t in tokenized_example["labels"] if t != -100] decoded = tokenizer.decode(non_masked_ids, skip_special_tokens=True) print("\n--- Sanity check: decoded labels vs. original summary ---") print(f"Original summary : {example['summary']!r}") print(f"Decoded labels : {decoded!r}") match = decoded.strip() == example["summary"].strip() print(f"Match : {match}") if not match: print("WARNING: Mismatch detected. Check tokenize_and_mask boundary logic.") # Token length stats on a random sample of 500 train examples sample = random.sample(list(raw["train"]), min(500, len(raw["train"]))) lengths = [] for ex in sample: enc = tokenize_and_mask(ex, tokenizer) lengths.append(len(enc["input_ids"])) print(f"\n--- Token length stats (sample n={len(sample)}) ---") print(f" min: {min(lengths)}") print(f" max: {max(lengths)}") print(f" mean: {sum(lengths) / len(lengths):.1f}") over_limit = sum(1 for length in lengths if length >= DEFAULT_MAX_LENGTH) print(f" >= {DEFAULT_MAX_LENGTH} (truncated): {over_limit} ({100 * over_limit / len(lengths):.1f}%)")