subbunanepalli committed on
Commit
356cf69
·
verified ·
1 Parent(s): 891d206

Create dataset_utils.py

Browse files
Files changed (1) hide show
  1. dataset_utils.py +93 -0
dataset_utils.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import torch
3
+ from torch.utils.data import Dataset, DataLoader
4
+ from sklearn.preprocessing import LabelEncoder
5
+ from transformers import BertTokenizer, RobertaTokenizer, DebertaTokenizer
6
+ import pickle
7
+ import os
8
+
9
+ from config import TEXT_COLUMN, LABEL_COLUMNS, MAX_LEN, TOKENIZER_PATH, LABEL_ENCODERS_PATH, METADATA_COLUMNS
10
+
11
class ComplianceDataset(Dataset):
    """Map-style dataset that tokenizes one text per lookup and pairs it with its labels.

    Args:
        texts: Indexable collection of raw texts; each item is passed through
            ``str()`` before tokenization.
        labels: Indexable collection aligned with ``texts``; ``labels[idx]`` is
            converted to a ``torch.long`` tensor.
        tokenizer: Callable invoked with ``padding``, ``truncation``,
            ``max_length`` and ``return_tensors`` keyword arguments
            (presumably a Hugging Face tokenizer -- confirm against callers).
        max_len: Value forwarded to the tokenizer as ``max_length``.
    """

    def __init__(self, texts, labels, tokenizer, max_len):
        self.texts, self.labels = texts, labels
        self.tokenizer, self.max_len = tokenizer, max_len

    def __len__(self):
        """Return the number of texts held by the dataset."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Return ``(inputs, labels)`` for the sample at ``idx``.

        ``inputs`` maps each key produced by the tokenizer to its tensor with
        dimension 0 squeezed; ``labels`` is a ``torch.long`` tensor.
        """
        encoded = self.tokenizer(
            str(self.texts[idx]),
            padding='max_length',
            truncation=True,
            max_length=self.max_len,
            return_tensors="pt",
        )
        # The tokenizer is called on a single text, so the leading dimension
        # is squeezed away to leave per-sample tensors.
        features = {}
        for name, tensor in encoded.items():
            features[name] = tensor.squeeze(0)
        target = torch.tensor(self.labels[idx], dtype=torch.long)
        return features, target
33
+
34
class ComplianceDatasetWithMetadata(Dataset):
    """Map-style dataset yielding tokenized text, a metadata vector and labels.

    Args:
        texts: Indexable collection of raw texts; each item is passed through
            ``str()`` before tokenization.
        metadata: Indexable collection aligned with ``texts``;
            ``metadata[idx]`` is converted to a ``torch.float`` tensor.
        labels: Indexable collection aligned with ``texts``; ``labels[idx]`` is
            converted to a ``torch.long`` tensor.
        tokenizer: Callable invoked with ``padding``, ``truncation``,
            ``max_length`` and ``return_tensors`` keyword arguments
            (presumably a Hugging Face tokenizer -- confirm against callers).
        max_len: Value forwarded to the tokenizer as ``max_length``.
    """

    def __init__(self, texts, metadata, labels, tokenizer, max_len):
        self.texts, self.metadata, self.labels = texts, metadata, labels
        self.tokenizer, self.max_len = tokenizer, max_len

    def __len__(self):
        """Return the number of texts held by the dataset."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Return ``(inputs, metadata, labels)`` for the sample at ``idx``.

        ``inputs`` maps each key produced by the tokenizer to its tensor with
        dimension 0 squeezed; ``metadata`` is a ``torch.float`` tensor and
        ``labels`` a ``torch.long`` tensor.
        """
        encoded = self.tokenizer(
            str(self.texts[idx]),
            padding='max_length',
            truncation=True,
            max_length=self.max_len,
            return_tensors="pt",
        )
        # The tokenizer is called on a single text, so the leading dimension
        # is squeezed away to leave per-sample tensors.
        features = {}
        for name, tensor in encoded.items():
            features[name] = tensor.squeeze(0)
        extra = torch.tensor(self.metadata[idx], dtype=torch.float)
        target = torch.tensor(self.labels[idx], dtype=torch.long)
        return features, extra, target
58
+
59
def load_and_preprocess_data(data_path):
    """Read a CSV file, fill gaps, and label-encode the target columns.

    Steps, in order:
      1. Load ``data_path`` with ``pd.read_csv``.
      2. Replace every missing value with the string ``"Unknown"``.
      3. For each name in ``METADATA_COLUMNS`` that exists in the frame,
         coerce the column to numeric; values that cannot be parsed
         (including the ``"Unknown"`` placeholder) become 0.
      4. Fit one ``LabelEncoder`` per name in ``LABEL_COLUMNS`` and replace the
         column with its encoded values.

    Args:
        data_path: Path (or anything ``pd.read_csv`` accepts) of the CSV file.

    Returns:
        Tuple ``(data, label_encoders)`` where ``data`` is the processed
        DataFrame and ``label_encoders`` maps each label column name to its
        fitted encoder.
    """
    frame = pd.read_csv(data_path)
    frame.fillna("Unknown", inplace=True)

    for name in METADATA_COLUMNS:
        if name not in frame.columns:
            continue
        # errors='coerce' turns unparsable entries into NaN, which are then zeroed.
        as_numbers = pd.to_numeric(frame[name], errors='coerce')
        frame[name] = as_numbers.fillna(0)

    label_encoders = {}
    for name in LABEL_COLUMNS:
        encoder = LabelEncoder()
        frame[name] = encoder.fit_transform(frame[name])
        label_encoders[name] = encoder
    return frame, label_encoders
71
+
72
def get_tokenizer(model_name):
    """Load the tokenizer class that matches ``model_name``.

    The model family is chosen by case-insensitive substring match on
    ``model_name``: "deberta", then "roberta", then "bert".

    Args:
        model_name: Model identifier; it is also passed unchanged to the
            chosen class's ``from_pretrained``.

    Returns:
        The object returned by ``from_pretrained`` of the selected tokenizer class.

    Raises:
        ValueError: If none of the supported family names occurs in ``model_name``.
    """
    lowered = model_name.lower()
    # Important: order matters -- "roberta" and "deberta" both contain the
    # substring "bert", so the more specific names must be tested first.
    family = next(
        (tag for tag in ("deberta", "roberta", "bert") if tag in lowered),
        None,
    )
    if family is None:
        raise ValueError(f"Unsupported tokenizer for model: {model_name}")
    if family == "deberta":
        # NOTE(review): DeBERTa v2/v3 checkpoints are usually paired with
        # DebertaV2Tokenizer -- confirm DebertaTokenizer suits the models used here.
        return DebertaTokenizer.from_pretrained(model_name)
    if family == "roberta":
        return RobertaTokenizer.from_pretrained(model_name)
    return BertTokenizer.from_pretrained(model_name)
82
+
83
def save_label_encoders(label_encoders):
    """Pickle ``label_encoders`` to ``LABEL_ENCODERS_PATH`` and print a confirmation.

    Args:
        label_encoders: Object to serialize (the mapping of label column to
            fitted encoder produced elsewhere in this module).
    """
    with open(LABEL_ENCODERS_PATH, "wb") as handle:
        pickle.dump(label_encoders, handle)
    confirmation = f"Label encoders saved to {LABEL_ENCODERS_PATH}"
    print(confirmation)
87
+
88
def load_label_encoders():
    """Unpickle and return the object stored at ``LABEL_ENCODERS_PATH``.

    Returns:
        Whatever was pickled into the file (the label-encoder mapping when it
        was written by ``save_label_encoders``).
    """
    # NOTE: unpickling can execute arbitrary code; only load files this
    # project wrote itself, never files from an untrusted source.
    with open(LABEL_ENCODERS_PATH, "rb") as handle:
        restored = pickle.load(handle)
    return restored
91
+
92
def get_num_labels(label_encoders):
    """Return the class count of each label column, in ``LABEL_COLUMNS`` order.

    Args:
        label_encoders: Mapping from label column name to a fitted encoder
            exposing a ``classes_`` attribute.

    Returns:
        List of integers, one per entry of ``LABEL_COLUMNS``.
    """
    class_counts = []
    for column in LABEL_COLUMNS:
        fitted = label_encoders[column]
        class_counts.append(len(fitted.classes_))
    return class_counts