"""src/models/train_generator.py

Fine-tune a ResNet-18 to detect which generator produced an image
(multiclass generator-type classification).
"""
import os

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchvision import models
from sklearn.metrics import classification_report, confusion_matrix

from src.data.generator_loader import get_generator_dataloaders, CLASS_NAMES


def build_multiclass_model(num_classes=4, pretrained=True):
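    """Build a ResNet-18 for generator-type classification: freeze the
    backbone except layer4 and replace the stock fc layer with a
    dropout-regularized two-layer head."""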
    model = models.resnet18(weights="IMAGENET1K_V1" if pretrained else None)

    # Freeze the whole backbone first...
    for param in model.parameters():
        param.requires_grad = False
    # ...then unfreeze layer4 and the classifier head for fine-tuning.
    for name, param in model.named_parameters():
        if "layer4" in name or "fc" in name:
            param.requires_grad = True

    in_features = model.fc.in_features
    model.fc = nn.Sequential(
        nn.Dropout(0.5),
        nn.Linear(in_features, 256),
        nn.ReLU(),
        nn.Dropout(0.3),
        nn.Linear(256, num_classes),
    )
    return model


def evaluate(model, loader, device):
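    """Run inference over a dataloader; return predicted and true labels."""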
    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            preds = outputs.argmax(dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    return all_preds, all_labels


def train(epochs=15, batch_size=32, lr=1e-4):
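    """Train the generator-type classifier with early stopping, saving the
    best checkpoint by validation accuracy to saved_models/."""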
    # Use Apple MPS when available, otherwise fall back to CPU.
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    print(f"Using device: {device}")

    train_loader, val_loader, test_loader = get_generator_dataloaders(batch_size=batch_size)

    model = build_multiclass_model().to(device)
    criterion = nn.CrossEntropyLoss()
    # Optimize only the unfrozen parameters (layer4 + classifier head).
    optimizer = Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr)
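    # ReduceLROnPlateau defaults to mode="min", so the loop steps it on
    # 1 - val_acc: the LR drops when validation accuracy stops improving.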
    scheduler = ReduceLROnPlateau(optimizer, patience=2)

    best_val_acc = 0
    early_stop_patience = 3
    no_improve_count = 0
    os.makedirs("saved_models", exist_ok=True)  # ensure the checkpoint directory exists
    for epoch in range(epochs):
        model.train()
        train_loss, correct, total = 0, 0, 0
        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            preds = outputs.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

        train_acc = correct / total
        avg_train_loss = train_loss / len(train_loader)

        val_preds, val_labels = evaluate(model, val_loader, device)
        val_acc = sum(p == l for p, l in zip(val_preds, val_labels)) / len(val_labels)
        scheduler.step(1 - val_acc)

        print(f"Epoch {epoch+1}/{epochs} | "
              f"Train Loss: {avg_train_loss:.4f} | Train Acc: {train_acc:.4f} | "
              f"Val Acc: {val_acc:.4f}")
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            no_improve_count = 0
            torch.save(model.state_dict(), "saved_models/generator_model.pth")
            print("  -> Best model saved")
        else:
            no_improve_count += 1
            if no_improve_count >= early_stop_patience:
                print(f"Early stopping at epoch {epoch+1}")
                break
    # Final evaluation: reload the best checkpoint so the test metrics
    # reflect the best validation model, not the last epoch's weights.
    print("\n--- Final Evaluation ---")
    model.load_state_dict(torch.load("saved_models/generator_model.pth", map_location=device))
    test_preds, test_labels = evaluate(model, test_loader, device)
    test_acc = sum(p == l for p, l in zip(test_preds, test_labels)) / len(test_labels)
    print(f"Test Accuracy: {test_acc:.4f}")
print("\nClassification Report:")
print(classification_report(test_labels, test_preds,
target_names=list(CLASS_NAMES.values())))
print("Confusion Matrix:")
print(confusion_matrix(test_labels, test_preds))


if __name__ == "__main__":
    train()