"""PyTorch Dataset class for loading sensitive-word filtering training data."""

import csv
from typing import Dict

import torch
from torch.utils.data import Dataset
from transformers import BertTokenizer


class SensitiveWordDataset(Dataset):
    """Text-classification dataset for sensitive-word filtering.

    Reads a CSV file whose header row contains (at least) the columns
    ``text`` and ``label``. All rows are loaded eagerly in ``__init__``;
    tokenization happens lazily, per sample, in ``__getitem__``.

    Args:
        csv_path: Path to the UTF-8 encoded CSV file. A leading UTF-8
            byte-order mark (e.g. as written by Excel) is tolerated.
        tokenizer: Tokenizer applied to each text in ``__getitem__``.
        max_length: Every sample is padded / truncated to this many tokens.

    Raises:
        KeyError: If a row has no ``text`` or ``label`` column.
        ValueError: If a ``label`` value cannot be parsed as an integer.
    """

    def __init__(
        self,
        csv_path: str,
        tokenizer: BertTokenizer,
        max_length: int = 128,
    ):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.texts = []   # raw text of each sample, in file order
        self.labels = []  # integer label of each sample, parallel to texts

        # newline="" is required by the csv module so quoted fields that
        # contain embedded newlines are parsed correctly. "utf-8-sig"
        # strips a BOM if present (otherwise the first header would become
        # "\ufefftext") and is identical to "utf-8" for BOM-less files.
        with open(csv_path, "r", encoding="utf-8-sig", newline="") as f:
            reader = csv.DictReader(f)
            for row in reader:
                self.texts.append(row["text"])
                raw_label = row["label"]
                try:
                    self.labels.append(int(raw_label))
                except ValueError as err:
                    raise ValueError(
                        f"{csv_path}: invalid label {raw_label!r} "
                        f"on line {reader.line_num}"
                    ) from err

    def __len__(self) -> int:
        """Return the number of samples loaded from the CSV file."""
        return len(self.texts)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Tokenize sample ``idx`` and return model-ready tensors.

        Returns:
            A dict with ``input_ids`` and ``attention_mask`` (1-D tensors of
            length ``max_length``, because padding is ``"max_length"`` and
            truncation is enabled) and ``label`` (a 0-d ``torch.long``
            tensor).
        """
        text = self.texts[idx]
        label = self.labels[idx]

        encoding = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )

        # return_tensors="pt" adds a leading batch dimension of size 1;
        # drop it so the DataLoader can stack samples into a batch itself.
        return {
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
            "label": torch.tensor(label, dtype=torch.long),
        }