# crackrammer's picture
# Upload folder using huggingface_hub
# db60e24 verified
"""
PyTorch Dataset 类:用于加载敏感词过滤训练数据
"""
import csv
import torch
from torch.utils.data import Dataset
from transformers import BertTokenizer
class SensitiveWordDataset(Dataset):
    """Text-classification dataset for sensitive-word filtering.

    Reads a CSV file whose header row contains ``text`` and ``label``
    columns, keeps every row in memory, and tokenizes a single sample each
    time the dataset is indexed.

    Args:
        csv_path: Path to the CSV file. It is decoded as UTF-8; a leading
            byte-order mark, if present, is ignored.
        tokenizer: Callable invoked as ``tokenizer(text,
            padding="max_length", truncation=True, max_length=max_length,
            return_tensors="pt")``. Its result must provide ``"input_ids"``
            and ``"attention_mask"`` entries that support ``.squeeze(0)``.
        max_length: Value forwarded to the tokenizer as ``max_length``.

    Raises:
        KeyError: If a row has no ``text`` or ``label`` column.
        ValueError: If a ``label`` value cannot be converted by ``int``.
    """

    def __init__(
        self,
        csv_path: str,
        # Quoted so the annotation is not evaluated when the class is defined.
        tokenizer: "BertTokenizer",
        max_length: int = 128,
    ):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.texts = []   # raw text of each row, in file order
        self.labels = []  # integer label of each row, parallel to self.texts

        # newline="" hands line-ending handling to the csv module, so newlines
        # embedded in quoted fields are preserved exactly as written.
        # utf-8-sig strips a leading BOM (otherwise the first header would be
        # "\ufefftext") and reads BOM-less UTF-8 files unchanged.
        with open(csv_path, "r", encoding="utf-8-sig", newline="") as f:
            for row in csv.DictReader(f):
                self.texts.append(row["text"])
                self.labels.append(int(row["label"]))

    def __len__(self):
        """Return the number of rows loaded from the CSV file."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Tokenize the sample at ``idx`` and return it as tensors.

        Returns:
            A dict with ``"input_ids"`` and ``"attention_mask"`` (the
            tokenizer outputs with their first dimension squeezed away) and
            ``"label"`` (a scalar ``torch.long`` tensor).
        """
        text = self.texts[idx]
        label = self.labels[idx]

        encoding = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )

        return {
            # return_tensors="pt" is requested for a single text; squeeze(0)
            # drops the leading dimension of the returned tensors.
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
            "label": torch.tensor(label, dtype=torch.long),
        }