DamageLensAI / src /training /trainer.py
junaid17's picture
Upload 43 files
eef8873 verified
Raw
History Blame Contribute Delete
8.09 kB
import logging
import torch
from tqdm import tqdm
from transformers import get_cosine_schedule_with_warmup
from src.config import CHECKPOINT_DIR
# Module-level logger shared by every helper in this file.
logger = logging.getLogger(__name__)


class EarlyStopping:
    """Track a maximised validation metric and flag when it stops improving.

    Call the instance once per epoch with the latest metric value. The first
    call only records a baseline. Afterwards, a value counts as an improvement
    when it is NOT below ``best_score + min_delta``. In that case it becomes
    the new best and the strike counter resets. Any other value adds a strike.
    Once the strike count reaches ``patience``, ``early_stop`` is set to
    ``True``; it is never reset back to ``False``.

    Attributes:
        patience: Consecutive non-improving calls tolerated before stopping.
        min_delta: Minimum gain over ``best_score`` that counts as improvement.
        counter: Non-improving calls since the last improvement.
        best_score: Best value seen so far; ``None`` until the first call.
        early_stop: Becomes ``True`` once ``counter`` reaches ``patience``.
    """

    def __init__(self, patience=7, min_delta=0.001):
        self.patience = patience
        self.min_delta = min_delta
        # Mutable tracking state, updated by __call__.
        self.counter = 0
        self.best_score = None
        self.early_stop = False

    def __call__(self, val_acc):
        """Record one epoch's metric value and update the stopping state."""
        # The very first observation just seeds the baseline.
        if self.best_score is None:
            self.best_score = val_acc
            return

        stalled = val_acc < self.best_score + self.min_delta
        if not stalled:
            # Improvement: adopt the new best and clear all strikes.
            self.best_score = val_acc
            self.counter = 0
            return

        self.counter += 1
        logger.info(
            f"EarlyStopping counter: {self.counter}/{self.patience}"
        )
        if self.counter >= self.patience:
            self.early_stop = True
def _check_loaders_not_empty(train_loader, eval_loader):
    """Raise ``ValueError`` if either loader yields no batches.

    The epoch averages divide by the number of batches and samples seen. An
    empty loader would otherwise only surface as a bare ``ZeroDivisionError``
    after a whole training pass.
    """
    if len(train_loader) == 0:
        raise ValueError("train_loader yields no batches; nothing to train on.")
    if len(eval_loader) == 0:
        raise ValueError("eval_loader yields no batches; cannot validate.")


def _train_one_epoch(model, loader, optimizer, scheduler, criterion,
                     forward_batch, desc):
    """Run one optimisation pass over ``loader``.

    ``forward_batch(batch)`` must move the batch to the device, run the model
    and return ``(logits, labels)``. The scheduler is stepped once per batch,
    matching the per-step ``num_training_steps`` it was built with.

    Returns:
        ``(mean_batch_loss, accuracy_percent)`` for the epoch.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for batch in tqdm(loader, desc=desc):
        optimizer.zero_grad(set_to_none=True)
        logits, labels = forward_batch(batch)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        scheduler.step()

        running_loss += loss.item()
        preds = torch.argmax(logits, dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
    return running_loss / len(loader), 100 * correct / total


def _evaluate(model, loader, criterion, forward_batch, desc):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Returns:
        ``(mean_batch_loss, accuracy_percent, preds, labels)``. ``preds`` and
        ``labels`` are flat lists built by extending with each batch's
        ``.cpu().numpy()`` array, so their elements are NumPy scalars.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for batch in tqdm(loader, desc=desc):
            logits, labels = forward_batch(batch)
            loss = criterion(logits, labels)
            running_loss += loss.item()
            preds = torch.argmax(logits, dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    return (
        running_loss / len(loader),
        100 * correct / total,
        all_preds,
        all_labels,
    )


def _save_checkpoint(model, optimizer, scheduler, epoch, val_acc,
                     checkpoint_model_name):
    """Save model/optimizer/scheduler state to ``CHECKPOINT_DIR/<name>.pt``."""
    checkpoint_path = CHECKPOINT_DIR / f"{checkpoint_model_name}.pt"
    # torch.save does not create missing parent directories.
    checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(
        {
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            # Lets a resumed run continue the warmup/cosine schedule
            # instead of restarting it from step 0.
            "scheduler_state_dict": scheduler.state_dict(),
            "epoch": epoch,
            "val_acc": val_acc,
        },
        checkpoint_path,
    )
    logger.info(f"Best checkpoint saved at: {checkpoint_path}")


def _fit(model, train_loader, eval_loader, optimizer, criterion, epochs,
         checkpoint_model_name, patience, forward_batch):
    """Run the shared train / validate / checkpoint / early-stop loop.

    ``forward_batch`` is the only model-specific piece. It turns one loader
    batch into ``(logits, labels)`` on the target device.

    Returns:
        ``(preds, labels)`` from the last validation pass that ran, or
        ``([], [])`` if no epoch ran.
    """
    if epochs > 0:
        _check_loaders_not_empty(train_loader, eval_loader)

    num_training_steps = epochs * len(train_loader)
    # Linear warmup over the first 10% of optimisation steps, then cosine decay.
    num_warmup_steps = int(0.1 * num_training_steps)
    scheduler = get_cosine_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps,
    )
    early_stopping = EarlyStopping(patience=patience)
    # -inf (not 0.0) guarantees the first epoch always writes a checkpoint,
    # even when its validation accuracy is exactly 0%.
    best_acc = float("-inf")
    all_preds, all_labels = [], []

    for epoch in range(epochs):
        logger.info(f"Epoch {epoch + 1}/{epochs}")

        train_loss, train_acc = _train_one_epoch(
            model, train_loader, optimizer, scheduler, criterion,
            forward_batch, desc=f"Epoch {epoch+1} Training",
        )
        val_loss, val_acc, all_preds, all_labels = _evaluate(
            model, eval_loader, criterion, forward_batch,
            desc=f"Epoch {epoch+1} Validation",
        )

        logger.info(
            f"Train Loss: {train_loss:.4f} | "
            f"Train Acc: {train_acc:.2f}% || "
            f"Val Loss: {val_loss:.4f} | "
            f"Val Acc: {val_acc:.2f}%"
        )

        # Only strict improvements overwrite the saved checkpoint.
        if val_acc > best_acc:
            best_acc = val_acc
            _save_checkpoint(
                model, optimizer, scheduler, epoch, val_acc,
                checkpoint_model_name,
            )

        early_stopping(val_acc)
        if early_stopping.early_stop:
            logger.info("Early stopping triggered.")
            break

    return all_preds, all_labels


def train_single_input_model(
    model,
    train_loader,
    eval_loader,
    optimizer,
    criterion,
    device,
    epochs,
    checkpoint_model_name,
    patience=7
):
    """Train a single-input classifier with validation, checkpointing and early stopping.

    Each epoch runs one optimisation pass over ``train_loader``, stepping a
    cosine-with-warmup LR schedule once per batch with warmup set to 10% of
    all steps. It then runs a no-grad pass over ``eval_loader``. Whenever
    validation accuracy strictly improves, model, optimizer and scheduler
    state is saved to ``CHECKPOINT_DIR / f"{checkpoint_model_name}.pt"``.
    Training stops early once ``EarlyStopping`` sees ``patience`` epochs
    without improvement.

    Args:
        model: Called as ``model(images)``. Class predictions are taken as
            ``argmax`` over dim 1 of its output.
        train_loader: Sized iterable yielding ``(images, labels)`` pairs.
        eval_loader: Sized iterable with the same format, used for validation.
        optimizer: Stepped once per training batch; the LR scheduler wraps it.
        criterion: Loss called as ``criterion(logits, labels)``.
        device: Target passed to ``Tensor.to`` for every batch.
        epochs: Maximum number of epochs to run.
        checkpoint_model_name: File stem of the checkpoint in ``CHECKPOINT_DIR``.
        patience: Epochs without improvement tolerated before stopping.

    Returns:
        ``(preds, labels)`` from the last validation pass that ran. This is
        not necessarily the best (checkpointed) epoch. Returns ``([], [])``
        if no epoch ran.

    Raises:
        ValueError: If ``epochs > 0`` and either loader yields no batches.
    """
    logger.info("Starting single-input training...")

    def forward_batch(batch):
        images, labels = batch
        images = images.to(device)
        labels = labels.to(device)
        return model(images), labels

    return _fit(
        model, train_loader, eval_loader, optimizer, criterion, epochs,
        checkpoint_model_name, patience, forward_batch,
    )


def train_dual_input_model(
    model,
    train_loader,
    eval_loader,
    optimizer,
    criterion,
    device,
    epochs,
    checkpoint_model_name,
    patience=7
):
    """Train a two-input classifier with validation, checkpointing and early stopping.

    The loop is identical to ``train_single_input_model``. Only the batch
    format and model call differ: each batch is a mapping with the keys
    ``"pixel_values_eff"``, ``"pixel_values_cnx"`` and ``"labels"``, and the
    model is called as ``model(pixel_values_eff, pixel_values_cnx)``.

    Args:
        model: Called with the two image tensors. Class predictions are taken
            as ``argmax`` over dim 1 of its output.
        train_loader: Sized iterable yielding the mappings described above.
        eval_loader: Sized iterable with the same format, used for validation.
        optimizer: Stepped once per training batch; the LR scheduler wraps it.
        criterion: Loss called as ``criterion(logits, labels)``.
        device: Target passed to ``Tensor.to`` for every tensor in a batch.
        epochs: Maximum number of epochs to run.
        checkpoint_model_name: File stem of the checkpoint in ``CHECKPOINT_DIR``.
        patience: Epochs without improvement tolerated before stopping.

    Returns:
        ``(preds, labels)`` from the last validation pass that ran. This is
        not necessarily the best (checkpointed) epoch. Returns ``([], [])``
        if no epoch ran.

    Raises:
        ValueError: If ``epochs > 0`` and either loader yields no batches.
    """
    logger.info("Starting dual-input training...")

    def forward_batch(batch):
        images_eff = batch["pixel_values_eff"].to(device)
        images_cnx = batch["pixel_values_cnx"].to(device)
        labels = batch["labels"].to(device)
        return model(images_eff, images_cnx), labels

    return _fit(
        model, train_loader, eval_loader, optimizer, criterion, epochs,
        checkpoint_model_name, patience, forward_batch,
    )
def _main():
    """Smoke-check entry point: reaching this print means the imports succeeded."""
    print("Trainer utilities ready.")


if __name__ == "__main__":
    _main()