{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "3c9e1a70",
   "metadata": {},
   "source": [
    "# MAE pre-training + segmentation fine-tuning sweep\n",
    "\n",
    "For every architecture in `models` and every input-channel configuration in `experiments`:\n",
    "\n",
    "1. **Pre-train** the network with the `task='mae'` objective on the training locations. This stage is skipped when `<model name>/<i>/<model name>_mae_models/epoch_best.pth` already exists, unless `RETRAIN_MAE = True`.\n",
    "2. **Fine-tune** it for single-class segmentation (`num_classes=1`, `freeze_encoder=True`), starting from the MAE weights with the parameters listed in `seg_remove_names` dropped.\n",
    "\n",
    "All checkpoints are written under `<model name>/<experiment index>/`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "22a0040d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# !uv pip install einops\n",
    "# !uv pip install torchmetrics\n",
    "# !uv pip install timm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7b8e575",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "import random\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader\n",
    "from tqdm import tqdm\n",
    "\n",
    "from trainer.dataloader import SatellitePatchDataset, GeoAugment\n",
    "from trainer.models import swinv2Cnn, swinv2Swinv2, unet, unetResnet34, deeplabv3, unetResnet34ASPP\n",
    "from trainer.trainer import train\n",
    "from utils import compute_band_statistics"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5b2d7e41",
   "metadata": {},
   "source": [
    "## Configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bed4a4f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training hyper-parameters\n",
    "EPOCH_MAE = 30\n",
    "EPOCH_SEG = 30\n",
    "LEARNING_RATE_MAE = 1e-3\n",
    "LEARNING_RATE_SEG = 1e-3\n",
    "BATCH_SIZE_MAE = 8\n",
    "BATCH_SIZE_SEG = 8\n",
    "MASKING_RATIO = 0.7   # masking_ratio passed to both MAE datasets\n",
    "SEED = 42             # re-applied before every training stage (see seed_everything)\n",
    "RETRAIN_MAE = False   # True: re-run MAE pre-training even if its checkpoint already exists\n",
    "\n",
    "# Data\n",
    "root = \"../dataset\"\n",
    "train_locs = [\"Aletsch\", \"Rhone\", \"Gorner\", \"Anzere\", \"Diablerets\", \"Gorbassiere\", \"Moiry\", \"Saas-Tal\"]\n",
    "val_locs = [\"PleineMorte\", \"Zmutt\"]\n",
    "MODALITIES = (\"S1\", \"S2\", \"DEM\", \"Hillshade\", \"Cloudmask\")\n",
    "\n",
    "band_stats_file = \"band_stats.json\"\n",
    "\n",
    "augment = GeoAugment()\n",
    "\n",
    "# Band statistics are computed on the training locations only and cached to disk;\n",
    "# delete band_stats.json to force recomputation.\n",
    "if os.path.exists(band_stats_file):\n",
    "    with open(band_stats_file, \"r\") as f:\n",
    "        band_stats = json.load(f)\n",
    "else:\n",
    "    band_stats = compute_band_statistics(\n",
    "        root,\n",
    "        locations=train_locs,\n",
    "        modalities=MODALITIES,\n",
    "        patch_size=256,\n",
    "        stride=256,\n",
    "    )\n",
    "    with open(band_stats_file, \"w\") as f:\n",
    "        json.dump(band_stats, f)\n",
    "\n",
    "# Input-channel configurations: the channel indices taken from each modality\n",
    "# (an empty list contributes no channels). The list index names the run folder.\n",
    "experiments = [\n",
    "    # S2 channels 0, 1, 2, 10 only\n",
    "    {\n",
    "        \"S1\": [],\n",
    "        \"S2\": [0, 1, 2, 10],\n",
    "        \"DEM\": [],\n",
    "        \"Hillshade\": [],\n",
    "        \"Cloudmask\": []\n",
    "    },\n",
    "    # S2 channels 0, 1, 2, 10 + DEM, Hillshade, Cloudmask\n",
    "    {\n",
    "        \"S1\": [],\n",
    "        \"S2\": [0, 1, 2, 10],\n",
    "        \"DEM\": [0],\n",
    "        \"Hillshade\": [0],\n",
    "        \"Cloudmask\": [0]\n",
    "    },\n",
    "    # S1 channels 0-1 + S2 channels 0-11\n",
    "    {\n",
    "        \"S1\": [0, 1],\n",
    "        \"S2\": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],\n",
    "        \"DEM\": [],\n",
    "        \"Hillshade\": [],\n",
    "        \"Cloudmask\": []\n",
    "    },\n",
    "    # S1 channels 0-1 + S2 channels 0-11 + DEM, Hillshade, Cloudmask\n",
    "    {\n",
    "        \"S1\": [0, 1],\n",
    "        \"S2\": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],\n",
    "        \"DEM\": [0],\n",
    "        \"Hillshade\": [0],\n",
    "        \"Cloudmask\": [0]\n",
    "    }\n",
    "]\n",
    "\n",
    "# seg_remove_names: parameter-name prefixes dropped from the MAE checkpoint before\n",
    "# fine-tuning (the output head, whose shape differs between the two stages).\n",
    "models = [\n",
    "    # Plain U-Net variants (disabled; uncomment to include them in the sweep)\n",
    "    # {\n",
    "    #     \"model\": unet,\n",
    "    #     \"name\": \"unet\",\n",
    "    #     \"image_size\": 256,\n",
    "    #     \"seg_remove_names\": [\"outc.conv\"],\n",
    "    #     \"loss\": \"focal\",\n",
    "    #     \"focal_alpha\": 0.25,\n",
    "    #     \"focal_gamma\": 2.0\n",
    "    # },\n",
    "    # {\n",
    "    #     \"model\": unet,\n",
    "    #     \"name\": \"unet_focal1\",\n",
    "    #     \"image_size\": 256,\n",
    "    #     \"seg_remove_names\": [\"outc.conv\"],\n",
    "    #     \"loss\": \"focal\",\n",
    "    #     \"focal_alpha\": 0.5,\n",
    "    #     \"focal_gamma\": 1.5\n",
    "    # },\n",
    "    # {\n",
    "    #     \"model\": unet,\n",
    "    #     \"name\": \"unet_ce\",\n",
    "    #     \"image_size\": 256,\n",
    "    #     \"seg_remove_names\": [\"outc.conv\"],\n",
    "    #     \"loss\": \"cross_entropy\"\n",
    "    # },\n",
    "    {\n",
    "        \"model\": unetResnet34,\n",
    "        \"name\": \"unet_resnet34\",\n",
    "        \"image_size\": 224,\n",
    "        \"seg_remove_names\": [\"outc.conv\"],\n",
    "        \"loss\": \"focal\",\n",
    "        \"focal_alpha\": 0.5,\n",
    "        \"focal_gamma\": 1.5\n",
    "    },\n",
    "    {\n",
    "        \"model\": unetResnet34ASPP,\n",
    "        \"name\": \"unet_resnet34ASPP\",\n",
    "        \"image_size\": 224,\n",
    "        \"seg_remove_names\": [\"outc.conv\"],\n",
    "        \"loss\": \"focal\",\n",
    "        \"focal_alpha\": 0.5,\n",
    "        \"focal_gamma\": 1.5\n",
    "    },\n",
    "    {\n",
    "        \"model\": deeplabv3,\n",
    "        \"name\": \"deeplabv3\",\n",
    "        \"image_size\": 256,\n",
    "        \"seg_remove_names\": [\"decoder.output.3\"],\n",
    "        \"loss\": \"focal\",\n",
    "        \"focal_alpha\": 0.5,\n",
    "        \"focal_gamma\": 1.5\n",
    "    },\n",
    "    {\n",
    "        \"model\": swinv2Cnn,\n",
    "        \"name\": \"swinv2_cnn\",\n",
    "        \"image_size\": 256,\n",
    "        \"seg_remove_names\": [\"head\"],\n",
    "        \"loss\": \"focal\",\n",
    "        \"focal_alpha\": 0.5,\n",
    "        \"focal_gamma\": 1.5\n",
    "    },\n",
    "    {\n",
    "        \"model\": swinv2Swinv2,\n",
    "        \"name\": \"swinv2_swinv2\",\n",
    "        \"image_size\": 224,\n",
    "        \"seg_remove_names\": [\"swin_unet.output.weight\"],\n",
    "        \"loss\": \"focal\",\n",
    "        \"focal_alpha\": 0.5,\n",
    "        \"focal_gamma\": 1.5\n",
    "    }\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8e4f0b93",
   "metadata": {},
   "source": [
    "## Helpers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d17a6c2e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def channel_kwargs(exp):\n",
    "    \"\"\"Map one experiment's per-modality channel lists to SatellitePatchDataset's ch_* keyword arguments.\"\"\"\n",
    "    return {\n",
    "        \"ch_s1\": exp[\"S1\"],\n",
    "        \"ch_s2\": exp[\"S2\"],\n",
    "        \"ch_dem\": exp[\"DEM\"],\n",
    "        \"ch_hillshade\": exp[\"Hillshade\"],\n",
    "        \"ch_cloudmask\": exp[\"Cloudmask\"],\n",
    "    }\n",
    "\n",
    "\n",
    "def count_channels(exp):\n",
    "    \"\"\"Total number of input channels selected by an experiment.\"\"\"\n",
    "    return sum(len(exp[modality]) for modality in MODALITIES)\n",
    "\n",
    "\n",
    "def seed_everything(seed):\n",
    "    \"\"\"Seed the Python, NumPy and PyTorch (CPU and CUDA) random number generators.\"\"\"\n",
    "    random.seed(seed)\n",
    "    np.random.seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "    torch.cuda.manual_seed_all(seed)\n",
    "\n",
    "\n",
    "def pretrain_mae(arch, exp, image_size, save_folder):\n",
    "    \"\"\"Stage 1: train `arch` with task='mae' on the selected channels; checkpoints go to `save_folder`.\"\"\"\n",
    "    train_ds = SatellitePatchDataset(\n",
    "        root,\n",
    "        locations=train_locs,\n",
    "        band_stats=band_stats,\n",
    "        transform=augment,\n",
    "        patch_size=image_size,\n",
    "        task=\"mae\",\n",
    "        skip_empty=True,\n",
    "        empty_tile_ratio=0.0,\n",
    "        masking_ratio=MASKING_RATIO,\n",
    "        **channel_kwargs(exp),\n",
    "    )\n",
    "    val_ds = SatellitePatchDataset(\n",
    "        root,\n",
    "        locations=val_locs,\n",
    "        band_stats=band_stats,\n",
    "        transform=None,\n",
    "        patch_size=image_size,\n",
    "        task=\"mae\",\n",
    "        skip_empty=False,\n",
    "        empty_tile_ratio=0.0,\n",
    "        masking_ratio=MASKING_RATIO,\n",
    "        **channel_kwargs(exp),\n",
    "    )\n",
    "    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE_MAE, shuffle=True)\n",
    "    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE_MAE, shuffle=False)\n",
    "\n",
    "    n_channels = count_channels(exp)\n",
    "    # The MAE model outputs one map per input channel (num_classes == in_channels).\n",
    "    model = arch(in_channels=n_channels, num_classes=n_channels, freeze_encoder=False).cuda()\n",
    "    train(model, train_loader, val_loader, optim.AdamW(model.parameters(), lr=LEARNING_RATE_MAE),\n",
    "          epochs=EPOCH_MAE, save_folder=save_folder, task=\"mae\")\n",
    "\n",
    "\n",
    "def load_mae_weights(model, checkpoint_path, weights_to_remove):\n",
    "    \"\"\"Load an MAE checkpoint into `model`, skipping every parameter whose name starts with a prefix in `weights_to_remove`.\"\"\"\n",
    "    checkpoint = torch.load(checkpoint_path)\n",
    "    state_dict = checkpoint[\"model\"] if \"model\" in checkpoint else checkpoint\n",
    "    state_dict = {k: v for k, v in state_dict.items()\n",
    "                  if not any(k.startswith(prefix) for prefix in weights_to_remove)}\n",
    "    # strict=False is required because the removed head is absent, but it would also\n",
    "    # silently skip encoder weights whose names do not match -- with freeze_encoder=True\n",
    "    # that would leave them at their initial values. Report the mismatch instead of hiding it.\n",
    "    result = model.load_state_dict(state_dict, strict=False)\n",
    "    print(f\"  MAE weights loaded: {len(result.missing_keys)} missing, \"\n",
    "          f\"{len(result.unexpected_keys)} unexpected keys\")\n",
    "    if result.unexpected_keys:\n",
    "        print(f\"  unexpected keys: {result.unexpected_keys}\")\n",
    "\n",
    "\n",
    "def finetune_segmentation(arch, exp, image_size, mae_checkpoint, model_config, save_folder):\n",
    "    \"\"\"Stage 2: build a 1-class segmentation model (freeze_encoder=True), load MAE weights, and train it.\"\"\"\n",
    "    train_ds = SatellitePatchDataset(\n",
    "        root,\n",
    "        locations=train_locs,\n",
    "        patch_size=image_size,\n",
    "        stride=image_size,\n",
    "        skip_empty=True,\n",
    "        empty_tile_ratio=0.0,\n",
    "        transform=augment,\n",
    "        band_stats=band_stats,\n",
    "        task=\"segmentation\",\n",
    "        **channel_kwargs(exp),\n",
    "    )\n",
    "    val_ds = SatellitePatchDataset(\n",
    "        root,\n",
    "        locations=val_locs,\n",
    "        patch_size=image_size,\n",
    "        stride=image_size,\n",
    "        band_stats=band_stats,\n",
    "        task=\"segmentation\",\n",
    "        **channel_kwargs(exp),\n",
    "    )\n",
    "    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE_SEG, shuffle=True)\n",
    "    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE_SEG, shuffle=False)\n",
    "\n",
    "    model = arch(in_channels=count_channels(exp), num_classes=1, freeze_encoder=True).cuda()\n",
    "    load_mae_weights(model, mae_checkpoint, model_config[\"seg_remove_names\"])\n",
    "    train(model, train_loader, val_loader, optim.AdamW(model.parameters(), lr=LEARNING_RATE_SEG),\n",
    "          epochs=EPOCH_SEG,\n",
    "          save_folder=save_folder,\n",
    "          task=\"segmentation\",\n",
    "          loss=model_config.get(\"loss\", \"default\"),\n",
    "          focal_alpha=model_config.get(\"focal_alpha\", 0.25),\n",
    "          focal_gamma=model_config.get(\"focal_gamma\", 2.0))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f2c85b06",
   "metadata": {},
   "source": [
    "## Run the sweep\n",
    "\n",
    "MAE pre-training is the expensive stage, so an existing `epoch_best.pth` is reused and the notebook can be re-run or resumed without repeating it. Set `RETRAIN_MAE = True` in the configuration cell to force pre-training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb5d4000",
   "metadata": {},
   "outputs": [],
   "source": [
    "for model_config in models:\n",
    "    model_name = model_config[\"name\"]\n",
    "    arch = model_config[\"model\"]\n",
    "    image_size = model_config[\"image_size\"]\n",
    "    print(f\"Current model beeing trained - {model_name}\")\n",
    "    for i, exp in enumerate(experiments):\n",
    "        mae_folder = f\"{model_name}/{i}/{model_name}_mae_models\"\n",
    "        mae_checkpoint = f\"{mae_folder}/epoch_best.pth\"\n",
    "\n",
    "        if RETRAIN_MAE or not os.path.exists(mae_checkpoint):\n",
    "            seed_everything(SEED)\n",
    "            pretrain_mae(arch, exp, image_size, mae_folder)\n",
    "\n",
    "        # Re-seed so fine-tuning is reproducible whether or not pre-training ran above.\n",
    "        seed_everything(SEED)\n",
    "        finetune_segmentation(arch, exp, image_size, mae_checkpoint, model_config,\n",
    "                              save_folder=f\"{model_name}/{i}/{model_name}_seg_models\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}