| | from huggingface_hub import hf_hub_download |
| |
|
# Fetch the three InstantIR weight files from the Hugging Face Hub into the
# current directory (each lands under ./models/ because of its filename).
for _weight_file in (
    "models/adapter.pt",
    "models/aggregator.pt",
    "models/previewer_lora_weights.bin",
):
    hf_hub_download(repo_id="InstantX/InstantIR", filename=_weight_file, local_dir=".")
| |
|
| | import torch |
| | from PIL import Image |
| |
|
| | from diffusers import DDPMScheduler |
| | from schedulers.lcm_single_step_scheduler import LCMSingleStepScheduler |
| |
|
| | from module.ip_adapter.utils import load_adapter_to_pipe |
| | from pipelines.sdxl_instantir import InstantIRPipeline |
| |
|
def resize_img(input_image, max_side=1280, min_side=1024, size=None,
               pad_to_max_side=False, mode=Image.BILINEAR, base_pixel_number=64):
    """Resize a PIL image, optionally centering it on a white square canvas.

    Args:
        input_image: PIL image to resize.
        max_side: Target length of the longer side when ``size`` is not
            given; also the side length of the padding canvas.
        min_side: Not used by this function; kept so existing callers that
            pass it keep working.
        size: Optional ``(width, height)``. When given, the image is resized
            to exactly this size and no aspect-ratio logic is applied.
        pad_to_max_side: If true, paste the resized image centered on a
            white ``max_side`` x ``max_side`` RGB canvas.
        mode: Resampling filter forwarded to ``Image.resize``.
        base_pixel_number: When ``size`` is not given, both output
            dimensions are rounded down to a multiple of this value.

    Returns:
        The resized (and possibly padded) PIL image.
    """
    w, h = input_image.size
    if size is not None:
        w_resize_new, h_resize_new = size
    else:
        # Scale so the longer side equals max_side, then snap both sides
        # down to a multiple of base_pixel_number.
        ratio = max_side / max(h, w)
        input_image = input_image.resize([round(ratio * w), round(ratio * h)], mode)
        w_resize_new = (round(ratio * w) // base_pixel_number) * base_pixel_number
        h_resize_new = (round(ratio * h) // base_pixel_number) * base_pixel_number
    input_image = input_image.resize([w_resize_new, h_resize_new], mode)

    if pad_to_max_side:
        # Imported here because the module does not import numpy at the top;
        # without this the padding branch raised NameError on `np`.
        import numpy as np

        res = np.ones([max_side, max_side, 3], dtype=np.uint8) * 255
        offset_x = (max_side - w_resize_new) // 2
        offset_y = (max_side - h_resize_new) // 2
        # The canvas has three channels, so force RGB before pasting;
        # grayscale/RGBA arrays would not fit the target slice.
        res[offset_y:offset_y + h_resize_new, offset_x:offset_x + w_resize_new] = np.array(
            input_image.convert("RGB")
        )
        input_image = Image.fromarray(res)
    return input_image
| |
|
| | |
# Directory that holds the InstantIR weight files used below.
instantir_path = './models'

# Build the InstantIR pipeline on top of the SDXL base checkpoint in fp16.
pipe = InstantIRPipeline.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', torch_dtype=torch.float16)

# Load the adapter weights into the pipeline, with DINOv2-large given as the
# image encoder.
load_adapter_to_pipe(
    pipe,
    f"{instantir_path}/adapter.pt",
    image_encoder_or_path='facebook/dinov2-large',
)

# Set up the previewer from the weights directory, then swap in a DDPM
# scheduler and derive the single-step LCM scheduler from its config.
# NOTE(review): presumably prepare_previewers reads previewer_lora_weights.bin
# from instantir_path -- confirm against the pipeline implementation.
pipe.prepare_previewers(instantir_path)
pipe.scheduler = DDPMScheduler.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', subfolder="scheduler")
lcm_scheduler = LCMSingleStepScheduler.from_config(pipe.scheduler.config)

# Load the aggregator weights.
# NOTE(security): torch.load unpickles the file; aggregator.pt is a
# downloaded artifact, so only use checkpoints from a trusted source
# (consider weights_only=True where the installed torch supports it).
# map_location="cpu" avoids materializing the checkpoint on whatever device
# it was saved from; the module is moved to CUDA right below.
pretrained_state_dict = torch.load(f"{instantir_path}/aggregator.pt", map_location="cpu")
pipe.aggregator.load_state_dict(pretrained_state_dict)

# Move the whole pipeline, and the aggregator explicitly, to the GPU in fp16.
pipe.to(device='cuda', dtype=torch.float16)
pipe.aggregator.to(device='cuda', dtype=torch.float16)
| |
|
# Default positive prompt, used by infer() when the user leaves the prompt
# box empty. Adjacent literals are concatenated into a single line of text.
PROMPT = (
    "Photorealistic, highly detailed, hyper detailed photo - realistic maximum detail, 32k, "
    "ultra HD, extreme meticulous detailing, skin pore detailing, "
    "hyper sharpness, perfect without deformations, "
    "taken using a Canon EOS R camera, Cinematic, High Contrast, Color Grading. "
)

# Negative prompt that infer() sends with every request.
NEG_PROMPT = (
    "blurry, out of focus, unclear, depth of field, over-smooth, "
    "sketch, oil painting, cartoon, CG Style, 3D render, unreal engine, "
    "dirty, messy, worst quality, low quality, frames, painting, illustration, drawing, art, "
    "watermark, signature, jpeg artifacts, deformed, lowres"
)
| |
|
def infer(prompt, input_image, steps=30, cfg_scale=7.0, guidance_end=1.0,
          creative_restoration=False, seed=3407, height=1024, width=1024):
    """Restore a low-quality image with the InstantIR pipeline.

    Args:
        prompt: Text prompt; when empty, the module-level ``PROMPT`` is used.
        input_image: Path (or file object) of the low-quality image.
        steps: Number of inference steps; must be at least 1.
        cfg_scale: Value passed to the pipeline as ``guidance_scale``.
        guidance_end: Currently not used by this function.
        creative_restoration: Currently not used by this function.
        seed: Seed for the CUDA random generator.
        height: Height the input is resized to before restoration.
        width: Width the input is resized to before restoration.

    Returns:
        The first image of the pipeline output.

    Raises:
        ValueError: If ``steps`` is less than 1 or no input image is given.
    """
    # Validate up front: steps == 0 would otherwise fail with a
    # ZeroDivisionError below, and a missing image with an opaque PIL error.
    if steps < 1:
        raise ValueError("steps must be a positive integer")
    if input_image is None:
        raise ValueError("Please provide a low-quality input image.")

    # Open inside a context manager so the file handle is released as soon
    # as the RGB copy has been made.
    with Image.open(input_image) as source_image:
        low_quality_image = source_image.convert("RGB")

    lq = [resize_img(low_quality_image, size=(width, height))]
    generator = torch.Generator(device='cuda').manual_seed(seed)

    # Evenly spaced timesteps over a 1000-step range, shifted by the
    # scheduler's configured offset and ordered from high to low.
    timesteps = [
        i * (1000 // steps) + pipe.scheduler.config.steps_offset for i in range(0, steps)
    ]
    timesteps = timesteps[::-1]

    # Fall back to the default prompt when none (empty or None) is supplied.
    prompt = prompt if prompt else PROMPT
    neg_prompt = NEG_PROMPT

    image = pipe(
        prompt=[prompt] * len(lq),
        image=lq,
        num_inference_steps=steps,
        generator=generator,
        timesteps=timesteps,
        negative_prompt=[neg_prompt] * len(lq),
        guidance_scale=cfg_scale,
        previewer_scheduler=lcm_scheduler,
    ).images[0]

    return image
| |
|
| | import gradio as gr |
| |
|
| |
|
| |
|
# Gradio UI: an image input and a prompt box feed infer(); the restored image
# is shown in a second image component.
# NOTE(review): the source's indentation was lost, so the container nesting
# below is a reconstruction -- confirm the intended layout.
with gr.Blocks() as demo:
    with gr.Column():
        with gr.Row():
            with gr.Column():
                # type="filepath" hands infer() a path string to open.
                lq_img = gr.Image(label="Low-quality image", type="filepath")
                with gr.Group():
                    prompt = gr.Textbox(label="Prompt", value="")

                submit_btn = gr.Button("InstantIR magic!")
            output_img = gr.Image(label="InstantIR restored")
    # Only prompt and image are wired; infer() uses its defaults for the rest.
    submit_btn.click(
        fn=infer,
        inputs=[prompt, lq_img],
        outputs=[output_img]
    )
# show_error=True is passed so errors raised in infer() are shown in the UI.
demo.launch(show_error=True)