# mediscan-api/detection/train.py
# Mittalyash's picture
# Upload folder using huggingface_hub
# 9916246 verified
import math
import sys
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from dataset import VinDrCXRBoxesDataset, collate_fn
from model import get_detection_model
def train_one_epoch(model, optimizer, data_loader, device, epoch):
    """Run one training pass over ``data_loader`` and return the mean batch loss.

    The model is called as ``model(images, targets)`` and must return a mapping
    whose values are loss tensors; these are summed into one scalar that is
    back-propagated for every batch.

    Args:
        model: Model to train; it is switched to training mode here.
        optimizer: Optimizer that updates the model's parameters.
        data_loader: Sized iterable yielding ``(images, targets)`` batches.
            ``images`` is an iterable of objects exposing ``.to(device)`` and
            ``targets`` is an iterable of dicts whose values expose
            ``.to(device)``.
        device: Device that every image and every target value is moved to.
        epoch: Epoch index, used only in the progress output.

    Returns:
        The sum of the per-batch losses divided by ``len(data_loader)``, or
        NaN when the loader contains no batches.

    Raises:
        SystemExit: With exit code 1 as soon as a batch loss is NaN or
            infinite.
    """
    model.train()
    # The loader length does not change during the pass, so read it once.
    num_batches = len(data_loader)
    running_loss = 0.0
    loop = tqdm(data_loader, desc=f"Detection Training Epoch {epoch}")

    for i, (images, targets) in enumerate(loop):
        images = [image.to(device) for image in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        loss_dict = model(images, targets)
        losses = sum(loss_dict.values())
        loss_value = losses.item()

        # Abort before the optimizer step so a diverged loss never reaches
        # the weights.
        if not math.isfinite(loss_value):
            print(f"Loss is {loss_value}, stopping training")
            sys.exit(1)

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        running_loss += loss_value
        loop.set_postfix(loss=loss_value)

        # Plain-text progress line in addition to the tqdm bar.
        if (i + 1) % 10 == 0:
            print(f"Epoch {epoch} - Batch {i+1}/{num_batches} - Loss: {loss_value:.4f}", flush=True)

    if num_batches == 0:
        # An empty loader has no mean loss. NaN avoids a ZeroDivisionError
        # and compares false against any "best loss so far", so an empty
        # epoch is never mistaken for an improvement.
        return float("nan")
    return running_loss / num_batches
def main():
    """Train the detection model on CPU and checkpoint the best epoch.

    Builds the dataset from ``./data/train.csv`` and ``./data/images``, trains
    for a fixed number of epochs with SGD and a step learning-rate schedule,
    and writes the model's ``state_dict`` to ``best_faster_rcnn_detection.pth``
    each time the epoch's mean training loss drops below the best seen so far.
    Returns early, after printing a notice, if constructing the dataset raises
    ``FileNotFoundError``.
    """
    # Force CPU for detection as MPS is currently unstable for Faster R-CNN
    # on this machine.
    device = torch.device("cpu")
    print(f"Using device: {device}")

    # 14 classes + 1 No Finding + 1 background = 16
    num_classes = 16

    annotations_csv = "./data/train.csv"
    images_root = "./data/images"

    try:
        train_set = VinDrCXRBoxesDataset(annotations_csv, images_root, transform=None)
    except FileNotFoundError:
        print("Data files not found. Skipping dataset initialization for demo purposes.")
        return

    train_loader = DataLoader(
        train_set,
        batch_size=4,
        shuffle=True,
        num_workers=0,
        collate_fn=collate_fn,
    )

    model = get_detection_model(num_classes).to(device)

    # Only parameters that require gradients are handed to the optimizer.
    trainable_params = list(filter(lambda p: p.requires_grad, model.parameters()))
    sgd_settings = {"lr": 0.001, "momentum": 0.9, "weight_decay": 0.0005}
    optimizer = torch.optim.SGD(trainable_params, **sgd_settings)

    # Optional learning rate scheduler: scale the LR by 0.1 every 3 epochs.
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

    num_epochs = 10
    lowest_loss = float("inf")

    for epoch in range(num_epochs):
        mean_loss = train_one_epoch(model, optimizer, train_loader, device, epoch)
        lr_scheduler.step()

        # Written as "not (a < b)" so a NaN loss is skipped, never saved.
        if not mean_loss < lowest_loss:
            continue

        lowest_loss = mean_loss
        torch.save(model.state_dict(), "best_faster_rcnn_detection.pth")
        print("Saved Best Detection Model!")
# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()