AlbeRota committed on
Commit e4a56e4 · verified · 1 Parent(s): 222d18b

Upload weights, notebooks, sample images

Files changed (1)
  1. notebooks/api_examples.ipynb +201 -92
notebooks/api_examples.ipynb CHANGED
@@ -5,7 +5,16 @@
5
  "id": "d5e78019",
6
  "metadata": {},
7
  "source": [
8
- "# UnReflectAnything API Examples\n"
9
  ]
10
  },
11
  {
@@ -35,162 +44,262 @@
35
  },
36
  {
37
  "cell_type": "markdown",
38
- "id": "94f8c2fb",
39
- "metadata": {},
40
- "source": [
41
- "### 1. Get the model class (for custom setup or training)\n",
42
- "\n",
43
- "`unreflectanything.model()` with no arguments returns the underlying model class `UnReflect_Model_TokenInpainter`. Use it when you need to build the architecture yourself (e.g. from config or for training)."
44
- ]
45
- },
46
- {
47
- "cell_type": "code",
48
- "execution_count": 13,
49
- "id": "f49c99b7",
50
  "metadata": {},
51
- "outputs": [
52
- {
53
- "name": "stdout",
54
- "output_type": "stream",
55
- "text": [
56
- "cuda:0\n"
57
- ]
58
- }
59
- ],
60
  "source": [
61
- "UnReflectModel = unreflectanything.model()\n",
62
- "UnReflectModel_Pretrained = unreflectanything.model(pretrained=True)\n",
63
- "print((next(UnReflectModel.parameters()).device))"
64
  ]
65
  },
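The removed cell above only fetched the class and printed a device. A hedged sketch of actually using the class for a custom setup follows; the constructor signature of `UnReflect_Model_TokenInpainter` is not shown anywhere in this notebook, so the no-argument construction below is an assumption.

```python
# Hedged sketch: building the architecture yourself from the class returned by
# unreflectanything.model(). Constructing with no arguments is an ASSUMPTION;
# the real class may require a config object.
import unreflectanything

UnReflect_Model_TokenInpainter = unreflectanything.model()  # the class itself
model = UnReflect_Model_TokenInpainter()  # assumed default construction
model.train()  # ready for a custom training loop
```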
66
  {
67
  "cell_type": "markdown",
68
- "id": "575fb9a1",
69
  "metadata": {},
70
  "source": [
71
- "### 2. Get a pretrained model and run on batched RGB\n",
72
  "\n",
73
- "`unreflectanything.model(pretrained=True)` returns an `UnReflectModel` instance (a `torch.nn.Module`) with weights loaded. Call it with a batch of RGB tensors `[B, 3, H, W]` (values in [0, 1]); it returns the diffuse (reflection-removed) tensor."
74
- ]
75
- },
76
- {
77
- "cell_type": "markdown",
78
- "id": "d1cdc14f",
79
- "metadata": {},
80
- "source": [
81
- "#### Load pretrained model (uses cached weights; run 'unreflectanything download --weights' first)"
82
  ]
83
  },
84
  {
85
  "cell_type": "code",
86
- "execution_count": null,
87
  "id": "d58ad7f1",
88
  "metadata": {},
89
  "outputs": [
90
  {
91
  "name": "stdout",
92
  "output_type": "stream",
93
  "text": [
94
- "Model is nn.Module: True\n",
95
- "Expected image size (side): 896\n",
96
- "Device: cuda\n"
97
  ]
98
  }
99
  ],
100
  "source": [
101
- "import torch\n",
102
- "\n",
103
- "# Load pretrained model (uses cached weights; run 'unreflectanything download --weights' first)\n",
104
- "unreflectanythingmodel = unreflectanything.model(pretrained=True)\n",
105
- "unreflectanythingmodel_scratch = unreflectanything.model(pretrained=False)\n",
106
- "print(f\"Model is nn.Module: {isinstance(unreflectanythingmodel, torch.nn.Module)}\")\n",
107
- "print(f\"Expected image size (side): {unreflectanythingmodel.image_size}\")\n",
108
- "print(f\"Device: {unreflectanythingmodel.device}\")"
109
- ]
110
- },
111
- {
112
- "cell_type": "code",
113
- "execution_count": null,
114
- "id": "34e01754",
115
- "metadata": {},
116
- "outputs": [],
117
- "source": [
118
- "# Batched RGB tensor [B, 3, H, W], values in [0, 1]\n",
119
- "batch_size = 2\n",
120
- "images = torch.rand(batch_size, 3, 448, 448, device=unreflectanythingmodel.device)\n",
121
- "model_out = unreflectanythingmodel(images) # [B, 3, H, W] diffuse tensor\n",
122
- "print(f\"Input shape: {images.shape} -> Output shape: {model_out.shape}\")"
123
  ]
124
  },
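The removed cell feeds random tensors; for real images, the expected `[B, 3, H, W]` layout with values in `[0, 1]` can be produced with PIL and torchvision. A sketch, where `photo.png` is a placeholder path and the 896-pixel side comes from the `image_size` printed in the cell above:

```python
# Sketch: image file -> [1, 3, H, W] float tensor in [0, 1] for the model.
import torch
from PIL import Image
from torchvision.transforms.functional import to_tensor, resize

img = Image.open("photo.png").convert("RGB")  # placeholder path
x = to_tensor(img)                            # [3, H, W], float32 in [0, 1]
x = resize(x, [896, 896])                     # expected side per the cell above
x = x.unsqueeze(0)                            # add the batch dimension
# diffuse = unreflectanythingmodel(x.to(unreflectanythingmodel.device))
```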
125
  {
126
  "cell_type": "markdown",
127
- "id": "696bce42",
128
  "metadata": {},
129
  "source": [
130
- "### 3. Full output dict and custom mask (optional)\n",
131
- "\n",
132
- "You can get the full model outputs (e.g. highlight mask, patch mask) with `return_dict=True`, or pass a custom inpainting mask with `inpaint_mask_override`."
133
  ]
134
  },
135
  {
136
  "cell_type": "code",
137
  "execution_count": null,
138
- "id": "dc2ecc8a",
139
  "metadata": {},
140
  "outputs": [],
141
  "source": [
142
- "# Get full outputs: diffuse, highlight, patch_mask, etc.\n",
143
- "outputs = unreflectanythingmodel(images, return_dict=True)\n",
144
- "print(\"Keys:\", list(outputs.keys())) # e.g. diffuse, highlight, patch_mask, tokens_completed\n",
145
- "diffuse_only = outputs[\"diffuse\"]\n",
146
- "highlight_mask = outputs[\"highlight\"] # [B, 1, H, W]"
147
  ]
148
  },
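`inpaint_mask_override` is mentioned in the removed markdown but never exercised. A hedged sketch follows; the `[B, 1, H, W]` shape with 1 marking pixels to inpaint mirrors the highlight mask and is an assumption, not a documented contract.

```python
# Hedged sketch of inpaint_mask_override; mask shape/semantics are ASSUMED.
custom_mask = torch.zeros_like(images[:, :1])  # [B, 1, H, W], same device/dtype
custom_mask[:, :, 100:200, 100:200] = 1.0      # force-inpaint a square region
result = unreflectanythingmodel(images, inpaint_mask_override=custom_mask)
```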
149
  {
150
  "cell_type": "markdown",
151
- "id": "87fe354c",
152
  "metadata": {},
153
  "source": [
154
- "### 4. One-shot inference (no model handle)\n",
155
- "\n",
156
- "For a single call without keeping a model in memory, use `unreflectanything.inference()`. It accepts a file path, directory, or tensor and returns a tensor (or writes to disk if `output=` is set)."
157
  ]
158
  },
159
  {
160
  "cell_type": "code",
161
  "execution_count": null,
162
- "id": "ff5740b8",
163
  "metadata": {},
164
  "outputs": [],
165
  "source": [
166
- "# Tensor in -> tensor out (loads model internally, then discards)\n",
167
- "result = unreflectanything.inference(images)\n",
168
- "print(f\"unreflectanything.inference(images) shape: {result.shape}\")\n",
169
- "\n",
170
- "# File-based: save to disk\n",
171
- "# unreflectanything.inference(\"input.png\", output=\"output.png\")\n",
172
- "# unreflectanything.inference(\"input_dir/\", output=\"output_dir/\", batch_size=8)"
173
  ]
174
  },
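Since the removed cell leaves the file-based calls commented out, here is a sketch of a directory round trip with a quick on-disk check. It assumes, as the markdown above implies, that `output=` writes one result per input image into the given directory.

```python
# Sketch: directory in -> directory out, then verify results landed on disk.
from pathlib import Path

unreflectanything.inference("input_dir/", output="output_dir/", batch_size=8)
written = sorted(Path("output_dir/").glob("*"))
print(f"{len(written)} files written; first few: {[p.name for p in written[:3]]}")
```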
175
  {
176
  "cell_type": "markdown",
177
- "id": "e2d1673d",
178
  "metadata": {},
179
  "source": [
180
- "### 5. Loading sample images (optional)\n",
181
- "\n",
182
- "If you have downloaded sample images with `unreflectanything download --images`, you can run inference on that directory."
183
  ]
184
  },
185
  {
186
  "cell_type": "code",
187
  "execution_count": null,
188
- "id": "1834686c",
189
  "metadata": {},
190
- "outputs": [],
191
  "source": [
192
- "SAMPLE_IMAGE_PATH_DIR = \"sample_images\" # default from 'unreflectanything download --images'\n",
193
- "# unreflectanything.inference(SAMPLE_IMAGE_PATH_DIR, output=\"output_sample/\", verbose=True)"
194
  ]
195
  }
196
  ],
5
  "id": "d5e78019",
6
  "metadata": {},
7
  "source": [
8
+ "# UnReflectAnything API Examples\n",
9
+ "---"
10
+ ]
11
+ },
12
+ {
13
+ "cell_type": "markdown",
14
+ "id": "d423248d",
15
+ "metadata": {},
16
+ "source": [
17
+ "### Package Import"
18
  ]
19
  },
20
  {
44
  },
45
  {
46
  "cell_type": "markdown",
47
+ "id": "c3828c5e",
48
  "metadata": {},
49
  "source": [
50
+ "### Model Loading"
51
  ]
52
  },
53
  {
54
  "cell_type": "markdown",
55
+ "id": "cabb1b8a",
56
  "metadata": {},
57
  "source": [
58
+ "If you haven't downloaded the pre-trained weights yet, do so with \n",
59
  "\n",
60
+ "`unreflectanything download --weights` from the terminal\n",
61
+ "\n",
62
+ "\n",
63
+ "or with `unreflectanything.download(\"weights\")` from Python."
64
  ]
65
  },
66
  {
67
  "cell_type": "code",
68
+ "execution_count": 4,
69
  "id": "d58ad7f1",
70
  "metadata": {},
71
  "outputs": [
72
+ {
73
+ "data": {
74
+ "text/html": [
75
+ "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">MODEL <span style=\"font-weight: bold\">[</span><span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">23:19:01</span><span style=\"font-weight: bold\">]</span> ✓ Decoder <span style=\"color: #008000; text-decoration-color: #008000\">'diffuse'</span>: Successfully loaded all <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">54</span> state dict keys from weights/rgb_decoder.pth\n",
76
+ "</pre>\n"
77
+ ],
78
+ "text/plain": [
79
+ "MODEL \u001b[1m[\u001b[0m\u001b[1;92m23:19:01\u001b[0m\u001b[1m]\u001b[0m ✓ Decoder \u001b[32m'diffuse'\u001b[0m: Successfully loaded all \u001b[1;36m54\u001b[0m state dict keys from weights/rgb_decoder.pth\n"
80
+ ]
81
+ },
82
+ "metadata": {},
83
+ "output_type": "display_data"
84
+ },
85
+ {
86
+ "data": {
87
+ "text/html": [
88
+ "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">MODEL <span style=\"font-weight: bold\">[</span><span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">23:19:01</span><span style=\"font-weight: bold\">]</span> Loaded pre-trained decoder weights from weights/rgb_decoder.pth\n",
89
+ "</pre>\n"
90
+ ],
91
+ "text/plain": [
92
+ "MODEL \u001b[1m[\u001b[0m\u001b[1;92m23:19:01\u001b[0m\u001b[1m]\u001b[0m Loaded pre-trained decoder weights from weights/rgb_decoder.pth\n"
93
+ ]
94
+ },
95
+ "metadata": {},
96
+ "output_type": "display_data"
97
+ },
98
+ {
99
+ "data": {
100
+ "text/html": [
101
+ "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">MODEL <span style=\"font-weight: bold\">[</span><span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">23:19:02</span><span style=\"font-weight: bold\">]</span> ✓ Token Inpainter: Successfully loaded all <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">78</span> state dict keys from weights/token_inpainter.pth\n",
102
+ "</pre>\n"
103
+ ],
104
+ "text/plain": [
105
+ "MODEL \u001b[1m[\u001b[0m\u001b[1;92m23:19:02\u001b[0m\u001b[1m]\u001b[0m ✓ Token Inpainter: Successfully loaded all \u001b[1;36m78\u001b[0m state dict keys from weights/token_inpainter.pth\n"
106
+ ]
107
+ },
108
+ "metadata": {},
109
+ "output_type": "display_data"
110
+ },
111
+ {
112
+ "data": {
113
+ "text/html": [
114
+ "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">MODEL <span style=\"font-weight: bold\">[</span><span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">23:19:02</span><span style=\"font-weight: bold\">]</span> Loaded pretrained token inpainter weights from weights/token_inpainter.pth\n",
115
+ "</pre>\n"
116
+ ],
117
+ "text/plain": [
118
+ "MODEL \u001b[1m[\u001b[0m\u001b[1;92m23:19:02\u001b[0m\u001b[1m]\u001b[0m Loaded pretrained token inpainter weights from weights/token_inpainter.pth\n"
119
+ ]
120
+ },
121
+ "metadata": {},
122
+ "output_type": "display_data"
123
+ },
124
  {
125
  "name": "stdout",
126
  "output_type": "stream",
127
  "text": [
128
+ "Warning: missing keys when loading checkpoint: ['decoders.highlight.reassemble_layers.0.proj.weight', 'decoders.highlight.reassemble_layers.0.proj.bias', 'decoders.highlight.reassemble_layers.0.resample.weight', 'decoders.highlight.reassemble_layers.0.resample.bias', 'decoders.highlight.reassemble_layers.1.proj.weight', 'decoders.highlight.reassemble_layers.1.proj.bias', 'decoders.highlight.reassemble_layers.1.resample.weight', 'decoders.highlight.reassemble_layers.1.resample.bias', 'decoders.highlight.reassemble_layers.2.proj.weight', 'decoders.highlight.reassemble_layers.2.proj.bias', 'decoders.highlight.reassemble_layers.3.proj.weight', 'decoders.highlight.reassemble_layers.3.proj.bias', 'decoders.highlight.reassemble_layers.3.resample.weight', 'decoders.highlight.reassemble_layers.3.resample.bias', 'decoders.highlight.fusion_blocks.0.residual_conv1.weight', 'decoders.highlight.fusion_blocks.0.residual_conv1.bias', 'decoders.highlight.fusion_blocks.0.residual_conv2.0.weight', 'decoders.highlight.fusion_blocks.0.residual_conv2.0.bias', 'decoders.highlight.fusion_blocks.0.residual_conv2.3.weight', 'decoders.highlight.fusion_blocks.0.residual_conv2.3.bias', 'decoders.highlight.fusion_blocks.0.out_conv.weight', 'decoders.highlight.fusion_blocks.0.out_conv.bias', 'decoders.highlight.fusion_blocks.1.residual_conv1.weight', 'decoders.highlight.fusion_blocks.1.residual_conv1.bias', 'decoders.highlight.fusion_blocks.1.residual_conv2.0.weight', 'decoders.highlight.fusion_blocks.1.residual_conv2.0.bias', 'decoders.highlight.fusion_blocks.1.residual_conv2.3.weight', 'decoders.highlight.fusion_blocks.1.residual_conv2.3.bias', 'decoders.highlight.fusion_blocks.1.out_conv.weight', 'decoders.highlight.fusion_blocks.1.out_conv.bias', 'decoders.highlight.fusion_blocks.2.residual_conv1.weight', 'decoders.highlight.fusion_blocks.2.residual_conv1.bias', 'decoders.highlight.fusion_blocks.2.residual_conv2.0.weight', 'decoders.highlight.fusion_blocks.2.residual_conv2.0.bias', 'decoders.highlight.fusion_blocks.2.residual_conv2.3.weight', 'decoders.highlight.fusion_blocks.2.residual_conv2.3.bias', 'decoders.highlight.fusion_blocks.2.out_conv.weight', 'decoders.highlight.fusion_blocks.2.out_conv.bias', 'decoders.highlight.fusion_blocks.3.residual_conv1.weight', 'decoders.highlight.fusion_blocks.3.residual_conv1.bias', 'decoders.highlight.fusion_blocks.3.residual_conv2.0.weight', 'decoders.highlight.fusion_blocks.3.residual_conv2.0.bias', 'decoders.highlight.fusion_blocks.3.residual_conv2.3.weight', 'decoders.highlight.fusion_blocks.3.residual_conv2.3.bias', 'decoders.highlight.fusion_blocks.3.out_conv.weight', 'decoders.highlight.fusion_blocks.3.out_conv.bias', 'decoders.highlight.rgb_head.0.weight', 'decoders.highlight.rgb_head.0.bias', 'decoders.highlight.rgb_head.5.weight', 'decoders.highlight.rgb_head.5.bias', 'decoders.highlight.rgb_head.9.weight', 'decoders.highlight.rgb_head.9.bias', 'decoders.highlight.rgb_head.13.weight', 'decoders.highlight.rgb_head.13.bias', 'token_inpaint.mask_token', 'token_inpaint.mask_indicator', 'token_inpaint.blocks.0.attn.norm.weight', 'token_inpaint.blocks.0.attn.norm.bias', 'token_inpaint.blocks.0.attn.fn.attn.in_proj_weight', 'token_inpaint.blocks.0.attn.fn.attn.in_proj_bias', 'token_inpaint.blocks.0.attn.fn.attn.out_proj.weight', 'token_inpaint.blocks.0.attn.fn.attn.out_proj.bias', 'token_inpaint.blocks.0.mlp.norm.weight', 'token_inpaint.blocks.0.mlp.norm.bias', 'token_inpaint.blocks.0.mlp.fn.fc1.weight', 'token_inpaint.blocks.0.mlp.fn.fc1.bias', 
'token_inpaint.blocks.0.mlp.fn.fc2.weight', 'token_inpaint.blocks.0.mlp.fn.fc2.bias', 'token_inpaint.blocks.1.attn.norm.weight', 'token_inpaint.blocks.1.attn.norm.bias', 'token_inpaint.blocks.1.attn.fn.attn.in_proj_weight', 'token_inpaint.blocks.1.attn.fn.attn.in_proj_bias', 'token_inpaint.blocks.1.attn.fn.attn.out_proj.weight', 'token_inpaint.blocks.1.attn.fn.attn.out_proj.bias', 'token_inpaint.blocks.1.mlp.norm.weight', 'token_inpaint.blocks.1.mlp.norm.bias', 'token_inpaint.blocks.1.mlp.fn.fc1.weight', 'token_inpaint.blocks.1.mlp.fn.fc1.bias', 'token_inpaint.blocks.1.mlp.fn.fc2.weight', 'token_inpaint.blocks.1.mlp.fn.fc2.bias', 'token_inpaint.blocks.2.attn.norm.weight', 'token_inpaint.blocks.2.attn.norm.bias', 'token_inpaint.blocks.2.attn.fn.attn.in_proj_weight', 'token_inpaint.blocks.2.attn.fn.attn.in_proj_bias', 'token_inpaint.blocks.2.attn.fn.attn.out_proj.weight', 'token_inpaint.blocks.2.attn.fn.attn.out_proj.bias', 'token_inpaint.blocks.2.mlp.norm.weight', 'token_inpaint.blocks.2.mlp.norm.bias', 'token_inpaint.blocks.2.mlp.fn.fc1.weight', 'token_inpaint.blocks.2.mlp.fn.fc1.bias', 'token_inpaint.blocks.2.mlp.fn.fc2.weight', 'token_inpaint.blocks.2.mlp.fn.fc2.bias', 'token_inpaint.blocks.3.attn.norm.weight', 'token_inpaint.blocks.3.attn.norm.bias', 'token_inpaint.blocks.3.attn.fn.attn.in_proj_weight', 'token_inpaint.blocks.3.attn.fn.attn.in_proj_bias', 'token_inpaint.blocks.3.attn.fn.attn.out_proj.weight', 'token_inpaint.blocks.3.attn.fn.attn.out_proj.bias', 'token_inpaint.blocks.3.mlp.norm.weight', 'token_inpaint.blocks.3.mlp.norm.bias', 'token_inpaint.blocks.3.mlp.fn.fc1.weight', 'token_inpaint.blocks.3.mlp.fn.fc1.bias', 'token_inpaint.blocks.3.mlp.fn.fc2.weight', 'token_inpaint.blocks.3.mlp.fn.fc2.bias', 'token_inpaint.blocks.4.attn.norm.weight', 'token_inpaint.blocks.4.attn.norm.bias', 'token_inpaint.blocks.4.attn.fn.attn.in_proj_weight', 'token_inpaint.blocks.4.attn.fn.attn.in_proj_bias', 'token_inpaint.blocks.4.attn.fn.attn.out_proj.weight', 'token_inpaint.blocks.4.attn.fn.attn.out_proj.bias', 'token_inpaint.blocks.4.mlp.norm.weight', 'token_inpaint.blocks.4.mlp.norm.bias', 'token_inpaint.blocks.4.mlp.fn.fc1.weight', 'token_inpaint.blocks.4.mlp.fn.fc1.bias', 'token_inpaint.blocks.4.mlp.fn.fc2.weight', 'token_inpaint.blocks.4.mlp.fn.fc2.bias', 'token_inpaint.blocks.5.attn.norm.weight', 'token_inpaint.blocks.5.attn.norm.bias', 'token_inpaint.blocks.5.attn.fn.attn.in_proj_weight', 'token_inpaint.blocks.5.attn.fn.attn.in_proj_bias', 'token_inpaint.blocks.5.attn.fn.attn.out_proj.weight', 'token_inpaint.blocks.5.attn.fn.attn.out_proj.bias', 'token_inpaint.blocks.5.mlp.norm.weight', 'token_inpaint.blocks.5.mlp.norm.bias', 'token_inpaint.blocks.5.mlp.fn.fc1.weight', 'token_inpaint.blocks.5.mlp.fn.fc1.bias', 'token_inpaint.blocks.5.mlp.fn.fc2.weight', 'token_inpaint.blocks.5.mlp.fn.fc2.bias', 'token_inpaint.out_proj.weight', 'token_inpaint.out_proj.bias', 'token_inpaint._final_norm.weight', 'token_inpaint._final_norm.bias']\n"
129
  ]
130
  }
131
  ],
132
  "source": [
133
+ "# unreflectanything.download(\"weights\")\n",
134
+ "# unreflectanything.download(\"images\") # --> Loads 20 sample images\n",
135
+ "unreflectanythingmodel = unreflectanything.model(pretrained=True)"
136
  ]
137
  },
138
  {
139
  "cell_type": "markdown",
140
+ "id": "f3dfa889",
141
  "metadata": {},
142
  "source": [
143
+ "Load a dataset of images. Change `PATH_TO_IMAGE_DIR` to point to your own image directory"
144
  ]
145
  },
146
  {
147
  "cell_type": "code",
148
  "execution_count": null,
149
+ "id": "da39fa39",
150
  "metadata": {},
151
  "outputs": [],
152
  "source": [
153
+ "from unreflectanything import ImageDirDataset, get_cache_dir\n",
154
+ "from torch.utils.data import DataLoader\n",
155
+ "\n",
156
+ "PATH_TO_IMAGE_DIR = get_cache_dir(\n",
157
+ " \"images\"\n",
158
+ ") # Modify this path to point to your image directory\n",
159
+ "\n",
160
+ "ds = ImageDirDataset(PATH_TO_IMAGE_DIR, target_size=(896, 896), return_path=False)\n",
161
+ "loader = DataLoader(ds, batch_size=1, shuffle=False)"
162
  ]
163
  },
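A quick sanity check on the loader can catch path or size mistakes early. This assumes `ImageDirDataset` yields plain image tensors when `return_path=False`, matching how the forward-pass cell below consumes it.

```python
# Sanity check: one batch should be float, [B, 3, 896, 896], values in [0, 1].
batch = next(iter(loader))
print(batch.shape, batch.dtype, float(batch.min()), float(batch.max()))
```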
164
  {
165
  "cell_type": "markdown",
166
+ "id": "4c8312f0",
167
  "metadata": {},
168
  "source": [
169
+ "### Forward Pass / Inference"
170
  ]
171
  },
172
  {
173
  "cell_type": "code",
174
  "execution_count": null,
175
+ "id": "34e01754",
176
  "metadata": {},
177
  "outputs": [],
178
  "source": [
179
+ "output_images = [unreflectanythingmodel(batch_images) for batch_images in loader]"
180
  ]
181
  },
182
  {
183
  "cell_type": "markdown",
184
+ "id": "94690751",
185
  "metadata": {},
186
  "source": [
187
+ "### Displaying results"
188
  ]
189
  },
190
  {
191
  "cell_type": "code",
192
  "execution_count": null,
193
+ "id": "b588087b",
194
  "metadata": {},
195
+ "outputs": [
196
+ {
197
+ "data": {
198
+ "text/html": [
199
+ "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">MODEL <span style=\"font-weight: bold\">[</span><span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">22:43:07</span><span style=\"font-weight: bold\">]</span> ✓ Decoder <span style=\"color: #008000; text-decoration-color: #008000\">'diffuse'</span>: Successfully loaded all <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">54</span> state dict keys from weights/rgb_decoder.pth\n",
200
+ "</pre>\n"
201
+ ],
202
+ "text/plain": [
203
+ "MODEL \u001b[1m[\u001b[0m\u001b[1;92m22:43:07\u001b[0m\u001b[1m]\u001b[0m ✓ Decoder \u001b[32m'diffuse'\u001b[0m: Successfully loaded all \u001b[1;36m54\u001b[0m state dict keys from weights/rgb_decoder.pth\n"
204
+ ]
205
+ },
206
+ "metadata": {},
207
+ "output_type": "display_data"
208
+ },
209
+ {
210
+ "data": {
211
+ "text/html": [
212
+ "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">MODEL <span style=\"font-weight: bold\">[</span><span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">22:43:07</span><span style=\"font-weight: bold\">]</span> Loaded pre-trained decoder weights from weights/rgb_decoder.pth\n",
213
+ "</pre>\n"
214
+ ],
215
+ "text/plain": [
216
+ "MODEL \u001b[1m[\u001b[0m\u001b[1;92m22:43:07\u001b[0m\u001b[1m]\u001b[0m Loaded pre-trained decoder weights from weights/rgb_decoder.pth\n"
217
+ ]
218
+ },
219
+ "metadata": {},
220
+ "output_type": "display_data"
221
+ },
222
+ {
223
+ "data": {
224
+ "text/html": [
225
+ "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">MODEL <span style=\"font-weight: bold\">[</span><span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">22:43:08</span><span style=\"font-weight: bold\">]</span> ✓ Token Inpainter: Successfully loaded all <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">78</span> state dict keys from weights/token_inpainter.pth\n",
226
+ "</pre>\n"
227
+ ],
228
+ "text/plain": [
229
+ "MODEL \u001b[1m[\u001b[0m\u001b[1;92m22:43:08\u001b[0m\u001b[1m]\u001b[0m ✓ Token Inpainter: Successfully loaded all \u001b[1;36m78\u001b[0m state dict keys from weights/token_inpainter.pth\n"
230
+ ]
231
+ },
232
+ "metadata": {},
233
+ "output_type": "display_data"
234
+ },
235
+ {
236
+ "data": {
237
+ "text/html": [
238
+ "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">MODEL <span style=\"font-weight: bold\">[</span><span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">22:43:08</span><span style=\"font-weight: bold\">]</span> Loaded pretrained token inpainter weights from weights/token_inpainter.pth\n",
239
+ "</pre>\n"
240
+ ],
241
+ "text/plain": [
242
+ "MODEL \u001b[1m[\u001b[0m\u001b[1;92m22:43:08\u001b[0m\u001b[1m]\u001b[0m Loaded pretrained token inpainter weights from weights/token_inpainter.pth\n"
243
+ ]
244
+ },
245
+ "metadata": {},
246
+ "output_type": "display_data"
247
+ },
248
+ {
249
+ "name": "stdout",
250
+ "output_type": "stream",
251
+ "text": [
252
+ "Warning: missing keys when loading checkpoint: ['decoders.highlight.reassemble_layers.0.proj.weight', 'decoders.highlight.reassemble_layers.0.proj.bias', 'decoders.highlight.reassemble_layers.0.resample.weight', 'decoders.highlight.reassemble_layers.0.resample.bias', 'decoders.highlight.reassemble_layers.1.proj.weight', 'decoders.highlight.reassemble_layers.1.proj.bias', 'decoders.highlight.reassemble_layers.1.resample.weight', 'decoders.highlight.reassemble_layers.1.resample.bias', 'decoders.highlight.reassemble_layers.2.proj.weight', 'decoders.highlight.reassemble_layers.2.proj.bias', 'decoders.highlight.reassemble_layers.3.proj.weight', 'decoders.highlight.reassemble_layers.3.proj.bias', 'decoders.highlight.reassemble_layers.3.resample.weight', 'decoders.highlight.reassemble_layers.3.resample.bias', 'decoders.highlight.fusion_blocks.0.residual_conv1.weight', 'decoders.highlight.fusion_blocks.0.residual_conv1.bias', 'decoders.highlight.fusion_blocks.0.residual_conv2.0.weight', 'decoders.highlight.fusion_blocks.0.residual_conv2.0.bias', 'decoders.highlight.fusion_blocks.0.residual_conv2.3.weight', 'decoders.highlight.fusion_blocks.0.residual_conv2.3.bias', 'decoders.highlight.fusion_blocks.0.out_conv.weight', 'decoders.highlight.fusion_blocks.0.out_conv.bias', 'decoders.highlight.fusion_blocks.1.residual_conv1.weight', 'decoders.highlight.fusion_blocks.1.residual_conv1.bias', 'decoders.highlight.fusion_blocks.1.residual_conv2.0.weight', 'decoders.highlight.fusion_blocks.1.residual_conv2.0.bias', 'decoders.highlight.fusion_blocks.1.residual_conv2.3.weight', 'decoders.highlight.fusion_blocks.1.residual_conv2.3.bias', 'decoders.highlight.fusion_blocks.1.out_conv.weight', 'decoders.highlight.fusion_blocks.1.out_conv.bias', 'decoders.highlight.fusion_blocks.2.residual_conv1.weight', 'decoders.highlight.fusion_blocks.2.residual_conv1.bias', 'decoders.highlight.fusion_blocks.2.residual_conv2.0.weight', 'decoders.highlight.fusion_blocks.2.residual_conv2.0.bias', 'decoders.highlight.fusion_blocks.2.residual_conv2.3.weight', 'decoders.highlight.fusion_blocks.2.residual_conv2.3.bias', 'decoders.highlight.fusion_blocks.2.out_conv.weight', 'decoders.highlight.fusion_blocks.2.out_conv.bias', 'decoders.highlight.fusion_blocks.3.residual_conv1.weight', 'decoders.highlight.fusion_blocks.3.residual_conv1.bias', 'decoders.highlight.fusion_blocks.3.residual_conv2.0.weight', 'decoders.highlight.fusion_blocks.3.residual_conv2.0.bias', 'decoders.highlight.fusion_blocks.3.residual_conv2.3.weight', 'decoders.highlight.fusion_blocks.3.residual_conv2.3.bias', 'decoders.highlight.fusion_blocks.3.out_conv.weight', 'decoders.highlight.fusion_blocks.3.out_conv.bias', 'decoders.highlight.rgb_head.0.weight', 'decoders.highlight.rgb_head.0.bias', 'decoders.highlight.rgb_head.5.weight', 'decoders.highlight.rgb_head.5.bias', 'decoders.highlight.rgb_head.9.weight', 'decoders.highlight.rgb_head.9.bias', 'decoders.highlight.rgb_head.13.weight', 'decoders.highlight.rgb_head.13.bias', 'token_inpaint.mask_token', 'token_inpaint.mask_indicator', 'token_inpaint.blocks.0.attn.norm.weight', 'token_inpaint.blocks.0.attn.norm.bias', 'token_inpaint.blocks.0.attn.fn.attn.in_proj_weight', 'token_inpaint.blocks.0.attn.fn.attn.in_proj_bias', 'token_inpaint.blocks.0.attn.fn.attn.out_proj.weight', 'token_inpaint.blocks.0.attn.fn.attn.out_proj.bias', 'token_inpaint.blocks.0.mlp.norm.weight', 'token_inpaint.blocks.0.mlp.norm.bias', 'token_inpaint.blocks.0.mlp.fn.fc1.weight', 'token_inpaint.blocks.0.mlp.fn.fc1.bias', 
'token_inpaint.blocks.0.mlp.fn.fc2.weight', 'token_inpaint.blocks.0.mlp.fn.fc2.bias', 'token_inpaint.blocks.1.attn.norm.weight', 'token_inpaint.blocks.1.attn.norm.bias', 'token_inpaint.blocks.1.attn.fn.attn.in_proj_weight', 'token_inpaint.blocks.1.attn.fn.attn.in_proj_bias', 'token_inpaint.blocks.1.attn.fn.attn.out_proj.weight', 'token_inpaint.blocks.1.attn.fn.attn.out_proj.bias', 'token_inpaint.blocks.1.mlp.norm.weight', 'token_inpaint.blocks.1.mlp.norm.bias', 'token_inpaint.blocks.1.mlp.fn.fc1.weight', 'token_inpaint.blocks.1.mlp.fn.fc1.bias', 'token_inpaint.blocks.1.mlp.fn.fc2.weight', 'token_inpaint.blocks.1.mlp.fn.fc2.bias', 'token_inpaint.blocks.2.attn.norm.weight', 'token_inpaint.blocks.2.attn.norm.bias', 'token_inpaint.blocks.2.attn.fn.attn.in_proj_weight', 'token_inpaint.blocks.2.attn.fn.attn.in_proj_bias', 'token_inpaint.blocks.2.attn.fn.attn.out_proj.weight', 'token_inpaint.blocks.2.attn.fn.attn.out_proj.bias', 'token_inpaint.blocks.2.mlp.norm.weight', 'token_inpaint.blocks.2.mlp.norm.bias', 'token_inpaint.blocks.2.mlp.fn.fc1.weight', 'token_inpaint.blocks.2.mlp.fn.fc1.bias', 'token_inpaint.blocks.2.mlp.fn.fc2.weight', 'token_inpaint.blocks.2.mlp.fn.fc2.bias', 'token_inpaint.blocks.3.attn.norm.weight', 'token_inpaint.blocks.3.attn.norm.bias', 'token_inpaint.blocks.3.attn.fn.attn.in_proj_weight', 'token_inpaint.blocks.3.attn.fn.attn.in_proj_bias', 'token_inpaint.blocks.3.attn.fn.attn.out_proj.weight', 'token_inpaint.blocks.3.attn.fn.attn.out_proj.bias', 'token_inpaint.blocks.3.mlp.norm.weight', 'token_inpaint.blocks.3.mlp.norm.bias', 'token_inpaint.blocks.3.mlp.fn.fc1.weight', 'token_inpaint.blocks.3.mlp.fn.fc1.bias', 'token_inpaint.blocks.3.mlp.fn.fc2.weight', 'token_inpaint.blocks.3.mlp.fn.fc2.bias', 'token_inpaint.blocks.4.attn.norm.weight', 'token_inpaint.blocks.4.attn.norm.bias', 'token_inpaint.blocks.4.attn.fn.attn.in_proj_weight', 'token_inpaint.blocks.4.attn.fn.attn.in_proj_bias', 'token_inpaint.blocks.4.attn.fn.attn.out_proj.weight', 'token_inpaint.blocks.4.attn.fn.attn.out_proj.bias', 'token_inpaint.blocks.4.mlp.norm.weight', 'token_inpaint.blocks.4.mlp.norm.bias', 'token_inpaint.blocks.4.mlp.fn.fc1.weight', 'token_inpaint.blocks.4.mlp.fn.fc1.bias', 'token_inpaint.blocks.4.mlp.fn.fc2.weight', 'token_inpaint.blocks.4.mlp.fn.fc2.bias', 'token_inpaint.blocks.5.attn.norm.weight', 'token_inpaint.blocks.5.attn.norm.bias', 'token_inpaint.blocks.5.attn.fn.attn.in_proj_weight', 'token_inpaint.blocks.5.attn.fn.attn.in_proj_bias', 'token_inpaint.blocks.5.attn.fn.attn.out_proj.weight', 'token_inpaint.blocks.5.attn.fn.attn.out_proj.bias', 'token_inpaint.blocks.5.mlp.norm.weight', 'token_inpaint.blocks.5.mlp.norm.bias', 'token_inpaint.blocks.5.mlp.fn.fc1.weight', 'token_inpaint.blocks.5.mlp.fn.fc1.bias', 'token_inpaint.blocks.5.mlp.fn.fc2.weight', 'token_inpaint.blocks.5.mlp.fn.fc2.bias', 'token_inpaint.out_proj.weight', 'token_inpaint.out_proj.bias', 'token_inpaint._final_norm.weight', 'token_inpaint._final_norm.bias']\n"
253
+ ]
254
+ }
255
+ ],
256
+ "source": [
257
+ "unreflectanythingmodel = unreflectanything.model(\n",
258
+ " pretrained=True,\n",
259
+ " config_path=\"huggingface/configs/pretrained_config.yaml\",\n",
260
+ " weights_path=get_cache_dir(\"weights\") / \"full_model_weights.pt\",\n",
261
+ ")"
262
+ ]
263
+ },
264
+ {
265
+ "cell_type": "code",
266
+ "execution_count": 57,
267
+ "id": "a130c042",
268
+ "metadata": {},
269
+ "outputs": [
270
+ {
271
+ "ename": "AttributeError",
272
+ "evalue": "'dict' object has no attribute 'cpu'",
273
+ "output_type": "error",
274
+ "traceback": [
275
+ "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
276
+ "\u001b[31mAttributeError\u001b[39m Traceback (most recent call last)",
277
+ "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[57]\u001b[39m\u001b[32m, line 15\u001b[39m\n\u001b[32m 10\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m arr\n\u001b[32m 13\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m input_batch, output_batch \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mzip\u001b[39m(loader, output_images):\n\u001b[32m 14\u001b[39m concat_images = torch.cat(\n\u001b[32m---> \u001b[39m\u001b[32m15\u001b[39m [input_batch.cpu(), \u001b[43moutput_batch\u001b[49m\u001b[43m.\u001b[49m\u001b[43mcpu\u001b[49m()], dim=\u001b[32m3\u001b[39m\n\u001b[32m 16\u001b[39m ) \u001b[38;5;66;03m# (B, 3, H, 2W)\u001b[39;00m\n\u001b[32m 17\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m sample \u001b[38;5;129;01min\u001b[39;00m concat_images:\n\u001b[32m 18\u001b[39m img_uint8 = tensor_to_uint8_img(sample)\n",
278
+ "\u001b[31mAttributeError\u001b[39m: 'dict' object has no attribute 'cpu'"
279
+ ]
280
+ }
281
+ ],
282
  "source": [
283
+ "from PIL import Image\n",
284
+ "import numpy as np\n",
285
+ "\n",
286
+ "\n",
287
+ "# Helper: Convert tensor [H, W, C] in [0,1] float32 to uint8\n",
288
+ "def tensor_to_uint8_img(t):\n",
289
+ " arr = t.permute(1, 2, 0).cpu().detach().numpy()\n",
290
+ " arr = np.clip(arr, 0, 1)\n",
291
+ " arr = (arr * 255).round().astype(np.uint8)\n",
292
+ " return arr\n",
293
+ "\n",
294
+ "\n",
295
+ "for input_batch, output_batch in zip(loader, output_images):\n",
296
+ " concat_images = torch.cat(\n",
297
+ " [input_batch.cpu(), output_batch.cpu()], dim=3\n",
298
+ " ) # (B, 3, H, 2W)\n",
299
+ " for sample in concat_images:\n",
300
+ " img_uint8 = tensor_to_uint8_img(sample)\n",
301
+ " display(Image.fromarray(img_uint8))\n",
302
+ " break\n"
303
  ]
304
  }
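Beyond inline display, the comparisons can be written to disk. A sketch using torchvision's `save_image`, which handles the float-to-uint8 conversion itself; `comparisons/` is a placeholder directory, and `concat_images` here holds only the last batch from the loop above.

```python
# Sketch: save side-by-side input/output pairs instead of displaying inline.
from pathlib import Path
from torchvision.utils import save_image

out_dir = Path("comparisons")
out_dir.mkdir(exist_ok=True)
for i, sample in enumerate(concat_images):  # last batch from the loop above
    save_image(sample, out_dir / f"pair_{i:03d}.png")
```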
305
  ],