AlbeRota commited on
Commit
c9f4709
·
verified ·
1 Parent(s): 766ba3d

Upload weights, notebooks, sample images

Browse files
Files changed (1) hide show
  1. notebooks/api_examples.ipynb +17 -91
notebooks/api_examples.ipynb CHANGED
@@ -65,18 +65,18 @@
65
  },
66
  {
67
  "cell_type": "code",
68
- "execution_count": 4,
69
  "id": "d58ad7f1",
70
  "metadata": {},
71
  "outputs": [
72
  {
73
  "data": {
74
  "text/html": [
75
- "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">MODEL <span style=\"font-weight: bold\">[</span><span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">23:19:01</span><span style=\"font-weight: bold\">]</span> ✓ Decoder <span style=\"color: #008000; text-decoration-color: #008000\">'diffuse'</span>: Successfully loaded all <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">54</span> state dict keys from weights/rgb_decoder.pth\n",
76
  "</pre>\n"
77
  ],
78
  "text/plain": [
79
- "MODEL \u001b[1m[\u001b[0m\u001b[1;92m23:19:01\u001b[0m\u001b[1m]\u001b[0m ✓ Decoder \u001b[32m'diffuse'\u001b[0m: Successfully loaded all \u001b[1;36m54\u001b[0m state dict keys from weights/rgb_decoder.pth\n"
80
  ]
81
  },
82
  "metadata": {},
@@ -85,11 +85,11 @@
85
  {
86
  "data": {
87
  "text/html": [
88
- "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">MODEL <span style=\"font-weight: bold\">[</span><span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">23:19:01</span><span style=\"font-weight: bold\">]</span> Loaded pre-trained decoder weights from weights/rgb_decoder.pth\n",
89
  "</pre>\n"
90
  ],
91
  "text/plain": [
92
- "MODEL \u001b[1m[\u001b[0m\u001b[1;92m23:19:01\u001b[0m\u001b[1m]\u001b[0m Loaded pre-trained decoder weights from weights/rgb_decoder.pth\n"
93
  ]
94
  },
95
  "metadata": {},
@@ -98,11 +98,11 @@
98
  {
99
  "data": {
100
  "text/html": [
101
- "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">MODEL <span style=\"font-weight: bold\">[</span><span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">23:19:02</span><span style=\"font-weight: bold\">]</span> ✓ Token Inpainter: Successfully loaded all <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">78</span> state dict keys from weights/token_inpainter.pth\n",
102
  "</pre>\n"
103
  ],
104
  "text/plain": [
105
- "MODEL \u001b[1m[\u001b[0m\u001b[1;92m23:19:02\u001b[0m\u001b[1m]\u001b[0m ✓ Token Inpainter: Successfully loaded all \u001b[1;36m78\u001b[0m state dict keys from weights/token_inpainter.pth\n"
106
  ]
107
  },
108
  "metadata": {},
@@ -111,11 +111,11 @@
111
  {
112
  "data": {
113
  "text/html": [
114
- "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">MODEL <span style=\"font-weight: bold\">[</span><span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">23:19:02</span><span style=\"font-weight: bold\">]</span> Loaded pretrained token inpainter weights from weights/token_inpainter.pth\n",
115
  "</pre>\n"
116
  ],
117
  "text/plain": [
118
- "MODEL \u001b[1m[\u001b[0m\u001b[1;92m23:19:02\u001b[0m\u001b[1m]\u001b[0m Loaded pretrained token inpainter weights from weights/token_inpainter.pth\n"
119
  ]
120
  },
121
  "metadata": {},
@@ -157,7 +157,7 @@
157
  " \"images\"\n",
158
  ") # Modify this path to point to your image directory\n",
159
  "\n",
160
- "ds = ImageDirDataset(PATH_TO_IMAGE_DIR, target_size=(896, 896), return_path=False)\n",
161
  "loader = DataLoader(ds, batch_size=1, shuffle=False)"
162
  ]
163
  },
@@ -171,7 +171,7 @@
171
  },
172
  {
173
  "cell_type": "code",
174
- "execution_count": null,
175
  "id": "34e01754",
176
  "metadata": {},
177
  "outputs": [],
@@ -189,93 +189,19 @@
189
  },
190
  {
191
  "cell_type": "code",
192
- "execution_count": null,
193
- "id": "b588087b",
194
- "metadata": {},
195
- "outputs": [
196
- {
197
- "data": {
198
- "text/html": [
199
- "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">MODEL <span style=\"font-weight: bold\">[</span><span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">22:43:07</span><span style=\"font-weight: bold\">]</span> ✓ Decoder <span style=\"color: #008000; text-decoration-color: #008000\">'diffuse'</span>: Successfully loaded all <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">54</span> state dict keys from weights/rgb_decoder.pth\n",
200
- "</pre>\n"
201
- ],
202
- "text/plain": [
203
- "MODEL \u001b[1m[\u001b[0m\u001b[1;92m22:43:07\u001b[0m\u001b[1m]\u001b[0m ✓ Decoder \u001b[32m'diffuse'\u001b[0m: Successfully loaded all \u001b[1;36m54\u001b[0m state dict keys from weights/rgb_decoder.pth\n"
204
- ]
205
- },
206
- "metadata": {},
207
- "output_type": "display_data"
208
- },
209
- {
210
- "data": {
211
- "text/html": [
212
- "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">MODEL <span style=\"font-weight: bold\">[</span><span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">22:43:07</span><span style=\"font-weight: bold\">]</span> Loaded pre-trained decoder weights from weights/rgb_decoder.pth\n",
213
- "</pre>\n"
214
- ],
215
- "text/plain": [
216
- "MODEL \u001b[1m[\u001b[0m\u001b[1;92m22:43:07\u001b[0m\u001b[1m]\u001b[0m Loaded pre-trained decoder weights from weights/rgb_decoder.pth\n"
217
- ]
218
- },
219
- "metadata": {},
220
- "output_type": "display_data"
221
- },
222
- {
223
- "data": {
224
- "text/html": [
225
- "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">MODEL <span style=\"font-weight: bold\">[</span><span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">22:43:08</span><span style=\"font-weight: bold\">]</span> ✓ Token Inpainter: Successfully loaded all <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">78</span> state dict keys from weights/token_inpainter.pth\n",
226
- "</pre>\n"
227
- ],
228
- "text/plain": [
229
- "MODEL \u001b[1m[\u001b[0m\u001b[1;92m22:43:08\u001b[0m\u001b[1m]\u001b[0m ✓ Token Inpainter: Successfully loaded all \u001b[1;36m78\u001b[0m state dict keys from weights/token_inpainter.pth\n"
230
- ]
231
- },
232
- "metadata": {},
233
- "output_type": "display_data"
234
- },
235
- {
236
- "data": {
237
- "text/html": [
238
- "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">MODEL <span style=\"font-weight: bold\">[</span><span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">22:43:08</span><span style=\"font-weight: bold\">]</span> Loaded pretrained token inpainter weights from weights/token_inpainter.pth\n",
239
- "</pre>\n"
240
- ],
241
- "text/plain": [
242
- "MODEL \u001b[1m[\u001b[0m\u001b[1;92m22:43:08\u001b[0m\u001b[1m]\u001b[0m Loaded pretrained token inpainter weights from weights/token_inpainter.pth\n"
243
- ]
244
- },
245
- "metadata": {},
246
- "output_type": "display_data"
247
- },
248
- {
249
- "name": "stdout",
250
- "output_type": "stream",
251
- "text": [
252
- "Warning: missing keys when loading checkpoint: ['decoders.highlight.reassemble_layers.0.proj.weight', 'decoders.highlight.reassemble_layers.0.proj.bias', 'decoders.highlight.reassemble_layers.0.resample.weight', 'decoders.highlight.reassemble_layers.0.resample.bias', 'decoders.highlight.reassemble_layers.1.proj.weight', 'decoders.highlight.reassemble_layers.1.proj.bias', 'decoders.highlight.reassemble_layers.1.resample.weight', 'decoders.highlight.reassemble_layers.1.resample.bias', 'decoders.highlight.reassemble_layers.2.proj.weight', 'decoders.highlight.reassemble_layers.2.proj.bias', 'decoders.highlight.reassemble_layers.3.proj.weight', 'decoders.highlight.reassemble_layers.3.proj.bias', 'decoders.highlight.reassemble_layers.3.resample.weight', 'decoders.highlight.reassemble_layers.3.resample.bias', 'decoders.highlight.fusion_blocks.0.residual_conv1.weight', 'decoders.highlight.fusion_blocks.0.residual_conv1.bias', 'decoders.highlight.fusion_blocks.0.residual_conv2.0.weight', 'decoders.highlight.fusion_blocks.0.residual_conv2.0.bias', 'decoders.highlight.fusion_blocks.0.residual_conv2.3.weight', 'decoders.highlight.fusion_blocks.0.residual_conv2.3.bias', 'decoders.highlight.fusion_blocks.0.out_conv.weight', 'decoders.highlight.fusion_blocks.0.out_conv.bias', 'decoders.highlight.fusion_blocks.1.residual_conv1.weight', 'decoders.highlight.fusion_blocks.1.residual_conv1.bias', 'decoders.highlight.fusion_blocks.1.residual_conv2.0.weight', 'decoders.highlight.fusion_blocks.1.residual_conv2.0.bias', 'decoders.highlight.fusion_blocks.1.residual_conv2.3.weight', 'decoders.highlight.fusion_blocks.1.residual_conv2.3.bias', 'decoders.highlight.fusion_blocks.1.out_conv.weight', 'decoders.highlight.fusion_blocks.1.out_conv.bias', 'decoders.highlight.fusion_blocks.2.residual_conv1.weight', 'decoders.highlight.fusion_blocks.2.residual_conv1.bias', 'decoders.highlight.fusion_blocks.2.residual_conv2.0.weight', 'decoders.highlight.fusion_blocks.2.residual_conv2.0.bias', 'decoders.highlight.fusion_blocks.2.residual_conv2.3.weight', 'decoders.highlight.fusion_blocks.2.residual_conv2.3.bias', 'decoders.highlight.fusion_blocks.2.out_conv.weight', 'decoders.highlight.fusion_blocks.2.out_conv.bias', 'decoders.highlight.fusion_blocks.3.residual_conv1.weight', 'decoders.highlight.fusion_blocks.3.residual_conv1.bias', 'decoders.highlight.fusion_blocks.3.residual_conv2.0.weight', 'decoders.highlight.fusion_blocks.3.residual_conv2.0.bias', 'decoders.highlight.fusion_blocks.3.residual_conv2.3.weight', 'decoders.highlight.fusion_blocks.3.residual_conv2.3.bias', 'decoders.highlight.fusion_blocks.3.out_conv.weight', 'decoders.highlight.fusion_blocks.3.out_conv.bias', 'decoders.highlight.rgb_head.0.weight', 'decoders.highlight.rgb_head.0.bias', 'decoders.highlight.rgb_head.5.weight', 'decoders.highlight.rgb_head.5.bias', 'decoders.highlight.rgb_head.9.weight', 'decoders.highlight.rgb_head.9.bias', 'decoders.highlight.rgb_head.13.weight', 'decoders.highlight.rgb_head.13.bias', 'token_inpaint.mask_token', 'token_inpaint.mask_indicator', 'token_inpaint.blocks.0.attn.norm.weight', 'token_inpaint.blocks.0.attn.norm.bias', 'token_inpaint.blocks.0.attn.fn.attn.in_proj_weight', 'token_inpaint.blocks.0.attn.fn.attn.in_proj_bias', 'token_inpaint.blocks.0.attn.fn.attn.out_proj.weight', 'token_inpaint.blocks.0.attn.fn.attn.out_proj.bias', 'token_inpaint.blocks.0.mlp.norm.weight', 'token_inpaint.blocks.0.mlp.norm.bias', 'token_inpaint.blocks.0.mlp.fn.fc1.weight', 'token_inpaint.blocks.0.mlp.fn.fc1.bias', 'token_inpaint.blocks.0.mlp.fn.fc2.weight', 'token_inpaint.blocks.0.mlp.fn.fc2.bias', 'token_inpaint.blocks.1.attn.norm.weight', 'token_inpaint.blocks.1.attn.norm.bias', 'token_inpaint.blocks.1.attn.fn.attn.in_proj_weight', 'token_inpaint.blocks.1.attn.fn.attn.in_proj_bias', 'token_inpaint.blocks.1.attn.fn.attn.out_proj.weight', 'token_inpaint.blocks.1.attn.fn.attn.out_proj.bias', 'token_inpaint.blocks.1.mlp.norm.weight', 'token_inpaint.blocks.1.mlp.norm.bias', 'token_inpaint.blocks.1.mlp.fn.fc1.weight', 'token_inpaint.blocks.1.mlp.fn.fc1.bias', 'token_inpaint.blocks.1.mlp.fn.fc2.weight', 'token_inpaint.blocks.1.mlp.fn.fc2.bias', 'token_inpaint.blocks.2.attn.norm.weight', 'token_inpaint.blocks.2.attn.norm.bias', 'token_inpaint.blocks.2.attn.fn.attn.in_proj_weight', 'token_inpaint.blocks.2.attn.fn.attn.in_proj_bias', 'token_inpaint.blocks.2.attn.fn.attn.out_proj.weight', 'token_inpaint.blocks.2.attn.fn.attn.out_proj.bias', 'token_inpaint.blocks.2.mlp.norm.weight', 'token_inpaint.blocks.2.mlp.norm.bias', 'token_inpaint.blocks.2.mlp.fn.fc1.weight', 'token_inpaint.blocks.2.mlp.fn.fc1.bias', 'token_inpaint.blocks.2.mlp.fn.fc2.weight', 'token_inpaint.blocks.2.mlp.fn.fc2.bias', 'token_inpaint.blocks.3.attn.norm.weight', 'token_inpaint.blocks.3.attn.norm.bias', 'token_inpaint.blocks.3.attn.fn.attn.in_proj_weight', 'token_inpaint.blocks.3.attn.fn.attn.in_proj_bias', 'token_inpaint.blocks.3.attn.fn.attn.out_proj.weight', 'token_inpaint.blocks.3.attn.fn.attn.out_proj.bias', 'token_inpaint.blocks.3.mlp.norm.weight', 'token_inpaint.blocks.3.mlp.norm.bias', 'token_inpaint.blocks.3.mlp.fn.fc1.weight', 'token_inpaint.blocks.3.mlp.fn.fc1.bias', 'token_inpaint.blocks.3.mlp.fn.fc2.weight', 'token_inpaint.blocks.3.mlp.fn.fc2.bias', 'token_inpaint.blocks.4.attn.norm.weight', 'token_inpaint.blocks.4.attn.norm.bias', 'token_inpaint.blocks.4.attn.fn.attn.in_proj_weight', 'token_inpaint.blocks.4.attn.fn.attn.in_proj_bias', 'token_inpaint.blocks.4.attn.fn.attn.out_proj.weight', 'token_inpaint.blocks.4.attn.fn.attn.out_proj.bias', 'token_inpaint.blocks.4.mlp.norm.weight', 'token_inpaint.blocks.4.mlp.norm.bias', 'token_inpaint.blocks.4.mlp.fn.fc1.weight', 'token_inpaint.blocks.4.mlp.fn.fc1.bias', 'token_inpaint.blocks.4.mlp.fn.fc2.weight', 'token_inpaint.blocks.4.mlp.fn.fc2.bias', 'token_inpaint.blocks.5.attn.norm.weight', 'token_inpaint.blocks.5.attn.norm.bias', 'token_inpaint.blocks.5.attn.fn.attn.in_proj_weight', 'token_inpaint.blocks.5.attn.fn.attn.in_proj_bias', 'token_inpaint.blocks.5.attn.fn.attn.out_proj.weight', 'token_inpaint.blocks.5.attn.fn.attn.out_proj.bias', 'token_inpaint.blocks.5.mlp.norm.weight', 'token_inpaint.blocks.5.mlp.norm.bias', 'token_inpaint.blocks.5.mlp.fn.fc1.weight', 'token_inpaint.blocks.5.mlp.fn.fc1.bias', 'token_inpaint.blocks.5.mlp.fn.fc2.weight', 'token_inpaint.blocks.5.mlp.fn.fc2.bias', 'token_inpaint.out_proj.weight', 'token_inpaint.out_proj.bias', 'token_inpaint._final_norm.weight', 'token_inpaint._final_norm.bias']\n"
253
- ]
254
- }
255
- ],
256
- "source": [
257
- "unreflectanythingmodel = unreflectanything.model(\n",
258
- " pretrained=True,\n",
259
- " config_path=\"huggingface/configs/pretrained_config.yaml\",\n",
260
- " weights_path=get_cache_dir(\"weights\") / \"full_model_weights.pt\",\n",
261
- ")"
262
- ]
263
- },
264
- {
265
- "cell_type": "code",
266
- "execution_count": 57,
267
  "id": "a130c042",
268
  "metadata": {},
269
  "outputs": [
270
  {
271
- "ename": "AttributeError",
272
- "evalue": "'dict' object has no attribute 'cpu'",
273
  "output_type": "error",
274
  "traceback": [
275
  "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
276
- "\u001b[31mAttributeError\u001b[39m Traceback (most recent call last)",
277
- "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[57]\u001b[39m\u001b[32m, line 15\u001b[39m\n\u001b[32m 10\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m arr\n\u001b[32m 13\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m input_batch, output_batch \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mzip\u001b[39m(loader, output_images):\n\u001b[32m 14\u001b[39m concat_images = torch.cat(\n\u001b[32m---> \u001b[39m\u001b[32m15\u001b[39m [input_batch.cpu(), \u001b[43moutput_batch\u001b[49m\u001b[43m.\u001b[49m\u001b[43mcpu\u001b[49m()], dim=\u001b[32m3\u001b[39m\n\u001b[32m 16\u001b[39m ) \u001b[38;5;66;03m# (B, 3, H, 2W)\u001b[39;00m\n\u001b[32m 17\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m sample \u001b[38;5;129;01min\u001b[39;00m concat_images:\n\u001b[32m 18\u001b[39m img_uint8 = tensor_to_uint8_img(sample)\n",
278
- "\u001b[31mAttributeError\u001b[39m: 'dict' object has no attribute 'cpu'"
279
  ]
280
  }
281
  ],
 
65
  },
66
  {
67
  "cell_type": "code",
68
+ "execution_count": 6,
69
  "id": "d58ad7f1",
70
  "metadata": {},
71
  "outputs": [
72
  {
73
  "data": {
74
  "text/html": [
75
+ "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">MODEL <span style=\"font-weight: bold\">[</span><span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">18:45:03</span><span style=\"font-weight: bold\">]</span> ✓ Decoder <span style=\"color: #008000; text-decoration-color: #008000\">'diffuse'</span>: Successfully loaded all <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">54</span> state dict keys from weights/rgb_decoder.pth\n",
76
  "</pre>\n"
77
  ],
78
  "text/plain": [
79
+ "MODEL \u001b[1m[\u001b[0m\u001b[1;92m18:45:03\u001b[0m\u001b[1m]\u001b[0m ✓ Decoder \u001b[32m'diffuse'\u001b[0m: Successfully loaded all \u001b[1;36m54\u001b[0m state dict keys from weights/rgb_decoder.pth\n"
80
  ]
81
  },
82
  "metadata": {},
 
85
  {
86
  "data": {
87
  "text/html": [
88
+ "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">MODEL <span style=\"font-weight: bold\">[</span><span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">18:45:03</span><span style=\"font-weight: bold\">]</span> Loaded pre-trained decoder weights from weights/rgb_decoder.pth\n",
89
  "</pre>\n"
90
  ],
91
  "text/plain": [
92
+ "MODEL \u001b[1m[\u001b[0m\u001b[1;92m18:45:03\u001b[0m\u001b[1m]\u001b[0m Loaded pre-trained decoder weights from weights/rgb_decoder.pth\n"
93
  ]
94
  },
95
  "metadata": {},
 
98
  {
99
  "data": {
100
  "text/html": [
101
+ "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">MODEL <span style=\"font-weight: bold\">[</span><span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">18:45:03</span><span style=\"font-weight: bold\">]</span> ✓ Token Inpainter: Successfully loaded all <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">78</span> state dict keys from weights/token_inpainter.pth\n",
102
  "</pre>\n"
103
  ],
104
  "text/plain": [
105
+ "MODEL \u001b[1m[\u001b[0m\u001b[1;92m18:45:03\u001b[0m\u001b[1m]\u001b[0m ✓ Token Inpainter: Successfully loaded all \u001b[1;36m78\u001b[0m state dict keys from weights/token_inpainter.pth\n"
106
  ]
107
  },
108
  "metadata": {},
 
111
  {
112
  "data": {
113
  "text/html": [
114
+ "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">MODEL <span style=\"font-weight: bold\">[</span><span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">18:45:03</span><span style=\"font-weight: bold\">]</span> Loaded pretrained token inpainter weights from weights/token_inpainter.pth\n",
115
  "</pre>\n"
116
  ],
117
  "text/plain": [
118
+ "MODEL \u001b[1m[\u001b[0m\u001b[1;92m18:45:03\u001b[0m\u001b[1m]\u001b[0m Loaded pretrained token inpainter weights from weights/token_inpainter.pth\n"
119
  ]
120
  },
121
  "metadata": {},
 
157
  " \"images\"\n",
158
  ") # Modify this path to point to your image directory\n",
159
  "\n",
160
+ "ds = ImageDirDataset(PATH_TO_IMAGE_DIR, target_size=(448, 448), return_path=False)\n",
161
  "loader = DataLoader(ds, batch_size=1, shuffle=False)"
162
  ]
163
  },
 
171
  },
172
  {
173
  "cell_type": "code",
174
+ "execution_count": 8,
175
  "id": "34e01754",
176
  "metadata": {},
177
  "outputs": [],
 
189
  },
190
  {
191
  "cell_type": "code",
192
+ "execution_count": 9,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
  "id": "a130c042",
194
  "metadata": {},
195
  "outputs": [
196
  {
197
+ "ename": "RuntimeError",
198
+ "evalue": "Sizes of tensors must match except in dimension 3. Expected size 896 but got size 448 for tensor number 1 in the list.",
199
  "output_type": "error",
200
  "traceback": [
201
  "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
202
+ "\u001b[31mRuntimeError\u001b[39m Traceback (most recent call last)",
203
+ "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[9]\u001b[39m\u001b[32m, line 14\u001b[39m\n\u001b[32m 10\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m arr\n\u001b[32m 13\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m input_batch, output_batch \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mzip\u001b[39m(loader, output_images):\n\u001b[32m---> \u001b[39m\u001b[32m14\u001b[39m concat_images = \u001b[43mtorch\u001b[49m\u001b[43m.\u001b[49m\u001b[43mcat\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 15\u001b[39m \u001b[43m \u001b[49m\u001b[43m[\u001b[49m\u001b[43minput_batch\u001b[49m\u001b[43m.\u001b[49m\u001b[43mcpu\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moutput_batch\u001b[49m\u001b[43m.\u001b[49m\u001b[43mcpu\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdim\u001b[49m\u001b[43m=\u001b[49m\u001b[32;43m3\u001b[39;49m\n\u001b[32m 16\u001b[39m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# (B, 3, H, 2W)\u001b[39;00m\n\u001b[32m 17\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m sample \u001b[38;5;129;01min\u001b[39;00m concat_images:\n\u001b[32m 18\u001b[39m img_uint8 = tensor_to_uint8_img(sample)\n",
204
+ "\u001b[31mRuntimeError\u001b[39m: Sizes of tensors must match except in dimension 3. Expected size 896 but got size 448 for tensor number 1 in the list."
205
  ]
206
  }
207
  ],