AlbeRota commited on
Commit
431d504
·
verified ·
1 Parent(s): db1444c

Upload weights, notebooks, sample images

Browse files
notebooks/UnReflectAnything.ipynb CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:0fc591caa8e1f251f75b0bd093de2f86535a079ef73e3afde0372871213cdaf2
3
- size 14524
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:18d11460be23d0a05546ea99774373739f3ce6a75148bea9b992548de852f3d4
3
+ size 15136
notebooks/api_examples.ipynb DELETED
@@ -1,253 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "markdown",
5
- "id": "d5e78019",
6
- "metadata": {},
7
- "source": [
8
- "# UnReflectAnything API Examples\n",
9
- "---"
10
- ]
11
- },
12
- {
13
- "cell_type": "markdown",
14
- "id": "d423248d",
15
- "metadata": {},
16
- "source": [
17
- "### Package Import"
18
- ]
19
- },
20
- {
21
- "cell_type": "code",
22
- "execution_count": 1,
23
- "id": "db2eda79",
24
- "metadata": {},
25
- "outputs": [
26
- {
27
- "name": "stdout",
28
- "output_type": "stream",
29
- "text": [
30
- "Using device: cuda\n"
31
- ]
32
- }
33
- ],
34
- "source": [
35
- "import unreflectanything\n",
36
- "import torch\n",
37
- "\n",
38
- "%load_ext autoreload\n",
39
- "%autoreload 2\n",
40
- "\n",
41
- "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
42
- "print(f\"Using device: {device}\")"
43
- ]
44
- },
45
- {
46
- "cell_type": "markdown",
47
- "id": "c3828c5e",
48
- "metadata": {},
49
- "source": [
50
- "### Model Loading"
51
- ]
52
- },
53
- {
54
- "cell_type": "markdown",
55
- "id": "cabb1b8a",
56
- "metadata": {},
57
- "source": [
58
- "If you haven't downloaded the pre-trained weights yet, do so with \n",
59
- "\n",
60
- "`unreflectanything download --weights` from the terminal\n",
61
- "\n",
62
- "\n",
63
- "or with `unreflectanything.download(\"weights\")` from Python."
64
- ]
65
- },
66
- {
67
- "cell_type": "code",
68
- "execution_count": 6,
69
- "id": "d58ad7f1",
70
- "metadata": {},
71
- "outputs": [
72
- {
73
- "data": {
74
- "text/html": [
75
- "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">MODEL <span style=\"font-weight: bold\">[</span><span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">18:45:03</span><span style=\"font-weight: bold\">]</span> ✓ Decoder <span style=\"color: #008000; text-decoration-color: #008000\">'diffuse'</span>: Successfully loaded all <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">54</span> state dict keys from weights/rgb_decoder.pth\n",
76
- "</pre>\n"
77
- ],
78
- "text/plain": [
79
- "MODEL \u001b[1m[\u001b[0m\u001b[1;92m18:45:03\u001b[0m\u001b[1m]\u001b[0m ✓ Decoder \u001b[32m'diffuse'\u001b[0m: Successfully loaded all \u001b[1;36m54\u001b[0m state dict keys from weights/rgb_decoder.pth\n"
80
- ]
81
- },
82
- "metadata": {},
83
- "output_type": "display_data"
84
- },
85
- {
86
- "data": {
87
- "text/html": [
88
- "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">MODEL <span style=\"font-weight: bold\">[</span><span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">18:45:03</span><span style=\"font-weight: bold\">]</span> Loaded pre-trained decoder weights from weights/rgb_decoder.pth\n",
89
- "</pre>\n"
90
- ],
91
- "text/plain": [
92
- "MODEL \u001b[1m[\u001b[0m\u001b[1;92m18:45:03\u001b[0m\u001b[1m]\u001b[0m Loaded pre-trained decoder weights from weights/rgb_decoder.pth\n"
93
- ]
94
- },
95
- "metadata": {},
96
- "output_type": "display_data"
97
- },
98
- {
99
- "data": {
100
- "text/html": [
101
- "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">MODEL <span style=\"font-weight: bold\">[</span><span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">18:45:03</span><span style=\"font-weight: bold\">]</span> ✓ Token Inpainter: Successfully loaded all <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">78</span> state dict keys from weights/token_inpainter.pth\n",
102
- "</pre>\n"
103
- ],
104
- "text/plain": [
105
- "MODEL \u001b[1m[\u001b[0m\u001b[1;92m18:45:03\u001b[0m\u001b[1m]\u001b[0m ✓ Token Inpainter: Successfully loaded all \u001b[1;36m78\u001b[0m state dict keys from weights/token_inpainter.pth\n"
106
- ]
107
- },
108
- "metadata": {},
109
- "output_type": "display_data"
110
- },
111
- {
112
- "data": {
113
- "text/html": [
114
- "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">MODEL <span style=\"font-weight: bold\">[</span><span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">18:45:03</span><span style=\"font-weight: bold\">]</span> Loaded pretrained token inpainter weights from weights/token_inpainter.pth\n",
115
- "</pre>\n"
116
- ],
117
- "text/plain": [
118
- "MODEL \u001b[1m[\u001b[0m\u001b[1;92m18:45:03\u001b[0m\u001b[1m]\u001b[0m Loaded pretrained token inpainter weights from weights/token_inpainter.pth\n"
119
- ]
120
- },
121
- "metadata": {},
122
- "output_type": "display_data"
123
- },
124
- {
125
- "name": "stdout",
126
- "output_type": "stream",
127
- "text": [
128
- "Warning: missing keys when loading checkpoint: ['decoders.highlight.reassemble_layers.0.proj.weight', 'decoders.highlight.reassemble_layers.0.proj.bias', 'decoders.highlight.reassemble_layers.0.resample.weight', 'decoders.highlight.reassemble_layers.0.resample.bias', 'decoders.highlight.reassemble_layers.1.proj.weight', 'decoders.highlight.reassemble_layers.1.proj.bias', 'decoders.highlight.reassemble_layers.1.resample.weight', 'decoders.highlight.reassemble_layers.1.resample.bias', 'decoders.highlight.reassemble_layers.2.proj.weight', 'decoders.highlight.reassemble_layers.2.proj.bias', 'decoders.highlight.reassemble_layers.3.proj.weight', 'decoders.highlight.reassemble_layers.3.proj.bias', 'decoders.highlight.reassemble_layers.3.resample.weight', 'decoders.highlight.reassemble_layers.3.resample.bias', 'decoders.highlight.fusion_blocks.0.residual_conv1.weight', 'decoders.highlight.fusion_blocks.0.residual_conv1.bias', 'decoders.highlight.fusion_blocks.0.residual_conv2.0.weight', 'decoders.highlight.fusion_blocks.0.residual_conv2.0.bias', 'decoders.highlight.fusion_blocks.0.residual_conv2.3.weight', 'decoders.highlight.fusion_blocks.0.residual_conv2.3.bias', 'decoders.highlight.fusion_blocks.0.out_conv.weight', 'decoders.highlight.fusion_blocks.0.out_conv.bias', 'decoders.highlight.fusion_blocks.1.residual_conv1.weight', 'decoders.highlight.fusion_blocks.1.residual_conv1.bias', 'decoders.highlight.fusion_blocks.1.residual_conv2.0.weight', 'decoders.highlight.fusion_blocks.1.residual_conv2.0.bias', 'decoders.highlight.fusion_blocks.1.residual_conv2.3.weight', 'decoders.highlight.fusion_blocks.1.residual_conv2.3.bias', 'decoders.highlight.fusion_blocks.1.out_conv.weight', 'decoders.highlight.fusion_blocks.1.out_conv.bias', 'decoders.highlight.fusion_blocks.2.residual_conv1.weight', 'decoders.highlight.fusion_blocks.2.residual_conv1.bias', 'decoders.highlight.fusion_blocks.2.residual_conv2.0.weight', 'decoders.highlight.fusion_blocks.2.residual_conv2.0.bias', 'decoders.highlight.fusion_blocks.2.residual_conv2.3.weight', 'decoders.highlight.fusion_blocks.2.residual_conv2.3.bias', 'decoders.highlight.fusion_blocks.2.out_conv.weight', 'decoders.highlight.fusion_blocks.2.out_conv.bias', 'decoders.highlight.fusion_blocks.3.residual_conv1.weight', 'decoders.highlight.fusion_blocks.3.residual_conv1.bias', 'decoders.highlight.fusion_blocks.3.residual_conv2.0.weight', 'decoders.highlight.fusion_blocks.3.residual_conv2.0.bias', 'decoders.highlight.fusion_blocks.3.residual_conv2.3.weight', 'decoders.highlight.fusion_blocks.3.residual_conv2.3.bias', 'decoders.highlight.fusion_blocks.3.out_conv.weight', 'decoders.highlight.fusion_blocks.3.out_conv.bias', 'decoders.highlight.rgb_head.0.weight', 'decoders.highlight.rgb_head.0.bias', 'decoders.highlight.rgb_head.5.weight', 'decoders.highlight.rgb_head.5.bias', 'decoders.highlight.rgb_head.9.weight', 'decoders.highlight.rgb_head.9.bias', 'decoders.highlight.rgb_head.13.weight', 'decoders.highlight.rgb_head.13.bias', 'token_inpaint.mask_token', 'token_inpaint.mask_indicator', 'token_inpaint.blocks.0.attn.norm.weight', 'token_inpaint.blocks.0.attn.norm.bias', 'token_inpaint.blocks.0.attn.fn.attn.in_proj_weight', 'token_inpaint.blocks.0.attn.fn.attn.in_proj_bias', 'token_inpaint.blocks.0.attn.fn.attn.out_proj.weight', 'token_inpaint.blocks.0.attn.fn.attn.out_proj.bias', 'token_inpaint.blocks.0.mlp.norm.weight', 'token_inpaint.blocks.0.mlp.norm.bias', 'token_inpaint.blocks.0.mlp.fn.fc1.weight', 'token_inpaint.blocks.0.mlp.fn.fc1.bias', 'token_inpaint.blocks.0.mlp.fn.fc2.weight', 'token_inpaint.blocks.0.mlp.fn.fc2.bias', 'token_inpaint.blocks.1.attn.norm.weight', 'token_inpaint.blocks.1.attn.norm.bias', 'token_inpaint.blocks.1.attn.fn.attn.in_proj_weight', 'token_inpaint.blocks.1.attn.fn.attn.in_proj_bias', 'token_inpaint.blocks.1.attn.fn.attn.out_proj.weight', 'token_inpaint.blocks.1.attn.fn.attn.out_proj.bias', 'token_inpaint.blocks.1.mlp.norm.weight', 'token_inpaint.blocks.1.mlp.norm.bias', 'token_inpaint.blocks.1.mlp.fn.fc1.weight', 'token_inpaint.blocks.1.mlp.fn.fc1.bias', 'token_inpaint.blocks.1.mlp.fn.fc2.weight', 'token_inpaint.blocks.1.mlp.fn.fc2.bias', 'token_inpaint.blocks.2.attn.norm.weight', 'token_inpaint.blocks.2.attn.norm.bias', 'token_inpaint.blocks.2.attn.fn.attn.in_proj_weight', 'token_inpaint.blocks.2.attn.fn.attn.in_proj_bias', 'token_inpaint.blocks.2.attn.fn.attn.out_proj.weight', 'token_inpaint.blocks.2.attn.fn.attn.out_proj.bias', 'token_inpaint.blocks.2.mlp.norm.weight', 'token_inpaint.blocks.2.mlp.norm.bias', 'token_inpaint.blocks.2.mlp.fn.fc1.weight', 'token_inpaint.blocks.2.mlp.fn.fc1.bias', 'token_inpaint.blocks.2.mlp.fn.fc2.weight', 'token_inpaint.blocks.2.mlp.fn.fc2.bias', 'token_inpaint.blocks.3.attn.norm.weight', 'token_inpaint.blocks.3.attn.norm.bias', 'token_inpaint.blocks.3.attn.fn.attn.in_proj_weight', 'token_inpaint.blocks.3.attn.fn.attn.in_proj_bias', 'token_inpaint.blocks.3.attn.fn.attn.out_proj.weight', 'token_inpaint.blocks.3.attn.fn.attn.out_proj.bias', 'token_inpaint.blocks.3.mlp.norm.weight', 'token_inpaint.blocks.3.mlp.norm.bias', 'token_inpaint.blocks.3.mlp.fn.fc1.weight', 'token_inpaint.blocks.3.mlp.fn.fc1.bias', 'token_inpaint.blocks.3.mlp.fn.fc2.weight', 'token_inpaint.blocks.3.mlp.fn.fc2.bias', 'token_inpaint.blocks.4.attn.norm.weight', 'token_inpaint.blocks.4.attn.norm.bias', 'token_inpaint.blocks.4.attn.fn.attn.in_proj_weight', 'token_inpaint.blocks.4.attn.fn.attn.in_proj_bias', 'token_inpaint.blocks.4.attn.fn.attn.out_proj.weight', 'token_inpaint.blocks.4.attn.fn.attn.out_proj.bias', 'token_inpaint.blocks.4.mlp.norm.weight', 'token_inpaint.blocks.4.mlp.norm.bias', 'token_inpaint.blocks.4.mlp.fn.fc1.weight', 'token_inpaint.blocks.4.mlp.fn.fc1.bias', 'token_inpaint.blocks.4.mlp.fn.fc2.weight', 'token_inpaint.blocks.4.mlp.fn.fc2.bias', 'token_inpaint.blocks.5.attn.norm.weight', 'token_inpaint.blocks.5.attn.norm.bias', 'token_inpaint.blocks.5.attn.fn.attn.in_proj_weight', 'token_inpaint.blocks.5.attn.fn.attn.in_proj_bias', 'token_inpaint.blocks.5.attn.fn.attn.out_proj.weight', 'token_inpaint.blocks.5.attn.fn.attn.out_proj.bias', 'token_inpaint.blocks.5.mlp.norm.weight', 'token_inpaint.blocks.5.mlp.norm.bias', 'token_inpaint.blocks.5.mlp.fn.fc1.weight', 'token_inpaint.blocks.5.mlp.fn.fc1.bias', 'token_inpaint.blocks.5.mlp.fn.fc2.weight', 'token_inpaint.blocks.5.mlp.fn.fc2.bias', 'token_inpaint.out_proj.weight', 'token_inpaint.out_proj.bias', 'token_inpaint._final_norm.weight', 'token_inpaint._final_norm.bias']\n"
129
- ]
130
- }
131
- ],
132
- "source": [
133
- "# unreflectanything.download(\"weights\")\n",
134
- "# unreflectanything.download(\"images\") # --> Loads 20 sample images\n",
135
- "unreflectanythingmodel = unreflectanything.model(pretrained=True)"
136
- ]
137
- },
138
- {
139
- "cell_type": "markdown",
140
- "id": "f3dfa889",
141
- "metadata": {},
142
- "source": [
143
- "Load a dataset of images. Change `PATH_TO_IMAGE_DIR` to point to your own image directory"
144
- ]
145
- },
146
- {
147
- "cell_type": "code",
148
- "execution_count": null,
149
- "id": "da39fa39",
150
- "metadata": {},
151
- "outputs": [],
152
- "source": [
153
- "from unreflectanything import ImageDirDataset, get_cache_dir\n",
154
- "from torch.utils.data import DataLoader\n",
155
- "\n",
156
- "PATH_TO_IMAGE_DIR = get_cache_dir(\n",
157
- " \"images\"\n",
158
- ") # Modify this path to point to your image directory\n",
159
- "\n",
160
- "ds = ImageDirDataset(PATH_TO_IMAGE_DIR, target_size=(448, 448), return_path=False)\n",
161
- "loader = DataLoader(ds, batch_size=1, shuffle=False)"
162
- ]
163
- },
164
- {
165
- "cell_type": "markdown",
166
- "id": "4c8312f0",
167
- "metadata": {},
168
- "source": [
169
- "### Forward Pass / Inference"
170
- ]
171
- },
172
- {
173
- "cell_type": "code",
174
- "execution_count": 8,
175
- "id": "34e01754",
176
- "metadata": {},
177
- "outputs": [],
178
- "source": [
179
- "output_images = [unreflectanythingmodel(batch_images) for batch_images in loader]"
180
- ]
181
- },
182
- {
183
- "cell_type": "markdown",
184
- "id": "94690751",
185
- "metadata": {},
186
- "source": [
187
- "### Displaying results"
188
- ]
189
- },
190
- {
191
- "cell_type": "code",
192
- "execution_count": 9,
193
- "id": "a130c042",
194
- "metadata": {},
195
- "outputs": [
196
- {
197
- "ename": "RuntimeError",
198
- "evalue": "Sizes of tensors must match except in dimension 3. Expected size 896 but got size 448 for tensor number 1 in the list.",
199
- "output_type": "error",
200
- "traceback": [
201
- "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
202
- "\u001b[31mRuntimeError\u001b[39m Traceback (most recent call last)",
203
- "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[9]\u001b[39m\u001b[32m, line 14\u001b[39m\n\u001b[32m 10\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m arr\n\u001b[32m 13\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m input_batch, output_batch \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mzip\u001b[39m(loader, output_images):\n\u001b[32m---> \u001b[39m\u001b[32m14\u001b[39m concat_images = \u001b[43mtorch\u001b[49m\u001b[43m.\u001b[49m\u001b[43mcat\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 15\u001b[39m \u001b[43m \u001b[49m\u001b[43m[\u001b[49m\u001b[43minput_batch\u001b[49m\u001b[43m.\u001b[49m\u001b[43mcpu\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moutput_batch\u001b[49m\u001b[43m.\u001b[49m\u001b[43mcpu\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdim\u001b[49m\u001b[43m=\u001b[49m\u001b[32;43m3\u001b[39;49m\n\u001b[32m 16\u001b[39m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# (B, 3, H, 2W)\u001b[39;00m\n\u001b[32m 17\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m sample \u001b[38;5;129;01min\u001b[39;00m concat_images:\n\u001b[32m 18\u001b[39m img_uint8 = tensor_to_uint8_img(sample)\n",
204
- "\u001b[31mRuntimeError\u001b[39m: Sizes of tensors must match except in dimension 3. Expected size 896 but got size 448 for tensor number 1 in the list."
205
- ]
206
- }
207
- ],
208
- "source": [
209
- "from PIL import Image\n",
210
- "import numpy as np\n",
211
- "\n",
212
- "\n",
213
- "# Helper: Convert tensor [H, W, C] in [0,1] float32 to uint8\n",
214
- "def tensor_to_uint8_img(t):\n",
215
- " arr = t.permute(1, 2, 0).cpu().detach().numpy()\n",
216
- " arr = np.clip(arr, 0, 1)\n",
217
- " arr = (arr * 255).round().astype(np.uint8)\n",
218
- " return arr\n",
219
- "\n",
220
- "\n",
221
- "for input_batch, output_batch in zip(loader, output_images):\n",
222
- " concat_images = torch.cat(\n",
223
- " [input_batch.cpu(), output_batch.cpu()], dim=3\n",
224
- " ) # (B, 3, H, 2W)\n",
225
- " for sample in concat_images:\n",
226
- " img_uint8 = tensor_to_uint8_img(sample)\n",
227
- " display(Image.fromarray(img_uint8))\n",
228
- " break\n"
229
- ]
230
- }
231
- ],
232
- "metadata": {
233
- "kernelspec": {
234
- "display_name": "Python 3 (ipykernel)",
235
- "language": "python",
236
- "name": "python3"
237
- },
238
- "language_info": {
239
- "codemirror_mode": {
240
- "name": "ipython",
241
- "version": 3
242
- },
243
- "file_extension": ".py",
244
- "mimetype": "text/x-python",
245
- "name": "python",
246
- "nbconvert_exporter": "python",
247
- "pygments_lexer": "ipython3",
248
- "version": "3.12.11"
249
- }
250
- },
251
- "nbformat": 4,
252
- "nbformat_minor": 5
253
- }