cellanalyzer / utils.py
Bharatmali999's picture
Create utils.py
8fb5ee3 verified
raw
history blame
1.26 kB
import torch
from torchvision.models.detection import FasterRCNN
from collections import defaultdict
def collate_fn(batch):
    """Transpose a list of samples into a tuple of per-field tuples.

    Used as a DataLoader ``collate_fn`` so that variable-sized detection
    samples are kept as separate items instead of being stacked into one
    tensor. For example, ``[(img1, tgt1), (img2, tgt2)]`` becomes
    ``((img1, img2), (tgt1, tgt2))``. An empty batch yields ``()``.
    """
    columns = zip(*batch)
    return tuple(columns)
def train_one_epoch(model, optimizer, data_loader, device, epoch):
    """Run one training epoch and return the mean losses per batch.

    Each batch from ``data_loader`` is unpacked as ``(images, targets)``.
    Every image and every value of every target dict is moved to ``device``.
    The model is called as ``model(images, targets)`` and must return a dict
    of scalar loss tensors. Those losses are summed, backpropagated, and the
    optimizer is stepped once per batch.

    Args:
        model: Model to train; switched to training mode by this function.
        optimizer: Optimizer over ``model``'s parameters.
        data_loader: Iterable of ``(images, targets)`` batches.
        device: Device the batch tensors are moved to.
        epoch: Epoch index. Not used by the current implementation; kept
            for interface compatibility.

    Returns:
        dict[str, float]: ``"loss"`` maps to the mean summed loss per batch,
        and every key of the model's loss dict maps to that component's mean
        per batch. Empty if ``data_loader`` yielded no batches.
    """
    model.train()
    metric_logger = defaultdict(float)  # running sums of each loss component
    total_loss = 0.0
    num_batches = 0
    for images, targets in data_loader:
        images = [image.to(device) for image in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        optimizer.zero_grad()
        loss_dict = model(images, targets)
        losses = sum(loss_dict.values())
        losses.backward()
        optimizer.step()

        # .item() converts to a Python float so no autograd graph is kept alive
        # across iterations by the running sums.
        total_loss += losses.item()
        for name, value in loss_dict.items():
            metric_logger[name] += value.item()
        num_batches += 1

    if num_batches == 0:
        return {}
    # Report averages rather than discarding the accumulated totals.
    averages = {name: s / num_batches for name, s in metric_logger.items()}
    averages["loss"] = total_loss / num_batches
    return averages
def evaluate(model, data_loader, device):
    """Run inference over ``data_loader`` and return predictions with targets.

    The model is switched to eval mode and left there. Each batch is unpacked
    as ``(images, targets)``, and only the images are moved to ``device``.
    The whole pass runs under ``torch.no_grad()`` so no autograd graph is
    built.

    Assumes ``model(images)`` returns one dict of tensors per image, as
    torchvision detection models do in eval mode (TODO confirm for custom
    models).

    Args:
        model: Model to evaluate.
        data_loader: Iterable of ``(images, targets)`` batches.
        device: Device the images are moved to for the forward pass.

    Returns:
        list[tuple[dict, dict]]: One ``(prediction, target)`` pair per image,
        in loader order, with every tensor on the CPU. The caller can then
        compute metrics (e.g. IoU, detection accuracy) without holding device
        memory. Empty if ``data_loader`` yielded no batches.
    """
    model.eval()
    cpu = torch.device("cpu")
    results = []
    # One no_grad context for the whole pass instead of one per batch.
    with torch.no_grad():
        for images, targets in data_loader:
            images = [image.to(device) for image in images]
            predictions = model(images)
            # Previously predictions were discarded; keep them aligned with
            # their targets so callers can score them.
            for prediction, target in zip(predictions, targets):
                results.append((
                    {k: v.to(cpu) for k, v in prediction.items()},
                    {k: v.to(cpu) for k, v in target.items()},
                ))
    return results