File size: 1,182 Bytes
bf6d7ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import pandas as pd
import torch
from torch.utils.data import Dataset
from sklearn.preprocessing import LabelEncoder
import pickle
from config import TEXT_COLUMN, LABEL_COLUMNS, MAX_LEN, TOKENIZER_PATH, LABEL_ENCODERS_PATH, METADATA_COLUMNS

class ComplianceDataset(Dataset):
    """Map-style dataset that tokenizes one text per item and pairs it with its labels.

    Args:
        texts: Indexable collection of texts. Items are coerced with ``str``
            before tokenization. pandas objects are read by position.
        labels: Indexable collection aligned with ``texts``. Each item must be
            convertible to a ``torch.long`` tensor.
        tokenizer: Callable invoked as ``tokenizer(text, padding='max_length',
            truncation=True, max_length=max_len, return_tensors="pt")`` and
            returning a mapping whose values support ``squeeze(0)``
            (presumably a Hugging Face tokenizer; confirm against the caller).
        max_len: Value forwarded to the tokenizer as ``max_length``.
    """

    def __init__(self, texts, labels, tokenizer, max_len):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        """Return the number of samples, taken from ``texts``."""
        return len(self.texts)

    @staticmethod
    def _item_at(collection, position):
        """Return the element at zero-based ``position`` of ``collection``.

        pandas Series/DataFrame objects resolve ``obj[i]`` by index *label*,
        so a collection whose index is not 0..n-1 (for example after a
        shuffle or split) would raise ``KeyError`` or return the wrong row.
        Those objects are read through ``.iloc``; everything else uses
        ordinary ``[]`` indexing.
        """
        positional = getattr(collection, "iloc", None)
        if positional is not None:
            return positional[position]
        return collection[position]

    def __getitem__(self, idx):
        """Return ``(inputs, labels)`` for the sample at position ``idx``.

        Returns:
            A tuple of a dict holding the tokenizer outputs with their leading
            dimension squeezed away, and a ``torch.long`` tensor built from
            the matching labels entry.
        """
        text = str(self._item_at(self.texts, idx))
        encoded = self.tokenizer(
            text,
            padding='max_length',
            truncation=True,
            max_length=self.max_len,
            return_tensors="pt"
        )
        # The tokenizer is called on a single text, so drop the leading
        # dimension of each returned tensor.
        inputs = {key: val.squeeze(0) for key, val in encoded.items()}

        label_row = self._item_at(self.labels, idx)
        # A DataFrame row comes back as a Series; hand torch a plain array so
        # it is not iterated by index label.
        if hasattr(label_row, "to_numpy"):
            label_row = label_row.to_numpy()
        labels = torch.tensor(label_row, dtype=torch.long)
        return inputs, labels

def save_label_encoders(label_encoders, path=None):
    """Pickle ``label_encoders`` to disk.

    Args:
        label_encoders: Object to serialize with ``pickle.dump``.
        path: Destination file. When omitted (``None``), the module-level
            ``LABEL_ENCODERS_PATH`` is used, matching the default of
            ``load_label_encoders``.
    """
    # Resolve the default at call time so an explicit path never touches the
    # config constant.
    if path is None:
        path = LABEL_ENCODERS_PATH
    with open(path, "wb") as f:
        pickle.dump(label_encoders, f)

def load_label_encoders(path=LABEL_ENCODERS_PATH):
    """Read back a pickled object from ``path``.

    Args:
        path: File to read; defaults to ``LABEL_ENCODERS_PATH``.

    Returns:
        Whatever object ``pickle.load`` reconstructs from the file.

    Note:
        Unpickling can execute arbitrary code, so only point this at files
        the project itself produced.
    """
    with open(path, mode="rb") as encoder_file:
        restored = pickle.load(encoder_file)
    return restored