# MIT License # Copyright (c) 2026 Tim Büchner, Sai Karthikeya Vemuri, Joachim Denzler # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. 
import os import requests from io import BytesIO from pathlib import Path import gradio as gr import jax import jax.numpy as jnp import numpy as np import optax from PIL import Image from skimage.metrics import peak_signal_noise_ratio, structural_similarity from model import DecompositionType, EmbeddingType, MLPType, get_model_2D from utils import img_loss, img_train_generator os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" def train_model(img, decomp_type, backend_type, embedding_type, rank, epochs, update_freq, diff_boost): img_array = np.array(img).astype(np.float32) / 255.0 backend = MLPType[backend_type] if embedding_type == "None": embedding = EmbeddingType.PE000 elif embedding_type == "Positional": embedding = EmbeddingType.PE100 elif embedding_type == "Hash": embedding = EmbeddingType.HE else: raise ValueError(f"Unsupported embedding type: {embedding_type}") # Forcing CP decomposition as per requirement, while preserving the user dropdown decomp = DecompositionType.CP features = [256] * 4 model = get_model_2D(backend=backend, embedding=embedding, decomp=decomp, rank=rank, features=features) key = jax.random.PRNGKey(42) key, subkey = jax.random.split(key, 2) # CP Decomposition (Always used) train_data = img_train_generator(img_array) x_coords, y_coords, u_channels = train_data params = model.init(subkey, x_coords[:1], y_coords[:1]) optim = optax.adam(0.001) apply_fn = jax.jit(model.apply) l_fn = img_loss(apply_fn, *train_data) @jax.jit def train_step(params, opt_state): loss, grads = jax.value_and_grad(l_fn)(params) updates, opt_state = optim.update(grads, opt_state) params = optax.apply_updates(params, updates) return loss, params, opt_state opt_state = optim.init(params) # Calculate model size and compression ratio num_params = sum(x.size for x in jax.tree_util.tree_leaves(params)) model_size_bytes = num_params * 4 img_size_bytes = np.array(img).shape[0] * np.array(img).shape[1] * 3 compression_ratio = img_size_bytes / model_size_bytes if 
model_size_bytes > 0 else 0 info_str = f"Original image size: {img_size_bytes / 1024:.2f} KB | Model size: {model_size_bytes / 1024:.2f} KB | Compression ratio: {compression_ratio:.2f}x" print(info_str) loss_history = [] for i in range(epochs): loss, params, opt_state = train_step(params, opt_state) loss_history.append([i, float(loss)]) if stop_requested: break if i % update_freq == 0 or i == epochs - 1: u = apply_fn(params, x_coords, y_coords) ut = jnp.transpose(jnp.array(u), (1, 2, 0)).clip(0, 1) recon = (np.array(ut) * 255).astype(np.uint8) diff = np.abs(np.array(ut) - img_array) * diff_boost diff_img = (np.clip(diff, 0, 1) * 255).astype(np.uint8) yield Image.fromarray(recon), Image.fromarray(diff_img), f"Epoch {i}: loss {loss:.4f}\n{info_str}", loss_history # Final execution will already cover up to the last state due to the loop # Load sample image remote_url = "https://f-inr.github.io/static/images/0886.png" try: response = requests.get(remote_url, timeout=10) sample_img = Image.open(BytesIO(response.content)) print(f"Loaded remote image from {remote_url}") except Exception as e: print(f"Failed to load remote image: {e}. Falling back to local search.") image_files = list(Path(__file__).parent.glob("*.png")) + list(Path(__file__).parent.glob("*.jpg")) sample_img_path = image_files[0] if image_files else None if sample_img_path is None or not sample_img_path.exists(): sample_img = None else: sample_img = Image.open(sample_img_path) with gr.Blocks() as demo: gr.HTML("""
Winter Conference on Applications of Computer Vision (WACV 2026)
Implicit Neural Representations (INRs) model signals as continuous, differentiable functions. However, monolithic INRs scale poorly with data dimensionality. F-INR factorizes a high-dimensional INR into a set of compact, axis-specific sub-networks using functional tensor decomposition. This demo allows you to train an INR model on a 2D image (as showcased in our paper) and compare our approach to a standard baseline.