AgriQA-assistant / src/data/tokenize_dataset.py
#!/usr/bin/env python3
"""
Tokenization script for preprocessed agriQA dataset.
This script only handles tokenization of already preprocessed data files.
"""
import argparse
import logging
import os
from typing import List

from datasets import Dataset
from transformers import AutoTokenizer

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


class DatasetTokenizer:
    def __init__(self, model_name: str = "Qwen/Qwen1.5-1.8B-Chat", output_dir: str = "data"):
        self.model_name = model_name
        self.output_dir = output_dir
        self.tokenizer = None

    def load_tokenizer(self):
        """Load the tokenizer for the model."""
        logger.info(f"Loading tokenizer for {self.model_name}")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                trust_remote_code=True,
                padding_side="right"
            )
            # Set pad token if not present
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            logger.info("Tokenizer loaded successfully")
            return True
        except Exception as e:
            logger.error(f"Failed to load tokenizer: {e}")
            return False
    def load_preprocessed_data(self, file_path: str) -> List[str]:
        """Load preprocessed data from a text file."""
        logger.info(f"Loading preprocessed data from {file_path}")
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                lines = f.readlines()
            # Remove empty lines and strip whitespace
            data = [line.strip() for line in lines if line.strip()]
            logger.info(f"Loaded {len(data)} samples from {file_path}")
            return data
        except Exception as e:
            logger.error(f"Failed to load data from {file_path}: {e}")
            return []
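    # Input format note (an assumption based on this loader, not on the preprocessing
    # script): each non-empty line of train_data.txt / val_data.txt is treated as one
    # complete, already-formatted training sample.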
    def tokenize_function(self, examples, max_length: int = 512):
        """Tokenize the text data for training."""
        if not self.tokenizer:
            raise ValueError("Tokenizer not loaded")
        # Tokenize the text
        tokenized = self.tokenizer(
            examples['text'],
            truncation=True,
            padding='max_length',  # Pad to max_length for consistent lengths
            max_length=max_length,
            return_tensors=None  # Return lists, not tensors
        )
        # Ensure labels are properly formatted
        labels = []
        for i, input_ids in enumerate(tokenized['input_ids']):
            # Create labels that are the same as input_ids
            label = input_ids.copy()
            # Mask padding tokens in labels (set to -100)
            attention_mask = tokenized['attention_mask'][i]
            for j, mask_val in enumerate(attention_mask):
                if mask_val == 0:  # This is a padding token
                    label[j] = -100
            labels.append(label)
        tokenized['labels'] = labels
        # Add length column for memory optimization
        lengths = [len(ids) for ids in tokenized['input_ids']]
        tokenized['length'] = lengths
        return tokenized
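    # Worked example (hypothetical token ids, max_length=8): a short sample might
    # come back from the tokenizer as
    #   input_ids      = [t1, t2, t3, t4,  pad,  pad,  pad,  pad]
    #   attention_mask = [ 1,  1,  1,  1,    0,    0,    0,    0]
    # and the masking loop above turns that into
    #   labels         = [t1, t2, t3, t4, -100, -100, -100, -100]
    # so the causal-LM loss is computed only on real tokens, never on padding.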
    def tokenize_dataset(self, dataset: Dataset, max_length: int = 512) -> Dataset:
        """Tokenize the entire dataset."""
        logger.info(f"Tokenizing dataset with max_length={max_length}")
        tokenized_dataset = dataset.map(
            lambda examples: self.tokenize_function(examples, max_length),
            batched=True,
            batch_size=100,  # Process in smaller batches for memory efficiency
            num_proc=1,  # Use single process for Windows compatibility
            remove_columns=dataset.column_names,
            desc="Tokenizing dataset"
        )
        logger.info(f"Tokenized dataset with {len(tokenized_dataset)} samples")
        return tokenized_dataset
    def run(self, max_length: int = 512):
        """Main tokenization process."""
        logger.info("Starting dataset tokenization...")

        # Check if tokenized datasets already exist
        tokenized_dir = os.path.join(self.output_dir, "tokenized")
        train_path = os.path.join(tokenized_dir, "train")
        val_path = os.path.join(tokenized_dir, "validation")
        if os.path.exists(train_path) and os.path.exists(val_path):
            logger.info("Tokenized datasets already exist. Skipping tokenization.")
            logger.info(f"Training samples: {len(Dataset.load_from_disk(train_path))}")
            logger.info(f"Validation samples: {len(Dataset.load_from_disk(val_path))}")
            return

        # Load tokenizer
        if not self.load_tokenizer():
            logger.error("Failed to load tokenizer. Exiting.")
            return

        # Load preprocessed data
        train_file = os.path.join(self.output_dir, "train_data.txt")
        val_file = os.path.join(self.output_dir, "val_data.txt")
        if not os.path.exists(train_file):
            logger.error(f"Training data file not found: {train_file}")
            return
        if not os.path.exists(val_file):
            logger.error(f"Validation data file not found: {val_file}")
            return
        train_data = self.load_preprocessed_data(train_file)
        val_data = self.load_preprocessed_data(val_file)
        if not train_data or not val_data:
            logger.error("Failed to load preprocessed data. Exiting.")
            return

        # Create datasets for tokenization
        train_dataset = Dataset.from_dict({"text": train_data})
        val_dataset = Dataset.from_dict({"text": val_data})

        # Tokenize datasets
        logger.info("Tokenizing training dataset...")
        tokenized_train = self.tokenize_dataset(train_dataset, max_length)
        logger.info("Tokenizing validation dataset...")
        tokenized_val = self.tokenize_dataset(val_dataset, max_length)

        # Save tokenized datasets
        os.makedirs(tokenized_dir, exist_ok=True)
        logger.info(f"Saving tokenized datasets to {tokenized_dir}")
        tokenized_train.save_to_disk(train_path)
        tokenized_val.save_to_disk(val_path)
        logger.info("Tokenized datasets saved successfully!")
        logger.info(f"Training samples: {len(tokenized_train)}")
        logger.info(f"Validation samples: {len(tokenized_val)}")
        logger.info("Dataset tokenization completed successfully!")


def main():
    parser = argparse.ArgumentParser(description="Tokenize preprocessed agriQA dataset")
    parser.add_argument("--model_name", type=str, default="Qwen/Qwen1.5-1.8B-Chat",
                        help="Model name for tokenizer")
    parser.add_argument("--output_dir", type=str, default="data",
                        help="Output directory for tokenized datasets")
    parser.add_argument("--max_length", type=int, default=512,
                        help="Maximum sequence length for tokenization")
    args = parser.parse_args()

    dataset_tokenizer = DatasetTokenizer(
        model_name=args.model_name,
        output_dir=args.output_dir
    )
    dataset_tokenizer.run(max_length=args.max_length)


if __name__ == "__main__":
    main()
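# Example invocation (a sketch using the argparse defaults above; adjust paths to your layout):
#   python src/data/tokenize_dataset.py \
#       --model_name Qwen/Qwen1.5-1.8B-Chat --output_dir data --max_length 512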