import torch
from torch.utils.data import DataLoader
from typing import Dict, List
from tqdm import tqdm
from torch.amp import autocast, GradScaler
class ModelTrainer:
    """Run training and evaluation loops with optional CUDA mixed precision (AMP).

    The model is assumed to be a binary/multi-label classifier: ``evaluate``
    applies a sigmoid to the raw outputs to obtain probabilities.

    Args:
        model: A module called as ``model(input_ids, attention_mask)``.
        optimizer: Optimizer over ``model.parameters()``.
        criterion: Loss called as ``criterion(outputs, labels)``.
        device: Target ``torch.device``; AMP is enabled only for CUDA devices.
        scaler: Optional pre-built ``GradScaler``. If omitted, one is created,
            disabled automatically on non-CUDA devices.
        scheduler: Optional LR scheduler, stepped once per training batch.
    """

    def __init__(self, model, optimizer, criterion, device, scaler: GradScaler = None, scheduler=None):
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion
        self.device = device
        # Decide AMP eligibility first: mixed precision is CUDA-only here.
        self.use_amp = device.type == 'cuda'
        # Bug fix: the original always constructed an *enabled* CUDA scaler,
        # even on CPU devices. A scaler with enabled=False is a safe no-op
        # (scale/unscale_/step/update degenerate to plain optimizer calls).
        self.scaler = scaler if scaler is not None else GradScaler('cuda', enabled=self.use_amp)
        self.scheduler = scheduler

    def train_epoch(self, dataloader: DataLoader) -> Dict[str, float]:
        """Train for one epoch.

        Returns:
            ``{'loss': mean per-batch loss}`` (0.0 for an empty dataloader).
        """
        self.model.train()
        total_loss = 0.0
        for batch in tqdm(dataloader, desc="Training"):
            input_ids = batch['input_ids'].to(self.device)
            attention_mask = batch['attention_mask'].to(self.device)
            labels = batch['labels'].to(self.device)
            self.optimizer.zero_grad()
            if self.use_amp:
                # Forward/loss under autocast; backward on the scaled loss.
                with autocast('cuda'):
                    outputs = self.model(input_ids, attention_mask)
                    loss = self.criterion(outputs, labels)
                self.scaler.scale(loss).backward()
                # Unscale before clipping so the norm is measured in true units.
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                outputs = self.model(input_ids, attention_mask)
                loss = self.criterion(outputs, labels)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                self.optimizer.step()
            # Bug fix (original layout suggests this ran only on the non-AMP
            # path, silently freezing the LR under AMP): step the scheduler
            # once per batch regardless of precision mode.
            if self.scheduler is not None:
                self.scheduler.step()
            total_loss += loss.item()
        # Guard against an empty dataloader instead of raising ZeroDivisionError.
        return {'loss': total_loss / max(len(dataloader), 1)}

    def evaluate(self, dataloader: DataLoader) -> Dict[str, object]:
        """Evaluate the model without gradient tracking.

        Returns:
            Dict with ``'loss'`` (float mean per-batch loss), ``'predictions'``
            (list of sigmoid probability arrays) and ``'true_labels'``
            (list of label arrays).
        """
        self.model.eval()
        total_loss = 0.0
        predictions = []
        true_labels = []
        with torch.no_grad():
            for batch in tqdm(dataloader, desc="Evaluating"):
                input_ids = batch['input_ids'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)
                labels = batch['labels'].to(self.device)
                if self.use_amp:
                    with autocast('cuda'):
                        outputs = self.model(input_ids, attention_mask)
                        loss = self.criterion(outputs, labels)
                else:
                    outputs = self.model(input_ids, attention_mask)
                    loss = self.criterion(outputs, labels)
                # Sigmoid converts raw logits to per-class probabilities.
                probs = torch.sigmoid(outputs)
                total_loss += loss.item()
                predictions.extend(probs.cpu().numpy())
                true_labels.extend(labels.cpu().numpy())
        return {
            'loss': total_loss / max(len(dataloader), 1),
            'predictions': predictions,
            'true_labels': true_labels
        }