File size: 2,281 Bytes
c843d82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
# imports
import numpy as np
from cellpose import models, core, io, plot, train
from pathlib import Path
from tqdm import trange
import matplotlib.pyplot as plt

io.logger_setup() # run this to get printing of progress

train_dir = "/mnt/WorkingDos/cellpose_sam/8_hdrg_jayden_dataset_data"
model_name = "cp_sam_hdrg_topoint_model"

def train_cp_sam_model(train_dir, model_name, n_epochs=100, learning_rate=1e-5, weight_decay=0.1, batch_size=1):
  """
  Train a Cellpose model using the SAM (Segment Anything) algorithm.

  Args:
    train_dir (str): Path to the directory containing the training data.
    model_name (str): Name of the model to be trained.
    n_epochs (int): Number of epochs to train the model.
    learning_rate (float): Learning rate for the optimizer.
    weight_decay (float): Weight decay for the optimizer.
    batch_size (int): Batch size for training.

  Returns:
    None
    """
  # Check if colab notebook instance has GPU access
  if core.use_gpu()==False:raise ImportError("No GPU access, change your runtime")

  model = models.CellposeModel(gpu=True)

  if not Path(train_dir).exists():raise FileNotFoundError("directory does not exist")

  test_dir = None # optionally you can specify a directory with test files

  # *** change to your mask extension ***
  # masks_ext = "_seg.npy"
  # ^ assumes images from Cellpose GUI, if labels are tiffs, then "_masks.tif"

  # list all files
  masks_ext = "_masks"
  files = [f for f in Path(train_dir).glob("*") if "_masks" not in f.name and "_flows" not in f.name and "_seg" not in f.name]

  if(len(files)==0):raise FileNotFoundError("no files found, did you specify the correct folder and extension?")
  else:print(f"{len(files)} files in folder:")

  output = io.load_train_test_data(train_dir, test_dir, mask_filter=masks_ext)
  train_data, train_labels, _, test_data, test_labels, _ = output
  new_model_path, train_losses, test_losses = train.train_seg(model.net, train_data=train_data, train_labels=train_labels, batch_size=batch_size, n_epochs=n_epochs, learning_rate=learning_rate, weight_decay=weight_decay, nimg_per_epoch=max(2, len(train_data)), model_name=model_name)


if __name__ == "__main__":
  train_cp_sam_model(train_dir, model_name, n_epochs, learning_rate, weight_decay, batch_size)