Spaces:
Sleeping
Sleeping
import json

import cv2
import numpy as np
import torch
import torchvision
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.models.detection import maskrcnn_resnet50_fpn

import utils
from engine import evaluate, train_one_epoch
# Define a custom dataset class for COCO-style annotations
class CellDataset(torch.utils.data.Dataset):
    """Single-class ("cell") instance-segmentation dataset backed by a COCO-style JSON file.

    Each item is ``(img, target)``. ``target`` holds the keys torchvision's
    Mask R-CNN reads during training (``boxes``, ``labels``, ``masks``) plus
    the keys the torchvision detection-reference ``evaluate`` reads
    (``image_id``, ``area``, ``iscrowd``).

    Args:
        json_file: Path to the annotation file. It must contain ``images``
            (entries with ``id`` and ``file_name``) and ``annotations``
            (entries with ``image_id`` and ``bbox`` as ``[x, y, w, h]``;
            ``segmentation``, ``area`` and ``iscrowd`` are optional).
        root_dir: Directory that each image ``file_name`` is relative to.
        transforms: Optional callable applied to the PIL image only. It must
            not move pixels (e.g. ``ToTensor``), because the target is not
            transformed alongside the image.
    """

    def __init__(self, json_file, root_dir, transforms=None):
        with open(json_file, encoding="utf-8") as f:
            self.annotations = json.load(f)
        self.root_dir = root_dir
        self.transforms = transforms
        self.images = self.annotations['images']
        self.annotations_data = self.annotations['annotations']
        # Group annotations by image id once, so __getitem__ only touches the
        # annotations of its own image instead of scanning the whole list.
        self._anns_by_image = {}
        for ann in self.annotations_data:
            self._anns_by_image.setdefault(ann['image_id'], []).append(ann)

    @staticmethod
    def _ann_to_mask(ann, height, width):
        """Rasterize one annotation into a ``(height, width)`` uint8 0/1 mask.

        Polygon segmentations (a list of flat ``[x1, y1, x2, y2, ...]`` lists)
        are filled with OpenCV. A missing, empty or non-polygon segmentation
        (e.g. an RLE dict) falls back to the filled bounding box.
        """
        mask = np.zeros((height, width), dtype=np.uint8)
        seg = ann.get('segmentation')
        if isinstance(seg, list) and seg:
            if not isinstance(seg[0], (list, tuple)):
                seg = [seg]  # tolerate a single flat polygon
            polygons = [
                np.round(np.asarray(poly, dtype=np.float64)).astype(np.int32).reshape(-1, 2)
                for poly in seg
                # A polygon needs at least three complete (x, y) points.
                if len(poly) >= 6 and len(poly) % 2 == 0
            ]
            if polygons:
                cv2.fillPoly(mask, polygons, 1)
                return mask
        # No usable polygon: fill the (image-clipped) bounding box instead.
        x, y, w, h = ann['bbox']
        x0 = max(int(np.floor(x)), 0)
        y0 = max(int(np.floor(y)), 0)
        x1 = min(int(np.ceil(x + w)), width)
        y1 = min(int(np.ceil(y + h)), height)
        mask[y0:y1, x0:x1] = 1
        return mask

    def _build_target(self, image_info, height, width):
        """Build the Mask R-CNN target dict for one image of size ``height`` x ``width``."""
        boxes, masks, areas, iscrowd = [], [], [], []
        for ann in self._anns_by_image.get(image_info['id'], []):
            # Distinct names so the image width/height are not shadowed.
            x_min, y_min, box_w, box_h = ann['bbox']
            if box_w <= 0 or box_h <= 0:
                # torchvision rejects boxes with x_max <= x_min or y_max <= y_min.
                continue
            boxes.append([x_min, y_min, x_min + box_w, y_min + box_h])
            masks.append(self._ann_to_mask(ann, height, width))
            areas.append(ann.get('area', box_w * box_h))
            iscrowd.append(ann.get('iscrowd', 0))

        if masks:
            mask_array = np.stack(masks)
        else:
            mask_array = np.zeros((0, height, width), dtype=np.uint8)

        return {
            # reshape keeps the (N, 4) shape torchvision asserts even when N == 0.
            'boxes': torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4),
            # Label 1 for 'cell'; 0 is reserved for background.
            'labels': torch.ones((len(boxes),), dtype=torch.int64),
            'masks': torch.as_tensor(mask_array, dtype=torch.uint8),
            # NOTE(review): current torchvision reference scripts expect a plain
            # int here; older copies of engine.py call .item() on it and need
            # torch.tensor([id]) instead — confirm against your engine.py.
            'image_id': image_info['id'],
            'area': torch.as_tensor(areas, dtype=torch.float32),
            'iscrowd': torch.as_tensor(iscrowd, dtype=torch.int64),
        }

    def __getitem__(self, idx):
        image_info = self.images[idx]
        img_path = f"{self.root_dir}/{image_info['file_name']}"
        # Context manager closes the file handle; convert() returns a loaded copy.
        with Image.open(img_path) as raw:
            img = raw.convert("RGB")
        width, height = img.size  # PIL reports (width, height)
        target = self._build_target(image_info, height, width)
        if self.transforms:
            img = self.transforms(img)
        return img, target

    def __len__(self):
        return len(self.images)
# Training configuration
NUM_CLASSES = 2  # background + cell
NUM_EPOCHS = 10  # adjust as needed
PRINT_FREQ = 10  # log every N iterations inside train_one_epoch

# Define transformations (ToTensor only: it does not move pixels, so the
# target built by the dataset stays aligned with the image)
transform = transforms.Compose([
    transforms.ToTensor(),
])

# Load dataset
train_dataset = CellDataset('annotations.json', 'images', transforms=transform)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=utils.collate_fn)
# Evaluation reads the same images in a fixed order.
# NOTE(review): this scores the training images, so it measures fit rather than
# generalization — split off a held-out set for a real validation metric.
eval_loader = DataLoader(train_dataset, batch_size=4, shuffle=False, collate_fn=utils.collate_fn)

# Define model: COCO-pretrained Mask R-CNN with both prediction heads resized
model = maskrcnn_resnet50_fpn(pretrained=True)
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(
    in_features, NUM_CLASSES
)
# Replace the mask head as well; otherwise it keeps the 91 COCO class channels
# while the box head predicts only NUM_CLASSES.
in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
model.roi_heads.mask_predictor = torchvision.models.detection.mask_rcnn.MaskRCNNPredictor(
    in_features_mask, 256, NUM_CLASSES
)

# Move model to GPU if available
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)

# Define optimizer over the trainable parameters only
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)

# Training loop
for epoch in range(NUM_EPOCHS):
    # NOTE(review): assumes engine.py is torchvision's detection reference copy,
    # where print_freq is a required parameter of train_one_epoch.
    train_one_epoch(model, optimizer, train_loader, device, epoch, print_freq=PRINT_FREQ)
    evaluate(model, eval_loader, device)

# Save the trained model weights
torch.save(model.state_dict(), "cell_detection_model.pth")