import torch
from torch.utils.data import Dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments

from ..nlp.preprocess import NLPPreprocessor


class TextDataset(Dataset):
    """Map-style dataset pairing pre-computed tokenizer encodings with labels.

    Args:
        encodings: Mapping of tokenizer output names (e.g. ``input_ids``,
            ``attention_mask``) to tensors whose first dimension indexes
            examples — as produced by a HF tokenizer with
            ``return_tensors="pt"``.
        labels: Sequence of numeric labels, one per example.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        # Slice each encoding tensor at idx; Trainer expects the target
        # under the key 'labels'.
        item = {key: val[idx] for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)


class TextClassifier:
    """Fine-tunes a HF sequence-classification model on cleaned text.

    Args:
        model_name: Hub identifier of the pretrained checkpoint.
        num_labels: Number of target classes for the classification head.
    """

    def __init__(self, model_name="distilbert-base-uncased", num_labels=2):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_name, num_labels=num_labels
        )
        self.preprocessor = NLPPreprocessor()

    def train(
        self,
        texts,
        labels,
        *,
        num_train_epochs=3,
        train_batch_size=16,
        eval_batch_size=64,
        warmup_steps=500,
        weight_decay=0.01,
        max_length=512,
        output_dir='./results',
        logging_dir='./logs',
    ):
        """Fine-tune the model on (texts, labels) and return it.

        All hyperparameters are keyword-only with defaults matching the
        previous hard-coded values, so existing callers are unaffected.

        Args:
            texts: Iterable of raw input strings (cleaned before tokenizing).
            labels: Per-example numeric class labels, aligned with ``texts``.
            num_train_epochs: Number of passes over the training data.
            train_batch_size: Per-device training batch size.
            eval_batch_size: Per-device evaluation batch size (unused while
                ``evaluation_strategy="no"``, kept for forward use).
            warmup_steps: Linear LR warmup steps.
            weight_decay: AdamW weight decay.
            max_length: Tokenizer truncation length.
            output_dir: Where Trainer writes checkpoints/artifacts.
            logging_dir: Where Trainer writes logs.

        Returns:
            The fine-tuned model (also retained on ``self.model``).
        """
        # Clean texts before tokenization.
        texts = [self.preprocessor.clean(text) for text in texts]

        # Tokenize the whole corpus once; padding to the batch max and
        # truncating to max_length.
        encodings = self.tokenizer(
            texts,
            truncation=True,
            padding=True,
            max_length=max_length,
            return_tensors="pt",
        )

        dataset = TextDataset(encodings, labels)

        # NOTE: save_steps was removed — it is ignored when
        # save_strategy="no", so keeping it only misled readers.
        training_args = TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=num_train_epochs,
            per_device_train_batch_size=train_batch_size,
            per_device_eval_batch_size=eval_batch_size,
            warmup_steps=warmup_steps,
            weight_decay=weight_decay,
            logging_dir=logging_dir,
            logging_steps=10,
            evaluation_strategy="no",
            save_strategy="no",
        )

        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=dataset,
        )
        trainer.train()
        return self.model