# Mittalyash's picture
# Upload folder using huggingface_hub
# 9916246 verified
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from dataset import VinDrCXRClassificationDataset, get_train_val_transforms
from model import CliniScanClassifier
from sklearn.metrics import roc_auc_score
def train_one_epoch(model, loader, optimizer, criterion, device):
    """Run one optimisation pass over ``loader`` and return the mean loss.

    Args:
        model: Network to train; it is switched to training mode here.
        loader: Iterable yielding ``(images, targets)`` batches.
        optimizer: Optimizer whose ``zero_grad``/``step`` drive each update.
        criterion: Callable mapping ``(outputs, targets)`` to a scalar loss.
        device: Device that both tensors of every batch are moved to.

    Returns:
        The per-sample average loss over the samples actually processed,
        or ``0.0`` if the loader yielded no batches.
    """
    model.train()
    running_loss = 0.0
    num_samples = 0
    loop = tqdm(loader, desc="Training Epoch")
    for images, targets in loop:
        images, targets = images.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        batch_size = images.size(0)
        batch_loss = loss.item()  # read once: .item() forces a device sync
        # Weight by batch size so a smaller final batch does not skew the
        # average (assumes the criterion returns a batch mean -- TODO confirm).
        running_loss += batch_loss * batch_size
        num_samples += batch_size
        loop.set_postfix(loss=batch_loss)

    # Divide by the samples seen rather than len(loader.dataset): the two
    # differ when the loader drops or subsamples data (e.g. drop_last=True),
    # and an empty dataset would otherwise raise ZeroDivisionError.
    return running_loss / num_samples if num_samples else 0.0
def _macro_auc(targets, scores):
    """Return the mean ROC AUC over label columns where AUC is defined.

    Columns whose targets contain a single value are skipped, because ROC AUC
    is undefined for them. Returns ``0.0`` when no column can be scored or
    when ``targets`` and ``scores`` disagree in shape.
    """
    # View 1-D inputs as a single column so both layouts share one code path.
    targets = targets.reshape(targets.shape[0], -1)
    scores = scores.reshape(scores.shape[0], -1)
    if targets.shape != scores.shape:
        return 0.0

    per_class = []
    for col in range(targets.shape[1]):
        labels = targets[:, col]
        if labels.min() == labels.max():
            # Only one class present. Depending on the scikit-learn version,
            # roc_auc_score raises ValueError or returns NaN here; either
            # would wipe out the score of every other label.
            continue
        try:
            per_class.append(roc_auc_score(labels, scores[:, col]))
        except ValueError:
            # Labels that roc_auc_score rejects (e.g. non-binary values).
            continue
    return float(sum(per_class) / len(per_class)) if per_class else 0.0


@torch.no_grad()
def evaluate(model, loader, criterion, device):
    """Compute the mean loss and macro ROC AUC of ``model`` over ``loader``.

    Args:
        model: Network to evaluate; it is switched to eval mode here.
        loader: Iterable yielding ``(images, targets)`` batches.
        criterion: Callable mapping ``(outputs, targets)`` to a scalar loss.
        device: Device that both tensors of every batch are moved to.

    Returns:
        A ``(loss, auc)`` tuple. ``loss`` is the per-sample average over the
        samples actually processed. ``auc`` is the unweighted mean of the
        per-label ROC AUC of the sigmoid-activated outputs, taken over the
        labels that have both classes present; it is ``0.0`` when no label
        can be scored. Both values are ``0.0`` if the loader yields nothing.
    """
    model.eval()
    running_loss = 0.0
    num_samples = 0
    all_targets = []
    all_outputs = []
    for images, targets in tqdm(loader, desc="Evaluating"):
        images, targets = images.to(device), targets.to(device)
        outputs = model(images)
        loss = criterion(outputs, targets)

        batch_size = images.size(0)
        running_loss += loss.item() * batch_size
        num_samples += batch_size
        all_targets.append(targets.cpu())
        # The model emits logits; sigmoid turns them into per-label scores.
        all_outputs.append(torch.sigmoid(outputs).cpu())

    if not all_targets:
        # Nothing was evaluated; torch.cat would raise on an empty list.
        return 0.0, 0.0

    all_targets = torch.cat(all_targets).numpy()
    all_outputs = torch.cat(all_outputs).numpy()

    # Macro AUC over the labels for which it is defined. Scoring all labels
    # in one roc_auc_score call would fail for the whole evaluation set as
    # soon as a single label has no positive (or no negative) example.
    auc = _macro_auc(all_targets, all_outputs)

    # Divide by the samples seen rather than len(loader.dataset) so the mean
    # stays correct when the loader does not yield every sample.
    return running_loss / num_samples, auc
def main():
    """Train ``CliniScanClassifier`` and checkpoint the best validation AUC.

    Selects a device (MPS, then CUDA, then CPU), builds the train/validation
    datasets and loaders from hard-coded local paths, and trains for a fixed
    number of epochs. After each epoch the weights are written to
    ``best_resnet_classification.pth`` if the validation AUC exceeds the best
    value seen so far. If the dataset files are missing, a message is printed
    and the function returns without training.
    """
    # Backend preference order: Apple MPS, then CUDA, then plain CPU.
    if torch.backends.mps.is_available():
        backend = 'mps'
    elif torch.cuda.is_available():
        backend = 'cuda'
    else:
        backend = 'cpu'
    device = torch.device(backend)
    print(f"Using device: {device}")

    # Dummy paths for local runs.
    train_csv = './data/train.csv'
    # For demonstration the validation split reuses the training CSV.
    val_csv = './data/train.csv'
    img_dir = './data/images'

    train_transform, val_transform = get_train_val_transforms()
    try:
        train_dataset = VinDrCXRClassificationDataset(
            train_csv, img_dir, transform=train_transform
        )
        val_dataset = VinDrCXRClassificationDataset(
            val_csv, img_dir, transform=val_transform
        )
    except FileNotFoundError:
        print("Data files not found. Skipping dataset initialization for demo purposes.")
        return

    # Both loaders share every setting except shuffling.
    shared_loader_args = dict(batch_size=32, num_workers=4)
    train_loader = DataLoader(train_dataset, shuffle=True, **shared_loader_args)
    val_loader = DataLoader(val_dataset, shuffle=False, **shared_loader_args)

    model = CliniScanClassifier(num_classes=15)
    model = model.to(device)

    # BCEWithLogitsLoss is used for multi-label classification.
    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    epochs = 10
    best_auc = 0.0
    checkpoint_path = "best_resnet_classification.pth"

    for epoch in range(epochs):
        train_loss = train_one_epoch(model, train_loader, optimizer, criterion, device)
        val_loss, val_auc = evaluate(model, val_loader, criterion, device)
        print(f"Epoch {epoch+1}/{epochs} - Train Loss: {train_loss:.4f} - Val Loss: {val_loss:.4f} - Val AUC: {val_auc:.4f}")

        # Checkpoint only on a strict improvement of the validation AUC.
        improved = val_auc > best_auc
        if improved:
            best_auc = val_auc
            torch.save(model.state_dict(), checkpoint_path)
            print("Saved Best Model!")


# Start training only when executed as a script, not when imported.
if __name__ == '__main__':
    main()