"""Dataset wrapper and label-encoder persistence helpers."""

import pickle

import pandas as pd
import torch
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import Dataset

from config import (
    LABEL_COLUMNS,
    LABEL_ENCODERS_PATH,
    MAX_LEN,
    METADATA_COLUMNS,
    TEXT_COLUMN,
    TOKENIZER_PATH,
)


class ComplianceDataset(Dataset):
    """Map-style dataset that tokenizes one text per item and pairs it with its labels.

    Args:
        texts: Indexable collection of texts; each item is passed through
            ``str()`` before tokenization.
        labels: Indexable collection of labels; ``labels[idx]`` must be
            convertible to a ``torch.long`` tensor.
        tokenizer: Callable accepting ``(text, padding=..., truncation=...,
            max_length=..., return_tensors="pt")`` and returning a mapping of
            tensors (presumably a Hugging Face tokenizer -- confirm with caller).
        max_len: Length every sequence is padded/truncated to.

    NOTE(review): items are fetched with plain ``[idx]``. If ``texts`` or
    ``labels`` is a pandas object, that is a label-based lookup rather than a
    positional one -- confirm callers pass lists/arrays or objects with a
    default 0..n-1 index.
    """

    def __init__(self, texts, labels, tokenizer, max_len: int):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self) -> int:
        """Return the number of samples (the length of ``texts``)."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Return ``(inputs, labels)`` for the sample at ``idx``.

        Returns:
            A tuple of a dict holding the tokenizer outputs with their leading
            dimension squeezed away, and a ``torch.long`` tensor built from
            ``labels[idx]``.
        """
        text = str(self.texts[idx])
        encoded = self.tokenizer(
            text,
            padding='max_length',
            truncation=True,
            max_length=self.max_len,
            return_tensors="pt",
        )
        # Drop dimension 0 of every returned tensor so that a DataLoader can
        # stack samples itself (assumes dim 0 is a size-1 batch axis added by
        # return_tensors="pt" -- TODO confirm).
        inputs = {key: val.squeeze(0) for key, val in encoded.items()}
        labels = torch.tensor(self.labels[idx], dtype=torch.long)
        return inputs, labels


def save_label_encoders(label_encoders, path=LABEL_ENCODERS_PATH):
    """Pickle ``label_encoders`` to ``path``.

    Args:
        label_encoders: Object to serialize (any picklable value).
        path: Destination file; defaults to ``LABEL_ENCODERS_PATH`` so existing
            single-argument calls keep writing to the same location. Mirrors
            the ``path`` parameter of ``load_label_encoders``.
    """
    with open(path, "wb") as f:
        pickle.dump(label_encoders, f)


def load_label_encoders(path=LABEL_ENCODERS_PATH):
    """Load and return the object pickled at ``path``.

    Args:
        path: File to read; defaults to ``LABEL_ENCODERS_PATH``.

    Returns:
        Whatever object was stored in the pickle file.

    Raises:
        OSError: If the file cannot be opened (e.g. it does not exist).
    """
    # SECURITY: unpickling can execute arbitrary code. Only load files that
    # were produced by a trusted source such as save_label_encoders.
    with open(path, "rb") as f:
        return pickle.load(f)