#!/usr/bin/env python3
"""

Tokenization script for preprocessed agriQA dataset.

This script only handles tokenization of already preprocessed data files.

"""

import argparse
import logging
import os
from typing import List

from datasets import Dataset
from transformers import AutoTokenizer

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

class DatasetTokenizer:
    def __init__(self, model_name: str = "Qwen/Qwen1.5-1.8B-Chat", output_dir: str = "data"):
        self.model_name = model_name
        self.output_dir = output_dir
        self.tokenizer = None
        
    def load_tokenizer(self):
        """Load the tokenizer for the model."""
        logger.info(f"Loading tokenizer for {self.model_name}")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                trust_remote_code=True,
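                # Right padding is the usual choice when tokenizing for training;
                # generation with decoder-only models typically uses left padding.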
                padding_side="right"
            )
            
            # Set pad token if not present
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
                
            logger.info("Tokenizer loaded successfully")
            return True
        except Exception as e:
            logger.error(f"Failed to load tokenizer: {e}")
            return False
    
    def load_preprocessed_data(self, file_path: str) -> List[str]:
        """Load preprocessed data from text file."""
        logger.info(f"Loading preprocessed data from {file_path}")
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                lines = f.readlines()
            
            # Remove empty lines and strip whitespace
            data = [line.strip() for line in lines if line.strip()]
            logger.info(f"Loaded {len(data)} samples from {file_path}")
            return data
        except Exception as e:
            logger.error(f"Failed to load data from {file_path}: {e}")
            return []
    
    def tokenize_function(self, examples, max_length: int = 512):
        """Tokenize the text data for training."""
        if not self.tokenizer:
            raise ValueError("Tokenizer not loaded")
        
        # Tokenize the text
        tokenized = self.tokenizer(
            examples['text'],
            truncation=True,
            padding='max_length',  # Pad to max_length for consistent lengths
            max_length=max_length,
            return_tensors=None  # Return lists, not tensors
        )
        
        # Build labels for causal LM training: labels mirror input_ids, with
        # padding positions masked to -100 so they are ignored by the loss.
        labels = []
        for i, input_ids in enumerate(tokenized['input_ids']):
            # Create labels that are the same as input_ids
            label = input_ids.copy()
            # Mask padding tokens in labels (set to -100)
            attention_mask = tokenized['attention_mask'][i]
            for j, mask_val in enumerate(attention_mask):
                if mask_val == 0:  # This is a padding token
                    label[j] = -100
            labels.append(label)
        
        tokenized['labels'] = labels
        
        # Add a length column (used, for example, by length-grouped batching in the Trainer)
        lengths = [len(ids) for ids in tokenized['input_ids']]
        tokenized['length'] = lengths
        
        return tokenized
    
    def tokenize_dataset(self, dataset: Dataset, max_length: int = 512) -> Dataset:
        """Tokenize the entire dataset."""
        logger.info(f"Tokenizing dataset with max_length={max_length}")
        
        tokenized_dataset = dataset.map(
            lambda examples: self.tokenize_function(examples, max_length),
            batched=True,
            batch_size=100,  # Process in smaller batches for memory efficiency
            num_proc=1,  # Use single process for Windows compatibility
            remove_columns=dataset.column_names,
            desc="Tokenizing dataset"
        )
        
        logger.info(f"Tokenized dataset with {len(tokenized_dataset)} samples")
        return tokenized_dataset
    
    def run(self, max_length: int = 512):
        """Main tokenization process."""
        logger.info("Starting dataset tokenization...")
        
        # Check if tokenized datasets already exist
        tokenized_dir = os.path.join(self.output_dir, "tokenized")
        train_path = os.path.join(tokenized_dir, "train")
        val_path = os.path.join(tokenized_dir, "validation")
        
        if os.path.exists(train_path) and os.path.exists(val_path):
            logger.info("Tokenized datasets already exist. Skipping tokenization.")
            logger.info(f"Training samples: {len(Dataset.load_from_disk(train_path))}")
            logger.info(f"Validation samples: {len(Dataset.load_from_disk(val_path))}")
            return
        
        # Load tokenizer
        if not self.load_tokenizer():
            logger.error("Failed to load tokenizer. Exiting.")
            return
        
        # Load preprocessed data
        train_file = os.path.join(self.output_dir, "train_data.txt")
        val_file = os.path.join(self.output_dir, "val_data.txt")
        
        if not os.path.exists(train_file):
            logger.error(f"Training data file not found: {train_file}")
            return
        
        if not os.path.exists(val_file):
            logger.error(f"Validation data file not found: {val_file}")
            return
        
        train_data = self.load_preprocessed_data(train_file)
        val_data = self.load_preprocessed_data(val_file)
        
        if not train_data or not val_data:
            logger.error("Failed to load preprocessed data. Exiting.")
            return
        
        # Create datasets for tokenization
        train_dataset = Dataset.from_dict({"text": train_data})
        val_dataset = Dataset.from_dict({"text": val_data})
        
        # Tokenize datasets
        logger.info("Tokenizing training dataset...")
        tokenized_train = self.tokenize_dataset(train_dataset, max_length)
        
        logger.info("Tokenizing validation dataset...")
        tokenized_val = self.tokenize_dataset(val_dataset, max_length)
        
        # Save tokenized datasets
        os.makedirs(tokenized_dir, exist_ok=True)
        
        logger.info(f"Saving tokenized datasets to {tokenized_dir}")
        tokenized_train.save_to_disk(train_path)
        tokenized_val.save_to_disk(val_path)
        
        logger.info(f"Tokenized datasets saved successfully!")
        logger.info(f"Training samples: {len(tokenized_train)}")
        logger.info(f"Validation samples: {len(tokenized_val)}")
        logger.info("Dataset tokenization completed successfully!")

def main():
    parser = argparse.ArgumentParser(description="Tokenize preprocessed agriQA dataset")
    parser.add_argument("--model_name", type=str, default="Qwen/Qwen1.5-1.8B-Chat",
                       help="Model name for tokenizer")
    parser.add_argument("--output_dir", type=str, default="data",
                       help="Output directory for tokenized datasets")
    parser.add_argument("--max_length", type=int, default=512,
                       help="Maximum sequence length for tokenization")
    
    args = parser.parse_args()
    
    tokenizer = DatasetTokenizer(
        model_name=args.model_name,
        output_dir=args.output_dir
    )
    
    tokenizer.run(max_length=args.max_length)

if __name__ == "__main__":
    main()
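
# A minimal sketch of reloading the saved output for later use (paths assume
# the default --output_dir of "data"):
#   from datasets import Dataset
#   train_dataset = Dataset.load_from_disk("data/tokenized/train")
#   val_dataset = Dataset.load_from_disk("data/tokenized/validation")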