Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| from torch.nn import functional as F | |
| import math, time, os | |
| from torch.utils.data import Dataset, DataLoader | |
| import tiktoken | |
| # from torch.cuda.amp import autocast, GradScaler | |
| from torch.amp.autocast_mode import autocast | |
| from torch.amp.grad_scaler import GradScaler | |
| from tqdm import tqdm | |
| from datasets import load_dataset | |
| from components.model import GPTModel | |
| from components.tokenizer import encode, decode, tokenizer | |
def decode(tokens):
    """Turn a sequence of token ids back into text.

    Delegates to ``tokenizer.decode``. Note that this definition rebinds the
    ``decode`` name that is also imported from ``components.tokenizer`` at the
    top of the file, so this wrapper is the one the rest of the module sees.
    """
    text = tokenizer.decode(tokens)
    return text
class TextDataset(Dataset):
    """Map-style dataset producing next-token-prediction pairs from text rows.

    Every item is a pair ``(x, y)`` of 1-D ``torch.long`` tensors, each of
    length ``block_size``, where ``y`` is ``x`` shifted left by one token.

    Args:
        hf_dataset: Mapping with a ``"train"`` entry that supports ``len()``
            and integer indexing; each row must be a mapping with a
            ``"text"`` string. (Presumably a Hugging Face ``DatasetDict``;
            confirm against the caller.)
        block_size: Number of tokens in each returned sequence.
        encode_fn: Optional callable mapping a string to a sequence of
            integer token ids. When omitted, the module-level ``encode`` is
            used.
    """

    def __init__(self, hf_dataset, block_size, encode_fn=None):
        self.dataset = hf_dataset
        self.block_size = block_size
        # The conditional expression only evaluates the selected branch, so
        # the module-level ``encode`` is looked up only when it is needed.
        self._encode = encode if encode_fn is None else encode_fn

    def __len__(self):
        """Return the number of rows in the ``"train"`` split."""
        return len(self.dataset["train"])

    def __getitem__(self, idx):
        """Return the ``(x, y)`` training pair that starts at document ``idx``.

        The document at ``idx`` supplies the leading tokens, so samplers,
        ``shuffle=False`` and out-of-range indices behave as they do for any
        map-style dataset. If that document is shorter than
        ``block_size + 1`` tokens, randomly chosen documents are appended
        (each preceded by a single space) until enough tokens are available.

        Args:
            idx: Index of the row in the ``"train"`` split to start from.

        Returns:
            A pair of ``torch.long`` tensors ``(x, y)``, each with
            ``block_size`` elements, where ``y[i] == x[i + 1]``.

        Raises:
            Whatever the underlying split raises for an invalid ``idx``
            (``IndexError`` for ordinary sequences).
        """
        docs = self.dataset["train"]
        needed = self.block_size + 1  # one extra token for the shifted target

        # Copy into a fresh list so ``extend`` never mutates a list that the
        # encoder may own or cache.
        tokens = list(self._encode(docs[idx]["text"]))

        # Pad short documents with random ones. The loop condition alone
        # bounds the growth, so no extra "runaway" guard is required.
        while len(tokens) < needed:
            filler = torch.randint(0, len(docs), (1,)).item()
            tokens.extend(self._encode(" " + docs[filler]["text"]))

        window = torch.tensor(tokens[:needed], dtype=torch.long)
        return window[:-1], window[1:]