# Source: cellanalyzer / train.py
# Commit 1c6bf70 ("Create train.py") by Bharatmali999, 2.78 kB.
import torch
import torchvision
from torch.utils.data import DataLoader
from torchvision import transforms
import numpy as np
import json
import cv2
from PIL import Image
from engine import train_one_epoch, evaluate
import utils
from torchvision.models.detection import maskrcnn_resnet50_fpn
# Define a custom dataset class for COCO-like annotations
class CellDataset(torch.utils.data.Dataset):
    """Single-class ("cell") Mask R-CNN dataset backed by a COCO-style JSON file.

    Every annotation is treated as class 1 ("cell"); ``category_id`` is not read.

    Each item is ``(image, target)``. ``target`` contains:

    * ``boxes``    -- float32 tensor (N, 4) as (x_min, y_min, x_max, y_max);
      shape (0, 4) when the image has no usable annotation.
    * ``labels``   -- int64 tensor (N,), all ones.
    * ``masks``    -- uint8 tensor (N, H, W) with values 0/1.
    * ``image_id`` -- the image's ``id`` from the JSON ``images`` entry.
    * ``area``     -- float32 tensor (N,): the annotation's ``area`` field, or
      bbox width * height when that field is absent.
    * ``iscrowd``  -- int64 tensor (N,): the annotation's ``iscrowd`` field,
      defaulting to 0.

    Annotations whose bbox has non-positive width or height are skipped.

    Args:
        json_file: path to a JSON file with top-level ``images`` and
            ``annotations`` lists in COCO layout (``bbox`` is [x, y, w, h]).
        root_dir: directory containing the files named by
            ``images[i]['file_name']``.
        transforms: optional callable applied to the PIL image only. The
            target is NOT transformed with it, so it must not change the
            image geometry (no flips, crops or resizes).
    """

    def __init__(self, json_file, root_dir, transforms=None):
        # JSON text is UTF-8 by specification; don't depend on the locale.
        with open(json_file, encoding="utf-8") as f:
            self.annotations = json.load(f)
        self.root_dir = root_dir
        self.transforms = transforms
        self.images = self.annotations['images']
        self.annotations_data = self.annotations['annotations']
        # Index annotations by image id once, instead of scanning the whole
        # annotation list on every __getitem__ call.
        self._anns_by_image = {}
        for ann in self.annotations_data:
            self._anns_by_image.setdefault(ann['image_id'], []).append(ann)

    def __getitem__(self, idx):
        image_info = self.images[idx]
        img_path = f"{self.root_dir}/{image_info['file_name']}"
        # convert() returns a new in-memory image, so the source file can be
        # closed immediately rather than left open until garbage collection.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert("RGB")
        img_width, img_height = img.size
        anns = self._anns_by_image.get(image_info['id'], [])
        target = self._build_target(anns, image_info['id'], img_height, img_width)
        if self.transforms:
            img = self.transforms(img)
        return img, target

    def __len__(self):
        return len(self.images)

    @staticmethod
    def _build_target(anns, image_id, img_height, img_width):
        """Build the Mask R-CNN target dict for one image from its COCO annotations."""
        boxes, masks, areas, iscrowd = [], [], [], []
        for ann in anns:
            x_min, y_min, box_w, box_h = ann['bbox']
            # torchvision's detection models raise on boxes with non-positive
            # width or height, so degenerate annotations are dropped here.
            if box_w <= 0 or box_h <= 0:
                continue
            boxes.append([x_min, y_min, x_min + box_w, y_min + box_h])
            masks.append(CellDataset._annotation_mask(ann, img_height, img_width))
            areas.append(ann.get('area', box_w * box_h))
            iscrowd.append(ann.get('iscrowd', 0))

        if masks:
            masks_tensor = torch.from_numpy(np.stack(masks))
        else:
            masks_tensor = torch.zeros((0, img_height, img_width), dtype=torch.uint8)

        return {
            # reshape keeps the required (N, 4) rank even when N == 0.
            'boxes': torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4),
            'labels': torch.ones((len(boxes),), dtype=torch.int64),  # 1 == 'cell'
            'masks': masks_tensor,
            # NOTE(review): recent torchvision detection reference scripts
            # (engine.py / coco_utils.py) expect a plain int here; older copies
            # expect torch.tensor([image_id]). Match the local engine.py.
            'image_id': image_id,
            'area': torch.as_tensor(areas, dtype=torch.float32),
            'iscrowd': torch.as_tensor(iscrowd, dtype=torch.int64),
        }

    @staticmethod
    def _annotation_mask(ann, img_height, img_width):
        """Rasterise one annotation into a uint8 (H, W) mask of 0/1 values."""
        mask = np.zeros((img_height, img_width), dtype=np.uint8)
        segmentation = ann.get('segmentation')
        if isinstance(segmentation, list):
            # COCO polygon format: a list of flat [x1, y1, x2, y2, ...] lists.
            polygons = [
                np.asarray(poly, dtype=np.float64).reshape(-1, 2).round().astype(np.int32)
                for poly in segmentation
                if len(poly) >= 6 and len(poly) % 2 == 0
            ]
            if polygons:
                cv2.fillPoly(mask, polygons, 1)
                return mask
        # No usable polygon (absent, or an RLE dict, which is not decoded
        # here): fall back to filling the bounding box. This is only a coarse
        # stand-in for a real instance mask.
        x_min, y_min, box_w, box_h = ann['bbox']
        x0 = max(int(np.floor(x_min)), 0)
        y0 = max(int(np.floor(y_min)), 0)
        x1 = min(int(np.ceil(x_min + box_w)), img_width)
        y1 = min(int(np.ceil(y_min + box_h)), img_height)
        mask[y0:y1, x0:x1] = 1
        return mask
# Define transformations.
# CellDataset applies these to the image only (its target is not transformed),
# so keep them to format/photometric changes; anything that moves pixels
# (flips, crops, resizes) would desynchronise the image from its boxes.
transform = transforms.Compose([
    transforms.ToTensor(),
])

# Load dataset
train_dataset = CellDataset('annotations.json', 'images', transforms=transform)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=utils.collate_fn)

# Define model
NUM_CLASSES = 2  # background + cell
# NOTE: newer torchvision versions prefer weights="DEFAULT" over pretrained=True.
model = maskrcnn_resnet50_fpn(pretrained=True)
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(
    in_features, NUM_CLASSES
)
# The mask head must be resized as well; otherwise it keeps predicting the
# pretrained COCO class set while the box head predicts NUM_CLASSES.
in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
hidden_layer = 256
model.roi_heads.mask_predictor = torchvision.models.detection.mask_rcnn.MaskRCNNPredictor(
    in_features_mask, hidden_layer, NUM_CLASSES
)

# Move model to GPU if available
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)

# Define optimizer (only parameters that are being trained)
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)

# Training loop
NUM_EPOCHS = 10  # adjust as needed
PRINT_FREQ = 10  # log every N iterations
for epoch in range(NUM_EPOCHS):
    # The torchvision detection reference engine.py declares print_freq
    # without a default, so it must be passed explicitly (confirm against
    # the local engine.py).
    train_one_epoch(model, optimizer, train_loader, device, epoch, print_freq=PRINT_FREQ)
    # NOTE(review): this evaluates on the training data, which measures fit,
    # not generalisation; use a held-out split for meaningful metrics.
    evaluate(model, train_loader, device)

# Save the trained model
torch.save(model.state_dict(), "cell_detection_model.pth")