Spaces:
nemo10101
/
Runtime error

gQIR / app.py
aRy4n's picture
Update app.py
6c52580 verified
#!/usr/bin/env python3
"""gQIR Gradio demo
Single-frame mode follows infer_sd2GAN_stage2.py (color path only).
Burst mode follows infer_burst_realistic.py for 77->11 aggregation and reconstruction.
Local run cmd:
python gradio_app.py
--single-config configs/inference/eval_sd2GAN.yaml \
--burst-config configs/inference/eval_burst_mosaic.yaml \
--device cuda --local
"""
from __future__ import annotations
import spaces
import argparse
import atexit
import os
import random
import shutil
import subprocess
import tempfile
import threading
import traceback
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Optional
import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from accelerate.utils import set_seed
from diffusers import DDPMScheduler, UNet2DConditionModel
from omegaconf import OmegaConf
from peft import LoraConfig
from PIL import Image
from transformers import CLIPTextModel, CLIPTokenizer
from gqvr.dataset.utils import emulate_spc, srgb_to_linearrgb
from gqvr.model.core_raft.raft import RAFT
from gqvr.model.fusionViT import LightweightHybrid3DFusion
from gqvr.model.generator import SD2Enhancer
from gqvr.model.vae import AutoencoderKL
try:
import h5py
except Exception:
h5py = None
try:
from huggingface_hub import hf_hub_download
except Exception:
hf_hub_download = None
# Directory containing this file; relative checkpoint paths are also tried
# against it, and the RAFT checkpoint is loaded from beneath it.
APP_ROOT = Path(__file__).resolve().parent
# Default config paths; used as the argparse defaults in parse_args().
DEFAULT_SINGLE_CONFIG_COLOR = "configs/inference/eval_3bit_color.yaml"
DEFAULT_SINGLE_CONFIG_MONO = "configs/inference/eval_3bit_mono.yaml"
DEFAULT_BURST_CONFIG_COLOR = "configs/inference/eval_burst_mosaic.yaml"
DEFAULT_BURST_CONFIG_MONO = "configs/inference/eval_burst.yaml"
# Single-frame pixel budget (H, W); element 0 is also the fallback burst output size.
DEFAULT_MAX_SIZE = (512, 512)
# Number of binary frames a burst reconstruction window must contain.
BURST_WINDOW = 77
# Pipeline type labels; they key the HF_MODEL_FILES table and the pipeline caches.
PIPELINE_COLOR = "Color"
PIPELINE_MONO = "Monochrome"
PIPELINE_OPTIONS = [PIPELINE_COLOR, PIPELINE_MONO]
# Hub repository and per-pipeline file names used when a checkpoint path from
# the config does not resolve to a local file.
HF_DEFAULT_REPO_ID = "aRy4n/gQIR"
HF_MODEL_FILES = {
    PIPELINE_COLOR: {
        "single_qvae": "0105000.pt",
        "single_lora": "state_dict.pth",
        "burst_qvae": "0105000.pt",
        "burst_lora": "state_dict.pth",
        "burst_fusion": "fusion_vit_0050000.pt",
    },
    PIPELINE_MONO: {
        "single_qvae": "mono/0150000.pt",
        "single_lora": "mono/state_dict.pth",
        "burst_qvae": "mono/0150000.pt",
        "burst_lora": "mono/state_dict.pth",
        "burst_fusion": "mono/fusion_vit_0020000.pt",
    },
}
# Input-mode labels (presumably shown in the UI — the UI code is not visible here).
SINGLE_MODE_GT = "GT image (simulate 3-bit SPAD)"
SINGLE_MODE_REAL = "Real SPAD frame"
BURST_MODE_GT = "GT cube (simulate SPAD from RGB cube)"
BURST_MODE_REAL = "Real photon cube / SPAD cube"
# Shared torchvision converter used to turn the low-quality frame into a tensor.
TO_TENSOR = transforms.ToTensor()
# Pipeline caches keyed by pipeline type; each is accessed under its lock.
_SINGLE_PIPELINES: dict[str, "SingleColorPipeline"] = {}
_BURST_PIPELINES: dict[str, "BurstColorPipeline"] = {}
_SINGLE_LOCK = threading.Lock()
_BURST_LOCK = threading.Lock()
# Runtime settings read by the helpers below.
# NOTE(review): presumably populated from parse_args() at startup — the code
# that assigns them is outside this view; confirm.
RUNTIME_SINGLE_CONFIGS: dict[str, Path] = {}
RUNTIME_BURST_CONFIGS: dict[str, Path] = {}
RUNTIME_DEVICE: str = "cuda"
RUNTIME_BURST_OUT_SIZES: dict[str, int] = {}
RUNTIME_HF_REPO_ID: str = HF_DEFAULT_REPO_ID
RUNTIME_HF_CACHE_DIR: Optional[str] = None
RUNTIME_HF_TOKEN: Optional[str] = None
# Temp directories holding extracted video frames; removed at interpreter exit.
_TEMP_VIDEO_DIRS: list[str] = []
def _cleanup_temp_video_dirs() -> None:
    """Best-effort removal of every directory recorded in ``_TEMP_VIDEO_DIRS``."""
    for frames_dir in _TEMP_VIDEO_DIRS:
        try:
            shutil.rmtree(frames_dir, ignore_errors=True)
        except Exception:
            # Runs at interpreter exit: a failed removal must never raise.
            continue


# Remove extracted video frames when the process exits.
atexit.register(_cleanup_temp_video_dirs)
@dataclass
class CubeDescriptor:
    """Describes an input frame cube so windows of frames can be loaded on demand."""

    # Input mode label; the window loader compares it against BURST_MODE_GT.
    source_mode: str
    kind: str # dir | video | array | h5
    # Location of the cube; used as the file path for the array and h5 kinds.
    path: str
    # Number of frames available; window requests are bounds-checked against it.
    total_frames: int
    # NOTE(review): presumably the target output side length — not read in the visible code; confirm.
    out_size: int
    # Frame image paths, required for the dir and video kinds.
    files: Optional[list[str]] = None
    array_format: Optional[str] = None # npy | npz | pt
    # Name of the array inside an .npz archive (required when array_format is npz).
    array_key: Optional[str] = None
    # Raw-plane dataset keys for the h5 kind; the loader consumes four keys per frame.
    h5_keys: Optional[list[str]] = None
    # NOTE(review): presumably the temp directory of extracted video frames — confirm against the code that builds descriptors.
    temp_dir: Optional[str] = None
def parse_args() -> argparse.Namespace:
    """Build the command-line parser for the demo and parse ``sys.argv``.

    Returns:
        The parsed namespace with config paths, device, Hugging Face
        download settings and server options.
    """
    cli = argparse.ArgumentParser()

    # Legacy flags, kept so older launch commands still parse.
    for legacy_flag, replacement in (
        ("--single-config", "--single-config-color"),
        ("--burst-config", "--burst-config-color"),
    ):
        cli.add_argument(
            legacy_flag,
            type=str,
            default=None,
            help=f"Deprecated alias for {replacement}",
        )

    # One config path per (mode, pipeline type) combination.
    for flag, default_config in (
        ("--single-config-color", DEFAULT_SINGLE_CONFIG_COLOR),
        ("--single-config-mono", DEFAULT_SINGLE_CONFIG_MONO),
        ("--burst-config-color", DEFAULT_BURST_CONFIG_COLOR),
        ("--burst-config-mono", DEFAULT_BURST_CONFIG_MONO),
    ):
        cli.add_argument(flag, type=str, default=str(default_config))

    default_device = "cuda" if torch.cuda.is_available() else "cpu"
    cli.add_argument(
        "--device",
        type=str,
        default=default_device,
        help="Inference device, e.g. cuda, cuda:0, cpu",
    )

    # Hugging Face Hub settings used when checkpoints are not available locally.
    cli.add_argument(
        "--hf-repo-id",
        type=str,
        default=HF_DEFAULT_REPO_ID,
        help="Hugging Face repo containing gQIR checkpoints used when config paths are not local files.",
    )
    cli.add_argument(
        "--hf-cache-dir",
        type=str,
        default=None,
        help="Optional Hugging Face cache directory for checkpoint downloads.",
    )
    cli.add_argument(
        "--hf-token",
        type=str,
        default=None,
        help="Optional HF token. If omitted, app reads HF_TOKEN or HUGGINGFACE_HUB_TOKEN env vars.",
    )

    # Server options.
    cli.add_argument("--port", type=int, default=7860)
    cli.add_argument("--local", action="store_true", help="Bind to 127.0.0.1 instead of 0.0.0.0")
    cli.add_argument("--share", action="store_true")
    return cli.parse_args()
def _resolve_existing_file(path_value: Optional[str]) -> Optional[str]:
if not path_value:
return None
raw = Path(str(path_value)).expanduser()
candidates = [raw]
if not raw.is_absolute():
candidates.append(APP_ROOT / raw)
for p in candidates:
if p.is_file():
return str(p.resolve())
return None
def _download_hf_checkpoint(filename: str) -> str:
    """Download ``filename`` from the runtime Hugging Face repo.

    Returns:
        The resolved local path of the downloaded file.

    Raises:
        RuntimeError: if ``huggingface_hub`` could not be imported, or if the
            download fails for any reason (the original error is chained).
    """
    if hf_hub_download is None:
        raise RuntimeError(
            "huggingface_hub is required to download checkpoints. Install it or provide local model paths."
        )

    request: dict[str, Any] = {"repo_id": RUNTIME_HF_REPO_ID, "filename": filename}
    # Optional settings are forwarded only when they carry a value.
    for option, value in (("cache_dir", RUNTIME_HF_CACHE_DIR), ("token", RUNTIME_HF_TOKEN)):
        if value:
            request[option] = value

    try:
        local_copy = hf_hub_download(**request)
    except Exception as exc:
        raise RuntimeError(
            f"Failed to download '{filename}' from '{RUNTIME_HF_REPO_ID}'. "
            "Check repo, token permissions, and network availability."
        ) from exc
    return str(Path(local_copy).resolve())
def _resolve_checkpoint_path(config_value: Optional[str], pipeline_type: str, file_key: str) -> str:
    """Resolve a checkpoint to a local path, downloading it when necessary.

    Args:
        config_value: Path taken from the config; may be empty or missing.
        pipeline_type: Key into ``HF_MODEL_FILES``.
        file_key: Checkpoint role within that pipeline's file table.

    Returns:
        The local file if ``config_value`` resolves to one, otherwise the
        path of the file downloaded from the Hugging Face repo.

    Raises:
        ValueError: for an unknown pipeline type or checkpoint key.
    """
    hub_files = HF_MODEL_FILES.get(pipeline_type)
    if hub_files is None:
        raise ValueError(f"Unknown pipeline type for checkpoint resolution: {pipeline_type}")
    if file_key not in hub_files:
        raise ValueError(f"Unknown checkpoint key '{file_key}' for pipeline type '{pipeline_type}'")

    local_file = _resolve_existing_file(config_value)
    if local_file is None:
        hf_file = hub_files[file_key]
        print(f"[gQIR] Missing local checkpoint; downloading {RUNTIME_HF_REPO_ID}/{hf_file}")
        return _download_hf_checkpoint(hf_file)
    return local_file
def _prepare_single_cfg_paths(cfg: Any, pipeline_type: str) -> Any:
    """Return a copy of the single-frame config with checkpoint paths resolved.

    ``weight_path`` and the qVAE path are replaced with local files
    (downloaded if needed). The qVAE path is read from
    ``model.vae_cfg.qvae_path`` first and from top-level ``qvae_path`` as a
    fallback; the resolved value is written back to both places that exist.

    Raises:
        ValueError: if the config has no ``model.vae_cfg`` section.
    """
    # Work on a detached copy so the caller's config object is not mutated.
    working = OmegaConf.create(OmegaConf.to_container(cfg, resolve=False))
    has_vae_section = "model" in working and "vae_cfg" in working.model
    if not has_vae_section:
        raise ValueError("Single-frame config missing model.vae_cfg")

    configured_qvae = (
        working.model.vae_cfg.qvae_path if "qvae_path" in working.model.vae_cfg else None
    )
    if not configured_qvae and "qvae_path" in working:
        configured_qvae = working.qvae_path

    working.weight_path = _resolve_checkpoint_path(
        working.get("weight_path"), pipeline_type, "single_lora"
    )
    qvae_file = _resolve_checkpoint_path(configured_qvae, pipeline_type, "single_qvae")
    working.model.vae_cfg.qvae_path = qvae_file
    if "qvae_path" in working:
        working.qvae_path = qvae_file
    return working
def _prepare_burst_cfg_paths(cfg: Any, pipeline_type: str) -> Any:
    """Return a copy of the burst config with all checkpoint paths resolved.

    ``qvae_path``, ``unet_weight_path`` and ``fusion_vit_weight_path`` are
    replaced with local files (downloaded if needed). When the config has a
    ``model.vae_cfg`` section, its ``qvae_path`` is set to the resolved qVAE.
    """
    # Work on a detached copy so the caller's config object is not mutated.
    working = OmegaConf.create(OmegaConf.to_container(cfg, resolve=False))
    for field, hub_key in (
        ("qvae_path", "burst_qvae"),
        ("unet_weight_path", "burst_lora"),
        ("fusion_vit_weight_path", "burst_fusion"),
    ):
        working[field] = _resolve_checkpoint_path(working.get(field), pipeline_type, hub_key)

    if "model" in working and "vae_cfg" in working.model:
        working.model.vae_cfg.qvae_path = working.qvae_path
    return working
def _ensure_rgb_image(arr: np.ndarray) -> np.ndarray:
arr = np.asarray(arr)
if arr.ndim == 2:
arr = np.stack([arr] * 3, axis=-1)
elif arr.ndim == 3 and arr.shape[-1] == 1:
arr = np.repeat(arr, 3, axis=-1)
elif arr.ndim == 3 and arr.shape[-1] == 4:
arr = arr[..., :3]
if arr.ndim != 3 or arr.shape[-1] != 3:
raise ValueError(f"Expected image shape HxWx3, got {arr.shape}")
return arr
def _normalize_float01(arr: np.ndarray) -> np.ndarray:
arr = np.asarray(arr).astype(np.float32)
if arr.size == 0:
return arr
min_v = float(arr.min())
max_v = float(arr.max())
if 0.0 <= min_v and max_v <= 1.0:
return arr
if min_v >= 0.0 and max_v <= 255.0:
arr = arr / 255.0
elif min_v >= 0.0 and max_v > 0.0:
arr = arr / max_v
else:
den = max(max_v - min_v, 1e-8)
arr = (arr - min_v) / den
return np.clip(arr, 0.0, 1.0)
def _to_uint8(arr_float01: np.ndarray) -> np.ndarray:
return np.clip(arr_float01 * 255.0, 0.0, 255.0).astype(np.uint8)
def _resize_dims_keep_aspect(h: int, w: int, max_side: int, multiple_of: int = 1) -> tuple[int, int]:
if h <= 0 or w <= 0:
raise ValueError(f"Invalid frame size: {h}x{w}")
scale = float(max_side) / float(max(h, w))
new_h = max(1, int(round(h * scale)))
new_w = max(1, int(round(w * scale)))
if multiple_of > 1:
new_h = max(multiple_of, int(round(new_h / multiple_of) * multiple_of))
new_w = max(multiple_of, int(round(new_w / multiple_of) * multiple_of))
new_h = min(new_h, max_side)
new_w = min(new_w, max_side)
return new_h, new_w
def _resize_frame_rgb(frame_float01: np.ndarray, max_side: int, multiple_of: int = 1) -> np.ndarray:
    """Resize one frame so its longer side equals ``max_side``.

    The frame is first coerced to RGB and normalised to ``[0, 1]``. When the
    target size equals the current size the normalised frame is returned
    untouched; otherwise it is resampled with Lanczos through an 8-bit PIL
    image, so resized output is quantised to 1/255 steps.
    """
    rgb = _normalize_float01(_ensure_rgb_image(frame_float01))
    src_h, src_w = rgb.shape[:2]
    dst_h, dst_w = _resize_dims_keep_aspect(src_h, src_w, max_side=max_side, multiple_of=multiple_of)
    if (src_h, src_w) == (dst_h, dst_w):
        return rgb
    # PIL takes (width, height), the reverse of the array layout.
    resampled = Image.fromarray(_to_uint8(rgb)).resize((dst_w, dst_h), Image.LANCZOS)
    return np.asarray(resampled, dtype=np.float32) / 255.0
def _resize_frames_rgb(frames_thwc: np.ndarray, max_side: int, multiple_of: int = 1) -> np.ndarray:
    """Resize a ``T x H x W x 3`` stack so the longer side equals ``max_side``.

    The whole stack is normalised to ``[0, 1]`` once, before any resizing, so
    every frame shares the same scale whether or not a resize is needed.
    (Normalising frame by frame would rescale each frame by its own range
    whenever the input is not already in ``[0, 1]``.)

    Args:
        frames_thwc: Frame stack with three channels in the last axis.
        max_side: Target length of the longer spatial side.
        multiple_of: Required divisor of both output sides.

    Returns:
        float32 array of shape ``T x new_H x new_W x 3``.

    Raises:
        ValueError: if the input is not 4-D with three channels.
    """
    frames_thwc = np.asarray(frames_thwc)
    if frames_thwc.ndim != 4 or frames_thwc.shape[-1] != 3:
        raise ValueError(f"Expected THWC with C=3, got {frames_thwc.shape}")

    frames = _normalize_float01(frames_thwc)
    h, w = frames.shape[1:3]
    new_h, new_w = _resize_dims_keep_aspect(h, w, max_side=max_side, multiple_of=multiple_of)
    if h == new_h and w == new_w:
        return frames
    if frames.shape[0] == 0:
        # np.stack rejects an empty list; return an empty stack of the target size.
        return np.zeros((0, new_h, new_w, 3), dtype=np.float32)

    resized = [
        _resize_frame_rgb(frame, max_side=max_side, multiple_of=multiple_of)
        for frame in frames
    ]
    return np.stack(resized, axis=0).astype(np.float32)
def _to_gray_uint8(img_uint8_rgb: Optional[np.ndarray]) -> Optional[np.ndarray]:
if img_uint8_rgb is None:
return None
arr = np.asarray(img_uint8_rgb)
if arr.ndim == 2:
return arr.astype(np.uint8)
if arr.ndim == 3 and arr.shape[-1] == 1:
return arr[..., 0].astype(np.uint8)
return np.asarray(Image.fromarray(arr.astype(np.uint8)).convert("L"), dtype=np.uint8)
def _resize_uint8_to_hw(img_uint8: Optional[np.ndarray], target_h: int, target_w: int) -> Optional[np.ndarray]:
if img_uint8 is None:
return None
arr = np.asarray(img_uint8)
if arr.ndim == 2:
pil = Image.fromarray(arr.astype(np.uint8), mode="L")
return np.asarray(pil.resize((target_w, target_h), Image.LANCZOS), dtype=np.uint8)
if arr.ndim == 3 and arr.shape[-1] == 1:
pil = Image.fromarray(arr[..., 0].astype(np.uint8), mode="L")
return np.asarray(pil.resize((target_w, target_h), Image.LANCZOS), dtype=np.uint8)
pil = Image.fromarray(arr.astype(np.uint8), mode="RGB")
return np.asarray(pil.resize((target_w, target_h), Image.LANCZOS), dtype=np.uint8)
def _to_thwc(arr: np.ndarray) -> np.ndarray:
arr = np.asarray(arr)
if arr.ndim == 5 and arr.shape[0] == 1:
arr = arr[0]
if arr.ndim == 4:
if arr.shape[-1] in (1, 3, 4):
out = arr
elif arr.shape[1] in (1, 3, 4):
out = np.transpose(arr, (0, 2, 3, 1))
elif arr.shape[0] in (1, 3, 4):
out = np.transpose(arr, (3, 1, 2, 0))
else:
raise ValueError(f"Cannot infer channel axis from shape {arr.shape}")
elif arr.ndim == 3:
if arr.shape[-1] in (1, 3, 4):
out = arr[None, ...]
elif arr.shape[0] in (1, 3, 4):
out = np.transpose(arr, (1, 2, 0))[None, ...]
else:
# Treat as T x H x W single-channel.
out = arr[..., None]
else:
raise ValueError(f"Expected 3D or 4D array, got shape {arr.shape}")
if out.shape[-1] == 4:
out = out[..., :3]
return out
def _single_channel_bayer_to_sparse_rgb(frames_thw1: np.ndarray) -> np.ndarray:
bayer = np.asarray(frames_thw1).astype(np.float32)
if bayer.ndim != 4 or bayer.shape[-1] != 1:
raise ValueError(f"Expected THW1, got {bayer.shape}")
bayer = bayer[..., 0]
t, h, w = bayer.shape
out = np.zeros((t, h, w, 3), dtype=np.float32)
out[:, 0::2, 0::2, 0] = bayer[:, 0::2, 0::2]
out[:, 0::2, 1::2, 1] = bayer[:, 0::2, 1::2]
out[:, 1::2, 0::2, 1] = bayer[:, 1::2, 0::2]
out[:, 1::2, 1::2, 2] = bayer[:, 1::2, 1::2]
return out
def _mosaic_with_pattern(img_rgb: np.ndarray, pattern: str) -> np.ndarray:
r = img_rgb[:, :, 0]
g = img_rgb[:, :, 1]
b = img_rgb[:, :, 2]
out = np.zeros_like(img_rgb, dtype=np.float32)
if pattern == "RGGB":
out[0::2, 0::2, 0] = r[0::2, 0::2]
out[0::2, 1::2, 1] = g[0::2, 1::2]
out[1::2, 0::2, 1] = g[1::2, 0::2]
out[1::2, 1::2, 2] = b[1::2, 1::2]
elif pattern == "GRBG":
out[0::2, 1::2, 0] = r[0::2, 1::2]
out[0::2, 0::2, 1] = g[0::2, 0::2]
out[1::2, 1::2, 1] = g[1::2, 1::2]
out[1::2, 0::2, 2] = b[1::2, 0::2]
elif pattern == "BGGR":
out[0::2, 0::2, 2] = b[0::2, 0::2]
out[0::2, 1::2, 1] = g[0::2, 1::2]
out[1::2, 0::2, 1] = g[1::2, 0::2]
out[1::2, 1::2, 0] = r[1::2, 1::2]
elif pattern == "GBRG":
out[0::2, 0::2, 1] = g[0::2, 0::2]
out[1::2, 1::2, 1] = g[1::2, 1::2]
out[0::2, 1::2, 2] = b[0::2, 1::2]
out[1::2, 0::2, 0] = r[1::2, 0::2]
else:
raise ValueError(f"Unsupported Bayer pattern: {pattern}")
return out
def _simulate_single_3bit_from_gt(gt_rgb_float01: np.ndarray, target_ppp: float) -> np.ndarray:
    """Simulate a colour-mosaicked 3-bit frame from a ground-truth RGB image.

    Seven (``2**3 - 1``) frames are drawn with ``emulate_spc``, each one is
    mosaicked with a randomly chosen Bayer pattern, and the frames are
    averaged and clipped to ``[0, 1]``.

    Args:
        gt_rgb_float01: Ground-truth image; assumed HxWx3 in ``[0, 1]`` — TODO confirm.
        target_ppp: Exposure setting; it is passed to ``emulate_spc`` as
            ``target_ppp / 3.5``.
    """
    num_subframes = (2**3) - 1
    # NOTE(review): 3.5 looks like the reference PPP at factor 1.0 — confirm against emulate_spc.
    exposure_factor = target_ppp / 3.5
    bayer_choices = ["RGGB", "GRBG", "BGGR", "GBRG"]

    accumulated = np.zeros_like(gt_rgb_float01, dtype=np.float32)
    for _ in range(num_subframes):
        linear = srgb_to_linearrgb(gt_rgb_float01)
        subframe = emulate_spc(linear, factor=exposure_factor).astype(np.float32)
        accumulated += _mosaic_with_pattern(subframe, random.choice(bayer_choices))
    return np.clip(accumulated / float(num_subframes), 0.0, 1.0)
def _simulate_single_3bit_from_gt_mono(gt_rgb_float01: np.ndarray, target_ppp: float) -> np.ndarray:
    """Simulate a 3-bit frame from a ground-truth image without a colour mosaic.

    Seven (``2**3 - 1``) frames are drawn with ``emulate_spc``, averaged and
    clipped to ``[0, 1]``.

    Args:
        gt_rgb_float01: Ground-truth image; assumed HxWx3 in ``[0, 1]`` — TODO confirm.
        target_ppp: Exposure setting; it is passed to ``emulate_spc`` as
            ``target_ppp / 3.5``.
    """
    num_subframes = (2**3) - 1
    exposure_factor = target_ppp / 3.5

    accumulated = np.zeros_like(gt_rgb_float01, dtype=np.float32)
    for _ in range(num_subframes):
        linear = srgb_to_linearrgb(gt_rgb_float01)
        accumulated += emulate_spc(linear, factor=exposure_factor).astype(np.float32)
    return np.clip(accumulated / float(num_subframes), 0.0, 1.0)
def _simulate_binary_burst_frame_from_gt(gt_rgb_float01: np.ndarray, target_ppp: float = 3.5) -> np.ndarray:
    """Simulate one burst frame from a ground-truth image, mosaicked as BGGR.

    The exposure factor uses the same ``target_ppp / 3.5`` convention as the
    single-frame simulation.
    """
    exposure_factor = target_ppp / 3.5
    linear = srgb_to_linearrgb(gt_rgb_float01)
    frame = emulate_spc(linear, factor=exposure_factor).astype(np.float32)
    return _mosaic_with_pattern(frame, "BGGR")
def _simulate_binary_burst_frame_from_gt_mono(gt_rgb_float01: np.ndarray, target_ppp: float = 3.5) -> np.ndarray:
    """Simulate one burst frame from a ground-truth image without a colour mosaic.

    The exposure factor uses the same ``target_ppp / 3.5`` convention as the
    single-frame simulation.
    """
    exposure_factor = target_ppp / 3.5
    linear = srgb_to_linearrgb(gt_rgb_float01)
    frame = emulate_spc(linear, factor=exposure_factor)
    return frame.astype(np.float32)
def _tensor_to_uint8_image(x_bchw: torch.Tensor) -> np.ndarray:
x = x_bchw.detach().cpu().clamp(0.0, 1.0)
x = (x[0].permute(1, 2, 0).numpy() * 255.0).astype(np.uint8)
return x
def _encode_prompt(tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel, prompt: str, bs: int, device: str) -> torch.Tensor:
    """Encode ``prompt`` repeated ``bs`` times and return the encoder's first output.

    Prompts are padded/truncated to ``tokenizer.model_max_length`` tokens and
    the token ids are moved to ``device`` before encoding.
    """
    batch_prompts = [prompt] * bs
    tokenized = tokenizer(
        batch_prompts,
        max_length=tokenizer.model_max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    token_ids = tokenized.input_ids.to(device)
    encoder_outputs = text_encoder(token_ids)
    return encoder_outputs[0]
def differentiable_warp(x: torch.Tensor, flow: torch.Tensor) -> torch.Tensor:
    """Warp ``x`` by sampling it at pixel positions displaced by ``flow``.

    Args:
        x: Tensor of shape ``B x C x H x W``.
        flow: Displacements of shape ``B x 2 x H x W`` in pixels; channel 0 is
            added to the x (column) coordinate and channel 1 to the y (row)
            coordinate.

    Returns:
        Bilinearly sampled tensor with the shape of ``x``. Positions that fall
        outside the image use the border value.
    """
    b, _, h, w = x.size()
    # indexing="ij" is what the legacy default did; passing it explicitly
    # avoids the deprecation warning. The grid is built on x's device.
    rows, cols = torch.meshgrid(
        torch.arange(h, device=x.device),
        torch.arange(w, device=x.device),
        indexing="ij",
    )
    base = torch.stack((cols, rows), dim=2).float()
    sample_at = base.unsqueeze(0).expand(b, -1, -1, -1) + flow.permute(0, 2, 3, 1)

    # Map pixel coordinates to [-1, 1] as grid_sample expects with
    # align_corners=True. max(..., 1) keeps a one-pixel axis from dividing by zero.
    norm_x = 2.0 * sample_at[..., 0] / max(w - 1, 1) - 1.0
    norm_y = 2.0 * sample_at[..., 1] / max(h - 1, 1) - 1.0
    grid = torch.stack((norm_x, norm_y), dim=-1)
    return F.grid_sample(x, grid, align_corners=True, padding_mode="border")
class SingleColorPipeline:
    """Single-frame reconstruction wrapper around ``SD2Enhancer``.

    The model is built lazily by :meth:`load` from the config at
    ``config_path``. Inputs are limited to a pixel budget of ``max_size``.
    """

    def __init__(self, config_path: Path, device: str, pipeline_type: str):
        self.config_path = config_path
        self.device = device
        self.pipeline_type = pipeline_type
        self.max_size = DEFAULT_MAX_SIZE
        self.model: Optional[SD2Enhancer] = None

    def load(self) -> None:
        """Build and initialise the enhancer; a no-op once it is loaded.

        ``self.model`` is assigned only after ``init_models()`` has
        succeeded, so a failed initialisation leaves the pipeline unloaded
        and the next call retries instead of reusing a half-built model.

        Raises:
            ValueError: if the config's ``base_model_type`` is not ``"sd2"``.
        """
        if self.model is not None:
            return
        cfg = OmegaConf.load(str(self.config_path))
        cfg = _prepare_single_cfg_paths(cfg, self.pipeline_type)
        if cfg.base_model_type != "sd2":
            raise ValueError(f"Unsupported base_model_type for single pipeline: {cfg.base_model_type}")
        model = SD2Enhancer(
            base_model_path=cfg.base_model_path,
            weight_path=cfg.weight_path,
            lora_modules=cfg.lora_modules,
            lora_rank=cfg.lora_rank,
            model_t=cfg.model_t,
            coeff_t=cfg.coeff_t,
            vae_cfg=cfg.model.vae_cfg,
            device=self.device,
        )
        model.init_models()
        self.model = model

    def _enhance(self, lq_rgb_float01: np.ndarray, prompt: str, only_vae_output: bool, seed: int) -> tuple[np.ndarray, int]:
        """Run the enhancer on one low-quality frame.

        Args:
            lq_rgb_float01: Low-quality frame (H x W x C float array).
            prompt: Text prompt forwarded to the enhancer.
            only_vae_output: Forwarded to the enhancer unchanged.
            seed: RNG seed; ``-1`` draws a random 32-bit seed.

        Returns:
            ``(reconstruction as H x W x 3 uint8, seed actually used)``.

        Raises:
            ValueError: if the frame has more pixels than ``max_size`` allows.
        """
        if self.model is None:
            self.load()
        if seed == -1:
            seed = random.randint(0, 2**32 - 1)
        set_seed(seed)
        out_h, out_w = lq_rgb_float01.shape[:2]
        if out_h * out_w > self.max_size[0] * self.max_size[1]:
            raise ValueError(
                f"Resolution {out_h}x{out_w} exceeds max pixel budget "
                f"{self.max_size[0]}x{self.max_size[1]}."
            )
        image_tensor = TO_TENSOR(lq_rgb_float01).unsqueeze(0)
        pil_img = self.model.enhance(
            lq=image_tensor,
            prompt=prompt,
            upscale=1,
            return_type="pil",
            only_vae_output=only_vae_output,
            save_Gprocessed_latents=False,
            fname="",
        )[0]
        return np.asarray(pil_img.convert("RGB"), dtype=np.uint8), seed

    def reconstruct_from_gt(
        self,
        gt_image_np: np.ndarray,
        prompt: str,
        target_ppp: float,
        only_vae_output: bool,
        seed: int,
        simulate_color_mosaic: bool = True,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray, str]:
        """Simulate a 3-bit frame from a ground-truth image and reconstruct it.

        The image is resized so its longer side equals ``max_size[0]`` (sides
        snapped to multiples of 8) before simulation.

        Returns:
            ``(ground truth uint8, simulated input uint8, reconstruction
            uint8, status message)``.
        """
        gt_rgb = _normalize_float01(_ensure_rgb_image(gt_image_np))
        gt_rgb = _resize_frame_rgb(gt_rgb, self.max_size[0], multiple_of=8)
        if simulate_color_mosaic:
            lq_rgb = _simulate_single_3bit_from_gt(gt_rgb, target_ppp=target_ppp)
        else:
            lq_rgb = _simulate_single_3bit_from_gt_mono(gt_rgb, target_ppp=target_ppp)
        recon_uint8, used_seed = self._enhance(lq_rgb, prompt, only_vae_output, seed)
        status = f"Single reconstruction complete (mode=GT simulation, seed={used_seed}, PPP={target_ppp:.2f})."
        return _to_uint8(gt_rgb), _to_uint8(lq_rgb), recon_uint8, status

    def reconstruct_from_real_spad(
        self,
        lq_image_np: np.ndarray,
        prompt: str,
        only_vae_output: bool,
        seed: int,
    ) -> tuple[np.ndarray, np.ndarray, str]:
        """Reconstruct from a provided low-quality frame (no simulation).

        The frame is resized so its longer side equals ``max_size[0]`` (sides
        snapped to multiples of 8).

        Returns:
            ``(input uint8, reconstruction uint8, status message)``.
        """
        lq_rgb = _normalize_float01(_ensure_rgb_image(lq_image_np))
        lq_rgb = _resize_frame_rgb(lq_rgb, self.max_size[0], multiple_of=8)
        recon_uint8, used_seed = self._enhance(lq_rgb, prompt, only_vae_output, seed)
        status = f"Single reconstruction complete (mode=real SPAD frame, seed={used_seed})."
        return _to_uint8(lq_rgb), recon_uint8, status
class BurstColorPipeline:
    """Burst reconstruction: 77 binary frames -> 11 averaged frames -> one image.

    Components (VAE, RAFT flow model, fusion network, CLIP text encoder,
    scheduler and a LoRA-adapted UNet) are built lazily by :meth:`load`.
    """

    def __init__(self, config_path: Path, device: str, pipeline_type: str):
        self.config_path = config_path
        self.device = device
        self.pipeline_type = pipeline_type
        self.cfg = None
        self.out_size = 512
        # bfloat16 on CUDA devices, float32 elsewhere (used for text encoder and UNet).
        self.weight_dtype = torch.bfloat16 if str(device).startswith("cuda") else torch.float32
        self.vae: Optional[AutoencoderKL] = None
        self.raft_model: Optional[RAFT] = None
        self.fusion_vit: Optional[LightweightHybrid3DFusion] = None
        self.tokenizer: Optional[CLIPTokenizer] = None
        self.text_encoder: Optional[CLIPTextModel] = None
        self.scheduler: Optional[DDPMScheduler] = None
        self.ls_burst_unet: Optional[UNet2DConditionModel] = None

    def load(self) -> None:
        """Build every component; a no-op once the pipeline is fully loaded.

        All components are built into locals and published on ``self`` only
        after the last one has loaded. ``ls_burst_unet`` is therefore the
        "loaded" marker: if any step raises, the pipeline stays unloaded and
        the next call retries from scratch rather than leaving some
        attributes set and others ``None``.

        Raises:
            RuntimeError: if the LoRA checkpoint keys do not match the
                adapter keys of the UNet.
        """
        if self.ls_burst_unet is not None:
            return
        cfg = OmegaConf.load(str(self.config_path))
        cfg = _prepare_burst_cfg_paths(cfg, self.pipeline_type)
        out_size = int(cfg.dataset.val.params.out_size)

        # VAE: copy over every checkpoint tensor whose key the model defines.
        vae = AutoencoderKL(cfg.model.vae_cfg.ddconfig, cfg.model.vae_cfg.embed_dim)
        da_vae = torch.load(cfg.qvae_path, map_location="cpu")
        init_vae = {key: da_vae[key].clone() for key in vae.state_dict() if key in da_vae}
        vae.load_state_dict(init_vae, strict=True)
        vae.requires_grad_(False)
        vae.eval().to(self.device)

        # RAFT optical-flow model.
        class RAFTArgs:
            mixed_precision = True
            small = False
            alternate_corr = True
            dropout = False

        raft_model = RAFT(RAFTArgs())
        raft_path = APP_ROOT / "pretrained_ckpts" / "models" / "raft-things.pth"
        raft_dict = torch.load(str(raft_path), map_location="cpu")
        # Drop the first dotted component of each key
        # (presumably a DataParallel "module." prefix — confirm).
        corrected = {
            (".".join(k.split(".")[1:]) if "." in k else k): v for k, v in raft_dict.items()
        }
        raft_model.load_state_dict(corrected)
        raft_model.eval().requires_grad_(False).to(self.device)

        # Temporal fusion network.
        fusion_vit = LightweightHybrid3DFusion()
        fusion_ckpt = torch.load(cfg.fusion_vit_weight_path, map_location="cpu")
        fusion_vit.load_state_dict(fusion_ckpt)
        fusion_vit.eval().requires_grad_(False).to(self.device)

        # Text encoder and scheduler from the base model.
        tokenizer = CLIPTokenizer.from_pretrained(cfg.base_model_path, subfolder="tokenizer")
        text_encoder = CLIPTextModel.from_pretrained(
            cfg.base_model_path,
            subfolder="text_encoder",
            torch_dtype=self.weight_dtype,
        ).to(self.device)
        text_encoder.eval().requires_grad_(False)
        scheduler = DDPMScheduler.from_pretrained(cfg.base_model_path, subfolder="scheduler")

        # UNet with a LoRA adapter whose weights come from the checkpoint.
        ls_burst_unet = UNet2DConditionModel.from_pretrained(
            cfg.base_model_path,
            subfolder="unet",
            torch_dtype=self.weight_dtype,
        ).to(self.device)
        lora_cfg = LoraConfig(
            r=cfg.lora_rank,
            lora_alpha=cfg.lora_rank,
            init_lora_weights="gaussian",
            target_modules=cfg.lora_modules,
        )
        ls_burst_unet.add_adapter(lora_cfg)
        try:
            state_dict = torch.load(cfg.unet_weight_path, map_location="cpu", weights_only=False)
        except TypeError:
            # torch versions whose torch.load has no weights_only argument.
            state_dict = torch.load(cfg.unet_weight_path, map_location="cpu")
        ls_burst_unet.load_state_dict(state_dict, strict=False)
        # strict=False above would hide a mismatch, so compare the LoRA keys explicitly.
        required_keys = {k for k in ls_burst_unet.state_dict().keys() if "lora" in k}
        input_keys = set(state_dict.keys())
        if required_keys != input_keys:
            missing = required_keys - input_keys
            unexpected = input_keys - required_keys
            raise RuntimeError(f"LoRA key mismatch. Missing={len(missing)} Unexpected={len(unexpected)}")
        ls_burst_unet.eval().requires_grad_(False)

        # Publish everything at once; ls_burst_unet last because it is the marker.
        self.cfg = cfg
        self.out_size = out_size
        self.vae = vae
        self.raft_model = raft_model
        self.fusion_vit = fusion_vit
        self.tokenizer = tokenizer
        self.text_encoder = text_encoder
        self.scheduler = scheduler
        self.ls_burst_unet = ls_burst_unet

    def _ensure_loaded(self) -> None:
        """Load the pipeline unless it is already fully loaded."""
        if self.ls_burst_unet is None:
            self.load()

    def _window_to_3bit(self, window: np.ndarray) -> torch.Tensor:
        """Turn a ``T x H x W x 3`` window into ``1 x (T // 7) x 3 x H' x W'`` averages.

        Frames are normalised, resized so the longer side equals ``out_size``
        (sides snapped to multiples of 64), mapped to ``[-1, 1]`` and averaged
        in consecutive groups of seven; a trailing partial group is dropped.
        """
        frames = _resize_frames_rgb(
            _normalize_float01(window),
            max_side=self.out_size,
            multiple_of=64,
        )
        stack = torch.from_numpy(frames).unsqueeze(0).permute(0, 1, 4, 2, 3).to(self.device)
        stack = (stack * 2.0) - 1.0
        groups = []
        for i in range(0, stack.size(1), 7):
            chunk = stack[:, i : i + 7, ...]
            if chunk.size(1) < 7:
                break
            groups.append(torch.mean(chunk, dim=1, keepdim=True))
        return torch.cat(groups, dim=1)

    def reconstruct_from_binary_window(
        self,
        binary_window_77: np.ndarray,
        gt_window_77: Optional[np.ndarray] = None,
    ) -> tuple[np.ndarray, np.ndarray, Optional[np.ndarray]]:
        """Reconstruct one image from a window of ``BURST_WINDOW`` binary frames.

        Args:
            binary_window_77: ``77 x H x W x 3`` input frames.
            gt_window_77: Optional matching ground-truth frames, used only to
                produce the averaged centre ground-truth image.

        Returns:
            ``(centre averaged input, reconstruction, centre averaged ground
            truth or None)``, each an ``H x W x 3`` uint8 array.

        Raises:
            ValueError: if the binary window does not hold ``BURST_WINDOW`` frames.
        """
        self._ensure_loaded()
        assert self.cfg is not None
        assert self.vae is not None
        assert self.raft_model is not None
        assert self.fusion_vit is not None
        assert self.tokenizer is not None
        assert self.text_encoder is not None
        assert self.scheduler is not None
        assert self.ls_burst_unet is not None
        if binary_window_77.shape[0] != BURST_WINDOW:
            raise ValueError(f"Burst window must have {BURST_WINDOW} frames, got {binary_window_77.shape[0]}")

        lqs = self._window_to_3bit(binary_window_77)
        gts = None
        if gt_window_77 is not None:
            gts = self._window_to_3bit(gt_window_77)

        with torch.inference_mode():
            bs = lqs.size(0)
            t_total = lqs.size(1)
            center_t = t_total // 2

            # Encode each averaged frame and keep its VAE round-trip for flow estimation.
            latents = []
            decoded_lqs = []
            for t in range(t_total):
                lq_t = lqs[:, t, ...].float()
                z_t = self.vae.encode(lq_t).mode()
                latents.append(z_t)
                decoded_lqs.append(self.vae.decode(z_t).float())
            y = torch.stack(decoded_lqs, dim=1)

            # Flow between every frame and the centre frame, estimated on the
            # decoded images and then rescaled to latent resolution. The
            # argument order handed to RAFT flips at the centre index.
            flow_vectors = []
            for t in range(t_total):
                ls_in = y[:, t, ...].float()
                center_in = y[:, center_t, ...].float()
                if t < center_t:
                    _, flow_bw = self.raft_model(center_in, ls_in, iters=20, test_mode=True)
                else:
                    _, flow_bw = self.raft_model(ls_in, center_in, iters=20, test_mode=True)
                z_h, z_w = latents[t].shape[-2:]
                in_h, in_w = ls_in.shape[-2:]
                flow_bw = F.interpolate(flow_bw, size=(z_h, z_w), mode="bilinear", align_corners=True)
                # Displacements are in pixels, so they shrink with the resolution.
                flow_bw[:, 0] *= float(z_w) / float(in_w)
                flow_bw[:, 1] *= float(z_h) / float(in_h)
                flow_vectors.append(flow_bw)

            # Warp every non-centre latent with its flow, then fuse.
            aligned_latents = []
            for t in range(t_total):
                latent_t = latents[t]
                if t == center_t:
                    aligned_latents.append(latent_t)
                else:
                    aligned_latents.append(differentiable_warp(latent_t, flow_vectors[t]))
            aligned_latents = torch.stack(aligned_latents, dim=1)
            merged_latent = self.fusion_vit(aligned_latents)

            # One UNet step on the fused latent.
            # NOTE(review): 0.18215 is presumably the Stable Diffusion latent scaling factor — confirm.
            z_in = (merged_latent * 0.18215).to(self.weight_dtype)
            timesteps = torch.full((bs,), int(self.cfg.model_t), dtype=torch.long, device=self.device)
            text_embed = _encode_prompt(self.tokenizer, self.text_encoder, "", bs=bs, device=self.device)
            eps = self.ls_burst_unet(z_in, timesteps, encoder_hidden_states=text_embed).sample
            z = self.scheduler.step(eps, int(self.cfg.coeff_t), z_in).pred_original_sample
            decoded_refined = self.vae.decode(z.float() / 0.18215).float().clamp(0.0, 1.0)

            # Centre frames mapped back from [-1, 1] to [0, 1].
            center_input = ((lqs[:, center_t, ...] + 1.0) / 2.0).clamp(0.0, 1.0)
            center_gt = None
            if gts is not None:
                center_gt = ((gts[:, center_t, ...] + 1.0) / 2.0).clamp(0.0, 1.0)

        if torch.cuda.is_available() and str(self.device).startswith("cuda"):
            torch.cuda.empty_cache()
        return (
            _tensor_to_uint8_image(center_input),
            _tensor_to_uint8_image(decoded_refined),
            _tensor_to_uint8_image(center_gt) if center_gt is not None else None,
        )
def _get_single_pipeline(pipeline_type: str) -> SingleColorPipeline:
    """Return the loaded single-frame pipeline for ``pipeline_type``.

    Pipelines are created on first use and cached per type under
    ``_SINGLE_LOCK``. A pipeline enters the cache only after ``load()`` has
    succeeded, so a failed load is retried on the next request and never
    leaves an unusable instance cached.

    Raises:
        ValueError: for an unknown pipeline type.
        RuntimeError: if no config is registered for the pipeline type.
    """
    if pipeline_type not in PIPELINE_OPTIONS:
        raise ValueError(f"Unknown pipeline type: {pipeline_type}")
    if pipeline_type not in RUNTIME_SINGLE_CONFIGS:
        raise RuntimeError(f"Single config not initialized for pipeline type: {pipeline_type}")
    with _SINGLE_LOCK:
        pipeline = _SINGLE_PIPELINES.get(pipeline_type)
        if pipeline is None:
            pipeline = SingleColorPipeline(
                RUNTIME_SINGLE_CONFIGS[pipeline_type],
                RUNTIME_DEVICE,
                pipeline_type,
            )
        # load() returns immediately once the pipeline is loaded.
        pipeline.load()
        _SINGLE_PIPELINES[pipeline_type] = pipeline
        return pipeline
def _get_burst_pipeline(pipeline_type: str) -> BurstColorPipeline:
    """Return the loaded burst pipeline for ``pipeline_type``.

    Pipelines are created on first use and cached per type under
    ``_BURST_LOCK``. A pipeline enters the cache only after ``load()`` has
    succeeded, so a failed load is retried on the next request and never
    leaves an unusable instance cached.

    Raises:
        ValueError: for an unknown pipeline type.
        RuntimeError: if no config is registered for the pipeline type.
    """
    if pipeline_type not in PIPELINE_OPTIONS:
        raise ValueError(f"Unknown pipeline type: {pipeline_type}")
    if pipeline_type not in RUNTIME_BURST_CONFIGS:
        raise RuntimeError(f"Burst config not initialized for pipeline type: {pipeline_type}")
    with _BURST_LOCK:
        pipeline = _BURST_PIPELINES.get(pipeline_type)
        if pipeline is None:
            pipeline = BurstColorPipeline(
                RUNTIME_BURST_CONFIGS[pipeline_type],
                RUNTIME_DEVICE,
                pipeline_type,
            )
        # load() returns immediately once the pipeline is loaded.
        pipeline.load()
        _BURST_PIPELINES[pipeline_type] = pipeline
        return pipeline
def _get_runtime_burst_out_size(pipeline_type: str) -> int:
    """Return the registered burst output size, or ``DEFAULT_MAX_SIZE[0]`` if none."""
    try:
        configured = RUNTIME_BURST_OUT_SIZES[pipeline_type]
    except KeyError:
        return DEFAULT_MAX_SIZE[0]
    return int(configured)
def _resolve_uploaded_path(uploaded_file: Any, local_path: str) -> Optional[str]:
if isinstance(uploaded_file, str) and uploaded_file:
return uploaded_file
if hasattr(uploaded_file, "name") and uploaded_file.name:
return uploaded_file.name
if isinstance(uploaded_file, dict) and uploaded_file.get("name"):
return uploaded_file["name"]
if local_path and local_path.strip():
return local_path.strip()
return None
def _list_image_files(dir_path: str) -> list[str]:
exts = {".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff"}
files = []
for name in sorted(os.listdir(dir_path)):
p = os.path.join(dir_path, name)
if os.path.isfile(p) and Path(name).suffix.lower() in exts:
files.append(p)
return files
def _is_video_extension(ext: str) -> bool:
# Include common public-facing formats. Keep ".wav" for user compatibility;
# extraction will fail with a clear message since it is audio-only.
return ext in {".mp4", ".mov", ".m4v", ".avi", ".mkv", ".webm", ".wmv", ".wav"}
def _extract_video_frames_to_temp(video_path: str) -> tuple[str, list[str]]:
    """Decode every frame of a video into PNG files in a new temp directory.

    Returns:
        ``(temp directory, sorted frame file paths)``. The directory is
        recorded in ``_TEMP_VIDEO_DIRS`` for later cleanup.

    Raises:
        RuntimeError: if the ``ffmpeg`` executable is not found.
        ValueError: if ffmpeg fails to decode the input (with a dedicated
            message for ``.wav`` files) or produces no frames.

    The temp directory is removed again before any of these errors is raised.
    """
    frames_dir = tempfile.mkdtemp(prefix="gqir_video_frames_")
    ffmpeg_cmd = [
        "ffmpeg",
        "-hide_banner",
        "-loglevel",
        "error",
        "-i",
        video_path,
        "-vsync",
        "0",
        "-start_number",
        "0",
        os.path.join(frames_dir, "frame_%06d.png"),
    ]
    try:
        completed = subprocess.run(ffmpeg_cmd, check=True, capture_output=True, text=True)
    except FileNotFoundError as exc:
        shutil.rmtree(frames_dir, ignore_errors=True)
        raise RuntimeError("ffmpeg is not installed; video input requires ffmpeg.") from exc
    except subprocess.CalledProcessError as exc:
        shutil.rmtree(frames_dir, ignore_errors=True)
        if Path(video_path).suffix.lower() == ".wav":
            raise ValueError(
                "WAV is audio-only and does not contain video frames. "
                "Please upload MP4/MOV/WMV/AVI/MKV/WebM for GT video input."
            ) from exc
        ffmpeg_message = (exc.stderr or "").strip()
        raise ValueError(
            f"Failed to decode video with ffmpeg: {ffmpeg_message or 'unknown ffmpeg error'}"
        ) from exc

    frame_files = _list_image_files(frames_dir)
    if frame_files:
        _TEMP_VIDEO_DIRS.append(frames_dir)
        return frames_dir, frame_files

    # ffmpeg exited cleanly but wrote nothing.
    shutil.rmtree(frames_dir, ignore_errors=True)
    ffmpeg_message = (completed.stderr or "").strip()
    raise ValueError(f"No frames were extracted from video. {ffmpeg_message}".strip())
def _extract_first_array(obj: Any) -> np.ndarray:
if isinstance(obj, np.ndarray):
return obj
if isinstance(obj, torch.Tensor):
return obj.detach().cpu().numpy()
if isinstance(obj, (list, tuple)) and obj:
if all(isinstance(x, (np.ndarray, torch.Tensor)) for x in obj):
stacked = [x.detach().cpu().numpy() if isinstance(x, torch.Tensor) else x for x in obj]
return np.stack(stacked, axis=0)
return _extract_first_array(obj[0])
if isinstance(obj, dict):
preferred = ["cube", "frames", "lqs", "data", "array"]
for key in preferred:
if key in obj:
return _extract_first_array(obj[key])
for value in obj.values():
try:
return _extract_first_array(value)
except Exception:
continue
raise ValueError("Could not extract ndarray/tensor from input object.")
def _inspect_array_file(path: str) -> tuple[str, Optional[str], int]:
    """Report the format, archive key and frame count of an array file.

    Returns:
        ``(format, key, frame_count)`` where format is ``"npy"``, ``"npz"``
        or ``"pt"``, key is the first array name for ``.npz`` files (else
        ``None``), and frame_count is the first axis after THWC conversion.

    Raises:
        ValueError: for an empty ``.npz`` archive or an unsupported extension.
    """
    suffix = Path(path).suffix.lower()

    if suffix == ".npy":
        # Memory-mapped: only the shape is needed here.
        frames = _to_thwc(np.load(path, mmap_mode="r"))
        return "npy", None, int(frames.shape[0])

    if suffix == ".npz":
        with np.load(path) as archive:
            array_names = archive.files
            if not array_names:
                raise ValueError("NPZ has no arrays.")
            first_key = array_names[0]
            frames = _to_thwc(archive[first_key])
        return "npz", first_key, int(frames.shape[0])

    if suffix in {".pt", ".pth"}:
        # NOTE(review): weights_only=False unpickles arbitrary objects. If
        # `path` can be a user upload this executes untrusted code — confirm
        # the trust boundary or restrict to weights_only=True.
        try:
            loaded = torch.load(path, map_location="cpu", weights_only=False)
        except TypeError:
            # torch versions whose torch.load has no weights_only argument.
            loaded = torch.load(path, map_location="cpu")
        frames = _to_thwc(_extract_first_array(loaded))
        return "pt", None, int(frames.shape[0])

    raise ValueError(f"Unsupported array file extension: {suffix}")
def _load_array_window(desc: CubeDescriptor, start: int, count: int) -> np.ndarray:
    """Load frames ``[start, start + count)`` from an array-backed cube.

    The stored array is converted to THWC *before* the window is sliced, so
    the frame axis used here is the same one ``_inspect_array_file`` counted.
    Slicing axis 0 of the raw array first would pick the wrong axis for
    layouts such as ``1 x T x H x W x C`` or ``C x H x W x T``. For ``.npy``
    files the array stays memory-mapped until the window is normalised.

    Args:
        desc: Descriptor with ``path``, ``array_format`` and, for ``npz``,
            ``array_key`` set.
        start: Index of the first frame.
        count: Number of frames to load.

    Returns:
        float32 ``T x H x W x C`` window normalised to ``[0, 1]``.

    Raises:
        ValueError: if the descriptor is incomplete or names an unknown format.
    """
    fmt = desc.array_format
    if fmt is None:
        raise ValueError("Array descriptor is missing array_format.")
    path = desc.path
    stop = start + count

    if fmt == "npy":
        window = _to_thwc(np.load(path, mmap_mode="r"))[start:stop]
    elif fmt == "npz":
        if desc.array_key is None:
            raise ValueError("NPZ descriptor is missing array_key.")
        with np.load(path) as npz_data:
            window = _to_thwc(npz_data[desc.array_key])[start:stop]
    elif fmt == "pt":
        # NOTE(review): weights_only=False unpickles arbitrary objects. If
        # `path` can be a user upload this executes untrusted code — confirm
        # the trust boundary or restrict to weights_only=True.
        try:
            obj = torch.load(path, map_location="cpu", weights_only=False)
        except TypeError:
            # torch versions whose torch.load has no weights_only argument.
            obj = torch.load(path, map_location="cpu")
        window = _to_thwc(_extract_first_array(obj))[start:stop]
    else:
        raise ValueError(f"Unsupported array format in descriptor: {fmt}")

    return _normalize_float01(window)
def _load_h5_window(desc: CubeDescriptor, start: int, count: int) -> np.ndarray:
    """Assemble ``count`` sparse-RGB mosaic frames from an H5 photon cube.

    Each output frame consumes four consecutive entries of ``desc.h5_keys``
    (treated, in order, as the R, G1, B and G2 planes). The planes are
    scattered onto a 2x2 mosaic: R on even rows/even columns, G1 on even
    rows/odd columns, G2 on odd rows/even columns and B on odd rows/odd
    columns; every other position stays zero.

    Args:
        desc: Descriptor of an H5 cube with ``h5_keys`` populated.
        start: Index of the first frame to build.
        count: Number of frames to build.

    Returns:
        The stacked frames passed through ``_normalize_float01``.

    Raises:
        RuntimeError: If ``h5py`` could not be imported.
        ValueError: If fewer than four planes remain for a requested frame.
    """
    if h5py is None:
        raise RuntimeError("h5py is required for .h5 photon cube loading. Install with: pip install h5py")
    assert desc.h5_keys is not None

    def _plane(group, name):
        # Keep index 0 of the two trailing axes of the stored dataset.
        return np.asarray(group[name])[:, :, 0, 0].astype(np.float32)

    mosaics = []
    with h5py.File(desc.path, "r") as handle:
        raw_group = handle["capture_integrated"]["raw_hdf5"]
        for frame_idx in range(start, start + count):
            plane_names = desc.h5_keys[frame_idx * 4 : (frame_idx + 1) * 4]
            if len(plane_names) < 4:
                raise ValueError("H5 does not contain enough raw planes for the requested frame window.")
            # Stored plane order within a frame is R, G1, B, G2.
            red, green_a, blue, green_b = (_plane(raw_group, name) for name in plane_names)
            rows, cols = red.shape
            mosaic = np.zeros((rows, cols, 3), dtype=np.float32)
            mosaic[0::2, 0::2, 0] = red[0::2, 0::2]
            mosaic[0::2, 1::2, 1] = green_a[0::2, 1::2]
            mosaic[1::2, 0::2, 1] = green_b[1::2, 0::2]
            mosaic[1::2, 1::2, 2] = blue[1::2, 1::2]
            mosaics.append(mosaic)

    return _normalize_float01(np.stack(mosaics, axis=0))
def _load_window_from_descriptor(
    desc: CubeDescriptor,
    start: int,
    count: int,
    pipeline_type: str = PIPELINE_COLOR,
    resize_for_model: bool = True,
) -> np.ndarray:
    """Load a window of frames from any supported cube source as 3-channel data.

    Args:
        desc: Descriptor of the loaded cube (``kind`` is one of ``"dir"``,
            ``"video"``, ``"array"`` or ``"h5"``).
        start: Index of the first frame in the window.
        count: Number of frames in the window.
        pipeline_type: Selected pipeline; together with ``desc.source_mode``
            it decides how single-channel data is expanded to three channels.
        resize_for_model: When ``True`` the frames go through
            ``_resize_frames_rgb`` (``max_side=desc.out_size``,
            ``multiple_of=64``); when ``False`` they are only passed through
            ``_normalize_float01``.

    Returns:
        The frame window with a trailing channel axis of size 3.

    Raises:
        ValueError: If the window falls outside the cube, the descriptor kind
            is unknown, or the data cannot be brought to three channels.
    """
    if start < 0 or start + count > desc.total_frames:
        raise ValueError(
            f"Invalid start index {start}. Valid range is [0, {max(desc.total_frames - count, 0)}]."
        )

    if desc.kind in {"dir", "video"}:
        assert desc.files is not None
        frames = []
        for frame_path in desc.files[start : start + count]:
            # Use a context manager so the underlying file handle is released
            # deterministically instead of whenever the image object happens
            # to be garbage collected; convert() returns an independent copy.
            with Image.open(frame_path) as img:
                rgb = img.convert("RGB")
            frames.append(np.asarray(rgb, dtype=np.float32) / 255.0)
        out = np.stack(frames, axis=0)
    elif desc.kind == "array":
        out = _load_array_window(desc, start, count)
    elif desc.kind == "h5":
        out = _load_h5_window(desc, start, count)
    else:
        raise ValueError(f"Unknown descriptor kind: {desc.kind}")

    if out.shape[-1] == 1:
        if desc.source_mode == BURST_MODE_GT or pipeline_type == PIPELINE_MONO:
            # GT or monochrome data: replicate the single channel.
            out = np.repeat(out, 3, axis=-1)
        else:
            # Otherwise the single channel is handed to the Bayer-to-sparse-RGB
            # helper.
            out = _single_channel_bayer_to_sparse_rgb(out)
    if out.shape[-1] != 3:
        raise ValueError(f"Expected 3 channels after conversion, got shape {out.shape}")

    if not resize_for_model:
        return _normalize_float01(out)
    return _resize_frames_rgb(out, max_side=desc.out_size, multiple_of=64)
def _h5_plane_sort_key(name: Any) -> tuple[int, int, str]:
    """Return a sort key that orders numeric names numerically, then the rest.

    The key is always a tuple of the same shape, so numeric and non-numeric
    names can be compared with each other (mixing bare ``int`` and ``str``
    keys makes ``sorted`` raise ``TypeError``).
    """
    text = str(name)
    # isdecimal() (unlike isdigit()) only accepts characters int() can parse.
    if text.isdecimal():
        return (0, int(text), "")
    return (1, 0, text)


def _build_cube_descriptor(source_mode: str, path: str, out_size: int) -> CubeDescriptor:
    """Inspect ``path`` and build the ``CubeDescriptor`` used to read frames later.

    Supported inputs are an image directory, a video file (GT burst mode
    only), an array file (``.npy``/``.npz``/``.pt``/``.pth``) and an H5 file
    (real photon cube mode only).

    Args:
        source_mode: Burst input mode the cube is being loaded for.
        path: File or directory to inspect.
        out_size: Value stored on the descriptor as ``out_size``.

    Returns:
        A descriptor whose ``kind`` is ``"dir"``, ``"video"``, ``"array"`` or
        ``"h5"``.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
        ValueError: If the input is empty, has an unsupported format, or is
            not allowed for ``source_mode``.
        RuntimeError: If an H5 file is given but ``h5py`` is unavailable.
    """
    p = Path(path)
    if not p.exists():
        raise FileNotFoundError(f"Path does not exist: {path}")

    if p.is_dir():
        files = _list_image_files(path)
        if not files:
            raise ValueError("Directory has no supported image files.")
        return CubeDescriptor(
            source_mode=source_mode,
            kind="dir",
            path=path,
            total_frames=len(files),
            out_size=out_size,
            files=files,
        )

    ext = p.suffix.lower()

    if _is_video_extension(ext):
        if source_mode != BURST_MODE_GT:
            raise ValueError("Video files are currently supported for GT burst mode only.")
        # Frames are extracted to a temporary directory that the descriptor
        # keeps a reference to.
        temp_dir, files = _extract_video_frames_to_temp(path)
        return CubeDescriptor(
            source_mode=source_mode,
            kind="video",
            path=path,
            total_frames=len(files),
            out_size=out_size,
            files=files,
            temp_dir=temp_dir,
        )

    if ext in {".npy", ".npz", ".pt", ".pth"}:
        fmt, key, total = _inspect_array_file(path)
        return CubeDescriptor(
            source_mode=source_mode,
            kind="array",
            path=path,
            total_frames=total,
            out_size=out_size,
            array_format=fmt,
            array_key=key,
        )

    if ext in {".h5", ".hdf5"}:
        if source_mode != BURST_MODE_REAL:
            raise ValueError("H5/UBI input is only supported for real photon cube mode.")
        if h5py is None:
            raise RuntimeError("h5py is required for .h5 photon cube loading. Install with: pip install h5py")
        with h5py.File(path, "r") as h5f:
            try:
                grp = h5f["capture_integrated"]["raw_hdf5"]
            except Exception as exc:
                raise ValueError("Expected H5 group capture_integrated/raw_hdf5") from exc
            keys = list(grp.keys())
        # Numeric names sort by value; a tuple key keeps mixed numeric and
        # non-numeric names comparable instead of raising TypeError.
        keys = sorted(keys, key=_h5_plane_sort_key)
        # Four consecutive raw planes make up one frame.
        total = len(keys) // 4
        if total <= 0:
            raise ValueError("No usable frame groups found in H5 file.")
        return CubeDescriptor(
            source_mode=source_mode,
            kind="h5",
            path=path,
            total_frames=total,
            out_size=out_size,
            h5_keys=keys,
        )

    raise ValueError(f"Unsupported cube input format: {p.suffix}")
def _single_inputs_visibility(mode: str):
    """Build the visibility updates applied when the single-frame mode changes.

    Returns three ``gr.update`` objects, in order: visible only in GT mode,
    visible only outside GT mode, visible only in GT mode.
    """
    show_gt_controls = mode == SINGLE_MODE_GT
    first_update = gr.update(visible=show_gt_controls)
    second_update = gr.update(visible=not show_gt_controls)
    third_update = gr.update(visible=show_gt_controls)
    return first_update, second_update, third_update
def _burst_ppp_interactivity(mode: str):
    """Return an update that makes a component interactive only in GT burst mode."""
    simulating_from_gt = mode == BURST_MODE_GT
    return gr.update(interactive=simulating_from_gt)
@spaces.GPU
def run_single_reconstruction(
    pipeline_type: str,
    mode: str,
    gt_image: Optional[np.ndarray],
    real_spad_image: Optional[np.ndarray],
    prompt: str,
    target_ppp: float,
    only_vae_output: bool,
    seed: int,
):
    """Run the single-frame pipeline on a GT image or on a real SPAD frame.

    In GT mode the pipeline's ``reconstruct_from_gt`` is used (with
    ``target_ppp`` and, for the colour pipeline, mosaic simulation enabled);
    otherwise ``reconstruct_from_real_spad`` is used. The reconstruction is
    resized back to the height/width of the supplied image, and every preview
    is converted to gray for the monochrome pipeline.

    Returns:
        ``(input_preview, reconstruction, lq_preview, status)``. On any
        failure the three images are ``None`` and ``status`` carries the error
        message followed by a short traceback.
    """
    try:
        pipeline = _get_single_pipeline(pipeline_type)
        prompt = (prompt or "").strip()
        seed = int(seed)
        use_gt = mode == SINGLE_MODE_GT

        # Pick and validate the image that drives this run.
        if use_gt:
            if gt_image is None:
                raise ValueError("Please provide a GT image.")
            source_image = gt_image
        else:
            if real_spad_image is None:
                raise ValueError("Please provide a real SPAD frame.")
            source_image = real_spad_image

        # Remember the incoming size so the result can be resized back to it.
        in_h, in_w = _ensure_rgb_image(source_image).shape[:2]

        if use_gt:
            # The first returned element (GT preview) is not used here.
            _, lq_prev, recon, status = pipeline.reconstruct_from_gt(
                gt_image_np=source_image,
                prompt=prompt,
                target_ppp=float(target_ppp),
                only_vae_output=bool(only_vae_output),
                seed=seed,
                simulate_color_mosaic=(pipeline_type == PIPELINE_COLOR),
            )
        else:
            lq_prev, recon, status = pipeline.reconstruct_from_real_spad(
                lq_image_np=source_image,
                prompt=prompt,
                only_vae_output=bool(only_vae_output),
                seed=seed,
            )

        recon = _resize_uint8_to_hw(recon, in_h, in_w)
        input_preview = _to_uint8(_normalize_float01(_ensure_rgb_image(source_image)))
        if pipeline_type == PIPELINE_MONO:
            input_preview, lq_prev, recon = (
                _to_gray_uint8(image) for image in (input_preview, lq_prev, recon)
            )
        return input_preview, recon, lq_prev, status
    except Exception as exc:
        # UI boundary: report the failure in the status box instead of raising.
        failure = f"Single reconstruction failed: {exc}"
        trace = traceback.format_exc(limit=1)
        return None, None, None, f"{failure}\n{trace}"
def load_cube_for_ui(pipeline_type: str, mode: str, cube_file: Any, cube_path: str):
    """Load a burst cube for the UI and prepare the widgets that depend on it.

    Resolves the uploaded file or typed path, builds a descriptor, checks that
    the cube holds at least ``BURST_WINDOW`` frames and renders the first
    frame (at source resolution) as a preview.

    Returns:
        ``(descriptor, info_text, slider_update, input_preview,
        model_input_preview, status)``. ``model_input_preview`` is ``None`` in
        GT burst mode. On failure the descriptor and previews are ``None``,
        the slider is made non-interactive and both text fields carry the
        error message.
    """
    try:
        path = _resolve_uploaded_path(cube_file, cube_path)
        if not path:
            raise ValueError("Provide a cube file upload or local path.")

        out_size = _get_runtime_burst_out_size(pipeline_type)
        descriptor = _build_cube_descriptor(mode, path, out_size=out_size)
        if descriptor.total_frames < BURST_WINDOW:
            raise ValueError(
                f"Cube has {descriptor.total_frames} frames, but gQIR burst requires at least {BURST_WINDOW}."
            )

        # Largest start index that still leaves a full window.
        max_start = descriptor.total_frames - BURST_WINDOW
        slider_update = gr.update(minimum=0, maximum=max_start, value=0, step=1, interactive=True)

        first_frame = _load_window_from_descriptor(
            descriptor,
            start=0,
            count=1,
            pipeline_type=pipeline_type,
            resize_for_model=False,
        )[0]
        preview = _to_uint8(first_frame)
        if pipeline_type == PIPELINE_MONO:
            preview = _to_gray_uint8(preview)

        if mode == BURST_MODE_GT:
            model_input_preview = None
        else:
            model_input_preview = preview

        info = (
            f"Loaded cube: {descriptor.path}\n"
            f"Input type: {descriptor.kind}\n"
            f"Frames: {descriptor.total_frames}\n"
            f"Valid start index range: [0, {max_start}]\n"
            f"Window size fixed at {BURST_WINDOW}"
        )
        return descriptor, info, slider_update, preview, model_input_preview, "Cube loaded successfully."
    except Exception as exc:
        # UI boundary: surface the problem as text rather than raising.
        err = f"Cube load failed: {exc}"
        return None, err, gr.update(interactive=False), None, None, err
@spaces.GPU
def run_burst_reconstruction(
    pipeline_type: str,
    mode: str,
    descriptor: Optional[CubeDescriptor],
    start_idx: int,
    target_ppp: float,
):
    """Reconstruct the centre frame of a ``BURST_WINDOW``-frame window.

    The window is loaded at source resolution. In GT burst mode every frame
    is first turned into a simulated binary frame (colour or monochrome
    variant depending on ``pipeline_type``) at ``target_ppp`` and the clean
    frames are passed along as the GT window; in real mode the loaded frames
    are used directly and no GT window is supplied.

    Returns:
        ``(display_input, reconstruction, model_input_center, status)``. The
        images are resized to the source resolution of the window's centre
        frame and converted to gray for the monochrome pipeline. On any
        failure the three images are ``None`` and ``status`` carries the
        error message followed by a short traceback.
    """
    try:
        if descriptor is None:
            raise ValueError("Load a burst cube first.")
        start_idx = int(start_idx)
        is_gt_mode = mode == BURST_MODE_GT

        raw_window = _load_window_from_descriptor(
            descriptor,
            start=start_idx,
            count=BURST_WINDOW,
            pipeline_type=pipeline_type,
            resize_for_model=False,
        )
        center_frame = raw_window[BURST_WINDOW // 2]
        raw_center_h, raw_center_w = center_frame.shape[:2]

        gt_window = None
        if is_gt_mode:
            gt_window = raw_window
            # Same simulation loop for both pipelines; only the per-frame
            # simulator differs.
            if pipeline_type == PIPELINE_COLOR:
                simulate_frame = _simulate_binary_burst_frame_from_gt
            else:
                simulate_frame = _simulate_binary_burst_frame_from_gt_mono
            binary_window = np.stack(
                [
                    simulate_frame(gt_window[i], target_ppp=float(target_ppp))
                    for i in range(BURST_WINDOW)
                ],
                axis=0,
            ).astype(np.float32)
        else:
            binary_window = raw_window

        pipeline = _get_burst_pipeline(pipeline_type)
        # The third returned element (centre GT frame) is never shown or
        # returned, so it is deliberately not post-processed. It may also be
        # absent when no GT window is supplied.
        center_input, recon, _unused_center_gt = pipeline.reconstruct_from_binary_window(
            binary_window_77=binary_window,
            gt_window_77=gt_window,
        )
        center_input = _resize_uint8_to_hw(center_input, raw_center_h, raw_center_w)
        recon = _resize_uint8_to_hw(recon, raw_center_h, raw_center_w)

        display_input = _to_uint8(center_frame)
        display_input = _resize_uint8_to_hw(display_input, raw_center_h, raw_center_w)
        if pipeline_type == PIPELINE_MONO:
            display_input = _to_gray_uint8(display_input)
            center_input = _to_gray_uint8(center_input)
            recon = _to_gray_uint8(recon)

        ppp_status = f"PPP={float(target_ppp):.2f}" if is_gt_mode else "PPP=ignored (real cube input)"
        status = (
            f"Burst reconstruction complete. "
            f"Pipeline={pipeline_type}, "
            f"Input mode={'GT simulation' if is_gt_mode else 'real photon cube'}, "
            f"{ppp_status}, "
            f"window=[{start_idx}, {start_idx + BURST_WINDOW - 1}]."
        )
        return display_input, recon, center_input, status
    except Exception as exc:
        # UI boundary: report the failure in the status box instead of raising.
        msg = f"Burst reconstruction failed: {exc}"
        tb = traceback.format_exc(limit=1)
        return None, None, None, f"{msg}\n{tb}"
def build_demo() -> gr.Blocks:
    """Assemble the Gradio UI: an intro banner plus single-frame and burst tabs.

    Each tab has a column of inputs and a column of previews, and wires its
    buttons and mode selectors to the module's callback functions.

    Returns:
        The constructed ``gr.Blocks`` app (not yet queued or launched).
    """
    # Introductory banner rendered at the top of the page.
    markdown = """
<h1 align="center">gQIR: Generative Quanta Image Reconstruction</h1>
<p align="center">
<a href="https://aryan-garg.github.io/gqir/">Project Page</a> |
<a href="https://arxiv.org/abs/2602.20417">ArXiv</a> |
<a href="https://github.com/Aryan-Garg/gQIR">GitHub</a>
</p>
### What You Can Run
- **Single Frame (Stage-2):** Reconstruct one frame from either a clean GT image (internally simulated to SPAD) or a real SPAD frame.
- **Burst (Stage-3):** Reconstruct from a fixed **77-frame** window using either GT videos/cubes or real photon cubes.
- **Pipelines:** Toggle between **Color** and **Monochrome** reconstruction in both tabs.
### Supported Inputs
- **Single GT:** Standard image uploads.
- **Single Real:** Real SPAD frame image uploads.
- **Burst GT:** Public-friendly videos (`.mp4`, `.mov`, `.wmv`, `.avi`, `.mkv`, `.webm`) plus research cube formats (`.npy`, `.npz`, `.pt`, `.h5`) or image folders.
- **Burst Real:** Photon cubes (`.npy`, `.npz`, `.pt`, `.h5`) or image folders.
### Quick Usage
1. Pick pipeline and input mode.
2. Load input and select burst start index (for Stage-3).
3. Set PPP for GT simulation paths, then run reconstruction and compare input vs output side-by-side.
"""
    with gr.Blocks(title="gQIR Demo") as demo:
        gr.Markdown(markdown)
        # ---- Tab 1: single-frame reconstruction ----
        with gr.Tab("Single Frame (Stage-2)"):
            with gr.Row():
                # Left column: inputs and run button.
                with gr.Column():
                    single_pipeline_type = gr.Radio(
                        PIPELINE_OPTIONS,
                        value=PIPELINE_COLOR,
                        label="Pipeline",
                    )
                    single_mode = gr.Radio(
                        [SINGLE_MODE_GT, SINGLE_MODE_REAL],
                        value=SINGLE_MODE_GT,
                        label="Input Mode",
                    )
                    gt_image = gr.Image(label="GT Image", type="numpy")
                    # Hidden initially because the default input mode is GT.
                    real_spad_image = gr.Image(label="Real SPAD Frame", type="numpy", visible=False)
                    prompt = gr.Textbox(label="Prompt (optional)", value="")
                    target_ppp = gr.Slider(
                        minimum=0.25,
                        maximum=5.0,
                        value=3.5,
                        step=0.25,
                        label="Target PPP (GT simulation only)",
                    )
                    only_vae_output = gr.Checkbox(label="Stage 1 (qVAE) output only", value=False)
                    seed = gr.Number(label="Seed (-1 for random)", value=310, precision=0)
                    run_single_btn = gr.Button("Run Single Reconstruction")
                # Right column: previews and status text.
                with gr.Column():
                    with gr.Row():
                        single_input_preview = gr.Image(label="Input Frame", type="numpy")
                        single_output_preview = gr.Image(label="Reconstruction (original aspect)", type="numpy")
                    single_model_input_preview = gr.Image(label="Model Input (resized for inference)", type="numpy")
                    single_status = gr.Textbox(label="Status", interactive=False)
            # Switching the input mode toggles which image input and whether
            # the PPP slider are shown.
            single_mode.change(
                fn=_single_inputs_visibility,
                inputs=[single_mode],
                outputs=[gt_image, real_spad_image, target_ppp],
            )
            run_single_btn.click(
                fn=run_single_reconstruction,
                inputs=[single_pipeline_type, single_mode, gt_image, real_spad_image, prompt, target_ppp, only_vae_output, seed],
                outputs=[single_input_preview, single_output_preview, single_model_input_preview, single_status],
            )
        # ---- Tab 2: burst reconstruction ----
        with gr.Tab("Burst (Stage-3)"):
            # Holds the descriptor produced by "Load Cube" until a run uses it.
            cube_state = gr.State(value=None)
            with gr.Row():
                # Left column: cube selection and run controls.
                with gr.Column():
                    burst_pipeline_type = gr.Radio(
                        PIPELINE_OPTIONS,
                        value=PIPELINE_COLOR,
                        label="Pipeline",
                    )
                    burst_mode = gr.Radio(
                        [BURST_MODE_GT, BURST_MODE_REAL],
                        value=BURST_MODE_GT,
                        label="Burst Input Mode",
                    )
                    cube_file = gr.File(
                        label=(
                            "GT video/cube file "
                            "(.mp4/.mov/.wmv/.avi/.mkv/.webm/.npy/.npz/.pt/.h5) "
                            "or image directory path below"
                        ),
                        type="filepath",
                    )
                    cube_path = gr.Textbox(label="Or Local Cube Path (file or folder)", value="")
                    load_cube_btn = gr.Button("Load Cube")
                    cube_info = gr.Textbox(label="Cube Info", interactive=False)
                    # Starts disabled with an empty range; load_cube_for_ui
                    # returns the update that enables it and sets its maximum.
                    start_idx = gr.Slider(
                        minimum=0,
                        maximum=0,
                        value=0,
                        step=1,
                        interactive=False,
                        label=f"Start Index (window size fixed to {BURST_WINDOW})",
                    )
                    burst_target_ppp = gr.Slider(
                        minimum=0.25,
                        maximum=5.0,
                        value=3.5,
                        step=0.25,
                        label="Target PPP (GT simulation only)",
                    )
                    run_burst_btn = gr.Button("Run Burst Reconstruction")
                # Right column: previews and status text.
                with gr.Column():
                    with gr.Row():
                        burst_input_display = gr.Image(label="Input Center Frame", type="numpy")
                        burst_recon = gr.Image(label="Reconstruction (original aspect)", type="numpy")
                    burst_model_input = gr.Image(label="Model Input Center (post-processing)", type="numpy")
                    burst_status = gr.Textbox(label="Status", interactive=False)
            load_cube_btn.click(
                fn=load_cube_for_ui,
                inputs=[burst_pipeline_type, burst_mode, cube_file, cube_path],
                outputs=[cube_state, cube_info, start_idx, burst_input_display, burst_model_input, burst_status],
            )
            # The PPP slider is only interactive while GT burst mode is selected.
            burst_mode.change(
                fn=_burst_ppp_interactivity,
                inputs=[burst_mode],
                outputs=[burst_target_ppp],
            )
            run_burst_btn.click(
                fn=run_burst_reconstruction,
                inputs=[burst_pipeline_type, burst_mode, cube_state, start_idx, burst_target_ppp],
                outputs=[burst_input_display, burst_recon, burst_model_input, burst_status],
            )
    return demo
def main() -> None:
    """Parse CLI arguments, publish the runtime configuration and launch the demo.

    Fills the module-level ``RUNTIME_*`` globals (config paths per pipeline,
    device, Hugging Face repo/cache/token and the burst output size read from
    each burst config) and then builds, queues and launches the Gradio app.
    """
    global RUNTIME_SINGLE_CONFIGS, RUNTIME_BURST_CONFIGS, RUNTIME_DEVICE, RUNTIME_BURST_OUT_SIZES
    global RUNTIME_HF_REPO_ID, RUNTIME_HF_CACHE_DIR, RUNTIME_HF_TOKEN

    args = parse_args()

    # --single-config / --burst-config, when given, take precedence over the
    # colour-specific options.
    RUNTIME_SINGLE_CONFIGS = {
        PIPELINE_COLOR: Path(args.single_config or args.single_config_color).resolve(),
        PIPELINE_MONO: Path(args.single_config_mono).resolve(),
    }
    RUNTIME_BURST_CONFIGS = {
        PIPELINE_COLOR: Path(args.burst_config or args.burst_config_color).resolve(),
        PIPELINE_MONO: Path(args.burst_config_mono).resolve(),
    }

    RUNTIME_DEVICE = args.device
    RUNTIME_HF_REPO_ID = str(args.hf_repo_id or HF_DEFAULT_REPO_ID).strip()
    if args.hf_cache_dir:
        RUNTIME_HF_CACHE_DIR = str(Path(args.hf_cache_dir).expanduser())
    else:
        RUNTIME_HF_CACHE_DIR = None
    # An explicit token wins; otherwise fall back to the environment.
    RUNTIME_HF_TOKEN = (
        args.hf_token
        or os.environ.get("HF_TOKEN")
        or os.environ.get("HUGGINGFACE_HUB_TOKEN")
    )

    # Read each burst config once up front to learn its output size.
    RUNTIME_BURST_OUT_SIZES = {}
    for pipeline_name, config_path in RUNTIME_BURST_CONFIGS.items():
        loaded_cfg = OmegaConf.load(str(config_path))
        RUNTIME_BURST_OUT_SIZES[pipeline_name] = int(loaded_cfg.dataset.val.params.out_size)

    app = build_demo().queue()
    app.launch()
# Start the demo only when this file is run as a script, not when imported.
if __name__ == "__main__":
    main()