import spaces # must be first!
import sys
import os
import torch
from PIL import Image
import gradio as gr
from glob import glob
from contextlib import nullcontext
import numpy as np
import cv2
import tempfile
from pipeline import Lotus2Pipeline
from diffusers import (
FlowMatchEulerDiscreteScheduler,
FluxTransformer2DModel,
)
from infer import (
load_all_task_weights,
process_single_image,
)
from huggingface_hub import login
# Authenticate with the Hugging Face Hub before any from_pretrained() download.
# Only call login() when a token is actually configured: login(token=None)
# falls back to an interactive prompt, which cannot be answered in a headless
# Space. Without HF_TOKEN, huggingface_hub uses any cached credentials instead.
_hf_token = os.getenv("HF_TOKEN")
if _hf_token:
    login(token=_hf_token)

# Shared Lotus-2 pipeline; populated once by load_pipeline().
pipeline = None
device = "cuda" if torch.cuda.is_available() else "cpu"
weight_dtype = torch.bfloat16
# Tasks served by this demo; each has its own adapters + LCM on the shared
# transformer (selected per run in _run_task).
TASKS = ("depth", "normal")
def load_pipeline():
    """Build the shared Lotus-2 pipeline and store it in the module-level ``pipeline``.

    A single FLUX transformer (frozen, cast to ``weight_dtype`` on ``device``)
    receives the adapters and local continuity modules (LCMs) for every task in
    ``TASKS`` via ``load_all_task_weights``. The per-task LCMs are attached to
    the pipeline as ``_lcms`` so ``_run_task`` can pick one per request.
    """
    global pipeline
    base_repo = 'black-forest-labs/FLUX.1-dev'

    scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
        base_repo, subfolder="scheduler", num_train_timesteps=10
    )

    flux = FluxTransformer2DModel.from_pretrained(
        base_repo, subfolder="transformer", revision=None, variant=None
    )
    flux.requires_grad_(False)  # weights are frozen; no gradients needed
    flux.to(device=device, dtype=weight_dtype)
    # Load every task's adapters + LCM onto this one shared transformer.
    flux, task_lcms = load_all_task_weights(flux, TASKS)

    pipeline = Lotus2Pipeline.from_pretrained(
        base_repo,
        scheduler=scheduler,
        transformer=flux,
        revision=None,
        variant=None,
        torch_dtype=weight_dtype,
    )
    # Per-task LCMs, looked up by task name in _run_task.
    pipeline._lcms = task_lcms
    pipeline = pipeline.to(device)
def _save_raw_outputs(output_npy, task):
    """Write one task's prediction losslessly into a fresh temp directory.

    Two files are produced:
      * ``{task}.npy`` -- the prediction cast to float32, otherwise untouched.
        Depth is (H, W) in [0, 1]; normal is (H, W, 3) in [0, 1] = (n + 1) / 2
        (directly load_normals-compatible).
      * ``{task}_16bit.png`` -- the prediction clipped to [0, 1] and quantized
        to uint16: 3-channel for normals, single-channel grayscale for depth.

    Args:
        output_npy: prediction array returned by ``process_single_image``.
        task: task name, used as the file-name stem.

    Returns:
        ``(npy_path, png16_path)``.

    Raises:
        OSError: if OpenCV could not write the PNG. ``cv2.imwrite`` signals
            failure through its return value rather than raising, so without
            this check a missing file would be handed to the UI.
    """
    out_dir = tempfile.mkdtemp(prefix="lotus2_")

    npy_path = os.path.join(out_dir, f"{task}.npy")
    np.save(npy_path, output_npy.astype(np.float32))

    # NaN survives np.clip and its uint16 cast is undefined; map it to 0 for
    # the PNG only (the .npy above keeps the raw values). +/-inf are handled
    # by the clip below.
    finite = np.nan_to_num(output_npy, nan=0.0)
    arr16 = (np.clip(finite, 0.0, 1.0) * 65535.0 + 0.5).astype(np.uint16)
    if arr16.ndim == 3:
        # Normal map: array is RGB, cv2 writes BGR -> swap channel order.
        arr16 = np.ascontiguousarray(arr16[:, :, ::-1])
    # Otherwise depth: written as single-channel 16-bit grayscale.

    png16_path = os.path.join(out_dir, f"{task}_16bit.png")
    if not cv2.imwrite(png16_path, arr16):
        raise OSError(f"cv2.imwrite failed to write {png16_path}")
    return npy_path, png16_path
def _run_task(image_path, task, process_res):
    """Run a single task (e.g. ``"depth"`` or ``"normal"``) with the shared pipeline.

    The pipeline is first pointed at this task's core-predictor and
    detail-sharpener adapter names and at its LCM. Then ``process_single_image``
    runs with 10 inference steps in core-only mode (the detail sharpener is
    skipped). The raw prediction is saved via ``_save_raw_outputs``.

    Returns:
        ``(output_vis, npy_path, png16_path)``.
    """
    # Switch the shared pipeline over to this task's adapters and LCM.
    pipeline._core_adapter_name = f"{task}_core_predictor"
    pipeline._sharpener_adapter_name = f"{task}_detail_sharpener"
    pipeline.local_continuity_module = pipeline._lcms[task]

    run_kwargs = dict(
        task_name=task,
        device=device,
        num_inference_steps=10,
        process_res=int(process_res),
        core_only=True,  # this Space is the core-only experiment
    )
    _unused, preview, raw_prediction = process_single_image(
        image_path, pipeline, **run_kwargs
    )

    saved_npy, saved_png16 = _save_raw_outputs(raw_prediction, task)
    return preview, saved_npy, saved_png16
@spaces.GPU(duration=300)
def fn(image_path, process_res=1024):
    """Gradio handler: predict depth and surface normals for one image.

    Args:
        image_path: file path from ``gr.Image(type="filepath")``. It is
            ``None`` when the user submits without uploading an image.
        process_res: resolution cap from the slider. Falsy values fall back
            to 1024. It is a CEILING: the model predicts at
            min(input long side, process_res) and the output array matches the
            input size. A larger value therefore only helps when the input is
            at least that large (upscale it first).

    Returns:
        ``([input, depth_vis], [input, normal_vis], depth_npy, depth_png16,
        normal_npy, normal_png16)``, in the order of the Interface outputs.

    Raises:
        gr.Error: if no image was provided, shown to the user in the UI
            instead of an opaque failure inside inference.
    """
    if not image_path:
        raise gr.Error("Please upload an image first.")

    pipeline.set_progress_bar_config(disable=True)
    process_res = int(process_res) if process_res else 1024

    depth_vis, depth_npy, depth_png16 = _run_task(image_path, "depth", process_res)
    normal_vis, normal_npy, normal_png16 = _run_task(image_path, "normal", process_res)

    inp = Image.open(image_path)
    return (
        [inp, depth_vis],
        [inp, normal_vis],
        depth_npy, depth_png16,
        normal_npy, normal_png16,
    )
def build_demo():
    """Construct the Gradio Interface that wraps ``fn``.

    Inputs: an image (passed to ``fn`` as a file path) and a process-resolution
    slider. Outputs, in the order ``fn`` returns them: depth and normal
    before/after image sliders, plus the raw float32 .npy and 16-bit PNG files
    for each task. Example images are every .png/.jpg file under
    ``assets/demo_examples/depth/``.
    """
    inputs = [
        gr.Image(label="Image", type="filepath"),
        gr.Slider(512, 2048, value=1024, step=128,
                  label="Process resolution (detail cap). Output detail = min(input long side, this). "
                        "Higher = finer but much more VRAM/time; >1536 may OOM on this GPU. "
                        "Feed an input at least this large (upscale first) to benefit."),
    ]
    outputs = [
        gr.ImageSlider(label="Depth", type="pil", slider_position=20),
        gr.ImageSlider(label="Normal", type="pil", slider_position=20),
        gr.File(label="Raw float32 depth.npy (lossless, [0,1])"),
        gr.File(label="16-bit depth PNG"),
        gr.File(label="Raw float32 normal.npy (lossless, [0,1])"),
        gr.File(label="16-bit normal PNG"),
    ]

    # glob() returns paths in arbitrary, filesystem-dependent order; sort so
    # the example gallery is stable across restarts and machines.
    example_paths = sorted(
        glob("assets/demo_examples/depth/*.png")
        + glob("assets/demo_examples/depth/*.jpg")
    )
    # One row per example: [image, process_res], matching the two inputs.
    examples = [[path, 1024] for path in example_paths]

    demo = gr.Interface(
        fn=fn,
        title="Lotus-2 Geometry CORE-ONLY (experiment): Depth + Normal, single-stage",
        description="""
Core-only variant of Lotus-2 Geometry. Skips the detail-sharpener
second stage entirely — the LCM output goes straight to the VAE decoder.
Mirrors Lotus-1's single-stage smoothness, on Lotus-2's improved base predictor.
No post-processing (no median de-grid) so the raw model behaviour is visible.
Returns previews + raw float32 .npy + 16-bit PNG for both depth and normal.
""",
        inputs=inputs,
        outputs=outputs,
        # If the assets folder is missing or empty, pass None (no examples)
        # rather than an empty list.
        examples=examples or None,
        examples_per_page=10,
        cache_examples=False,
    )
    return demo
def main():
    """Load the model once, then build and serve the Gradio UI."""
    load_pipeline()
    # launch() uses Gradio's default host/port. Pass server_name="0.0.0.0"
    # and/or server_port=... to expose the app on another interface or port.
    build_demo().launch()


if __name__ == "__main__":
    main()