Spaces:
Runtime error
Runtime error
| import pandas as pd | |
| import torch | |
| from torch.utils.data import Dataset | |
| from sklearn.preprocessing import LabelEncoder | |
| import pickle | |
| from config import TEXT_COLUMN, LABEL_COLUMNS, MAX_LEN, TOKENIZER_PATH, LABEL_ENCODERS_PATH, METADATA_COLUMNS | |
class ComplianceDataset(Dataset):
    """Map-style dataset pairing one tokenized text with its label row.

    Args:
        texts: Indexable collection of texts; each item is passed through
            ``str()`` before tokenization.
        labels: Indexable collection whose ``idx``-th item is converted to a
            ``torch.long`` tensor.
        tokenizer: Callable invoked as ``tokenizer(text, padding='max_length',
            truncation=True, max_length=max_len, return_tensors="pt")`` and
            expected to return a mapping of tensors.
        max_len: Value forwarded to the tokenizer as ``max_length``.
    """

    def __init__(self, texts, labels, tokenizer, max_len):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        """Return the number of texts held by the dataset."""
        return len(self.texts)

    @staticmethod
    def _item_at(container, idx):
        """Return the ``idx``-th item of ``container`` by position.

        pandas objects index by *label* with plain ``[]``; a Series whose
        index is not 0..n-1 (e.g. after a shuffle or split) would raise
        ``KeyError`` or return the wrong row. Use ``.iloc`` when the
        container offers it, and ordinary indexing otherwise.
        """
        positional = getattr(container, "iloc", None)
        return container[idx] if positional is None else positional[idx]

    def __getitem__(self, idx):
        """Return ``(inputs, labels)`` for the sample at position ``idx``.

        Returns:
            A tuple of a dict mapping each tokenizer output key to its tensor
            with dimension 0 squeezed, and a ``torch.long`` tensor of labels.
        """
        text = str(self._item_at(self.texts, idx))
        encoded = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.max_len,
            return_tensors="pt",
        )
        # Presumably the tokenizer returns a leading batch dimension of size 1
        # for a single text; squeeze(0) drops it so a DataLoader can re-batch.
        inputs = {key: val.squeeze(0) for key, val in encoded.items()}

        row = self._item_at(self.labels, idx)
        # A pandas row (Series) is converted to a NumPy array so torch.tensor
        # receives a plain numeric array rather than a labelled object.
        if hasattr(row, "to_numpy"):
            row = row.to_numpy()
        labels = torch.tensor(row, dtype=torch.long)
        return inputs, labels
def save_label_encoders(label_encoders, path=None):
    """Pickle ``label_encoders`` to disk.

    Args:
        label_encoders: Object to serialize with ``pickle``.
        path: Destination file. When ``None`` (the default) the module-level
            ``LABEL_ENCODERS_PATH`` is used, matching the default of
            ``load_label_encoders``.
    """
    # Resolve the default at call time so the function keeps its original
    # one-argument behavior while allowing callers to choose another file.
    target = LABEL_ENCODERS_PATH if path is None else path
    with open(target, "wb") as f:
        pickle.dump(label_encoders, f)
def load_label_encoders(path=LABEL_ENCODERS_PATH):
    """Read and return the pickled object stored at ``path``.

    Args:
        path: File to read; defaults to ``LABEL_ENCODERS_PATH``.

    Returns:
        Whatever object was pickled into the file.
    """
    # NOTE(review): unpickling executes code embedded in the file, so this
    # should only be pointed at files the application itself produced.
    with open(path, "rb") as stream:
        restored = pickle.Unpickler(stream).load()
    return restored