diff --git "a/diffusion.ipynb" "b/diffusion.ipynb" --- "a/diffusion.ipynb" +++ "b/diffusion.ipynb" @@ -246,7 +246,7 @@ " # dim = 2\n", " dim = 2\n", " stride = (2,2) if dim == 2 else (2,2,4)\n", - " num_image = 5000#2560#800#2560\n", + " num_image = 10000#2560#800#2560\n", " batch_size = 50#20#2#100 # 10\n", " n_epoch = 50#20#20#2#5#25 # 120\n", " HII_DIM = 64\n", @@ -452,9 +452,12 @@ " print('saved model at ' + self.config.save_name)\n", " # print('saved model at ' + config.save_dir + f\"model_epoch_{ep}_test_{config.run_name}.pth\")\n", "\n", - " def sample(self, file, params:torch.tensor=None, ema=False, entire=False):\n", + " def sample(self, file, params:torch.tensor=None, repeat=192, ema=False, entire=False):\n", " # n_sample = params.shape[0]\n", - " params = params or torch.tensor([0.20000000000000018, 0.5055875000000001]).repeat(192,1)\n", + " if params is None:\n", + " params = torch.tensor([0.20000000000000018, 0.5055875000000001])\n", + " print(f\"sampling {repeat} images with normalized params = {params}\")\n", + " params = params.repeat(repeat,1)\n", " assert params.dim() == 2, \"params must be a 2D torch.tensor\"\n", " # print(\"params =\", params)\n", " # print(\"params =\", params)\n", @@ -501,7 +504,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "b8ede2f6124343619dc5bd2c56869776", + "model_id": "3a9163ca96334ab5b7ce24b7f45d07cb", "version_major": 2, "version_minor": 0 }, @@ -519,7 +522,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 7, "metadata": {}, "outputs": [ { @@ -528,14 +531,14 @@ "text": [ "-------------------- round 0 ---------------------\n", "Number of parameters for nn_model: 111048705\n", - "run_name = 0604-1713\n", + "run_name = 0604-2255\n", "Launching training on one GPU.\n", "dataset content: \n", "51200 images can be loaded\n", "field.shape = (64, 64, 514)\n", "params keys = [b'ION_Tvir_MIN', b'HII_EFF_FACTOR']\n", - "loading 5000 images randomly\n", - "images loaded: (5000, 1, 64, 64)\n" + "loading 10000 images randomly\n", + "images loaded: (10000, 1, 64, 64)\n" ] }, { @@ -549,20 +552,20 @@ "name": "stdout", "output_type": "stream", "text": [ - "params loaded: (5000, 2)\n", - "images rescaled to [-1.0, 1.221698522567749]\n", - "params rescaled to [0.0, 0.99988945406874]\n" + "params loaded: (10000, 2)\n", + "images rescaled to [-1.0, 1.1904358863830566]\n", + "params rescaled to [0.0, 0.9999527071628225]\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "4f30ac5237314eb6b560c288cfcec5a0", + "model_id": "bacb18c398d24b2090ea56edf3491954", "version_major": 2, "version_minor": 0 }, "text/plain": [ - " 0%| | 0/100 [00:00" ] @@ -1422,7 +1477,7 @@ "source": [ "# plot(\"outputs/0528-1433.npy\")\n", "# plot(\"outputs/0520-2323.npy\")\n", - "plot(\"outputs/0604-1643.npy\")" + "plot(\"outputs/0604-1927.npy\")" ] }, {