# terrain-generation / train_heightmap.py
# Source: flappybird1084/terrain-generation, commit 6caa466 ("add initial files")
import os

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split

import util.dataset as ds
from util.unet import UNet
# Path to the terrain dataset root — edit to match your local copy.
# dataset: https://www.kaggle.com/datasets/tpapp157/earth-terrain-height-and-segmentation-map-images
dataset_path = "../../Other/cosmos/data/terrain_reconstruction/_dataset/"

# Resize every image to 128x128 and convert to a tensor. Normalization
# variants from earlier experiments are kept commented out for reference.
transform_pipeline = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    # transforms.Normalize(mean=[0.5], std=[0.5]),
    # transforms.Normalize(mean=[0.485, 0.456, 0.406],
    #                      std=[0.229, 0.224, 0.225])
])

dataset = ds.TerrainDataset(dataset_path, transform=transform_pipeline)

# 80% train / 20% test split.
n_train = int(0.8 * len(dataset))
dataset_train, dataset_test = random_split(
    dataset, [n_train, len(dataset) - n_train])

# Prefer Apple MPS, then CUDA, then fall back to CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# DataLoaders: shuffle only the training set.
_batch_size = 8
_num_workers = 0
train_loader = DataLoader(dataset_train, batch_size=_batch_size,
                          shuffle=True, num_workers=_num_workers)
test_loader = DataLoader(dataset_test, batch_size=_batch_size,
                         shuffle=False, num_workers=_num_workers)
class PerceptualLoss(nn.Module):
    """VGG16 feature-space MSE ("perceptual") loss.

    Compares predictions against targets through the first `feature_layer`
    layers of an ImageNet-pretrained VGG16 (frozen), after applying the
    standard ImageNet channel normalization to both inputs.
    """

    def __init__(self, feature_layer=9):
        super().__init__()
        extractor = models.vgg16(
            weights=models.VGG16_Weights.DEFAULT).features[:feature_layer]
        extractor.eval()
        # Freeze the extractor — it serves only as a fixed feature metric.
        for p in extractor.parameters():
            p.requires_grad = False
        self.vgg = extractor.to(device)
        # ImageNet statistics expected by the pretrained VGG16 weights.
        self.transform = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                              std=[0.229, 0.224, 0.225])

    def forward(self, pred, target):
        """Return the MSE between VGG16 feature maps of `pred` and `target`."""
        return nn.functional.mse_loss(
            self.vgg(self.transform(pred)),
            self.vgg(self.transform(target)))
def total_variation_loss(x):
    """Anisotropic total-variation penalty for a [B, C, H, W] tensor.

    Returns the mean absolute difference between horizontally adjacent
    pixels plus the mean absolute difference between vertically adjacent
    pixels — a smoothness regularizer that discourages high-frequency noise.
    """
    dh = x[:, :, :, 1:] - x[:, :, :, :-1]   # horizontal neighbor differences
    dv = x[:, :, 1:, :] - x[:, :, :-1, :]   # vertical neighbor differences
    return dh.abs().mean() + dv.abs().mean()
# U-Net mapping a 3-channel segmentation map to a 1-channel heightmap.
# use_sigmoid=False: raw regression output (heights are scaled later).
unet_model = UNet(in_channels=3, out_channels=1, use_sigmoid=False,
                  features=[64, 128, 256, 512, 1024]).to(device)

mse_loss = nn.MSELoss()
perceptual_loss = PerceptualLoss().to(device)
perceptual_loss_scaling_factor = 0.1  # weight of the perceptual term

optimizer = optim.Adam(unet_model.parameters(), lr=0.001)

# Uncomment to resume training from a saved checkpoint:
# unet_model.load_state_dict(torch.load('./models/terrain/heightmap_unet_model.pth'))
num_epochs = 5
for epoch in range(num_epochs):
    unet_model.train()
    running_loss = 0.0
    for i, (height, terrain, segmentation) in enumerate(train_loader):
        # Input: segmentation map; target: heightmap. `terrain` is unused here.
        images = segmentation.to(device).float()
        target_images = height.to(device).float()

        # Forward pass
        outputs = unet_model(images)

        # VGG16 expects 3 channels: replicate [B, 1, H, W] -> [B, 3, H, W].
        outputs_rgb = outputs.repeat(1, 3, 1, 1)
        targets_rgb = target_images.repeat(1, 3, 1, 1)

        # Heights are 16-bit values, so divide by 65535 to bring both sides
        # into [0, 1] before the losses. The TV term lightly penalizes
        # high-frequency noise in the predicted heightmap.
        tv_weight = 1e-6
        loss = (mse_loss(outputs / 65535, target_images / 65535)
                + perceptual_loss_scaling_factor
                * perceptual_loss(outputs_rgb / 65535, targets_rgb / 65535)
                + tv_weight * total_variation_loss(outputs / 65535))
        running_loss += loss.item()

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i + 1) % 10 == 0:
            # Bug fix: the original printed `epoch + 1/num_epochs` (e.g. 0.2
            # for epoch 0) because of operator precedence; report the actual
            # epoch and step counters instead.
            print(f"Epoch {epoch + 1}/{num_epochs}, "
                  f"Step {i + 1}/{len(train_loader)}, "
                  f"Loss: {loss.item():.6f}")

    # Per-epoch summary — running_loss was previously accumulated but never
    # reported. Guard against an empty loader to avoid ZeroDivisionError.
    print(f"Epoch {epoch + 1} average loss: "
          f"{running_loss / max(len(train_loader), 1):.6f}")

# Ensure the checkpoint directory exists before saving (torch.save does not
# create intermediate directories).
os.makedirs('./models/terrain', exist_ok=True)
torch.save(unet_model.state_dict(),
           './models/terrain/turbo_heightmap_unet_model.pth')
print("Model saved to './models/terrain/turbo_heightmap_unet_model.pth'")