"""Gradio demo for InterLCM blind face restoration.

At import time the module downloads the InterLCM encoder weights, builds the
CLIP image encoder, the 3-step and 1-step visual/spatial encoders, the LCM
pipeline, a RealESRGAN upsampler and a face helper.  ``inference`` is the
Gradio callback that restores the faces of one input image.
"""
import argparse
import glob
import os
import random
import re

import gradio as gr
import numpy as np
import spaces  # ZeroGPU support; kept ahead of the torch-based imports, as in the original order
from diffusers import ControlNetModel, DiffusionPipeline, UNet2DConditionModel
import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

import cv2
from torchvision.transforms.functional import normalize
import torchvision.transforms as transforms

from basicsr.utils import imwrite, img2tensor, tensor2img
from basicsr.utils.download_util import load_file_from_url
from basicsr.utils.misc import gpu_is_available, get_device
from facelib.utils.face_restoration_helper import FaceRestoreHelper
from facelib.utils.misc import is_gray
from basicsr.utils.registry import ARCH_REGISTRY

# CLIP
import clip
from basicsr.utils.clip_util import VisionTransformer

# LCM
from basicsr.utils.lcm_utils import register_lcm_forward, register_lcmschedule_step
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils.realesrgan_utils import RealESRGANer
from scripts.wavelet_color_fix import wavelet_reconstruction, adaptive_instance_normalization

# Swap in the project's VisionTransformer before any CLIP model is built.
clip.model.VisionTransformer = VisionTransformer

device = "cuda" if torch.cuda.is_available() else "cpu"

# ------------------ set up InterLCM restorer ------------------- #
REPO_ID = "senmaonk/InterLCM"
visual_encoder_path = "weights/InterLCM/visual_encoder_3step.pth"
spatial_encoder_path = "weights/InterLCM/spatial_encoder_3step.pth"
visual_encoder_path_1step = "weights/InterLCM/visual_encoder_1step.pth"
spatial_encoder_path_1step = "weights/InterLCM/spatial_encoder_1step.pth"
sd_path = "stable-diffusion-v1-5/stable-diffusion-v1-5"
lcm_path = "SimianLuo/LCM_Dreamshaper_v7"
detection_model = "retinaface_resnet50"


def download_weights(FILENAME):
    """Download ``FILENAME`` from ``REPO_ID`` on the Hugging Face Hub.

    Args:
        FILENAME: Path of the file inside the repository.

    Returns:
        The local path of the downloaded (or cached) file.
    """
    print(f"Downloading {FILENAME} from {REPO_ID}...")
    local_path = hf_hub_download(
        repo_id=REPO_ID,
        filename=FILENAME,
    )
    print(f"downloaded to: {local_path}")
    return local_path


def _load_params_ema(checkpoint_path):
    """Return the ``params_ema`` state dict of a checkpoint, loaded on the CPU.

    ``map_location="cpu"`` keeps CUDA-saved checkpoints loadable on CPU-only
    hosts; ``load_state_dict`` then copies the tensors to the model's device.
    """
    return torch.load(checkpoint_path, map_location="cpu")['params_ema']


def _build_visual_encoder(checkpoint_path):
    """Build a ``VisualEncoder`` on ``device`` and load its weights in eval mode."""
    encoder = ARCH_REGISTRY.get('VisualEncoder')(
        nf=64, emb_dim=197, ch_mult=[2, 4, 8], res_blocks=2, img_size=512,
    ).to(device)
    encoder.load_state_dict(_load_params_ema(checkpoint_path))
    encoder.eval()
    return encoder


def _build_spatial_encoder(unet, checkpoint_path):
    """Build a ControlNet-style spatial encoder from ``unet`` and load its weights."""
    encoder = ControlNetModel.from_unet(unet).to(device)
    encoder.load_state_dict(_load_params_ema(checkpoint_path))
    encoder.eval()
    return encoder


visual_encoder_path = download_weights(visual_encoder_path)
spatial_encoder_path = download_weights(spatial_encoder_path)
visual_encoder_path_1step = download_weights(visual_encoder_path_1step)
spatial_encoder_path_1step = download_weights(spatial_encoder_path_1step)

# CLIPImageEncoder
clip_model, clip_preprocess = clip.load('ViT-B/16', device=device)
preprocess = transforms.Compose(
    # Un-normalize from [-1.0, 1.0] (GAN output) to [0, 1].
    [transforms.Normalize(mean=[-1.0, -1.0, -1.0], std=[2.0, 2.0, 2.0])]
    # to match CLIP input scale assumptions
    + clip_preprocess.transforms[:2]
    # skip convert PIL to tensor
    + clip_preprocess.transforms[4:]
)

# Visual encoders (3-step and 1-step variants)
visual_encoder = _build_visual_encoder(visual_encoder_path)
visual_encoder_1step = _build_visual_encoder(visual_encoder_path_1step)

# Spatial encoders: the SD UNet is only the template for both ControlNets,
# so it is loaded once and released afterwards.
unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path=sd_path, subfolder="unet")
spatial_encoder = _build_spatial_encoder(unet, spatial_encoder_path)
spatial_encoder_1step = _build_spatial_encoder(unet, spatial_encoder_path_1step)
del unet
torch.cuda.empty_cache()

# lcm
lcm = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path=lcm_path).to(device)

# (visual encoder, spatial encoder) pair for each supported step count.
_ENCODERS_BY_STEP = {
    1: (visual_encoder_1step, spatial_encoder_1step),
    3: (visual_encoder, spatial_encoder),
}


# set enhancer with RealESRGAN
def set_realesrgan():
    """Create the x2 RealESRGAN upsampler (half precision when CUDA is available)."""
    half = torch.cuda.is_available()
    model = RRDBNet(
        num_in_ch=3,
        num_out_ch=3,
        num_feat=64,
        num_block=23,
        num_grow_ch=32,
        scale=2,
    )
    return RealESRGANer(
        scale=2,
        model_path="weights/realesrgan/RealESRGAN_x2plus.pth",
        model=model,
        tile=400,
        tile_pad=40,
        pre_pad=0,
        half=half,
        device=device,
    )


upsampler = set_realesrgan()
upscale = 2

face_helper = FaceRestoreHelper(
    upscale_factor=upscale,
    face_size=512,
    crop_ratio=(1, 1),
    det_model=detection_model,
    save_ext='png',
    use_parse=True,
    device=device)
# ------------------ set up InterLCM restorer ------------------- #


def _restore_face(cropped_face_t, interlcm_step, visual_enc):
    """Run InterLCM on one normalized face tensor and return a BGR image array."""
    with torch.no_grad():
        clip_input = preprocess(cropped_face_t)
        img_emb = clip_model.encode_image(clip_input).to(torch.float)
        visual_feat = visual_enc(img_emb)

        latent_code = lcm.vae.encode(cropped_face_t)['latent_dist'].mean
        # presumably the SD VAE latent scaling factor -- TODO confirm
        latent_code = latent_code * 0.18215

        output = lcm.forward(
            height=512,
            width=512,
            num_inference_steps=interlcm_step + 1,
            guidance_scale=8.0,
            latents=latent_code,
            prompt_embeds=visual_feat,
            output_type="pil",
            lcm_origin_steps=50,
            lq_input=cropped_face_t,
        ).images
        output = wavelet_reconstruction(output, cropped_face_t)
        restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
    del output
    torch.cuda.empty_cache()
    return restored_face


@spaces.GPU
def inference(input_img, interlcm_step, face_align, background_enhance, face_upsample):
    """Restore the faces of one image with InterLCM.

    Args:
        input_img: Path of the input image file.
        interlcm_step: Number of InterLCM steps, ``1`` or ``3`` (int or str).
        face_align: Detect and align faces first; ``False`` treats the input as
            one already cropped/aligned face.  ``None`` means ``True``.
        background_enhance: Upsample the background with RealESRGAN.
            ``None`` means ``True``.
        face_upsample: Also upsample the restored faces with RealESRGAN.

    Returns:
        The restored image as an RGB array.

    Raises:
        ValueError: If ``interlcm_step`` is not 1 or 3.
        gr.Error: If the input image is missing or cannot be read.
    """
    # NOTE(review): `face_helper` and `lcm` are module-level objects mutated
    # here; with `concurrency_limit=2` overlapping requests may interfere --
    # confirm how the hosting runtime isolates calls.
    only_center_face = False
    draw_box = False

    interlcm_step = int(interlcm_step)
    if interlcm_step not in _ENCODERS_BY_STEP:
        raise ValueError(f'interlcm_step must be 1 or 3, got {interlcm_step}')

    if input_img is None:
        raise gr.Error('Please provide an input image.')
    img = cv2.imread(str(input_img), cv2.IMREAD_COLOR)
    if img is None:  # cv2.imread signals failure by returning None
        raise gr.Error(f'Could not read input image: {input_img}')
    print('\timage size:', img.shape)

    visual_enc, spatial_enc = _ENCODERS_BY_STEP[interlcm_step]
    register_lcm_forward(lcm, spatial_enc)
    register_lcmschedule_step(lcm.scheduler)

    face_align = True if face_align is None else face_align
    has_aligned = not face_align
    background_enhance = True if background_enhance is None else background_enhance
    bg_upsampler = upsampler if background_enhance else None
    face_upsampler = upsampler if face_upsample else None

    face_helper.clean_all()

    if has_aligned:
        # the input faces are already cropped and aligned
        img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR)
        face_helper.is_gray = is_gray(img, threshold=10)
        if face_helper.is_gray:
            print('Grayscale input: True')
        face_helper.cropped_faces = [img]
    else:
        face_helper.read_image(img)
        # get face landmarks for each face
        num_det_faces = face_helper.get_face_landmarks_5(
            only_center_face=only_center_face, resize=640, eye_dist_threshold=5, device=device)
        print(f'\tdetect {num_det_faces} faces')
        # align and warp each face
        face_helper.align_warp_face()

    # face restoration for each cropped face
    restored_face = None
    for cropped_face in face_helper.cropped_faces:
        # prepare data
        cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
        normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
        cropped_face_t = cropped_face_t.unsqueeze(0).to(device)

        try:
            restored_face = _restore_face(cropped_face_t, interlcm_step, visual_enc)
        except Exception as error:
            # Best effort: fall back to the unrestored crop for this face.
            print(f'\tFailed inference for InterLCM: {error}')
            restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))

        restored_face = restored_face.astype('uint8')
        face_helper.add_restored_face(restored_face, cropped_face)

    # paste_back
    if has_aligned:
        restored_img = restored_face
    else:
        # upsample the background
        if bg_upsampler is not None:
            # Now only support RealESRGAN for upsampling background
            bg_img = bg_upsampler.enhance(img, outscale=upscale)[0]
        else:
            bg_img = None
        face_helper.get_inverse_affine(None)
        # paste each restored face to the input image
        if face_upsampler is not None:
            restored_img = face_helper.paste_faces_to_input_image(
                upsample_img=bg_img, draw_box=draw_box, face_upsampler=face_upsampler)
        else:
            restored_img = face_helper.paste_faces_to_input_image(
                upsample_img=bg_img, draw_box=draw_box)

    # save restored img
    save_path = 'output/out.png'
    imwrite(restored_img, save_path)

    restored_img = cv2.cvtColor(restored_img, cv2.COLOR_BGR2RGB)
    return restored_img


title = "InterLCM: Low-Quality Images as Intermediate States of Latent Consistency Models for Effective Blind Face Restoration"

description = r"""
InterLCM logo

Official Gradio demo for Low-Quality Images as Intermediate States of Latent Consistency Models for Effective Blind Face Restoration (ICLR 2025)
🔥 InterLCM is a robust blind face restoration algorithm.
⭐ If InterLCM is helpful to your images or projects, please help star this repo. Thanks! 🤗
"""

article = r"""
If InterLCM is helpful, please help to ⭐ the Github Repo. Thanks!
[![GitHub Stars](https://img.shields.io/github/stars/sen-mao/InterLCM?style=social)](https://github.com/sen-mao/InterLCM)

---

📝 **Citation**

If our work is useful for your research, please consider citing:
```bibtex
@inproceedings{li2025interlcm,
  title={InterLCM: Low-Quality Images as Intermediate States of Latent Consistency Models for Effective Blind Face Restoration},
  author={Li, Senmao and Wang, Kai and van de Weijer, Joost and Khan, Fahad Shahbaz and Guo, Chun-Le and Yang, Shiqi and Wang, Yaxing and Yang, Jian and Cheng, Ming-Ming},
  booktitle={ICLR},
  year={2025}
}
```

📧 **Contact**

If you have any questions, please feel free to reach me out at senmaonk@gmail.com.

visitors
"""

demo = gr.Interface(
    inference,
    [
        gr.Image(type="filepath", label="Input"),
        gr.Radio(
            choices=["1", "3"],
            value="3",
            label="Select InterLCM step (InterLCM enables 1-step⚡ BFR under non-extreme degradation conditions)",
        ),
        gr.Checkbox(value=True, label="Pre_Face_Align"),
        gr.Checkbox(value=True, label="Background_Enhance"),
        gr.Checkbox(value=True, label="Face_Upsample"),
    ],
    [
        gr.Image(type="numpy", label="Output"),
    ],
    title=title,
    description=description,
    article=article,
    examples=[
        ['inputs/cropped_faces/0631.png', "3", False, False, False],
        ['inputs/cropped_faces/Nora_Bendijo_0001_00.png', "3", False, False, False],
        ['inputs/whole_imgs/03.jpg', "1", True, True, True],
        ['inputs/whole_imgs/04.jpg', "3", True, True, True],
        ['inputs/whole_imgs/05.jpg', "3", True, True, True],
    ],
    concurrency_limit=2,
)

if __name__ == "__main__":
    demo.launch()