Spaces:
Sleeping
Sleeping
Rebuild: fp16 + scipy.ndimage.median_filter(ksize=11) — staircase fix + FLUX 16-px grid removal
55b9c4f verified | import spaces # must be first! | |
| import sys | |
| import os | |
| import torch | |
| from PIL import Image | |
| import gradio as gr | |
| from glob import glob | |
| from contextlib import nullcontext | |
| import numpy as np | |
| import cv2 | |
| import tempfile | |
| from pipeline import Lotus2Pipeline | |
| from diffusers import ( | |
| FlowMatchEulerDiscreteScheduler, | |
| FluxTransformer2DModel, | |
| ) | |
| from infer import ( | |
| load_lora_and_lcm_weights, | |
| process_single_image | |
| ) | |
| from huggingface_hub import login | |
| import os | |
# Authenticate with the Hugging Face Hub so the gated FLUX.1-dev weights can be
# downloaded. Only call login() when a token is actually provided: with
# token=None, huggingface_hub.login() falls back to an interactive prompt,
# which fails on a headless Space / CI runner. Without HF_TOKEN, credentials
# cached by `huggingface-cli login` are still used by from_pretrained().
_hf_token = os.getenv("HF_TOKEN")
if _hf_token:
    login(token=_hf_token)

# Module-level state shared by load_pipeline(), fn() and build_demo().
pipeline = None  # populated by load_pipeline()
device = "cuda" if torch.cuda.is_available() else "cpu"
# float16 (rather than bfloat16) for the FLUX transformer, as the demo
# description advertises.
# NOTE(review): float16 on the CPU fallback path is slow and some ops may not
# support it — confirm if CPU execution matters.
weight_dtype = torch.float16
task = None  # "depth" or "normal"; set by main()
def load_pipeline():
    """Build the Lotus-2 pipeline and store it in the module-level ``pipeline``.

    Loads the FLUX.1-dev flow-matching scheduler (``num_train_timesteps=10``)
    and transformer. The transformer is frozen, cast to ``weight_dtype`` on
    ``device``, and given the task-specific LoRA/LCM weights via
    ``load_lora_and_lcm_weights``. A ``Lotus2Pipeline`` is then assembled
    around both pieces and moved to ``device``.

    Reads the globals ``device``, ``weight_dtype`` and ``task``; only
    ``pipeline`` is reassigned.
    """
    global pipeline
    base_repo = 'black-forest-labs/FLUX.1-dev'

    scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
        base_repo, subfolder="scheduler", num_train_timesteps=10
    )

    flux_transformer = FluxTransformer2DModel.from_pretrained(
        base_repo, subfolder="transformer", revision=None, variant=None
    )
    # Inference only: freeze every parameter, then move + cast in one call.
    flux_transformer.requires_grad_(False)
    flux_transformer.to(device=device, dtype=weight_dtype)

    # NOTE(review): the three positional None arguments are forwarded as-is;
    # their meaning is defined by infer.load_lora_and_lcm_weights.
    flux_transformer, continuity_module = load_lora_and_lcm_weights(
        flux_transformer, None, None, None, task
    )

    pipeline = Lotus2Pipeline.from_pretrained(
        base_repo,
        scheduler=scheduler,
        transformer=flux_transformer,
        revision=None,
        variant=None,
        torch_dtype=weight_dtype,
    )
    # NOTE(review): attached as a plain attribute, so it is presumably not a
    # registered pipeline component and may not be moved by .to() below —
    # confirm load_lora_and_lcm_weights already places it on `device`.
    pipeline.local_continuity_module = continuity_module
    pipeline = pipeline.to(device)
def _save_raw_outputs(output_npy, task):
    """Write the raw prediction to disk as a lossless .npy and a 16-bit PNG.

    Both files go into a fresh temporary directory. Their paths are returned so
    Gradio can serve them as downloadable ``gr.File`` outputs, letting clients
    fetch full-precision geometry instead of the 8-bit visualization.

    Args:
        output_npy: Float prediction array. Depth is expected as (H, W) and
            normal as (H, W, 3) in [0, 1] (normal encoded as (n + 1) / 2).
            A singleton channel axis (H, W, 1) is accepted as grayscale.
        task: Task name ("depth" or "normal"), used as the file-name stem.

    Returns:
        ``(npy_path, png16_path)``. The .npy holds the prediction as float32
        exactly as given (not clipped). The PNG holds it clipped to [0, 1]
        and quantized to uint16, with NaN mapped to 0.

    Raises:
        ValueError: if the array is not (H, W), (H, W, 1) or (H, W, 3).
        OSError: if OpenCV fails to write the PNG (``cv2.imwrite`` reports
            many failures by returning False rather than raising).
    """
    raw = np.asarray(output_npy, dtype=np.float32)
    # Treat a singleton channel axis as a plain grayscale image for the PNG.
    img = raw[:, :, 0] if raw.ndim == 3 and raw.shape[2] == 1 else raw
    if not (img.ndim == 2 or (img.ndim == 3 and img.shape[2] == 3)):
        # Validate before touching the filesystem; 2/4-channel input would
        # otherwise be written with a wrong channel order (or not at all).
        raise ValueError(
            f"expected an (H, W) or (H, W, 3) prediction, got shape {raw.shape}"
        )

    # NOTE(review): one temp dir per request is never removed here; Gradio
    # presumably copies returned files into its own cache, so these could be
    # cleaned up periodically.
    out_dir = tempfile.mkdtemp(prefix="lotus2_")
    npy_path = os.path.join(out_dir, f"{task}.npy")
    np.save(npy_path, raw)

    # NaN (e.g. from float16 overflow) has no defined uint16 value; map it to 0
    # and saturate +/-inf before quantizing to the full 16-bit range.
    finite = np.nan_to_num(img, nan=0.0, posinf=1.0, neginf=0.0)
    arr16 = (np.clip(finite, 0.0, 1.0) * 65535.0 + 0.5).astype(np.uint16)
    if arr16.ndim == 3:
        # OpenCV writes BGR; reverse RGB -> BGR into a contiguous buffer.
        arr16 = np.ascontiguousarray(arr16[:, :, ::-1])

    png16_path = os.path.join(out_dir, f"{task}_16bit.png")
    if not cv2.imwrite(png16_path, arr16):
        raise OSError(f"cv2.imwrite failed to write {png16_path}")
    return npy_path, png16_path
# `import spaces` at the top of the file only helps on ZeroGPU hardware if the
# GPU work runs inside a @spaces.GPU function; on other hardware the decorator
# is documented to have no effect.
@spaces.GPU
def fn(image_path):
    """Gradio handler: predict depth/normal for one uploaded image.

    Args:
        image_path: Path of the uploaded image (the input component uses
            ``type="filepath"``), or None when the user submits without an
            image.

    Returns:
        ``([input_image, visualization], npy_path, png16_path)``, matching the
        ImageSlider and the two File outputs declared in build_demo().

    Raises:
        gr.Error: if no image was provided, so the UI shows a readable message
            instead of a traceback from inside inference.
    """
    if image_path is None:
        raise gr.Error("Please upload an image first.")

    pipeline.set_progress_bar_config(disable=True)
    _, output_vis, output_npy = process_single_image(
        image_path,
        pipeline,
        task_name=task,
        device=device,
        num_inference_steps=10,
        process_res=1024,
    )
    npy_path, png16_path = _save_raw_outputs(output_npy, task)
    return [Image.open(image_path), output_vis], npy_path, png16_path
def build_demo():
    """Build the Gradio Interface for the current ``task``.

    The interface has one file-path image input and three outputs: an
    input-vs-prediction ImageSlider, the raw float32 .npy, and the 16-bit PNG.
    Example images are read from ``assets/demo_examples/<task>/`` (*.png then
    *.jpg, each sorted). The title and description are derived from ``task``
    so the depth and normal variants are labeled correctly.

    Returns:
        The (not yet launched) ``gr.Interface``.

    Raises:
        RuntimeError: if the module-level ``task`` has not been set yet.
    """
    if task is None:
        raise RuntimeError("Set `task` (e.g. via main()) before building the demo.")
    task_label = task.title()

    inputs = [
        gr.Image(label="Image", type="filepath"),
    ]
    outputs = [
        gr.ImageSlider(
            label=f"{task_label}",
            type="pil",
            slider_position=20,
        ),
        gr.File(label=f"Raw float32 {task}.npy (lossless, [0,1])"),
        gr.File(label=f"16-bit {task} PNG"),
    ]

    example_dir = f"assets/demo_examples/{task}"
    # Sort so the gallery order does not depend on filesystem listing order.
    examples = sorted(glob(f"{example_dir}/*.png")) + sorted(glob(f"{example_dir}/*.jpg"))

    description = (
        "\n"
        '<strong>Please consider starring <span style="color: orange">★</span> our '
        '<a href="https://github.com/EnVision-Research/Lotus-2" target="_blank" '
        'rel="noopener noreferrer">GitHub Repo</a> if you find this demo useful! 😊</strong>\n'
        "<br>\n"
        f'<strong>Current Task: </strong><strong style="color: red;">{task_label}</strong>\n'
        "<br>\n"
        "Runs the FLUX transformer in <strong>float16</strong> (finer mantissa than bfloat16)\n"
        "and applies a per-channel <strong>scipy.ndimage.median_filter (ksize=11)</strong>\n"
        "to the raw prediction to strip the FLUX-VAE 16-px patch grid\n"
        f"while preserving sharp {task} edges.\n"
    )

    demo = gr.Interface(
        fn=fn,
        title=f"Lotus-2 {task_label} (FP16 + median degrid)",
        description=description,
        inputs=inputs,
        outputs=outputs,
        examples=examples or None,  # no example folder -> no examples panel
        examples_per_page=10,
    )
    return demo
def main(task_name):
    """Select the task, load the model, build the Gradio UI and launch it.

    Args:
        task_name: "depth" or "normal". Stored in the module-level ``task``
            that load_pipeline(), fn() and build_demo() read.

    Raises:
        ValueError: if ``task_name`` is not a supported task. It is checked
            before any weights are downloaded, so a typo fails fast.
    """
    global task
    if task_name not in ("depth", "normal"):
        raise ValueError("Invalid task. Please choose from 'depth' and 'normal'.")
    task = task_name
    load_pipeline()
    demo = build_demo()
    # When self-hosting, pass e.g. server_name="0.0.0.0", server_port=6382.
    demo.launch()
if __name__ == "__main__":
    # Task can be chosen on the command line (`python app.py depth`). With no
    # argument it defaults to "normal", as before.
    task_name = sys.argv[1] if len(sys.argv) > 1 else "normal"
    if task_name not in ("depth", "normal"):
        raise ValueError("Invalid task. Please choose from 'depth' and 'normal'.")
    main(task_name)