File size: 10,173 Bytes
83aefdf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 |
# Model/data_loader.py
import torch
import os
import logging
# Configure logging at import time: INFO level and above, with each record
# formatted as "<timestamp> - <level name> - <message>".
# NOTE(review): this runs as a side effect of importing the module; per the
# logging docs, basicConfig does nothing if the root logger already has handlers.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
class TextDataLoader:
    """Serve random (input, target) token batches sampled from a UTF-8 text file.

    The file is read and tokenized in one go by ``load_chunk``; ``get_batch``
    then samples ``batch_size`` windows of ``block_size`` tokens at random
    offsets from the resulting 1-D ``torch.long`` tensor. Targets are the
    inputs shifted one token to the right.

    The loader keeps the file handle open until ``close`` is called. It can
    be used as a context manager to guarantee the handle is released::

        with TextDataLoader(path, 32, 128, tokenizer) as loader:
            x, y = loader.get_batch()
    """

    def __init__(self, file_path, batch_size, block_size, tokenizer, chunk_size=10**4):
        """Open ``file_path`` and load/tokenize its contents.

        Args:
            file_path: Path of the UTF-8 text file to read.
            batch_size: Number of sequences returned per batch.
            block_size: Number of tokens in each sequence.
            tokenizer: Object exposing ``encode(text)`` that returns a sized
                sequence of integer token ids.
            chunk_size: Stored on the instance for API compatibility. It is
                currently not used: ``load_chunk`` reads the whole file.

        Raises:
            OSError: If the file cannot be opened.
        """
        self.file_path = file_path
        self.batch_size = batch_size
        self.block_size = block_size
        self.tokenizer = tokenizer
        self.chunk_size = chunk_size
        self.file = open(self.file_path, 'r', encoding='utf-8')
        self.data = None
        self.end_of_file = False
        # Load the initial data. If reading fails (e.g. a decode error), do not
        # leak the handle that was just opened.
        try:
            self.load_chunk()
        except BaseException:
            self.file.close()
            raise

    def __enter__(self):
        """Return the loader itself so it can be used in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the underlying file when leaving the ``with`` block."""
        self.close()
        return False

    def load_chunk(self):
        """Read the rest of the file, encode it, and handle end-of-file conditions.

        On success ``self.data`` holds the encoded tokens as a ``torch.long``
        tensor, and the tensor is also written to ``encoded_data.pth`` in the
        current working directory (overwriting any previous file). If nothing
        is left to read, or if encoding/saving fails, ``end_of_file`` is set.
        """
        text = self.file.read()
        if not text:
            self.end_of_file = True
            logging.info("End of file reached.")
            return
        try:
            # Encode the text using the tokenizer
            encoded = self.tokenizer.encode(text)
            if len(encoded) > 0:
                self.data = torch.tensor(encoded, dtype=torch.long)
                logging.info(f"Loaded new data chunk of size: {len(self.data)} tokens.")
                # Save the encoded data to a file
                torch.save(self.data, "encoded_data.pth")
        except Exception as e:
            # Best effort: any failure here (encoding or saving) is logged and
            # treated as "no more data" rather than propagated to the caller.
            logging.error(f"Error encoding text chunk: {e}")
            self.end_of_file = True

    def num_batches(self):
        """Return how many non-overlapping ``block_size`` windows fit in the data.

        Returns 0 when no data has been loaded.
        """
        if self.data is None:
            return 0
        # One token is reserved because targets are shifted by one position.
        return (len(self.data) - 1) // self.block_size

    def get_batch(self):
        """Return a random batch ``(x, y)`` from the loaded data.

        Returns:
            A pair of ``(batch_size, block_size)`` tensors where ``y`` is ``x``
            shifted one token to the right, or ``(None, None)`` when no data
            is available (end of file reached, nothing loaded, or fewer than
            ``block_size + 1`` tokens loaded).
        """
        if self.end_of_file:
            return None, None  # Return None when no data is left
        # At least block_size + 1 tokens are needed to build one (x, y) pair;
        # without this guard torch.randint would be called with high <= 0.
        if self.data is None or len(self.data) <= self.block_size:
            logging.warning("Not enough tokens loaded to build a batch.")
            return None, None
        # Sample one start offset per sequence in the batch.
        starts = torch.randint(len(self.data) - self.block_size, (self.batch_size,)).tolist()
        x = torch.stack([self.data[i:i + self.block_size] for i in starts])
        y = torch.stack([self.data[i + 1:i + self.block_size + 1] for i in starts])
        return x, y

    def reset(self):
        """Rewind to the start of the file and reload the data for a new epoch.

        If the file was closed (by ``close`` or by a finished iteration) it is
        reopened instead of failing on ``seek``.
        """
        if self.file.closed:
            self.file = open(self.file_path, 'r', encoding='utf-8')
        else:
            self.file.seek(0)
        self.end_of_file = False
        logging.info("Resetting file for a new epoch.")
        self.load_chunk()

    def close(self):
        """Clean up file resources when done. Safe to call more than once."""
        if not self.file.closed:
            self.file.close()
            logging.info("File closed.")

    def __iter__(self):
        """Yield ``(x, y)`` batches until no more data is available.

        While data is available this generator does not terminate on its own
        (batches are sampled at random), so callers are expected to bound the
        loop themselves, e.g. with ``num_batches()``.
        """
        while not self.end_of_file:
            x, y = self.get_batch()
            if x is None or y is None:
                break  # Stop iteration if there's no more data
            yield x, y  # Yield a batch of data
        # Once iteration is done, close the file
        self.close()
# Earlier versions of TextDataLoader, from before parallelizing (commented out):
# Set up logging
# logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# class TextDataLoader:
# def __init__(self, file_path, batch_size, block_size, tokenizer, device='cpu', chunk_size=10**4):
# self.file_path = file_path
# self.batch_size = batch_size
# self.block_size = block_size
# self.tokenizer = tokenizer
# self.device = device
# self.chunk_size = chunk_size
# self.file = open(self.file_path, 'r', encoding='utf-8')
# self.data = None
# self.end_of_file = False
# # Load the initial chunk of data
# self.load_chunk()
# def load_chunk(self):
# """Load a chunk from the file, encode it, and handle end-of-file conditions."""
# text = self.file.read()
# if not text:
# self.end_of_file = True
# logging.info("End of file reached.")
# else:
# try:
# # Encode the text using the tokenizer
# encoded = self.tokenizer.encode(text)
# if len(encoded) > 0:
# self.data = torch.tensor(encoded, dtype=torch.long).to(self.device)
# logging.info(f"Loaded new data chunk of size: {len(self.data)} tokens.")
# except Exception as e:
# logging.error(f"Error encoding text chunk: {e}")
# self.end_of_file = True
# def num_batches(self):
# """Calculate the total number of batches in the current chunk."""
# if self.data is not None:
# return (len(self.data) - 1) // self.block_size # Total batches in the current chunk
# return 0
# def get_batch(self):
# """Retrieve a batch of data from the current chunk or load a new chunk if needed."""
# if self.end_of_file:
# return None, None # Return None when no data is left
# # Generate a batch of data
# ix = torch.randint(len(self.data) - self.block_size, (self.batch_size,))
# x = torch.stack([self.data[i:i+self.block_size] for i in ix])
# y = torch.stack([self.data[i+1:i+self.block_size+1] for i in ix])
# return x, y
# def reset(self):
# """Reset the file and flags for a new epoch."""
# self.file.seek(0)
# self.end_of_file = False
# logging.info("Resetting file for a new epoch.")
# self.load_chunk()
# def close(self):
# """Clean up file resources when done."""
# self.file.close()
# logging.info("File closed.")
# def __iter__(self):
# """Make the data loader iterable so it can be used in a loop."""
# while not self.end_of_file:
# x, y = self.get_batch()
# if x is None or y is None:
# break # Stop iteration if there's no more data
# yield x, y # Yield a batch of data
# # Once iteration is done, close the file
# self.close()
# # Set up logging
# logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# class TextDataLoader:
# def __init__(self, file_path, batch_size, block_size, tokenizer, device='cpu', chunk_size=10**4):
# self.file_path = file_path
# self.batch_size = batch_size
# self.block_size = block_size
# self.tokenizer = tokenizer
# self.device = device
# self.chunk_size = chunk_size
# self.file = open(self.file_path, 'r', encoding='utf-8')
# self.data = None
# self.end_of_file = False
# # Print a preview of the file
# # self.print_file_preview()
# # Initial chunk loading
# self.load_chunk()
# def print_file_preview(self):
# """Prints the first few lines of the text file for preview"""
# self.file.seek(0) # Go to the beginning of the file
# lines = [self.file.readline() for _ in range(5)]
# preview_text = ''.join(lines)
# print("File preview:\n", preview_text)
# self.file.seek(0) # Reset to the start of the file for chunk reading
# def load_chunk(self):
# """Load a chunk from the file, encode it, and handle end-of-file conditions."""
# text = self.file.read()
# if not text:
# self.end_of_file = True
# logging.info("End of file reached.")
# else:
# try:
# # Log the first 100 characters of the text chunk to verify Urdu content
# # logging.info(f"First 100 characters of the chunk: {text[:100]}")
# # print("This is the chunk:", text)
# # Encode the text using the tokenizer
# # print("Tokenizer:", self.tokenizer)
# encoded = self.tokenizer.encode(text)
# print(len(encoded))
# print("encoded data: ")
# # Log the encoded output length to confirm successful encoding
# logging.info(f"Encoded data length: {len(encoded)} tokens")
# # if len(encoded) < self.block_size:
# # # Only stop if there's absolutely no usable data left
# # self.end_of_file = len(encoded) == 0
# # if self.end_of_file:
# # logging.warning("Insufficient data in chunk; stopping further loading.")
# # else:
# # logging.warning("Data chunk smaller than block size loaded; may limit training batch size.")
# if len(encoded) > 0:
# self.data = torch.tensor(encoded, dtype=torch.long).to(self.device)
# logging.info(f"Loaded new data chunk of size: {len(self.data)} tokens.")
# except Exception as e:
# logging.error(f"Error encoding text chunk: {e}")
# self.end_of_file = True
# def get_batch(self):
# """Retrieve a batch of data from the current chunk or load a new chunk if needed."""
# # if self.end_of_file:
# # return None, None # Return None when no data is left
# # if self.data is None or len(self.data) <= self.block_size:
# # self.load_chunk()
# # if self.end_of_file or self.data is None or len(self.data) < self.block_size:
# # return None, None # Stop if there’s insufficient data
# # Generate a batch of data
# ix = torch.randint(len(self.data) - self.block_size, (self.batch_size,))
# x = torch.stack([self.data[i:i+self.block_size] for i in ix])
# y = torch.stack([self.data[i+1:i+self.block_size+1] for i in ix])
# return x, y
# def reset(self):
# """Reset the file and flags for a new epoch."""
# self.file.seek(0)
# self.end_of_file = False
# logging.info("Resetting file for a new epoch.")
# self.load_chunk()
# def close(self):
# """Clean up file resources when done."""
# self.file.close()
# logging.info("File closed.") |