# Page-header residue from the hosting site: "Spaces: Sleeping"
import os

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.data import random_split

import util.dataset as ds
from util.unet import UNet
# Root folder of the terrain dataset -- point this at your own local copy.
# Dataset source:
# https://www.kaggle.com/datasets/tpapp157/earth-terrain-height-and-segmentation-map-images
dataset_path = "../../Other/cosmos/data/terrain_reconstruction/_dataset/"

# Each image is resized to 128x128 and converted to a tensor.
# No Normalize step is part of this pipeline.
transform_pipeline = transforms.Compose(
    [
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
    ]
)

dataset = ds.TerrainDataset(dataset_path, transform=transform_pipeline)

# Random 80% / 20% train / test split; the test part takes whatever is left
# after rounding the training share down.
train_size = int(len(dataset) * 0.8)
test_size = len(dataset) - train_size
dataset_train, dataset_test = random_split(dataset, [train_size, test_size])

# Pick the compute device in order of preference: MPS, then CUDA, then CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Data loaders: only the training split is shuffled.
numworkers = 0
batchsize = 8
_loader_options = {"batch_size": batchsize, "num_workers": numworkers}
train_loader = DataLoader(dataset_train, shuffle=True, **_loader_options)
test_loader = DataLoader(dataset_test, shuffle=False, **_loader_options)
class PerceptualLoss(nn.Module):
    """Mean-squared error between VGG16 feature maps of two image batches.

    The feature extractor is the first ``feature_layer`` modules of the
    pretrained torchvision VGG16 ``features`` stack, switched to eval mode
    with all parameters frozen.

    Args:
        feature_layer: Number of leading ``vgg16().features`` modules to keep.
    """

    def __init__(self, feature_layer=9):
        super().__init__()
        backbone = models.vgg16(weights=models.VGG16_Weights.DEFAULT)
        extractor = backbone.features[:feature_layer]
        extractor.eval()
        # The extractor is a fixed feature function; its weights never train.
        for weight in extractor.parameters():
            weight.requires_grad = False
        self.vgg = extractor.to(device)
        # Per-channel normalization applied to both inputs before extraction.
        self.transform = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        )

    def forward(self, pred, target):
        """Return the MSE between the VGG features of ``pred`` and ``target``.

        Both inputs are normalized with ``self.transform`` first; the three
        mean/std entries mean they must carry three channels.
        """
        pred_features = self.vgg(self.transform(pred))
        target_features = self.vgg(self.transform(target))
        return nn.functional.mse_loss(pred_features, target_features)
def total_variation_loss(x):
    """Return the total-variation penalty of a 4-D tensor.

    The result is the mean absolute difference between neighbouring elements
    along the last axis plus the same quantity along the second-to-last axis.

    Args:
        x: Tensor indexed with four dimensions.

    Returns:
        A scalar tensor holding the sum of the two mean absolute differences.
    """
    # Differences between each element and its successor on the last axis.
    last_axis_steps = x[:, :, :, :-1] - x[:, :, :, 1:]
    # Same, along the second-to-last axis.
    prev_axis_steps = x[:, :, :-1, :] - x[:, :, 1:, :]
    return last_axis_steps.abs().mean() + prev_axis_steps.abs().mean()
# U-Net taking a 3-channel input and producing a 1-channel output.
unet_model = UNet(
    in_channels=3,
    out_channels=1,
    use_sigmoid=False,
    features=[64, 128, 256, 512, 1024],
).to(device)

# Loss terms and their weights. The total loss is
#   MSE + perceptual_loss_scaling_factor * perceptual + tv_weight * TV.
mse_loss = nn.MSELoss()
perceptual_loss = PerceptualLoss().to(device)
perceptual_loss_scaling_factor = 0.1
tv_weight = 1e-6  # constant, so defined once instead of on every batch

optimizer = optim.Adam(unet_model.parameters(), lr=0.001)

# To resume from earlier weights, load them before training:
# unet_model.load_state_dict(torch.load('./models/terrain/heightmap_unet_model.pth'))

num_epochs = 5
num_batches = len(train_loader)
for epoch in range(num_epochs):
    unet_model.train()
    running_loss = 0.0
    for i, (height, _terrain, segmentation) in enumerate(train_loader):
        # The segmentation map is the network input, the height map the target.
        images = segmentation.to(device).float()
        target_images = height.to(device).float()

        # Forward pass
        outputs = unet_model(images)

        # Both prediction and target are divided by 65535 before any loss is
        # computed (presumably heights are 16-bit values -- TODO confirm).
        outputs_scaled = outputs / 65535
        targets_scaled = target_images / 65535

        # The perceptual loss normalizes three channels, so the
        # single-channel maps are tiled: [B, 1, H, W] -> [B, 3, H, W].
        outputs_rgb = outputs_scaled.repeat(1, 3, 1, 1)
        targets_rgb = targets_scaled.repeat(1, 3, 1, 1)

        loss = (
            mse_loss(outputs_scaled, targets_scaled)
            + perceptual_loss_scaling_factor
            * perceptual_loss(outputs_rgb, targets_rgb)
            + tv_weight * total_variation_loss(outputs_scaled)
        )

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        running_loss += loss_value
        if (i + 1) % 10 == 0:
            # Report 1-based "current/total" counters for epoch and step.
            print(
                f"Epoch [{epoch + 1}/{num_epochs}], "
                f"Step [{i + 1}/{num_batches}], "
                f"Loss: {loss_value:.6f}"
            )

    # Per-epoch summary; max() guards against an empty loader.
    print(
        f"Epoch [{epoch + 1}/{num_epochs}] "
        f"average loss: {running_loss / max(num_batches, 1):.6f}"
    )

# Save the trained weights. The target directory is created first so a
# missing folder cannot discard a finished training run.
model_path = './models/terrain/turbo_heightmap_unet_model.pth'
os.makedirs(os.path.dirname(model_path), exist_ok=True)
torch.save(unet_model.state_dict(), model_path)
print(f"Model saved to '{model_path}'")