#!/usr/bin/env python3
"""
Tokenization script for the preprocessed agriQA dataset.

This script only handles tokenization of already preprocessed data files.
"""

import argparse
import logging
import os
from typing import List

from datasets import Dataset
from transformers import AutoTokenizer

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


class DatasetTokenizer:
    def __init__(self, model_name: str = "Qwen/Qwen1.5-1.8B-Chat", output_dir: str = "data"):
        self.model_name = model_name
        self.output_dir = output_dir
        self.tokenizer = None

    def load_tokenizer(self) -> bool:
        """Load the tokenizer for the model."""
        logger.info(f"Loading tokenizer for {self.model_name}")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                trust_remote_code=True,
                padding_side="right"
            )
            # Set pad token if not present
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            logger.info("Tokenizer loaded successfully")
            return True
        except Exception as e:
            logger.error(f"Failed to load tokenizer: {e}")
            return False

    def load_preprocessed_data(self, file_path: str) -> List[str]:
        """Load preprocessed data from a text file."""
        logger.info(f"Loading preprocessed data from {file_path}")
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                lines = f.readlines()
            # Remove empty lines and strip whitespace
            data = [line.strip() for line in lines if line.strip()]
            logger.info(f"Loaded {len(data)} samples from {file_path}")
            return data
        except Exception as e:
            logger.error(f"Failed to load data from {file_path}: {e}")
            return []

    def tokenize_function(self, examples, max_length: int = 512):
        """Tokenize the text data for training."""
        if not self.tokenizer:
            raise ValueError("Tokenizer not loaded")

        # Tokenize the text
        tokenized = self.tokenizer(
            examples['text'],
            truncation=True,
            padding='max_length',  # Pad to max_length for consistent lengths
            max_length=max_length,
            return_tensors=None  # Return lists, not tensors
        )

        # Ensure labels are properly formatted
        labels = []
        for i, input_ids in enumerate(tokenized['input_ids']):
            # Create labels that are the same as input_ids
            label = input_ids.copy()
            # Mask padding tokens in labels (set to -100)
            attention_mask = tokenized['attention_mask'][i]
            for j, mask_val in enumerate(attention_mask):
                if mask_val == 0:  # This is a padding token
                    label[j] = -100
            labels.append(label)
        tokenized['labels'] = labels

        # Add length column for memory optimization
        lengths = [len(ids) for ids in tokenized['input_ids']]
        tokenized['length'] = lengths

        return tokenized

    def tokenize_dataset(self, dataset: Dataset, max_length: int = 512) -> Dataset:
        """Tokenize the entire dataset."""
        logger.info(f"Tokenizing dataset with max_length={max_length}")

        tokenized_dataset = dataset.map(
            lambda examples: self.tokenize_function(examples, max_length),
            batched=True,
            batch_size=100,  # Process in smaller batches for memory efficiency
            num_proc=1,  # Use a single process for Windows compatibility
            remove_columns=dataset.column_names,
            desc="Tokenizing dataset"
        )

        logger.info(f"Tokenized dataset with {len(tokenized_dataset)} samples")
        return tokenized_dataset

    def run(self, max_length: int = 512):
        """Main tokenization process."""
        logger.info("Starting dataset tokenization...")

        # Check if tokenized datasets already exist
        tokenized_dir = os.path.join(self.output_dir, "tokenized")
        train_path = os.path.join(tokenized_dir, "train")
        val_path = os.path.join(tokenized_dir, "validation")

        if os.path.exists(train_path) and os.path.exists(val_path):
            logger.info("Tokenized datasets already exist. Skipping tokenization.")
            logger.info(f"Training samples: {len(Dataset.load_from_disk(train_path))}")
            logger.info(f"Validation samples: {len(Dataset.load_from_disk(val_path))}")
            return

        # Load tokenizer
        if not self.load_tokenizer():
            logger.error("Failed to load tokenizer. Exiting.")
            return

        # Load preprocessed data
        train_file = os.path.join(self.output_dir, "train_data.txt")
        val_file = os.path.join(self.output_dir, "val_data.txt")

        if not os.path.exists(train_file):
            logger.error(f"Training data file not found: {train_file}")
            return
        if not os.path.exists(val_file):
            logger.error(f"Validation data file not found: {val_file}")
            return

        train_data = self.load_preprocessed_data(train_file)
        val_data = self.load_preprocessed_data(val_file)

        if not train_data or not val_data:
            logger.error("Failed to load preprocessed data. Exiting.")
            return

        # Create datasets for tokenization
        train_dataset = Dataset.from_dict({"text": train_data})
        val_dataset = Dataset.from_dict({"text": val_data})

        # Tokenize datasets
        logger.info("Tokenizing training dataset...")
        tokenized_train = self.tokenize_dataset(train_dataset, max_length)

        logger.info("Tokenizing validation dataset...")
        tokenized_val = self.tokenize_dataset(val_dataset, max_length)

        # Save tokenized datasets
        os.makedirs(tokenized_dir, exist_ok=True)
        logger.info(f"Saving tokenized datasets to {tokenized_dir}")
        tokenized_train.save_to_disk(train_path)
        tokenized_val.save_to_disk(val_path)

        logger.info("Tokenized datasets saved successfully!")
        logger.info(f"Training samples: {len(tokenized_train)}")
        logger.info(f"Validation samples: {len(tokenized_val)}")
        logger.info("Dataset tokenization completed successfully!")


def main():
    parser = argparse.ArgumentParser(description="Tokenize preprocessed agriQA dataset")
    parser.add_argument("--model_name", type=str, default="Qwen/Qwen1.5-1.8B-Chat",
                        help="Model name for tokenizer")
    parser.add_argument("--output_dir", type=str, default="data",
                        help="Output directory for tokenized datasets")
    parser.add_argument("--max_length", type=int, default=512,
                        help="Maximum sequence length for tokenization")
    args = parser.parse_args()

    tokenizer = DatasetTokenizer(
        model_name=args.model_name,
        output_dir=args.output_dir
    )
    tokenizer.run(max_length=args.max_length)


if __name__ == "__main__":
    main()
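
# Example invocation (the script filename below is an assumption; adjust it to wherever
# this file is saved). The flags and defaults match the argparse definitions above:
#
#   python tokenize_dataset.py --model_name Qwen/Qwen1.5-1.8B-Chat --output_dir data --max_length 512
#
# With the default --output_dir, the script expects data/train_data.txt and data/val_data.txt
# (one preprocessed sample per line) and writes the tokenized datasets to
# data/tokenized/train and data/tokenized/validation.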