File size: 10,651 Bytes
da89f1c
 
f292cd1
da89f1c
 
b41635a
 
 
da89f1c
 
 
 
 
 
b41635a
da89f1c
 
b42b434
f292cd1
da89f1c
b41635a
37586a2
 
 
f292cd1
 
 
 
b41635a
f292cd1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
da89f1c
 
 
 
 
 
37586a2
db7bdc3
da89f1c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77bc910
 
 
 
37586a2
77bc910
f292cd1
 
da89f1c
 
 
 
 
 
 
 
 
 
 
f292cd1
 
 
 
 
 
 
 
 
 
 
 
ffe97b0
f292cd1
 
ffe97b0
f292cd1
 
 
07a484d
 
f292cd1
07a484d
f292cd1
07a484d
 
 
 
 
 
 
 
 
 
 
f292cd1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
da89f1c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b41635a
 
8035112
b41635a
da89f1c
 
 
a4d7cd8
da89f1c
 
 
 
 
 
 
 
b41635a
 
 
da89f1c
a4d7cd8
 
 
da89f1c
ae47555
 
 
 
 
 
 
 
 
 
 
 
 
 
 
da89f1c
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, AutoTokenizer
import pandas as pd
import numpy as np
import logging

logger = logging.getLogger(__name__)

class DocumentDataset(Dataset):
    """
    Dataset for document classification.

    Holds raw texts plus integer labels and tokenizes lazily in
    ``__getitem__``. Labels may be a 1-D sequence (single task) or a 2-D
    array / list of lists (one column per task). Each label column is
    validated against ``num_classes`` and auto-remapped to be
    zero-indexed when the data starts at an offset (e.g. labels 1..K).
    """

    def __init__(self, texts, labels, tokenizer_name='bert-base-uncased', max_length=512, num_classes=None):
        """
        Args:
            texts: sequence of raw document strings.
            labels: 1-D sequence (single task), 2-D array (multi-task),
                or None for unlabeled data (dummy zeros are used).
            tokenizer_name: HuggingFace tokenizer checkpoint name.
            max_length: maximum token length; longer texts are truncated.
            num_classes: expected number of classes; used only to detect
                and warn about out-of-range labels.
        """
        self.texts = texts
        self.num_classes = num_classes
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.max_length = max_length

        if labels is None:
            # Unlabeled data: dummy zeros keep __getitem__ uniform.
            self.labels = np.zeros(len(texts), dtype=int)
        else:
            # BUG FIX: the original dispatched on `type(labels) is list/ndarray`,
            # which sent Python lists and the 1-D numpy arrays produced by
            # load_data into the multi-column branch, crashing on
            # `labels.shape[1]`. Dispatch on dimensionality instead.
            labels = np.asarray(labels)
            if labels.ndim == 1:
                self.labels = self._normalize_labels(labels, num_classes)
            else:
                # Multi-task labels: validate/remap each column independently.
                labels = labels.copy()
                for col in range(labels.shape[1]):
                    labels[:, col] = self._normalize_labels(labels[:, col], num_classes)
                self.labels = labels

    @staticmethod
    def _normalize_labels(column, num_classes):
        """Validate one 1-D label column; remap to zero-indexed ids if needed.

        Returns the column unchanged when it is already in range (or when
        ``num_classes`` is None); otherwise logs warnings and, if the labels
        do not start at 0, remaps sorted unique values to 0..K-1.
        """
        unique_labels = set(column.tolist())
        min_label = min(unique_labels) if unique_labels else 0
        max_label = max(unique_labels) if unique_labels else 0

        # Warn when labels fall outside the expected [0, num_classes) range.
        if num_classes is not None and (min_label < 0 or max_label >= num_classes):
            logger.warning(f"Label Range Error: Labels must be between 0 and {num_classes-1}, "
                        f"but found range [{min_label}, {max_label}]")
            logger.warning(f"Unique label values: {sorted(unique_labels)}")

            # Some datasets index labels from 1; remap to start from 0.
            if min_label != 0:
                logger.warning("Auto-correcting labels to be zero-indexed...")
                label_map = {original: idx for idx, original in enumerate(sorted(unique_labels))}
                column = np.array([label_map[label] for label in column])
                logger.warning(f"New unique label values: {sorted(set(column.tolist()))}")
        return column

    def __len__(self):
        """Number of documents in the dataset."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Tokenize document `idx` and return model-ready tensors."""
        text = str(self.texts[idx])
        label = self.labels[idx] if self.labels is not None else torch.tensor(0, dtype=torch.long)

        # Tokenize with special tokens, padding to max_length, truncation,
        # and an attention mask so padded positions are ignored.
        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_length,
            return_token_type_ids=True,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )

        # encode_plus returns (1, max_length) tensors; flatten to (max_length,)
        # so DataLoader batching produces (batch, max_length).
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'token_type_ids': encoding['token_type_ids'].flatten(),
            'label': torch.tensor(label, dtype=torch.long)
        }

    def get_text_(self, idx):
        """Get original (untokenized) text and label for a given index."""
        return {
            'text': self.texts[idx],
            'label': self.labels[idx] if self.labels is not None else None
        }
    
def load_data(data_path, text_col='text', label_col: str | list ='label', validation_split=0.1, test_split=0.1, seed=42):
    """
    Load data from CSV/TSV and split into train, validation and test sets
    """
    # Determine file format based on extension
    if data_path.endswith('.csv'):
        df = pd.read_csv(data_path)
    elif data_path.endswith('.tsv'):
        df = pd.read_csv(data_path, sep='\t')
    else:
        raise ValueError("Unsupported file format. Please provide CSV or TSV file.")
    
    # If label_col is a list of columns, do the below but for each column
    if isinstance(label_col, list):
        labels = None
        for idx, label in enumerate(label_col):
            if label not in df.columns:
                raise ValueError(f"Label column '{label}' not found in the dataset.")
            
            # Convert labels to numeric if they aren't already
            if not np.issubdtype(df[label].dtype, np.number):
                label_map = {label: idx for idx, label in enumerate(sorted(df[label].unique()))}
                df[f'label_numeric_{idx}'] = df[label].map(label_map)
                if labels is None:
                    labels = df[f'label_numeric_{idx}'].values.reshape(-1, 1)
                else:
                    # Extend the labels array to dim 1
                    labels = np.concatenate((labels, df[f'label_numeric_{idx}'].values.reshape(-1, 1)), axis=1)
                
                # Log the mapping for reference
                logger.info(f"Label mapping for column '{label}': {label_map}")
            else: # Column is already numeric
                current_col_labels = df[label].values # Get current column's data
                # Check if labels start from 0
                min_label = current_col_labels.min()
                if min_label != 0:
                    logger.warning(f"Labels in numeric column '{label}' don't start from 0 (min={min_label}). Converting to zero-indexed...")
                    label_map = {lbl: idx for idx, lbl in enumerate(sorted(set(current_col_labels)))}
                    current_col_labels = np.array([label_map[lbl] for lbl in current_col_labels]) # Apply mapping to current column data

                # Concatenate this column to the main 'labels' array
                if labels is None:
                    # This is the first column encountered (and it's numeric)
                    labels = current_col_labels.reshape(-1, 1)
                else:
                    # Append this numeric column to existing labels
                    labels = np.concatenate((labels, current_col_labels.reshape(-1, 1)), axis=1)
    else: # In case there is only one label column
        # Convert labels to numeric if they aren't already
        if not np.issubdtype(df[label_col].dtype, np.number):
            label_map = {label: idx for idx, label in enumerate(sorted(df[label_col].unique()))}
            df['label_numeric'] = df[label_col].map(label_map)
            labels = df['label_numeric'].values
            
            logger.info(f"Label mapping: {label_map}")
        else:
            labels = df[label_col].values
            
            # Check if labels start from 0
            min_label = labels.min()
            if min_label != 0:
                logger.warning(f"Labels don't start from 0 (min={min_label}). Converting to zero-indexed...")
                label_map = {label: idx for idx, label in enumerate(sorted(set(labels)))}
                labels = np.array([label_map[label] for label in labels])
    
    # Create a DataFrame with text and numeric labels
    texts = df[text_col].values
    
    # Shuffle and split the data
    np.random.seed(seed)
    indices = np.random.permutation(len(texts))
    
    test_size = int(test_split * len(texts))
    val_size = int(validation_split * len(texts))
    train_size = len(texts) - test_size - val_size
    
    train_indices = indices[:train_size]
    val_indices = indices[train_size:train_size + val_size]
    test_indices = indices[train_size + val_size:]
    
    train_texts, train_labels = texts[train_indices], labels[train_indices]
    val_texts, val_labels = texts[val_indices], labels[val_indices]
    test_texts, test_labels = texts[test_indices], labels[test_indices]
    
    # Log stats about the dataset
    logger.info(f"Dataset splits: train={train_size}, val={val_size}, test={test_size}")
    # Also print the num_categories being passed
    
    return (train_texts, train_labels), (val_texts, val_labels), (test_texts, test_labels)

def create_data_loaders(train_data, val_data, test_data, tokenizer_name='bert-base-uncased', 
                       max_length=512, batch_size=16, num_classes=None, return_datasets=False):
    """
    Build DataLoader objects (or raw DocumentDatasets) for the three splits.

    Each of train_data/val_data/test_data is a (texts, labels) pair. When
    return_datasets is True the three DocumentDataset objects are returned
    instead of loaders. An empty split yields None for its loader, with a
    warning logged. Only the training loader shuffles.
    """
    # Construct one dataset per split, in train/val/test order.
    datasets = [
        DocumentDataset(split_texts, split_labels, tokenizer_name, max_length, num_classes)
        for split_texts, split_labels in (train_data, val_data, test_data)
    ]

    if return_datasets:
        return tuple(datasets)

    # Create data loaders; an empty split becomes None with a warning.
    split_names = ("Training", "Validation", "Test")
    shuffle_flags = (True, False, False)
    loaders = []
    for split_name, dataset, shuffle in zip(split_names, datasets, shuffle_flags):
        if len(dataset.texts) == 0:
            logger.warning(f"{split_name} dataset is empty. Check your data loading and splitting.")
            loaders.append(None)
        else:
            loaders.append(DataLoader(dataset, batch_size=batch_size, shuffle=shuffle))

    return tuple(loaders)