AlbeRota commited on
Commit
656765a
·
verified ·
1 Parent(s): b67700d

Upload weights, notebooks, sample images

Browse files
notebooks/UnReflectAnything.ipynb CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:c58d9f13fa5df97d73d4b4769062d01c0ec034ff0e3237dae715292c72161bb7
3
- size 14063
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0fc591caa8e1f251f75b0bd093de2f86535a079ef73e3afde0372871213cdaf2
3
+ size 14524
notebooks/api_examples.ipynb ADDED
@@ -0,0 +1,253 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "d5e78019",
6
+ "metadata": {},
7
+ "source": [
8
+ "# UnReflectAnything API Examples\n",
9
+ "---"
10
+ ]
11
+ },
12
+ {
13
+ "cell_type": "markdown",
14
+ "id": "d423248d",
15
+ "metadata": {},
16
+ "source": [
17
+ "### Package Import"
18
+ ]
19
+ },
20
+ {
21
+ "cell_type": "code",
22
+ "execution_count": 1,
23
+ "id": "db2eda79",
24
+ "metadata": {},
25
+ "outputs": [
26
+ {
27
+ "name": "stdout",
28
+ "output_type": "stream",
29
+ "text": [
30
+ "Using device: cuda\n"
31
+ ]
32
+ }
33
+ ],
34
+ "source": [
35
+ "import unreflectanything\n",
36
+ "import torch\n",
37
+ "\n",
38
+ "%load_ext autoreload\n",
39
+ "%autoreload 2\n",
40
+ "\n",
41
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
42
+ "print(f\"Using device: {device}\")"
43
+ ]
44
+ },
45
+ {
46
+ "cell_type": "markdown",
47
+ "id": "c3828c5e",
48
+ "metadata": {},
49
+ "source": [
50
+ "### Model Loading"
51
+ ]
52
+ },
53
+ {
54
+ "cell_type": "markdown",
55
+ "id": "cabb1b8a",
56
+ "metadata": {},
57
+ "source": [
58
+ "If you haven't downloaded the pre-trained weights yet, do so with \n",
59
+ "\n",
60
+ "`unreflectanything download --weights` from the terminal\n",
61
+ "\n",
62
+ "\n",
63
+ "or with `unreflectanything.download(\"weights\")` from Python."
64
+ ]
65
+ },
66
+ {
67
+ "cell_type": "code",
68
+ "execution_count": 6,
69
+ "id": "d58ad7f1",
70
+ "metadata": {},
71
+ "outputs": [
72
+ {
73
+ "data": {
74
+ "text/html": [
75
+ "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">MODEL <span style=\"font-weight: bold\">[</span><span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">18:45:03</span><span style=\"font-weight: bold\">]</span> ✓ Decoder <span style=\"color: #008000; text-decoration-color: #008000\">'diffuse'</span>: Successfully loaded all <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">54</span> state dict keys from weights/rgb_decoder.pth\n",
76
+ "</pre>\n"
77
+ ],
78
+ "text/plain": [
79
+ "MODEL \u001b[1m[\u001b[0m\u001b[1;92m18:45:03\u001b[0m\u001b[1m]\u001b[0m ✓ Decoder \u001b[32m'diffuse'\u001b[0m: Successfully loaded all \u001b[1;36m54\u001b[0m state dict keys from weights/rgb_decoder.pth\n"
80
+ ]
81
+ },
82
+ "metadata": {},
83
+ "output_type": "display_data"
84
+ },
85
+ {
86
+ "data": {
87
+ "text/html": [
88
+ "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">MODEL <span style=\"font-weight: bold\">[</span><span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">18:45:03</span><span style=\"font-weight: bold\">]</span> Loaded pre-trained decoder weights from weights/rgb_decoder.pth\n",
89
+ "</pre>\n"
90
+ ],
91
+ "text/plain": [
92
+ "MODEL \u001b[1m[\u001b[0m\u001b[1;92m18:45:03\u001b[0m\u001b[1m]\u001b[0m Loaded pre-trained decoder weights from weights/rgb_decoder.pth\n"
93
+ ]
94
+ },
95
+ "metadata": {},
96
+ "output_type": "display_data"
97
+ },
98
+ {
99
+ "data": {
100
+ "text/html": [
101
+ "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">MODEL <span style=\"font-weight: bold\">[</span><span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">18:45:03</span><span style=\"font-weight: bold\">]</span> ✓ Token Inpainter: Successfully loaded all <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">78</span> state dict keys from weights/token_inpainter.pth\n",
102
+ "</pre>\n"
103
+ ],
104
+ "text/plain": [
105
+ "MODEL \u001b[1m[\u001b[0m\u001b[1;92m18:45:03\u001b[0m\u001b[1m]\u001b[0m ✓ Token Inpainter: Successfully loaded all \u001b[1;36m78\u001b[0m state dict keys from weights/token_inpainter.pth\n"
106
+ ]
107
+ },
108
+ "metadata": {},
109
+ "output_type": "display_data"
110
+ },
111
+ {
112
+ "data": {
113
+ "text/html": [
114
+ "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">MODEL <span style=\"font-weight: bold\">[</span><span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">18:45:03</span><span style=\"font-weight: bold\">]</span> Loaded pretrained token inpainter weights from weights/token_inpainter.pth\n",
115
+ "</pre>\n"
116
+ ],
117
+ "text/plain": [
118
+ "MODEL \u001b[1m[\u001b[0m\u001b[1;92m18:45:03\u001b[0m\u001b[1m]\u001b[0m Loaded pretrained token inpainter weights from weights/token_inpainter.pth\n"
119
+ ]
120
+ },
121
+ "metadata": {},
122
+ "output_type": "display_data"
123
+ },
124
+ {
125
+ "name": "stdout",
126
+ "output_type": "stream",
127
+ "text": [
128
+ "Warning: missing keys when loading checkpoint: ['decoders.highlight.reassemble_layers.0.proj.weight', 'decoders.highlight.reassemble_layers.0.proj.bias', 'decoders.highlight.reassemble_layers.0.resample.weight', 'decoders.highlight.reassemble_layers.0.resample.bias', 'decoders.highlight.reassemble_layers.1.proj.weight', 'decoders.highlight.reassemble_layers.1.proj.bias', 'decoders.highlight.reassemble_layers.1.resample.weight', 'decoders.highlight.reassemble_layers.1.resample.bias', 'decoders.highlight.reassemble_layers.2.proj.weight', 'decoders.highlight.reassemble_layers.2.proj.bias', 'decoders.highlight.reassemble_layers.3.proj.weight', 'decoders.highlight.reassemble_layers.3.proj.bias', 'decoders.highlight.reassemble_layers.3.resample.weight', 'decoders.highlight.reassemble_layers.3.resample.bias', 'decoders.highlight.fusion_blocks.0.residual_conv1.weight', 'decoders.highlight.fusion_blocks.0.residual_conv1.bias', 'decoders.highlight.fusion_blocks.0.residual_conv2.0.weight', 'decoders.highlight.fusion_blocks.0.residual_conv2.0.bias', 'decoders.highlight.fusion_blocks.0.residual_conv2.3.weight', 'decoders.highlight.fusion_blocks.0.residual_conv2.3.bias', 'decoders.highlight.fusion_blocks.0.out_conv.weight', 'decoders.highlight.fusion_blocks.0.out_conv.bias', 'decoders.highlight.fusion_blocks.1.residual_conv1.weight', 'decoders.highlight.fusion_blocks.1.residual_conv1.bias', 'decoders.highlight.fusion_blocks.1.residual_conv2.0.weight', 'decoders.highlight.fusion_blocks.1.residual_conv2.0.bias', 'decoders.highlight.fusion_blocks.1.residual_conv2.3.weight', 'decoders.highlight.fusion_blocks.1.residual_conv2.3.bias', 'decoders.highlight.fusion_blocks.1.out_conv.weight', 'decoders.highlight.fusion_blocks.1.out_conv.bias', 'decoders.highlight.fusion_blocks.2.residual_conv1.weight', 'decoders.highlight.fusion_blocks.2.residual_conv1.bias', 'decoders.highlight.fusion_blocks.2.residual_conv2.0.weight', 'decoders.highlight.fusion_blocks.2.residual_conv2.0.bias', 'decoders.highlight.fusion_blocks.2.residual_conv2.3.weight', 'decoders.highlight.fusion_blocks.2.residual_conv2.3.bias', 'decoders.highlight.fusion_blocks.2.out_conv.weight', 'decoders.highlight.fusion_blocks.2.out_conv.bias', 'decoders.highlight.fusion_blocks.3.residual_conv1.weight', 'decoders.highlight.fusion_blocks.3.residual_conv1.bias', 'decoders.highlight.fusion_blocks.3.residual_conv2.0.weight', 'decoders.highlight.fusion_blocks.3.residual_conv2.0.bias', 'decoders.highlight.fusion_blocks.3.residual_conv2.3.weight', 'decoders.highlight.fusion_blocks.3.residual_conv2.3.bias', 'decoders.highlight.fusion_blocks.3.out_conv.weight', 'decoders.highlight.fusion_blocks.3.out_conv.bias', 'decoders.highlight.rgb_head.0.weight', 'decoders.highlight.rgb_head.0.bias', 'decoders.highlight.rgb_head.5.weight', 'decoders.highlight.rgb_head.5.bias', 'decoders.highlight.rgb_head.9.weight', 'decoders.highlight.rgb_head.9.bias', 'decoders.highlight.rgb_head.13.weight', 'decoders.highlight.rgb_head.13.bias', 'token_inpaint.mask_token', 'token_inpaint.mask_indicator', 'token_inpaint.blocks.0.attn.norm.weight', 'token_inpaint.blocks.0.attn.norm.bias', 'token_inpaint.blocks.0.attn.fn.attn.in_proj_weight', 'token_inpaint.blocks.0.attn.fn.attn.in_proj_bias', 'token_inpaint.blocks.0.attn.fn.attn.out_proj.weight', 'token_inpaint.blocks.0.attn.fn.attn.out_proj.bias', 'token_inpaint.blocks.0.mlp.norm.weight', 'token_inpaint.blocks.0.mlp.norm.bias', 'token_inpaint.blocks.0.mlp.fn.fc1.weight', 'token_inpaint.blocks.0.mlp.fn.fc1.bias', 'token_inpaint.blocks.0.mlp.fn.fc2.weight', 'token_inpaint.blocks.0.mlp.fn.fc2.bias', 'token_inpaint.blocks.1.attn.norm.weight', 'token_inpaint.blocks.1.attn.norm.bias', 'token_inpaint.blocks.1.attn.fn.attn.in_proj_weight', 'token_inpaint.blocks.1.attn.fn.attn.in_proj_bias', 'token_inpaint.blocks.1.attn.fn.attn.out_proj.weight', 'token_inpaint.blocks.1.attn.fn.attn.out_proj.bias', 'token_inpaint.blocks.1.mlp.norm.weight', 'token_inpaint.blocks.1.mlp.norm.bias', 'token_inpaint.blocks.1.mlp.fn.fc1.weight', 'token_inpaint.blocks.1.mlp.fn.fc1.bias', 'token_inpaint.blocks.1.mlp.fn.fc2.weight', 'token_inpaint.blocks.1.mlp.fn.fc2.bias', 'token_inpaint.blocks.2.attn.norm.weight', 'token_inpaint.blocks.2.attn.norm.bias', 'token_inpaint.blocks.2.attn.fn.attn.in_proj_weight', 'token_inpaint.blocks.2.attn.fn.attn.in_proj_bias', 'token_inpaint.blocks.2.attn.fn.attn.out_proj.weight', 'token_inpaint.blocks.2.attn.fn.attn.out_proj.bias', 'token_inpaint.blocks.2.mlp.norm.weight', 'token_inpaint.blocks.2.mlp.norm.bias', 'token_inpaint.blocks.2.mlp.fn.fc1.weight', 'token_inpaint.blocks.2.mlp.fn.fc1.bias', 'token_inpaint.blocks.2.mlp.fn.fc2.weight', 'token_inpaint.blocks.2.mlp.fn.fc2.bias', 'token_inpaint.blocks.3.attn.norm.weight', 'token_inpaint.blocks.3.attn.norm.bias', 'token_inpaint.blocks.3.attn.fn.attn.in_proj_weight', 'token_inpaint.blocks.3.attn.fn.attn.in_proj_bias', 'token_inpaint.blocks.3.attn.fn.attn.out_proj.weight', 'token_inpaint.blocks.3.attn.fn.attn.out_proj.bias', 'token_inpaint.blocks.3.mlp.norm.weight', 'token_inpaint.blocks.3.mlp.norm.bias', 'token_inpaint.blocks.3.mlp.fn.fc1.weight', 'token_inpaint.blocks.3.mlp.fn.fc1.bias', 'token_inpaint.blocks.3.mlp.fn.fc2.weight', 'token_inpaint.blocks.3.mlp.fn.fc2.bias', 'token_inpaint.blocks.4.attn.norm.weight', 'token_inpaint.blocks.4.attn.norm.bias', 'token_inpaint.blocks.4.attn.fn.attn.in_proj_weight', 'token_inpaint.blocks.4.attn.fn.attn.in_proj_bias', 'token_inpaint.blocks.4.attn.fn.attn.out_proj.weight', 'token_inpaint.blocks.4.attn.fn.attn.out_proj.bias', 'token_inpaint.blocks.4.mlp.norm.weight', 'token_inpaint.blocks.4.mlp.norm.bias', 'token_inpaint.blocks.4.mlp.fn.fc1.weight', 'token_inpaint.blocks.4.mlp.fn.fc1.bias', 'token_inpaint.blocks.4.mlp.fn.fc2.weight', 'token_inpaint.blocks.4.mlp.fn.fc2.bias', 'token_inpaint.blocks.5.attn.norm.weight', 'token_inpaint.blocks.5.attn.norm.bias', 'token_inpaint.blocks.5.attn.fn.attn.in_proj_weight', 'token_inpaint.blocks.5.attn.fn.attn.in_proj_bias', 'token_inpaint.blocks.5.attn.fn.attn.out_proj.weight', 'token_inpaint.blocks.5.attn.fn.attn.out_proj.bias', 'token_inpaint.blocks.5.mlp.norm.weight', 'token_inpaint.blocks.5.mlp.norm.bias', 'token_inpaint.blocks.5.mlp.fn.fc1.weight', 'token_inpaint.blocks.5.mlp.fn.fc1.bias', 'token_inpaint.blocks.5.mlp.fn.fc2.weight', 'token_inpaint.blocks.5.mlp.fn.fc2.bias', 'token_inpaint.out_proj.weight', 'token_inpaint.out_proj.bias', 'token_inpaint._final_norm.weight', 'token_inpaint._final_norm.bias']\n"
129
+ ]
130
+ }
131
+ ],
132
+ "source": [
133
+ "# unreflectanything.download(\"weights\")\n",
134
+ "# unreflectanything.download(\"images\") # --> Loads 20 sample images\n",
135
+ "unreflectanythingmodel = unreflectanything.model(pretrained=True)"
136
+ ]
137
+ },
138
+ {
139
+ "cell_type": "markdown",
140
+ "id": "f3dfa889",
141
+ "metadata": {},
142
+ "source": [
143
+ "Load a dataset of images. Change `PATH_TO_IMAGE_DIR` to point to your own image directory"
144
+ ]
145
+ },
146
+ {
147
+ "cell_type": "code",
148
+ "execution_count": null,
149
+ "id": "da39fa39",
150
+ "metadata": {},
151
+ "outputs": [],
152
+ "source": [
153
+ "from unreflectanything import ImageDirDataset, get_cache_dir\n",
154
+ "from torch.utils.data import DataLoader\n",
155
+ "\n",
156
+ "PATH_TO_IMAGE_DIR = get_cache_dir(\n",
157
+ " \"images\"\n",
158
+ ") # Modify this path to point to your image directory\n",
159
+ "\n",
160
+ "ds = ImageDirDataset(PATH_TO_IMAGE_DIR, target_size=(448, 448), return_path=False)\n",
161
+ "loader = DataLoader(ds, batch_size=1, shuffle=False)"
162
+ ]
163
+ },
164
+ {
165
+ "cell_type": "markdown",
166
+ "id": "4c8312f0",
167
+ "metadata": {},
168
+ "source": [
169
+ "### Forward Pass / Inference"
170
+ ]
171
+ },
172
+ {
173
+ "cell_type": "code",
174
+ "execution_count": 8,
175
+ "id": "34e01754",
176
+ "metadata": {},
177
+ "outputs": [],
178
+ "source": [
179
+ "output_images = [unreflectanythingmodel(batch_images) for batch_images in loader]"
180
+ ]
181
+ },
182
+ {
183
+ "cell_type": "markdown",
184
+ "id": "94690751",
185
+ "metadata": {},
186
+ "source": [
187
+ "### Displaying results"
188
+ ]
189
+ },
190
+ {
191
+ "cell_type": "code",
192
+ "execution_count": 9,
193
+ "id": "a130c042",
194
+ "metadata": {},
195
+ "outputs": [
196
+ {
197
+ "ename": "RuntimeError",
198
+ "evalue": "Sizes of tensors must match except in dimension 3. Expected size 896 but got size 448 for tensor number 1 in the list.",
199
+ "output_type": "error",
200
+ "traceback": [
201
+ "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
202
+ "\u001b[31mRuntimeError\u001b[39m Traceback (most recent call last)",
203
+ "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[9]\u001b[39m\u001b[32m, line 14\u001b[39m\n\u001b[32m 10\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m arr\n\u001b[32m 13\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m input_batch, output_batch \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mzip\u001b[39m(loader, output_images):\n\u001b[32m---> \u001b[39m\u001b[32m14\u001b[39m concat_images = \u001b[43mtorch\u001b[49m\u001b[43m.\u001b[49m\u001b[43mcat\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 15\u001b[39m \u001b[43m \u001b[49m\u001b[43m[\u001b[49m\u001b[43minput_batch\u001b[49m\u001b[43m.\u001b[49m\u001b[43mcpu\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moutput_batch\u001b[49m\u001b[43m.\u001b[49m\u001b[43mcpu\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdim\u001b[49m\u001b[43m=\u001b[49m\u001b[32;43m3\u001b[39;49m\n\u001b[32m 16\u001b[39m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# (B, 3, H, 2W)\u001b[39;00m\n\u001b[32m 17\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m sample \u001b[38;5;129;01min\u001b[39;00m concat_images:\n\u001b[32m 18\u001b[39m img_uint8 = tensor_to_uint8_img(sample)\n",
204
+ "\u001b[31mRuntimeError\u001b[39m: Sizes of tensors must match except in dimension 3. Expected size 896 but got size 448 for tensor number 1 in the list."
205
+ ]
206
+ }
207
+ ],
208
+ "source": [
209
+ "from PIL import Image\n",
210
+ "import numpy as np\n",
211
+ "\n",
212
+ "\n",
213
+ "# Helper: Convert tensor [H, W, C] in [0,1] float32 to uint8\n",
214
+ "def tensor_to_uint8_img(t):\n",
215
+ " arr = t.permute(1, 2, 0).cpu().detach().numpy()\n",
216
+ " arr = np.clip(arr, 0, 1)\n",
217
+ " arr = (arr * 255).round().astype(np.uint8)\n",
218
+ " return arr\n",
219
+ "\n",
220
+ "\n",
221
+ "for input_batch, output_batch in zip(loader, output_images):\n",
222
+ " concat_images = torch.cat(\n",
223
+ " [input_batch.cpu(), output_batch.cpu()], dim=3\n",
224
+ " ) # (B, 3, H, 2W)\n",
225
+ " for sample in concat_images:\n",
226
+ " img_uint8 = tensor_to_uint8_img(sample)\n",
227
+ " display(Image.fromarray(img_uint8))\n",
228
+ " break\n"
229
+ ]
230
+ }
231
+ ],
232
+ "metadata": {
233
+ "kernelspec": {
234
+ "display_name": "Python 3 (ipykernel)",
235
+ "language": "python",
236
+ "name": "python3"
237
+ },
238
+ "language_info": {
239
+ "codemirror_mode": {
240
+ "name": "ipython",
241
+ "version": 3
242
+ },
243
+ "file_extension": ".py",
244
+ "mimetype": "text/x-python",
245
+ "name": "python",
246
+ "nbconvert_exporter": "python",
247
+ "pygments_lexer": "ipython3",
248
+ "version": "3.12.11"
249
+ }
250
+ },
251
+ "nbformat": 4,
252
+ "nbformat_minor": 5
253
+ }