{ "cells": [ { "cell_type": "markdown", "id": "d5e78019", "metadata": {}, "source": [ "# UnReflectAnything API Examples\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "db2eda79", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Using device: cuda\n" ] } ], "source": [ "import unreflectanything\n", "import torch\n", "\n", "%load_ext autoreload\n", "%autoreload 2\n", "\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "print(f\"Using device: {device}\")" ] }, { "cell_type": "markdown", "id": "94f8c2fb", "metadata": {}, "source": [ "### 1. Get the model class (for custom setup or training)\n", "\n", "`unreflectanything.model()` with no arguments returns the underlying model class `UnReflect_Model_TokenInpainter`. Use it when you need to build the architecture yourself (e.g. from a config or for training)." ] }, { "cell_type": "code", "execution_count": 13, "id": "f49c99b7", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "cuda:0\n" ] } ], "source": [ "# Underlying model class (architecture only)\n", "UnReflectModel = unreflectanything.model()\n", "# Pretrained instance: a torch.nn.Module with weights loaded\n", "UnReflectModel_Pretrained = unreflectanything.model(pretrained=True)\n", "# Check device placement through the instance; parameters() needs an instance, not the class\n", "print(next(UnReflectModel_Pretrained.parameters()).device)" ] }, { "cell_type": "markdown", "id": "575fb9a1", "metadata": {}, "source": [ "### 2. Get a pretrained model and run it on batched RGB input\n", "\n", "`unreflectanything.model(pretrained=True)` returns an `UnReflectModel` instance (a `torch.nn.Module`) with weights loaded. Call it on a batched RGB tensor of shape `[B, 3, H, W]` with values in [0, 1]; it returns the diffuse (reflection-removed) tensor of the same shape."
] }, { "cell_type": "markdown", "id": "d1cdc14f", "metadata": {}, "source": [ "#### Load a pretrained model (uses cached weights; run `unreflectanything download --weights` first)" ] }, { "cell_type": "code", "execution_count": null, "id": "d58ad7f1", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model is nn.Module: True\n", "Expected image size (side): 896\n", "Device: cuda\n" ] } ], "source": [ "import torch\n", "\n", "# Load pretrained model (uses cached weights; run 'unreflectanything download --weights' first)\n", "unreflectanythingmodel = unreflectanything.model(pretrained=True)\n", "# Same architecture without pretrained weights\n", "unreflectanythingmodel_scratch = unreflectanything.model(pretrained=False)\n", "print(f\"Model is nn.Module: {isinstance(unreflectanythingmodel, torch.nn.Module)}\")\n", "print(f\"Expected image size (side): {unreflectanythingmodel.image_size}\")\n", "print(f\"Device: {unreflectanythingmodel.device}\")" ] }, { "cell_type": "code", "execution_count": null, "id": "34e01754", "metadata": {}, "outputs": [], "source": [ "# Batched RGB tensor [B, 3, H, W], values in [0, 1]\n", "batch_size = 2\n", "images = torch.rand(batch_size, 3, 448, 448, device=unreflectanythingmodel.device)\n", "with torch.no_grad():  # inference only: skip autograd bookkeeping\n", "    model_out = unreflectanythingmodel(images)  # [B, 3, H, W] diffuse tensor\n", "print(f\"Input shape: {images.shape} -> Output shape: {model_out.shape}\")" ] }, { "cell_type": "markdown", "id": "696bce42", "metadata": {}, "source": [ "### 3. Full output dict and custom mask (optional)\n", "\n", "Pass `return_dict=True` to get all model outputs (e.g. the highlight mask and patch mask) rather than only the diffuse tensor, or pass a custom inpainting mask via `inpaint_mask_override`." ] }, { "cell_type": "code", "execution_count": null, "id": "dc2ecc8a", "metadata": {}, "outputs": [], "source": [ "# Get full outputs: diffuse, highlight, patch_mask, etc.\n", "with torch.no_grad():\n", "    outputs = unreflectanythingmodel(images, return_dict=True)\n", "print(\"Keys:\", list(outputs.keys()))  # e.g. diffuse, highlight, patch_mask, tokens_completed\n",
 "diffuse_only = outputs[\"diffuse\"]\n", "highlight_mask = outputs[\"highlight\"]  # [B, 1, H, W]" ] }, { "cell_type": "markdown", "id": "87fe354c", "metadata": {}, "source": [ "### 4. One-shot inference (no model handle)\n", "\n", "For a one-off call without keeping a model in memory, use `unreflectanything.inference()`. It accepts a file path, a directory, or a tensor, and returns a tensor (or writes the results to disk when `output=` is set)." ] }, { "cell_type": "code", "execution_count": null, "id": "ff5740b8", "metadata": {}, "outputs": [], "source": [ "# Tensor in -> tensor out (loads model internally, then discards)\n", "result = unreflectanything.inference(images)\n", "print(f\"unreflectanything.inference(images) shape: {result.shape}\")\n", "\n", "# File-based: save to disk\n", "# unreflectanything.inference(\"input.png\", output=\"output.png\")\n", "# unreflectanything.inference(\"input_dir/\", output=\"output_dir/\", batch_size=8)" ] }, { "cell_type": "markdown", "id": "e2d1673d", "metadata": {}, "source": [ "### 5. Loading sample images (optional)\n", "\n", "If you have downloaded the sample images with `unreflectanything download --images`, you can run inference on that directory." ] }, { "cell_type": "code", "execution_count": null, "id": "1834686c", "metadata": {}, "outputs": [], "source": [ "SAMPLE_IMAGE_PATH_DIR = \"sample_images\" # default from 'unreflectanything download --images'\n", "# unreflectanything.inference(SAMPLE_IMAGE_PATH_DIR, output=\"output_sample/\", verbose=True)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.11" } }, "nbformat": 4, "nbformat_minor": 5 }