AlbeRota committed on
Commit
0637cee
·
verified ·
1 Parent(s): 03935f5

Upload weights, notebooks, sample images

Files changed (1)
  1. notebooks/api_examples.ipynb +218 -0
notebooks/api_examples.ipynb ADDED
@@ -0,0 +1,218 @@
+ {
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "d5e78019",
+ "metadata": {},
+ "source": [
+ "# UnReflectAnything API Examples\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "db2eda79",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Using device: cuda\n"
+ ]
+ }
+ ],
+ "source": [
+ "import unreflectanything\n",
+ "import torch\n",
+ "\n",
+ "%load_ext autoreload\n",
+ "%autoreload 2\n",
+ "\n",
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
+ "print(f\"Using device: {device}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "94f8c2fb",
+ "metadata": {},
+ "source": [
+ "### 1. Get the model class (for custom setup or training)\n",
+ "\n",
+ "`unreflectanything.model()` with no arguments returns the underlying model class `UnReflect_Model_TokenInpainter`. Use it when you need to build the architecture yourself (e.g. from a config or for training)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "f49c99b7",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "cuda:0\n"
+ ]
+ }
+ ],
+ "source": [
+ "UnReflectModel = unreflectanything.model()\n",
+ "UnReflectModel_Pretrained = unreflectanything.model(pretrained=True)\n",
+ "print(next(UnReflectModel_Pretrained.parameters()).device)"
+ ]
+ },
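+ {
+ "cell_type": "markdown",
+ "id": "a1b2c3d4",
+ "metadata": {},
+ "source": [
+ "Since the pretrained instance is a plain `torch.nn.Module`, the usual PyTorch state-dict round trip applies for training workflows. The cell below is a generic PyTorch sketch (the checkpoint file name is arbitrary), not an `unreflectanything`-specific API."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "b2c3d4e5",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Standard PyTorch checkpointing sketch: save weights, then reload them\n",
+ "torch.save(UnReflectModel_Pretrained.state_dict(), \"unreflect_checkpoint.pt\")\n",
+ "UnReflectModel_Pretrained.load_state_dict(torch.load(\"unreflect_checkpoint.pt\"))"
+ ]
+ },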
+ {
+ "cell_type": "markdown",
+ "id": "575fb9a1",
+ "metadata": {},
+ "source": [
+ "### 2. Get a pretrained model and run on batched RGB\n",
+ "\n",
+ "`unreflectanything.model(pretrained=True)` returns an `UnReflectModel` instance (a `torch.nn.Module`) with weights loaded. Call it with a batch of RGB tensors `[B, 3, H, W]` (values in [0, 1]); it returns the diffuse (reflection-removed) tensor."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "d1cdc14f",
+ "metadata": {},
+ "source": [
+ "#### Load pretrained model (uses cached weights; run 'unreflectanything download --weights' first)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "d58ad7f1",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Model is nn.Module: True\n",
+ "Expected image size (side): 896\n",
+ "Device: cuda\n"
+ ]
+ }
+ ],
+ "source": [
+ "import torch\n",
+ "\n",
+ "# Load pretrained model (uses cached weights; run 'unreflectanything download --weights' first)\n",
+ "unreflectanythingmodel = unreflectanything.model(pretrained=True)\n",
+ "unreflectanythingmodel_scratch = unreflectanything.model(pretrained=False)\n",
+ "print(f\"Model is nn.Module: {isinstance(unreflectanythingmodel, torch.nn.Module)}\")\n",
+ "print(f\"Expected image size (side): {unreflectanythingmodel.image_size}\")\n",
+ "print(f\"Device: {unreflectanythingmodel.device}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "34e01754",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Batched RGB tensor [B, 3, H, W], values in [0, 1]\n",
+ "batch_size = 2\n",
+ "images = torch.rand(batch_size, 3, 448, 448, device=unreflectanythingmodel.device)\n",
+ "model_out = unreflectanythingmodel(images) # [B, 3, H, W] diffuse tensor\n",
+ "print(f\"Input shape: {images.shape} -> Output shape: {model_out.shape}\")"
+ ]
+ },
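+ {
+ "cell_type": "markdown",
+ "id": "c3d4e5f6",
+ "metadata": {},
+ "source": [
+ "For inference-only use, standard PyTorch practice is to put the module in eval mode and disable gradient tracking. The cell below is a generic PyTorch sketch and assumes nothing `unreflectanything`-specific."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "d4e5f6a7",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Generic PyTorch inference hygiene: eval mode + no gradient tracking\n",
+ "unreflectanythingmodel.eval()\n",
+ "with torch.inference_mode():\n",
+ "    diffuse = unreflectanythingmodel(images)\n",
+ "print(diffuse.shape)"
+ ]
+ },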
+ {
+ "cell_type": "markdown",
+ "id": "696bce42",
+ "metadata": {},
+ "source": [
+ "### 3. Full output dict and custom mask (optional)\n",
+ "\n",
+ "You can get the full model outputs (e.g. highlight mask, patch mask) with `return_dict=True`, or pass a custom inpainting mask with `inpaint_mask_override`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "dc2ecc8a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Get full outputs: diffuse, highlight, patch_mask, etc.\n",
+ "outputs = unreflectanythingmodel(images, return_dict=True)\n",
+ "print(\"Keys:\", list(outputs.keys())) # e.g. diffuse, highlight, patch_mask, tokens_completed\n",
+ "diffuse_only = outputs[\"diffuse\"]\n",
+ "highlight_mask = outputs[\"highlight\"] # [B, 1, H, W]"
+ ]
+ },
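+ {
+ "cell_type": "markdown",
+ "id": "e5f6a7b8",
+ "metadata": {},
+ "source": [
+ "The section above also mentions `inpaint_mask_override`. The sketch below assumes it accepts a binary `[B, 1, H, W]` mask at the input resolution; check the model signature if the expected shape or dtype differs."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "f6a7b8c9",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Sketch: force inpainting over the top-left quadrant (assumes a [B, 1, H, W] binary mask)\n",
+ "custom_mask = torch.zeros(batch_size, 1, 448, 448, device=unreflectanythingmodel.device)\n",
+ "custom_mask[:, :, :224, :224] = 1.0\n",
+ "masked_out = unreflectanythingmodel(images, inpaint_mask_override=custom_mask)"
+ ]
+ },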
+ {
+ "cell_type": "markdown",
+ "id": "87fe354c",
+ "metadata": {},
+ "source": [
+ "### 4. One-shot inference (no model handle)\n",
+ "\n",
+ "For a single call without keeping a model in memory, use `unreflectanything.inference()`. It accepts a file path, directory, or tensor and returns a tensor (or writes to disk if `output=` is set)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "ff5740b8",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Tensor in -> tensor out (loads model internally, then discards)\n",
+ "result = unreflectanything.inference(images)\n",
+ "print(f\"unreflectanything.inference(images) shape: {result.shape}\")\n",
+ "\n",
+ "# File-based: save to disk\n",
+ "# unreflectanything.inference(\"input.png\", output=\"output.png\")\n",
+ "# unreflectanything.inference(\"input_dir/\", output=\"output_dir/\", batch_size=8)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "e2d1673d",
+ "metadata": {},
+ "source": [
+ "### 5. Loading sample images (optional)\n",
+ "\n",
+ "If you have downloaded sample images with `unreflectanything download --images`, you can run inference on that directory."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "1834686c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "SAMPLE_IMAGE_PATH_DIR = \"sample_images\" # default from 'unreflectanything download --images'\n",
+ "# unreflectanything.inference(SAMPLE_IMAGE_PATH_DIR, output=\"output_sample/\", verbose=True)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.11"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+ }