# ModelSmith-AI / backend/nlp/trainers.py
# Uploaded by ACA050 ("Upload 79 files", commit a309487, verified).
import torch
from torch.utils.data import Dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
from ..nlp.preprocess import NLPPreprocessor
class TextDataset(Dataset):
    """Minimal ``Dataset`` pairing tokenizer encodings with integer labels."""

    def __init__(self, encodings, labels):
        # encodings: mapping of field name -> per-example indexable (e.g. padded tensors)
        # labels: sequence of class ids, one per example
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        """Return example *idx*: each encoding field plus a ``'labels'`` tensor."""
        example = {field: values[idx] for field, values in self.encodings.items()}
        example['labels'] = torch.tensor(self.labels[idx])
        return example

    def __len__(self):
        """Dataset size, taken from the label sequence."""
        return len(self.labels)
class TextClassifier:
    """Fine-tunes a pretrained transformer for sequence classification."""

    def __init__(self, model_name="distilbert-base-uncased", num_labels=2):
        """Load the tokenizer and classification model.

        Args:
            model_name: Hugging Face model identifier or local checkpoint path.
            num_labels: number of output classes for the classification head.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)
        self.preprocessor = NLPPreprocessor()

    def train(self, texts, labels, *, output_dir='./results', num_train_epochs=3,
              per_device_train_batch_size=16, warmup_steps=500, weight_decay=0.01):
        """Fine-tune ``self.model`` on the given texts and labels.

        Args:
            texts: iterable of raw input strings; cleaned via the preprocessor.
            labels: sequence of integer class ids, one per text.
            output_dir: directory where the Trainer writes its artifacts.
            num_train_epochs, per_device_train_batch_size, warmup_steps,
                weight_decay: standard HF hyperparameters; the defaults match
                the values this method previously hard-coded.

        Returns:
            The fine-tuned model (also retained on ``self.model``).

        Raises:
            ValueError: if ``texts`` and ``labels`` have different lengths.
        """
        texts = list(texts)
        if len(texts) != len(labels):
            raise ValueError(
                f"texts and labels must be the same length, got {len(texts)} and {len(labels)}"
            )
        # Clean texts before tokenization.
        texts = [self.preprocessor.clean(text) for text in texts]
        # Tokenize as one padded batch of tensors (truncated to the model limit).
        encodings = self.tokenizer(texts, truncation=True, padding=True, max_length=512, return_tensors="pt")
        dataset = TextDataset(encodings, labels)
        training_args = TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=num_train_epochs,
            per_device_train_batch_size=per_device_train_batch_size,
            per_device_eval_batch_size=64,
            warmup_steps=warmup_steps,
            weight_decay=weight_decay,
            logging_dir='./logs',
            logging_steps=10,
            save_steps=500,
            # NOTE(review): `evaluation_strategy` was renamed `eval_strategy`
            # in newer transformers releases — confirm against the pinned version.
            evaluation_strategy="no",
            save_strategy="no",
        )
        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=dataset,
        )
        trainer.train()
        return self.model