Update README.md
Browse files
README.md
CHANGED
|
@@ -67,33 +67,34 @@ instantir_path = f'./models/aggregator.pt'
|
|
| 67 |
# load SDXL
|
| 68 |
sdxl = StableDiffusionXLPipeline.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', torch_dtype=torch.float16)
|
| 69 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
# load adapter
|
| 71 |
image_proj_model = Resampler(
|
| 72 |
embedding_dim=image_encoder.config.hidden_size,
|
| 73 |
output_dim=sdxl.unet.config.cross_attention_dim,
|
| 74 |
)
|
| 75 |
init_adapter_in_unet(
|
| 76 |
-
|
| 77 |
image_proj_model,
|
| 78 |
dcp_adapter,
|
| 79 |
)
|
| 80 |
|
| 81 |
-
pipe = InstantIRPipeline(
|
| 82 |
-
sdxl.vae, sdxl.text_encoder, sdxl.text_encoder_2, sdxl.tokenizer, sdxl.tokenizer_2,
|
| 83 |
-
sdxl.unet, sdxl.scheduler, feature_extractor=image_processor, image_encoder=image_encoder,
|
| 84 |
-
)
|
| 85 |
-
pipe.cuda()
|
| 86 |
-
|
| 87 |
# load previewer lora
|
| 88 |
pipe.prepare_previewers(previewer_lora_path)
|
| 89 |
-
pipe.unet.to(dtype=torch.float16)
|
| 90 |
pipe.scheduler = DDPMScheduler.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', subfolder="scheduler")
|
| 91 |
lcm_scheduler = LCMSingleStepScheduler.from_config(pipe.scheduler.config)
|
|
|
|
|
|
|
| 92 |
|
| 93 |
# load aggregator weights
|
| 94 |
pretrained_state_dict = torch.load(instantir_path)
|
| 95 |
pipe.aggregator.load_state_dict(pretrained_state_dict)
|
| 96 |
-
pipe.aggregator.to(dtype=torch.float16)
|
| 97 |
```
|
| 98 |
|
| 99 |
Then, you can restore your broken images with:
|
|
|
|
| 67 |
# load SDXL
|
| 68 |
sdxl = StableDiffusionXLPipeline.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', torch_dtype=torch.float16)
|
| 69 |
|
| 70 |
+
# InstantIR pipeline
|
| 71 |
+
pipe = InstantIRPipeline(
|
| 72 |
+
sdxl.vae, sdxl.text_encoder, sdxl.text_encoder_2, sdxl.tokenizer, sdxl.tokenizer_2,
|
| 73 |
+
sdxl.unet, sdxl.scheduler, feature_extractor=image_processor, image_encoder=image_encoder,
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
# load adapter
|
| 77 |
image_proj_model = Resampler(
|
| 78 |
embedding_dim=image_encoder.config.hidden_size,
|
| 79 |
output_dim=sdxl.unet.config.cross_attention_dim,
|
| 80 |
)
|
| 81 |
init_adapter_in_unet(
|
| 82 |
+
pipe.unet,
|
| 83 |
image_proj_model,
|
| 84 |
dcp_adapter,
|
| 85 |
)
|
| 86 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
# load previewer lora
|
| 88 |
pipe.prepare_previewers(previewer_lora_path)
|
|
|
|
| 89 |
pipe.scheduler = DDPMScheduler.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', subfolder="scheduler")
|
| 90 |
lcm_scheduler = LCMSingleStepScheduler.from_config(pipe.scheduler.config)
|
| 91 |
+
pipe.unet.to(dtype=torch.float16)
|
| 92 |
+
pipe.to('cuda')
|
| 93 |
|
| 94 |
# load aggregator weights
|
| 95 |
pretrained_state_dict = torch.load(instantir_path)
|
| 96 |
pipe.aggregator.load_state_dict(pretrained_state_dict)
|
| 97 |
+
pipe.aggregator.to(dtype=torch.float16, device=pipe.unet.device)
|
| 98 |
```
|
| 99 |
|
| 100 |
Then, you can restore your broken images with:
|