Spaces:
Sleeping
Sleeping
| import math | |
| import sys | |
| import torch | |
| from torch.utils.data import DataLoader | |
| from tqdm import tqdm | |
| from dataset import VinDrCXRBoxesDataset, collate_fn | |
| from model import get_detection_model | |
def train_one_epoch(model, optimizer, data_loader, device, epoch):
    """Run one training epoch of a detection model and return its mean loss.

    The model is switched to training mode and called as
    ``model(images, targets)``; the values of the returned loss dict are
    summed, back-propagated, and the optimizer takes one step per batch.

    Args:
        model: Detection model that, in train mode, returns a dict of scalar
            loss tensors when called with ``(images, targets)``.
        optimizer: Optimizer over the model's trainable parameters.
        data_loader: Iterable of ``(images, targets)`` batches. ``images`` is
            iterated element-wise and each element moved with ``.to(device)``;
            ``targets`` is a sequence of dicts whose values all support
            ``.to(device)``.
        device: Device each batch is moved to before the forward pass.
        epoch: Epoch index, used only in progress/log messages.

    Returns:
        The mean of the per-batch summed losses over the batches processed in
        this epoch, or ``float("nan")`` if the loader yielded no batches.

    Raises:
        SystemExit: via ``sys.exit(1)`` as soon as a batch produces a
            non-finite total loss.
    """
    model.train()
    running_loss = 0.0
    num_batches = 0
    loop = tqdm(data_loader, desc=f"Detection Training Epoch {epoch}")
    # tqdm fills ``total`` from len(data_loader) when the loader is sized;
    # fall back to "?" so unsized (iterable-style) loaders still work.
    total = loop.total if loop.total is not None else "?"
    for i, (images, targets) in enumerate(loop):
        images = [image.to(device) for image in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        loss_dict = model(images, targets)
        losses = sum(loss_dict.values())
        loss_value = losses.item()

        if not math.isfinite(loss_value):
            print(f"Loss is {loss_value}, stopping training")
            # Show each loss component so the diverging term is identifiable.
            print({k: v.item() for k, v in loss_dict.items()})
            sys.exit(1)

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        running_loss += loss_value
        num_batches += 1
        loop.set_postfix(loss=loss_value)
        if (i + 1) % 10 == 0:
            # tqdm.write keeps the progress bar intact; flush explicitly so the
            # line shows up promptly when stdout is buffered (redirected logs).
            tqdm.write(f"Epoch {epoch} - Batch {i+1}/{total} - Loss: {loss_value:.4f}")
            sys.stdout.flush()

    if num_batches == 0:
        # Mean over zero batches is undefined; avoid ZeroDivisionError.
        return float("nan")
    return running_loss / num_batches
def main():
    """Train the Faster R-CNN detector and checkpoint the best epoch.

    Builds the VinDr-CXR box dataset from ``./data``, trains on CPU with SGD
    and a StepLR schedule for a fixed number of epochs, and writes the model's
    ``state_dict`` to ``best_faster_rcnn_detection.pth`` whenever an epoch's
    mean training loss beats the best seen so far. Returns early, without
    training, if the dataset files cannot be found.
    """
    # Detection is pinned to CPU: MPS is currently unstable for Faster R-CNN
    # on this machine.
    device = torch.device('cpu')
    print(f"Using device: {device}")

    # 14 classes + 1 No Finding + 1 background = 16
    n_classes = 16
    csv_path, image_root = './data/train.csv', './data/images'

    try:
        train_set = VinDrCXRBoxesDataset(csv_path, image_root, transform=None)
    except FileNotFoundError:
        print("Data files not found. Skipping dataset initialization for demo purposes.")
        return

    loader = DataLoader(
        train_set,
        batch_size=4,
        shuffle=True,
        num_workers=0,
        collate_fn=collate_fn,
    )

    detector = get_detection_model(n_classes).to(device)
    trainable = list(filter(lambda p: p.requires_grad, detector.parameters()))
    sgd = torch.optim.SGD(trainable, lr=0.001, momentum=0.9, weight_decay=0.0005)
    # Stepped once per epoch below: the LR is multiplied by 0.1 every 3 epochs.
    scheduler = torch.optim.lr_scheduler.StepLR(sgd, step_size=3, gamma=0.1)

    epochs = 10
    lowest = float('inf')
    for ep in range(epochs):
        mean_loss = train_one_epoch(detector, sgd, loader, device, ep)
        scheduler.step()
        if not (mean_loss < lowest):
            continue
        lowest = mean_loss
        torch.save(detector.state_dict(), "best_faster_rcnn_detection.pth")
        print("Saved Best Detection Model!")
if __name__ == "__main__":
    # Run training only when executed as a script, not when imported.
    main()