subbunanepalli committed on
Commit
93b526d
·
verified ·
1 Parent(s): b8c0e6b

Update dataset_utils.py

Browse files
Files changed (1) hide show
  1. dataset_utils.py +17 -8
dataset_utils.py CHANGED
@@ -1,12 +1,20 @@
1
  import pandas as pd
2
  import torch
3
- from torch.utils.data import Dataset, DataLoader
4
  from sklearn.preprocessing import LabelEncoder
5
  from transformers import BertTokenizer, RobertaTokenizer, DebertaTokenizer
6
  import pickle
7
  import os
8
 
9
- from config import TEXT_COLUMN, LABEL_COLUMNS, MAX_LEN, TOKENIZER_PATH, LABEL_ENCODERS_PATH, METADATA_COLUMNS
 
 
 
 
 
 
 
 
10
 
11
  class ComplianceDataset(Dataset):
12
  def __init__(self, texts, labels, tokenizer, max_len):
@@ -69,13 +77,14 @@ def load_and_preprocess_data(data_path):
69
  data[col] = label_encoders[col].fit_transform(data[col])
70
  return data, label_encoders
71
 
72
- def get_tokenizer(model_name):
73
- if "bert" in model_name.lower():
74
- return BertTokenizer.from_pretrained(model_name)
 
75
  elif "roberta" in model_name.lower():
76
- return RobertaTokenizer.from_pretrained(model_name)
77
- elif "deberta" in model_name.lower():
78
- return DebertaTokenizer.from_pretrained(model_name)
79
  else:
80
  raise ValueError(f"Unsupported tokenizer for model: {model_name}")
81
 
 
1
  import pandas as pd
2
  import torch
3
+ from torch.utils.data import Dataset
4
  from sklearn.preprocessing import LabelEncoder
5
  from transformers import BertTokenizer, RobertaTokenizer, DebertaTokenizer
6
  import pickle
7
  import os
8
 
9
+ from config import (
10
+ TEXT_COLUMN,
11
+ LABEL_COLUMNS,
12
+ MAX_LEN,
13
+ TOKENIZER_PATH,
14
+ LABEL_ENCODERS_PATH,
15
+ METADATA_COLUMNS,
16
+ MODEL_NAME # ✅ Add this in your config.py: MODEL_NAME = "roberta-base"
17
+ )
18
 
19
  class ComplianceDataset(Dataset):
20
  def __init__(self, texts, labels, tokenizer, max_len):
 
77
  data[col] = label_encoders[col].fit_transform(data[col])
78
  return data, label_encoders
79
 
80
+ def get_tokenizer(model_name=MODEL_NAME):
81
+ model_name = model_name or "roberta-base" # fallback
82
+ if "deberta" in model_name.lower():
83
+ return DebertaTokenizer.from_pretrained(model_name, cache_dir=TOKENIZER_PATH)
84
  elif "roberta" in model_name.lower():
85
+ return RobertaTokenizer.from_pretrained(model_name, cache_dir=TOKENIZER_PATH)
86
+ elif "bert" in model_name.lower():
87
+ return BertTokenizer.from_pretrained(model_name, cache_dir=TOKENIZER_PATH)
88
  else:
89
  raise ValueError(f"Unsupported tokenizer for model: {model_name}")
90