# chatbot/components/dataset.py
# frc 10252
# add files
# 3905c4a
import torch
import torch.nn as nn
from torch.nn import functional as F
import math, time, os
from torch.utils.data import Dataset, DataLoader
import tiktoken
# from torch.cuda.amp import autocast, GradScaler
from torch.amp.autocast_mode import autocast
from torch.amp.grad_scaler import GradScaler
from tqdm import tqdm
from datasets import load_dataset
from components.model import GPTModel
from components.tokenizer import encode, decode, tokenizer
def decode(tokens):
    """Decode ``tokens`` with the module-level ``tokenizer`` and return the result.

    NOTE: this definition rebinds the ``decode`` name that is also imported
    from ``components.tokenizer`` at the top of this module.
    """
    decoded = tokenizer.decode(tokens)
    return decoded
class TextDataset(Dataset):
    """Serve randomly sampled, fixed-length next-token training pairs.

    Each item is built from records of the ``"train"`` split of the wrapped
    dataset: record texts are encoded with ``encode`` and concatenated until
    at least ``block_size + 1`` tokens are available, then cut to exactly that
    length and split into an input sequence and its one-step-shifted target.
    """

    def __init__(self, hf_dataset, block_size):
        """Store the dataset and the sequence length.

        Args:
            hf_dataset: Object whose ``"train"`` entry supports ``len()`` and
                integer indexing, yielding records with a ``"text"`` field.
            block_size: Number of tokens in each returned input/target
                sequence.
        """
        self.dataset = hf_dataset
        self.block_size = block_size

    def __len__(self):
        """Return the number of records in the ``"train"`` split."""
        return len(self.dataset["train"])

    def __getitem__(self, idx):
        """Return one ``(x, y)`` pair of ``torch.long`` token tensors.

        ``idx`` is intentionally ignored: every call draws its records
        uniformly at random with ``torch.randint``, so results depend on the
        torch RNG state rather than on the requested index.

        Args:
            idx: Index requested by the caller; unused.

        Returns:
            A tuple ``(x, y)`` where ``x`` holds the first ``block_size``
            tokens of the sampled window and ``y`` holds the same window
            shifted one position to the right.
        """
        # Look the split up once instead of on every random draw.
        train_split = self.dataset["train"]
        num_records = len(train_split)
        needed = self.block_size + 1

        def random_text():
            # Same draw as before: one uniform integer in [0, num_records).
            pick = torch.randint(0, num_records, (1,)).item()
            return train_split[pick]["text"]

        tokens = encode(random_text())

        # Keep appending further random records (separated by a space) until
        # the window can be filled. The loop condition alone bounds growth:
        # it stops as soon as ``needed`` tokens exist, so no extra guard
        # against runaway growth is required.
        while len(tokens) < needed:
            tokens.extend(encode(" " + random_text()))

        # Build the tensor as int64 directly so no later cast is needed.
        window = torch.tensor(tokens[:needed], dtype=torch.long)
        x = window[: self.block_size]
        y = window[1:needed]
        return x, y