import spaces  # must be imported first, before torch (ZeroGPU requirement)

import os
import tempfile
from contextlib import nullcontext
from glob import glob

import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image
from diffusers import (
    FlowMatchEulerDiscreteScheduler,
    FluxTransformer2DModel,
)
from huggingface_hub import login

from pipeline import Lotus2Pipeline
from infer import (
    load_lora_and_lcm_weights,
    process_single_image,
)

# FLUX.1-dev is a gated model, so authenticate with the Space's HF_TOKEN secret.
login(token=os.getenv("HF_TOKEN"))

pipeline = None
device = "cuda" if torch.cuda.is_available() else "cpu"
weight_dtype = torch.float16
task = None


def load_pipeline():
    global pipeline

    noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        subfolder="scheduler",
        num_train_timesteps=10,
    )
    transformer = FluxTransformer2DModel.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        subfolder="transformer",
        revision=None,
        variant=None,
    )
    transformer.requires_grad_(False)
    transformer.to(device=device, dtype=weight_dtype)
    transformer, local_continuity_module = load_lora_and_lcm_weights(
        transformer, None, None, None, task
    )

    pipeline = Lotus2Pipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        scheduler=noise_scheduler,
        transformer=transformer,
        revision=None,
        variant=None,
        torch_dtype=weight_dtype,
    )
    pipeline.local_continuity_module = local_continuity_module
    pipeline = pipeline.to(device)


def _save_raw_outputs(output_npy, task):
    """Save the raw float prediction losslessly as .npy (float32 in [0, 1]) plus
    a 16-bit PNG, and return both paths so they can be served as downloadable
    gr.File outputs. Clients can then fetch full-precision geometry instead of
    the 8-bit visualization.

    Depth .npy is (H, W) in [0, 1]; normal .npy is (H, W, 3) in [0, 1],
    encoded as (n + 1) / 2.
    """
    out_dir = tempfile.mkdtemp(prefix="lotus2_")

    npy_path = os.path.join(out_dir, f"{task}.npy")
    np.save(npy_path, output_npy.astype(np.float32))

    # Quantize [0, 1] to the full uint16 range, rounding to nearest.
    arr16 = (np.clip(output_npy, 0.0, 1.0) * 65535.0 + 0.5).astype(np.uint16)
    png16_path = os.path.join(out_dir, f"{task}_16bit.png")
    if arr16.ndim == 3:
        # Normal: the array is RGB but cv2 writes BGR, so reverse the channels.
        cv2.imwrite(png16_path, np.ascontiguousarray(arr16[:, :, ::-1]))
    else:
        # Depth: single-channel 16-bit grayscale.
        cv2.imwrite(png16_path, arr16)

    return npy_path, png16_path


@spaces.GPU
def fn(image_path):
    pipeline.set_progress_bar_config(disable=True)
    with nullcontext():
        _, output_vis, output_npy = process_single_image(
            image_path,
            pipeline,
            task_name=task,
            device=device,
            num_inference_steps=10,
            process_res=1024,
        )
    npy_path, png16_path = _save_raw_outputs(output_npy, task)
    return (Image.open(image_path), output_vis), npy_path, png16_path


def build_demo():
    inputs = [
        gr.Image(label="Image", type="filepath"),
    ]
    outputs = [
        gr.ImageSlider(
            label=f"{task.title()}",
            type="pil",
            slider_position=20,
        ),
        gr.File(label=f"Raw float32 {task}.npy (lossless, [0,1])"),
        gr.File(label=f"16-bit {task} PNG"),
    ]
    examples = (
        glob(f"assets/demo_examples/{task}/*.png")
        + glob(f"assets/demo_examples/{task}/*.jpg")
    )

    demo = gr.Interface(
        fn=fn,
        title=f"Lotus-2 {task.title()} (FP16 + median degrid)",
        description=f"""
Please consider starring our GitHub repo if you find this demo useful! 😊

Current Task: {task.title()}

Runs the FLUX transformer in float16 (10 mantissa bits vs. 7 for bfloat16) and
applies a per-channel scipy.ndimage.median_filter (size=11) to the raw prediction
to remove the 16-px FLUX-VAE patch grid while preserving sharp normal edges.
""",
        inputs=inputs,
        outputs=outputs,
        examples=examples,
        examples_per_page=10,
    )
    return demo


def main(task_name):
    global task
    task = task_name

    load_pipeline()
    demo = build_demo()
    demo.launch(
        # server_name="0.0.0.0",
        # server_port=6382,
    )


if __name__ == "__main__":
    task_name = "normal"
    if task_name not in ("depth", "normal"):
        raise ValueError("Invalid task. Please choose 'depth' or 'normal'.")
    main(task_name)
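

# ---------------------------------------------------------------------------
# Reading the downloaded files back (illustrative sketch, not part of the app).
# _save_raw_outputs stores normals as (n + 1) / 2 in [0, 1] and quantizes the
# PNG to uint16 as round(x * 65535). The two helpers below are hypothetical.
# They assume exactly those conventions and invert them. The demo never calls
# them; in practice they would live in a separate client-side script.
# ---------------------------------------------------------------------------
import numpy as np


def decode_normal_npy(arr01):
    """Hypothetical helper: map an (H, W, 3) array in [0, 1] to unit normals."""
    n = np.asarray(arr01, dtype=np.float32) * 2.0 - 1.0  # invert (n + 1) / 2
    norm = np.linalg.norm(n, axis=-1, keepdims=True)
    return n / np.maximum(norm, 1e-8)  # re-normalize, avoiding division by zero


def decode_png16(arr16):
    """Hypothetical helper: map uint16 PNG values back to float32 in [0, 1].

    The round-trip error is at most 0.5 / 65535 per value. If you read the
    3-channel normal PNG with cv2.imread(path, cv2.IMREAD_UNCHANGED), reverse
    the channels first (arr16[..., ::-1]), because OpenCV returns BGR.
    """
    return np.asarray(arr16, dtype=np.float32) / 65535.0


# Usage (file name as written by _save_raw_outputs for the normal task):
#   normals = decode_normal_npy(np.load("normal.npy"))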