# Source: Hugging Face file viewer; uploaded by Prudhvinath07,
# commit dec266f ("added all files"), file size 3.36 kB.
import torch
from torch.utils.data import DataLoader
from typing import Dict, List
from tqdm import tqdm
from torch.amp import autocast, GradScaler
class ModelTrainer:
def __init__(self, model, optimizer, criterion, device, scaler: GradScaler = None, scheduler=None):
self.model = model
self.optimizer = optimizer
self.criterion = criterion
self.device = device
self.scaler = scaler or GradScaler('cuda')
self.use_amp = device.type == 'cuda'
self.scheduler = scheduler
def train_epoch(self, dataloader: DataLoader) -> Dict[str, float]:
self.model.train()
total_loss = 0
for batch in tqdm(dataloader, desc="Training"):
input_ids = batch['input_ids'].to(self.device)
attention_mask = batch['attention_mask'].to(self.device)
labels = batch['labels'].to(self.device)
self.optimizer.zero_grad()
if self.use_amp:
with autocast('cuda'):
outputs = self.model(input_ids, attention_mask)
loss = self.criterion(outputs, labels)
self.scaler.scale(loss).backward()
# Clip gradients
self.scaler.unscale_(self.optimizer)
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
self.scaler.step(self.optimizer)
self.scaler.update()
else:
outputs = self.model(input_ids, attention_mask)
loss = self.criterion(outputs, labels)
loss.backward()
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
self.optimizer.step()
if self.scheduler is not None:
self.scheduler.step()
total_loss += loss.item()
return {'loss': total_loss / len(dataloader)}
def evaluate(self, dataloader: DataLoader) -> Dict[str, float]:
self.model.eval()
total_loss = 0
predictions = []
true_labels = []
with torch.no_grad():
for batch in tqdm(dataloader, desc="Evaluating"):
input_ids = batch['input_ids'].to(self.device)
attention_mask = batch['attention_mask'].to(self.device)
labels = batch['labels'].to(self.device)
if self.use_amp:
with autocast('cuda'):
outputs = self.model(input_ids, attention_mask)
loss = self.criterion(outputs, labels)
else:
outputs = self.model(input_ids, attention_mask)
loss = self.criterion(outputs, labels)
# Apply sigmoid to get probabilities for predictions
probs = torch.sigmoid(outputs)
total_loss += loss.item()
predictions.extend(probs.cpu().numpy())
true_labels.extend(labels.cpu().numpy())
return {
'loss': total_loss / len(dataloader),
'predictions': predictions,
'true_labels': true_labels
}