Lotus-2_Normal / app.py
sciencellama's picture
Rebuild: fp16 + scipy.ndimage.median_filter(ksize=11) — staircase fix + FLUX 16-px grid removal
55b9c4f verified
Raw
History Blame Contribute Delete
4.76 kB
import spaces # must be first!
import sys
import os
import torch
from PIL import Image
import gradio as gr
from glob import glob
from contextlib import nullcontext
import numpy as np
import cv2
import tempfile
from pipeline import Lotus2Pipeline
from diffusers import (
FlowMatchEulerDiscreteScheduler,
FluxTransformer2DModel,
)
from infer import (
load_lora_and_lcm_weights,
process_single_image
)
from huggingface_hub import login
import os
# Authenticate with the Hugging Face Hub only when HF_TOKEN is configured.
# Calling login(token=None) makes huggingface_hub prompt for a token, which
# cannot be answered in a headless deployment; skipping the call instead lets
# any credentials already stored on the machine be used by later downloads.
_hf_token = os.getenv("HF_TOKEN")
if _hf_token:
    login(token=_hf_token)
# Module-level state shared by the functions in this file.
pipeline = None  # rebound by load_pipeline()
task = None  # rebound by main()
weight_dtype = torch.float16
# Use the GPU when torch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
def load_pipeline():
    """Build the Lotus-2 pipeline and store it in the module-level ``pipeline``.

    Loads the FLUX.1-dev scheduler and transformer, passes the transformer
    through ``load_lora_and_lcm_weights`` (which returns the transformer and a
    local continuity module), assembles a ``Lotus2Pipeline`` around them and
    moves the result to ``device``.

    Reads the module globals ``device``, ``weight_dtype`` and ``task``; only
    ``pipeline`` is rebound. ``task`` is forwarded unchanged to
    ``load_lora_and_lcm_weights``.
    """
    global pipeline
    base_repo = 'black-forest-labs/FLUX.1-dev'

    noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
        base_repo, subfolder="scheduler", num_train_timesteps=10
    )

    # Request the target dtype at load time so the checkpoint is not first
    # materialized at the default precision and only cast afterwards.
    transformer = FluxTransformer2DModel.from_pretrained(
        base_repo,
        subfolder="transformer",
        revision=None,
        variant=None,
        torch_dtype=weight_dtype,
    )
    transformer.requires_grad_(False)  # inference only
    # Kept from the original flow: places the module on the target device and
    # guarantees every parameter ends up in weight_dtype.
    transformer.to(device=device, dtype=weight_dtype)
    transformer, local_continuity_module = load_lora_and_lcm_weights(
        transformer, None, None, None, task
    )

    pipeline = Lotus2Pipeline.from_pretrained(
        base_repo,
        scheduler=noise_scheduler,
        transformer=transformer,
        revision=None,
        variant=None,
        torch_dtype=weight_dtype,
    )
    pipeline.local_continuity_module = local_continuity_module
    pipeline = pipeline.to(device)
def _save_raw_outputs(output_npy, task):
    """Write the raw prediction to disk as a float32 .npy and a 16-bit PNG.

    The two paths are meant to be returned as downloadable gr.File outputs so
    clients can fetch full-precision geometry instead of an 8-bit
    visualization. Expected inputs: depth is (H,W) in [0,1]; normal is
    (H,W,3) in [0,1] = (n+1)/2.

    A fresh temporary directory is created on every call and is not removed
    here.

    Args:
        output_npy: Prediction array. A 3-D array has its last axis reversed
            before being written as a PNG; anything else is written as-is.
        task: Task name, used as the stem of both file names.

    Returns:
        Tuple ``(npy_path, png16_path)``.

    Raises:
        OSError: If ``cv2.imwrite`` reports that the PNG was not written.
    """
    out_dir = tempfile.mkdtemp(prefix="lotus2_")

    # The .npy keeps the values unclipped, only cast to float32.
    npy_path = os.path.join(out_dir, f"{task}.npy")
    np.save(npy_path, output_npy.astype(np.float32))

    # Clip to [0, 1], scale to the uint16 range and add 0.5 so the truncating
    # cast rounds to nearest.
    arr16 = (np.clip(output_npy, 0.0, 1.0) * 65535.0 + 0.5).astype(np.uint16)
    png16_path = os.path.join(out_dir, f"{task}_16bit.png")

    # normal: RGB array, cv2 writes BGR -> swap; depth: single-channel
    # 16-bit grayscale is written unchanged.
    image_to_write = arr16[:, :, ::-1] if arr16.ndim == 3 else arr16

    # cv2.imwrite signals failure through its return value rather than an
    # exception; surface it instead of returning a path to a missing file.
    if not cv2.imwrite(png16_path, image_to_write):
        raise OSError(f"cv2.imwrite failed to write {png16_path}")
    return npy_path, png16_path
@spaces.GPU
def fn(image_path):
    """Run the pipeline on one image and return the demo outputs.

    Uses the module-level ``pipeline``, ``device`` and ``task``. Inference is
    run through ``process_single_image`` with 10 steps at a processing
    resolution of 1024.

    Args:
        image_path: Path of the input image file.

    Returns:
        A 3-tuple of: a two-element list ``[input image opened with PIL,
        visualization returned by process_single_image]``, the path of the
        saved .npy file, and the path of the saved 16-bit PNG.
    """
    pipeline.set_progress_bar_config(disable=True)

    # The first element returned by process_single_image is not used here.
    _unused, visualization, raw_prediction = process_single_image(
        image_path,
        pipeline,
        task_name=task,
        device=device,
        num_inference_steps=10,
        process_res=1024,
    )

    raw_npy_file, raw_png_file = _save_raw_outputs(raw_prediction, task)
    slider_pair = [Image.open(image_path), visualization]
    return slider_pair, raw_npy_file, raw_png_file
def build_demo():
    """Create the Gradio interface for the task held in the module-level ``task``.

    The interface takes one image (delivered to ``fn`` as a file path) and
    shows an image slider plus two downloadable files. Example inputs are
    collected from ``assets/demo_examples/<task>/`` (.png files first, then
    .jpg files).

    Returns:
        The constructed ``gr.Interface``.
    """
    image_input = gr.Image(label="Image", type="filepath")

    comparison_slider = gr.ImageSlider(
        label=task.title(),
        type="pil",
        slider_position=20,
    )
    npy_download = gr.File(label=f"Raw float32 {task}.npy (lossless, [0,1])")
    png_download = gr.File(label=f"16-bit {task} PNG")

    example_files = [
        path
        for pattern in ("*.png", "*.jpg")
        for path in glob(f"assets/demo_examples/{task}/{pattern}")
    ]

    # The HTML below is kept flush-left so the rendered string carries no
    # leading whitespace.
    description_html = f"""
<strong>Please consider starring <span style="color: orange">&#9733;</span> our <a href="https://github.com/EnVision-Research/Lotus-2" target="_blank" rel="noopener noreferrer">GitHub Repo</a> if you find this demo useful! 😊</strong>
<br>
<strong>Current Task: </strong><strong style="color: red;">{task.title()}</strong>
<br>
Runs the FLUX transformer in <strong>float16</strong> (finer mantissa than bfloat16)
and applies a per-channel <strong>scipy.ndimage.median_filter (ksize=11)</strong>
to the raw prediction to strip the FLUX-VAE 16-px patch grid
while preserving sharp normal edges.
"""

    return gr.Interface(
        fn=fn,
        title="Lotus-2 Normal (FP16 + median degrid)",
        description=description_html,
        inputs=[image_input],
        outputs=[comparison_slider, npy_download, png_download],
        examples=example_files,
        examples_per_page=10,
    )
def main(task_name):
    """Select the task, load the model and launch the demo.

    Args:
        task_name: Stored in the module-level ``task`` before the pipeline is
            loaded and the interface is built.
    """
    global task
    task = task_name

    load_pipeline()

    # launch() is called without arguments, so Gradio's default host and port
    # apply (no server_name / server_port override).
    build_demo().launch()
if __name__ == "__main__":
    # This entry point is fixed to the "normal" task; the check guards
    # against an unsupported value if the line above is edited.
    task_name = "normal"
    if task_name not in ('depth', 'normal'):
        raise ValueError("Invalid task. Please choose from 'depth' and 'normal'.")
    main(task_name)