Spaces:
Runtime error
Runtime error
| # dataset_utils.py | |
| import pandas as pd | |
| import torch | |
| from torch.utils.data import Dataset, DataLoader | |
| from sklearn.preprocessing import LabelEncoder | |
| from transformers import BertTokenizer, RobertaTokenizer, DebertaTokenizer | |
| import pickle | |
| import os | |
| from config import TEXT_COLUMN, LABEL_COLUMNS, MAX_LEN, TOKENIZER_PATH, LABEL_ENCODERS_PATH, METADATA_COLUMNS | |
class ComplianceDataset(Dataset):
    """
    Custom Dataset class for handling text and multi-output labels for PyTorch models.

    Each item is an ``(inputs, labels)`` pair: ``inputs`` is a dict holding the
    tokenizer's output tensors with the batch dimension squeezed away, and
    ``labels`` is a ``torch.long`` tensor built from the sample's label row.

    Args:
        texts: Indexable collection of raw texts (list, NumPy array or pandas Series).
        labels: Indexable collection of label rows (list of lists, NumPy array,
            or pandas DataFrame/Series).
        tokenizer: Callable invoked as
            ``tokenizer(text, padding=..., truncation=..., max_length=..., return_tensors="pt")``
            and returning a mapping of tensors.
        max_len (int): ``max_length`` passed to the tokenizer for every sample.
    """

    def __init__(self, texts, labels, tokenizer, max_len):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    @staticmethod
    def _item_at(container, idx):
        """
        Returns the element of ``container`` at *position* ``idx``.

        pandas objects index by label with ``[]``; after a shuffle or a
        train/test split those labels no longer match positions, so ``.iloc``
        is used whenever the container offers it. Rows that come back as pandas
        objects are converted to NumPy so ``torch.tensor`` never iterates them
        by label.
        """
        positional = getattr(container, "iloc", None)
        if positional is None:
            return container[idx]
        row = positional[idx]
        return row.to_numpy() if hasattr(row, "to_numpy") else row

    def __len__(self):
        """Returns the total number of samples in the dataset."""
        return len(self.texts)

    def __getitem__(self, idx):
        """
        Retrieves a sample from the dataset at the given index.
        Tokenizes the text and converts labels to a PyTorch tensor.

        Returns:
            tuple: ``(inputs, labels)`` where ``inputs`` is a dict of tensors
            and ``labels`` is a ``torch.long`` tensor.
        """
        text = str(self._item_at(self.texts, idx))
        # Tokenize the text, padding to max_length and truncating if longer.
        # return_tensors="pt" asks the tokenizer for PyTorch tensors.
        encoded = self.tokenizer(
            text,
            padding='max_length',
            truncation=True,
            max_length=self.max_len,
            return_tensors="pt"
        )
        # The tokenizer is called on a single text, so drop the leading batch
        # dimension of size 1; the DataLoader adds its own batch dimension.
        inputs = {key: val.squeeze(0) for key, val in encoded.items()}
        # Convert labels to a PyTorch long tensor
        labels = torch.tensor(self._item_at(self.labels, idx), dtype=torch.long)
        return inputs, labels
class ComplianceDatasetWithMetadata(Dataset):
    """
    Custom Dataset class for handling text, additional numerical metadata, and multi-output labels.
    Used for hybrid models combining text and tabular features.

    Each item is an ``(inputs, metadata, labels)`` triple: ``inputs`` is a dict
    holding the tokenizer's output tensors with the batch dimension squeezed
    away, ``metadata`` is a ``torch.float`` tensor and ``labels`` is a
    ``torch.long`` tensor.

    Args:
        texts: Indexable collection of raw texts (list, NumPy array or pandas Series).
        metadata: Indexable collection of numeric feature rows (NumPy array,
            list of lists, or pandas DataFrame).
        labels: Indexable collection of label rows (list of lists, NumPy array,
            or pandas DataFrame/Series).
        tokenizer: Callable invoked as
            ``tokenizer(text, padding=..., truncation=..., max_length=..., return_tensors="pt")``
            and returning a mapping of tensors.
        max_len (int): ``max_length`` passed to the tokenizer for every sample.
    """

    def __init__(self, texts, metadata, labels, tokenizer, max_len):
        self.texts = texts
        self.metadata = metadata
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    @staticmethod
    def _item_at(container, idx):
        """
        Returns the element of ``container`` at *position* ``idx``.

        pandas objects index by label with ``[]``; after a shuffle or a
        train/test split those labels no longer match positions, so ``.iloc``
        is used whenever the container offers it. Rows that come back as pandas
        objects are converted to NumPy so ``torch.tensor`` never iterates them
        by label.
        """
        positional = getattr(container, "iloc", None)
        if positional is None:
            return container[idx]
        row = positional[idx]
        return row.to_numpy() if hasattr(row, "to_numpy") else row

    def __len__(self):
        """Returns the total number of samples in the dataset."""
        return len(self.texts)

    def __getitem__(self, idx):
        """
        Retrieves a sample, its metadata, and labels from the dataset at the given index.
        Tokenizes text, converts metadata and labels to PyTorch tensors.

        Returns:
            tuple: ``(inputs, metadata, labels)`` where ``inputs`` is a dict of
            tensors, ``metadata`` is a ``torch.float`` tensor and ``labels`` is
            a ``torch.long`` tensor.
        """
        text = str(self._item_at(self.texts, idx))
        encoded = self.tokenizer(
            text,
            padding='max_length',
            truncation=True,
            max_length=self.max_len,
            return_tensors="pt"
        )
        # The tokenizer is called on a single text, so drop the leading batch
        # dimension of size 1; the DataLoader adds its own batch dimension.
        inputs = {key: val.squeeze(0) for key, val in encoded.items()}
        # Convert metadata for the current sample to a float tensor
        metadata = torch.tensor(self._item_at(self.metadata, idx), dtype=torch.float)
        labels = torch.tensor(self._item_at(self.labels, idx), dtype=torch.long)
        return inputs, metadata, labels
def load_and_preprocess_data(data_path):
    """
    Reads a CSV file and prepares it for training.

    Processing steps, in order:
      1. Every missing value in the frame is replaced by the string "Unknown".
      2. Each column named in METADATA_COLUMNS that exists in the frame is
         converted to numeric; values that cannot be parsed become 0.
      3. Each column named in LABEL_COLUMNS is integer-encoded with its own
         freshly fitted LabelEncoder.

    Args:
        data_path (str): Path to the CSV data file.
    Returns:
        tuple: A tuple containing:
            - data (pd.DataFrame): The preprocessed DataFrame.
            - label_encoders (dict): Maps each label column name to the
              LabelEncoder fitted on that column.
    """
    frame = pd.read_csv(data_path)
    frame = frame.fillna("Unknown")

    # Only metadata columns actually present in the CSV are converted.
    present_metadata = [name for name in METADATA_COLUMNS if name in frame.columns]
    for name in present_metadata:
        as_numbers = pd.to_numeric(frame[name], errors='coerce')
        # Unparseable entries (including the "Unknown" filler) were coerced to
        # NaN above; replace them with 0 so the column is fully numeric.
        frame[name] = as_numbers.fillna(0)

    # One encoder per label column, fitted on and applied to that column.
    encoders = {}
    for name in LABEL_COLUMNS:
        encoder = LabelEncoder()
        frame[name] = encoder.fit_transform(frame[name])
        encoders[name] = encoder

    return frame, encoders
def get_tokenizer(model_name):
    """
    Returns the appropriate Hugging Face tokenizer based on the model name.

    The family is chosen by case-insensitive substring match on the model
    name, checking the most specific family first.

    Args:
        model_name (str): The name of the pre-trained model (e.g., 'bert-base-uncased').
    Returns:
        transformers.PreTrainedTokenizer: The initialized tokenizer.
    Raises:
        ValueError: If the name matches none of the supported families.
    """
    name = model_name.lower()
    # "roberta" and "deberta" both contain the substring "bert", so they must
    # be tested before the generic BERT check or they would never be reached.
    if "deberta" in name:
        return DebertaTokenizer.from_pretrained(model_name)
    if "roberta" in name:
        return RobertaTokenizer.from_pretrained(model_name)
    if "bert" in name:
        return BertTokenizer.from_pretrained(model_name)
    raise ValueError(f"Unsupported tokenizer for model: {model_name}")
def save_label_encoders(label_encoders):
    """
    Saves a dictionary of label encoders to a pickle file.
    This is crucial for decoding predictions back to original labels.

    The file is written to LABEL_ENCODERS_PATH; its parent directory is
    created first if it does not exist yet.

    Args:
        label_encoders (dict): Dictionary of LabelEncoder objects.
    """
    # open(..., "wb") does not create missing directories, so make sure the
    # destination folder exists before writing.
    target_dir = os.path.dirname(LABEL_ENCODERS_PATH)
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)
    with open(LABEL_ENCODERS_PATH, "wb") as f:
        pickle.dump(label_encoders, f)
    print(f"Label encoders saved to {LABEL_ENCODERS_PATH}")
def load_label_encoders():
    """
    Loads a dictionary of label encoders from a pickle file.

    The file is read from LABEL_ENCODERS_PATH.

    Returns:
        dict: Loaded dictionary of LabelEncoder objects.
    """
    with open(LABEL_ENCODERS_PATH, "rb") as f:
        # NOTE: unpickling can execute arbitrary code; only load files that
        # this project produced itself.
        label_encoders = pickle.load(f)
    # Report before returning; a statement placed after `return` never runs.
    print(f"Label encoders loaded from {LABEL_ENCODERS_PATH}")
    return label_encoders
def get_num_labels(label_encoders):
    """
    Counts the classes known to each label encoder.

    The counts follow the order of LABEL_COLUMNS, so position i of the result
    belongs to the i-th label column.

    Args:
        label_encoders (dict): Maps each label column name to its fitted
            LabelEncoder.
    Returns:
        list: One integer per entry of LABEL_COLUMNS, giving the length of
        that column's encoder ``classes_``.
    """
    class_counts = []
    for column in LABEL_COLUMNS:
        encoder = label_encoders[column]
        class_counts.append(len(encoder.classes_))
    return class_counts