zy7_oldserver
1
fd601de
import torch
import argparse
from synthrad_conversion.utils.my_configs_yacs import init_cfg
import os
import shutil
from dataprocesser.step1_init_data_list import init_dataset
from synthrad_conversion.networks.launch_model import launch_model
# Usage: python train_3d.py --config ./configs/newserver/0510_test3d.yaml
import subprocess
import sys
from torch.multiprocessing import Process
import torch.distributed as dist
def install_and_check(package):
    """Import *package*; if that fails, pip-install it for the running interpreter.

    Prints whether the package was already importable or is being installed.

    Raises:
        subprocess.CalledProcessError: if the pip invocation fails.
    """
    try:
        __import__(package)
    except ImportError:
        print(f"'{package}' not found. Installing...")
        # Use this interpreter's pip so the package lands in the active environment.
        pip_command = [sys.executable, "-m", "pip", "install", package]
        subprocess.check_call(pip_command)
    else:
        print(f"'{package}' is already installed.")
def check_neccessary_packages():
    """Make sure every package in the hard-coded list is importable, installing any that are missing."""
    # Add your packages here
    required_packages = ('numpy', 'pandas', 'matplotlib')
    for name in required_packages:
        install_and_check(name)
def cleanup():
    """Destroy the default ``torch.distributed`` process group if one exists.

    Safe to call when distributed training was never initialised (e.g. the
    single-GPU path). In that case ``destroy_process_group`` would fail because
    there is no default group.
    """
    # is_initialized() only exists when the distributed package is built in.
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()
def setup(rank, world_size, using_torchrun=True):
    """Initialise the NCCL process group for this process and pin it to its GPU.

    Args:
        rank: rank of this process (``run`` passes ``LOCAL_RANK`` here).
        world_size: total number of processes in the group.
        using_torchrun: when False, fill in a single-host rendezvous
            (127.0.0.1:12355) ourselves. When True, the launcher is expected
            to export MASTER_ADDR / MASTER_PORT.
    """
    if not using_torchrun:
        # Respect an explicitly configured rendezvous; only supply defaults.
        os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
        os.environ.setdefault('MASTER_PORT', '12355')
    if torch.cuda.is_available():
        # NCCL requires each process to use its own device. Without this, every
        # rank stays on the default current device (cuda:0).
        local_rank = int(os.environ.get("LOCAL_RANK", rank))
        torch.cuda.set_device(local_rank)
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
import platform
def is_linux():
    """Return True when ``platform.system()`` reports Linux (case-insensitive)."""
    system_name = platform.system()
    return system_name.lower() == "linux"
def _resolve_option(cli_value, kargs, key, node, attr):
    """Return the CLI value if given, else ``kargs[key]``, else ``getattr(node, attr)``.

    The config attribute is only read when neither override is present.
    """
    if cli_value is not None:
        return cli_value
    if key in kargs:
        return kargs[key]
    return getattr(node, attr)


def _as_gpu_list(value):
    """Normalize a GPU-ID spec to a list of ints.

    Accepts None (-> []), a single int, a string such as "1" or "0,1", or a
    sequence of those (e.g. argparse's ``nargs='+'`` list of strings).
    """
    if value is None:
        return []
    if isinstance(value, (int, str)):
        value = [value]
    gpu_ids = []
    for item in value:
        if isinstance(item, str):
            gpu_ids.extend(int(part) for part in item.split(",") if part.strip())
        else:
            gpu_ids.append(int(item))
    return gpu_ids


def run(input_args=None, config='./configs/sample.yaml', dataset_name='combined_simplified_csv_seg_assigned', data_dir=r'E:\Projects\yang_proj\data\seg2med', **kargs):
    """Parse options, select the device(s), build the data loaders and launch the model.

    The overrides ``loss_type``, ``batch_size`` and ``GPU_ID`` are resolved by
    precedence: command-line flag, then keyword argument in ``kargs``, then the
    value in the YAML config.

    Args:
        input_args: argument list to parse instead of ``sys.argv[1:]``.
            Unknown arguments are ignored.
        config: default config path (overridden by ``--config``).
        dataset_name: dataset identifier forwarded to ``init_dataset``.
        data_dir: default data directory (overridden by ``--data_dir``). It is
            only used if the path exists; otherwise ``opt.dataset.data_dir``
            is set to None.
        **kargs: optional ``loss_type``, ``batch_size`` and ``GPU_ID`` overrides.
    """
    verbose = False
    parser = argparse.ArgumentParser(description="StyleGAN pytorch implementation.")
    parser.add_argument('--config', default=config)
    parser.add_argument('--data_dir', default=data_dir, help='data directory')
    parser.add_argument('--loss_type', type=str, default=None, help='Contrastive loss type: cossim, nt_xent, or cossim_ntxent')
    parser.add_argument('--batch_size', type=int, default=None, help='batch size used for training')
    # default=None (not [0]) so that kargs / the config can supply the IDs; the
    # old default made those fallbacks unreachable. Accepts "0 1" or "0,1".
    parser.add_argument('--GPU_ID', nargs='+', default=None, help='GPU ID(s), e.g. "0" or "0,1"')
    # Unknown arguments are tolerated and ignored.
    args, _unknown_args = parser.parse_known_args(input_args)

    opt = init_cfg(args.config)
    # Only use the data directory if it exists on this machine.
    if args.data_dir is not None and os.path.exists(args.data_dir):
        opt.dataset.data_dir = args.data_dir
    else:
        opt.dataset.data_dir = None
    if verbose:
        print(opt)

    # Precedence for each override: command line > kargs > config.
    opt.train.loss = _resolve_option(args.loss_type, kargs, "loss_type", opt.train, "loss")
    opt.dataset.batch_size = _resolve_option(args.batch_size, kargs, "batch_size", opt.dataset, "batch_size")
    if args.GPU_ID is not None:
        gpu_spec = args.GPU_ID
    elif "GPU_ID" in kargs:
        gpu_spec = kargs["GPU_ID"]
    else:
        # Prefer dataset.GPU_ID, then a top-level GPU_ID, then GPU 0 (the
        # value the old CLI default always forced).
        gpu_spec = getattr(opt.dataset, "GPU_ID", None)
        if gpu_spec is None:
            gpu_spec = getattr(opt, "GPU_ID", [0])
    gpu_ids = _as_gpu_list(gpu_spec)
    opt.dataset.GPU_ID = gpu_ids
    if hasattr(opt, "GPU_ID"):
        # Keep the top-level copy in sync so every consumer sees the override.
        opt.GPU_ID = gpu_ids

    print("##### training using batch size:", opt.dataset.batch_size)
    print("##### training using loss:", opt.train.loss)
    print("##### training using GPU:", opt.dataset.GPU_ID)

    mode = opt.mode
    if mode == 'train':
        model_name_path = opt.model_name + opt.name_prefix
    elif mode == 'test':
        model_name_path = 'Infer_' + opt.model_name + opt.name_prefix
    else:
        print('mode not implemented')
        model_name_path = 'Task_' + opt.model_name + opt.name_prefix

    config_file = args.config
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    print('given GPU IDs: ', gpu_ids)
    # DDP requires a torchrun-style launcher, which exports LOCAL_RANK. A plain
    # `python` launch on a multi-GPU Linux box falls back to a single GPU
    # instead of crashing on the missing variable.
    use_ddp = is_linux() and torch.cuda.device_count() > 1 and "LOCAL_RANK" in os.environ
    if use_ddp:
        print("🟢 Detected Linux with multiple GPUs — using DDP...")
        # Trust the launcher's process count: it may start fewer processes
        # than there are visible GPUs, and a mismatch would hang init.
        world_size = int(os.environ.get("WORLD_SIZE", torch.cuda.device_count()))
        opt.is_ddp = True
        opt.rank = int(os.environ["LOCAL_RANK"])
        opt.world_size = world_size
        setup(opt.rank, world_size)
        # Report the GPU used by this process.
        current_gpu = torch.cuda.current_device()
        print(f"🧠 Current GPU [rank {opt.rank}]: {torch.cuda.get_device_name(current_gpu)}")
    else:
        print("🟡 Using single-GPU training (Windows or single GPU)...")
        opt.is_ddp = False
        opt.rank = 0
        opt.world_size = 1
        # Report the GPU used in single-GPU mode.
        gpu_id = gpu_ids[0] if gpu_ids else 0
        torch.cuda.set_device(gpu_id)
        current_gpu = torch.cuda.current_device()
        print(f"🧠 Using GPU ID {gpu_id}: {torch.cuda.get_device_name(current_gpu)}")

    loader, opt, my_paths = init_dataset(opt, model_name_path, dataset_name)
    train_loader = loader.train_loader
    val_loader = loader.val_loader

    for folder_key in ("saved_logs_folder", "saved_model_folder", "tensorboard_log_dir",
                       "saved_img_folder", "saved_inference_folder"):
        os.makedirs(my_paths[folder_key], exist_ok=True)
    # One copy of the config is enough; avoid every DDP rank writing the same file.
    if opt.rank == 0:
        shutil.copy2(config_file, my_paths["saved_logs_folder"])

    try:
        launch_model(
            model_name=opt.model_name,
            opt=opt,
            paths=my_paths,
            train_loader=train_loader,
            val_loader=val_loader,
            mode=opt.mode,
        )
    finally:
        # Tear down the process group even if training raises.
        if opt.is_ddp:
            cleanup()
def initialize_collection(first_data):
    """Allocate empty buffers for stitching patches back into one volume.

    Args:
        first_data: a batch dict with ``'original_spatial_shape'`` and ``'img'``.
            The first is an iterable with one tensor per spatial dimension;
            the max of each tensor becomes that dimension's size (presumably
            one size per batch item — TODO confirm). ``'img'`` supplies the
            dtype of the reconstructed volume.

    Returns:
        ``(collected_patches, collected_coords, reconstructed_volume, count_volume)``.
        These are two empty lists, a zero volume with the dtype of
        ``first_data['img']``, and a zero ``torch.int`` volume of the same shape.
    """
    spatial_dims = first_data['original_spatial_shape']
    first_img = first_data['img']
    # Size each axis by the largest extent found for that dimension.
    volume_shape = tuple(torch.max(dim_sizes).item() for dim_sizes in spatial_dims)
    reconstructed_volume = torch.zeros(volume_shape, dtype=first_img.dtype)
    print('empty volume_shape:', volume_shape)
    # Tracks which voxels have received a patch.
    count_volume = torch.zeros(volume_shape, dtype=torch.int)
    patches, coords = [], []
    return patches, coords, reconstructed_volume, count_volume
def reconstruct_volume(collected_patches, collected_coords, reconstructed_volume, count_volume):
    """Paste collected patches into ``reconstructed_volume`` and mark the covered voxels.

    Args:
        collected_patches: list of batched patch tensors. ``batch[i][0]`` is
            the 3-D patch written for item ``i`` (presumably channel 0 of a
            ``(B, C, X, Y, Z)`` batch — TODO confirm).
        collected_coords: list parallel to ``collected_patches``. Element
            ``[i]`` of each entry is ``[(c0, c1), (x0, x1), (y0, y1), (z0, z1)]``;
            the channel range is not used.
        reconstructed_volume: 3-D tensor, written in place.
        count_volume: 3-D tensor; covered voxels are set to 1 in place.

    Overlapping patches overwrite each other (the later patch wins); no
    averaging is done. A patch whose coordinates do not match its shape is
    reported and skipped.

    Returns:
        ``(reconstructed_volume, count_volume)``. An empty
        ``collected_patches`` returns them unchanged.
    """
    print('batch_num:', len(collected_patches))
    for data_idx, data in enumerate(collected_patches):
        patch_coords = collected_coords[data_idx]
        # Use each batch's own size: the last batch is often smaller than the first.
        for batch_idx in range(data.shape[0]):
            data_patch_idx = data[batch_idx]
            patch_coords_idx = patch_coords[batch_idx]
            x_start, x_end = patch_coords_idx[1]
            y_start, y_end = patch_coords_idx[2]
            z_start, z_end = patch_coords_idx[3]
            region = (slice(x_start, x_end), slice(y_start, y_end), slice(z_start, z_end))
            try:
                reconstructed_volume[region] = data_patch_idx[0]
                count_volume[region] = 1
            except (IndexError, RuntimeError) as e:
                # torch reports a patch/region shape mismatch as RuntimeError,
                # so catching IndexError alone missed the common failure.
                print(f"{type(e).__name__}: {e} - check patch coordinates and dimensions")
                print('patch_coords_idx:', patch_coords_idx)
                print('data shape:', data_patch_idx.shape)
                print('to fill shape:', reconstructed_volume[region].shape)
                print('check the div_size and patch_size, they should be at least the same')
    return reconstructed_volume, count_volume
def print_data_info(A_data):
    """Print the shape, min/max/mean/std, affine and pixel spacing of ``A_data``.

    ``A_data`` must be a tensor that also exposes ``meta['affine']`` and a
    ``pixdim`` attribute (presumably a MONAI MetaTensor — confirm).
    """
    print('shape of A', A_data.shape)
    lowest, highest = torch.min(A_data), torch.max(A_data)
    average, spread = torch.mean(A_data), torch.std(A_data)
    print('min,max,mean,std of A', lowest, highest, average, spread)
    affine = A_data.meta['affine']
    print(f"source image affine:\n{affine}")
    print(f"source image pixdim:\n{A_data.pixdim}")
# Script entry point: train/test using the default arguments plus any CLI flags.
if __name__ == "__main__":
    run()