# testing-model / Naive_gpt / data_loader.py
# NOTE: the lines below are HuggingFace raw-page header residue, commented
# out so this module parses as valid Python:
#   AliMuhammad73's picture
#   gpt
#   83aefdf
#   raw
#   history blame
#   10.2 kB
# Model/data_loader.py
import torch
import os
import logging
# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
class TextDataLoader:
    """Serves random (input, target) token batches from a text file.

    Despite the ``chunk_size`` parameter, ``load_chunk`` currently reads and
    tokenizes the ENTIRE file in one call; batches are then sampled at random
    offsets from the resulting token tensor.
    """

    def __init__(self, file_path, batch_size, block_size, tokenizer, chunk_size=10**4):
        # file_path:  path to a UTF-8 text file.
        # batch_size: number of sequences per batch.
        # block_size: tokens per sequence (context length).
        # tokenizer:  object exposing encode(str) -> sequence of int token ids.
        # chunk_size: kept for interface compatibility; currently unused
        #             because load_chunk reads the whole file at once.
        self.file_path = file_path
        self.batch_size = batch_size
        self.block_size = block_size
        self.tokenizer = tokenizer
        self.chunk_size = chunk_size
        self.file = open(self.file_path, 'r', encoding='utf-8')
        self.data = None          # 1-D LongTensor of token ids, or None
        self.end_of_file = False  # True once no usable data remains
        # Load the initial chunk of data
        self.load_chunk()

    def load_chunk(self):
        """Read the remaining file contents, encode them, and update state."""
        text = self.file.read()
        if not text:
            self.end_of_file = True
            logging.info("End of file reached.")
            return
        try:
            # Encode the text using the tokenizer
            encoded = self.tokenizer.encode(text)
            if len(encoded) > 0:
                self.data = torch.tensor(encoded, dtype=torch.long)
                logging.info(f"Loaded new data chunk of size: {len(self.data)} tokens.")
                # Persist the encoded tensor for reuse (hard-coded path kept
                # from the original design).
                torch.save(self.data, "encoded_data.pth")
            else:
                # Tokenizer produced no tokens: nothing to batch from.
                # (Previously this left data=None with end_of_file=False,
                # which made get_batch crash on self.data.)
                self.end_of_file = True
        except Exception as e:
            logging.error(f"Error encoding text chunk: {e}")
            self.end_of_file = True

    def num_batches(self):
        """Return the number of non-overlapping block_size windows in the data."""
        if self.data is not None:
            return (len(self.data) - 1) // self.block_size
        return 0

    def get_batch(self):
        """Return (x, y) LongTensors of shape (batch_size, block_size), or (None, None).

        y is x shifted right by one token (next-token prediction targets).
        Returns (None, None) when no usable data remains — including when the
        encoded data is too short to fill a single block, a case where the
        original code crashed (torch.randint with a non-positive bound).
        """
        if self.end_of_file or self.data is None or len(self.data) <= self.block_size:
            return None, None  # Return None when no data is left
        # Sample batch_size random start offsets into the token tensor.
        ix = torch.randint(len(self.data) - self.block_size, (self.batch_size,))
        x = torch.stack([self.data[i:i + self.block_size] for i in ix])
        y = torch.stack([self.data[i + 1:i + self.block_size + 1] for i in ix])
        return x, y

    def reset(self):
        """Rewind the file and reload data for a new epoch."""
        self.file.seek(0)
        self.end_of_file = False
        logging.info("Resetting file for a new epoch.")
        self.load_chunk()

    def close(self):
        """Clean up file resources when done."""
        self.file.close()
        logging.info("File closed.")

    def __iter__(self):
        """Yield (x, y) batches, then close the file when data is exhausted.

        NOTE(review): end_of_file is never advanced by get_batch once data is
        loaded, so on a normal file this loop yields random batches
        indefinitely — presumably intentional for open-ended training loops;
        confirm with callers before bounding it.
        """
        while not self.end_of_file:
            x, y = self.get_batch()
            if x is None or y is None:
                break  # Stop iteration if there's no more data
            yield x, y  # Yield a batch of data
        # Once iteration is done, close the file
        self.close()
#before parallelizing
# Set up logging
# logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# class TextDataLoader:
# def __init__(self, file_path, batch_size, block_size, tokenizer, device='cpu', chunk_size=10**4):
# self.file_path = file_path
# self.batch_size = batch_size
# self.block_size = block_size
# self.tokenizer = tokenizer
# self.device = device
# self.chunk_size = chunk_size
# self.file = open(self.file_path, 'r', encoding='utf-8')
# self.data = None
# self.end_of_file = False
# # Load the initial chunk of data
# self.load_chunk()
# def load_chunk(self):
# """Load a chunk from the file, encode it, and handle end-of-file conditions."""
# text = self.file.read()
# if not text:
# self.end_of_file = True
# logging.info("End of file reached.")
# else:
# try:
# # Encode the text using the tokenizer
# encoded = self.tokenizer.encode(text)
# if len(encoded) > 0:
# self.data = torch.tensor(encoded, dtype=torch.long).to(self.device)
# logging.info(f"Loaded new data chunk of size: {len(self.data)} tokens.")
# except Exception as e:
# logging.error(f"Error encoding text chunk: {e}")
# self.end_of_file = True
# def num_batches(self):
# """Calculate the total number of batches in the current chunk."""
# if self.data is not None:
# return (len(self.data) - 1) // self.block_size # Total batches in the current chunk
# return 0
# def get_batch(self):
# """Retrieve a batch of data from the current chunk or load a new chunk if needed."""
# if self.end_of_file:
# return None, None # Return None when no data is left
# # Generate a batch of data
# ix = torch.randint(len(self.data) - self.block_size, (self.batch_size,))
# x = torch.stack([self.data[i:i+self.block_size] for i in ix])
# y = torch.stack([self.data[i+1:i+self.block_size+1] for i in ix])
# return x, y
# def reset(self):
# """Reset the file and flags for a new epoch."""
# self.file.seek(0)
# self.end_of_file = False
# logging.info("Resetting file for a new epoch.")
# self.load_chunk()
# def close(self):
# """Clean up file resources when done."""
# self.file.close()
# logging.info("File closed.")
# def __iter__(self):
# """Make the data loader iterable so it can be used in a loop."""
# while not self.end_of_file:
# x, y = self.get_batch()
# if x is None or y is None:
# break # Stop iteration if there's no more data
# yield x, y # Yield a batch of data
# # Once iteration is done, close the file
# self.close()
# # Set up logging
# logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# class TextDataLoader:
# def __init__(self, file_path, batch_size, block_size, tokenizer, device='cpu', chunk_size=10**4):
# self.file_path = file_path
# self.batch_size = batch_size
# self.block_size = block_size
# self.tokenizer = tokenizer
# self.device = device
# self.chunk_size = chunk_size
# self.file = open(self.file_path, 'r', encoding='utf-8')
# self.data = None
# self.end_of_file = False
# # Print a preview of the file
# # self.print_file_preview()
# # Initial chunk loading
# self.load_chunk()
# def print_file_preview(self):
# """Prints the first few lines of the text file for preview"""
# self.file.seek(0) # Go to the beginning of the file
# lines = [self.file.readline() for _ in range(5)]
# preview_text = ''.join(lines)
# print("File preview:\n", preview_text)
# self.file.seek(0) # Reset to the start of the file for chunk reading
# def load_chunk(self):
# """Load a chunk from the file, encode it, and handle end-of-file conditions."""
# text = self.file.read()
# if not text:
# self.end_of_file = True
# logging.info("End of file reached.")
# else:
# try:
# # Log the first 100 characters of the text chunk to verify Urdu content
# # logging.info(f"First 100 characters of the chunk: {text[:100]}")
# # print("This is the chunk:", text)
# # Encode the text using the tokenizer
# # print("Tokenizer:", self.tokenizer)
# encoded = self.tokenizer.encode(text)
# print(len(encoded))
# print("encoded data: ")
# # Log the encoded output length to confirm successful encoding
# logging.info(f"Encoded data length: {len(encoded)} tokens")
# # if len(encoded) < self.block_size:
# # # Only stop if there's absolutely no usable data left
# # self.end_of_file = len(encoded) == 0
# # if self.end_of_file:
# # logging.warning("Insufficient data in chunk; stopping further loading.")
# # else:
# # logging.warning("Data chunk smaller than block size loaded; may limit training batch size.")
# if len(encoded) > 0:
# self.data = torch.tensor(encoded, dtype=torch.long).to(self.device)
# logging.info(f"Loaded new data chunk of size: {len(self.data)} tokens.")
# except Exception as e:
# logging.error(f"Error encoding text chunk: {e}")
# self.end_of_file = True
# def get_batch(self):
# """Retrieve a batch of data from the current chunk or load a new chunk if needed."""
# # if self.end_of_file:
# # return None, None # Return None when no data is left
# # if self.data is None or len(self.data) <= self.block_size:
# # self.load_chunk()
# # if self.end_of_file or self.data is None or len(self.data) < self.block_size:
# # return None, None # Stop if there’s insufficient data
# # Generate a batch of data
# ix = torch.randint(len(self.data) - self.block_size, (self.batch_size,))
# x = torch.stack([self.data[i:i+self.block_size] for i in ix])
# y = torch.stack([self.data[i+1:i+self.block_size+1] for i in ix])
# return x, y
# def reset(self):
# """Reset the file and flags for a new epoch."""
# self.file.seek(0)
# self.end_of_file = False
# logging.info("Resetting file for a new epoch.")
# self.load_chunk()
# def close(self):
# """Clean up file resources when done."""
# self.file.close()
# logging.info("File closed.")