Spaces:
Configuration error
Configuration error
fix dir
Browse files
app.py
CHANGED
|
@@ -22,6 +22,12 @@ from module.ip_adapter.resampler import Resampler
|
|
| 22 |
from module.aggregator import Aggregator
|
| 23 |
from pipelines.sdxl_instantir import InstantIRPipeline, LCM_LORA_MODULES, PREVIEWER_LORA_MODULES
|
| 24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
transform = transforms.Compose([
|
| 27 |
transforms.Resize(1024, interpolation=transforms.InterpolationMode.BILINEAR),
|
|
@@ -66,7 +72,7 @@ image_proj_model = Resampler(
|
|
| 66 |
init_ip_adapter_in_unet(
|
| 67 |
unet,
|
| 68 |
image_proj_model,
|
| 69 |
-
"
|
| 70 |
adapter_tokens=64,
|
| 71 |
)
|
| 72 |
print("Initializing InstantIR...")
|
|
@@ -77,7 +83,7 @@ pipe = InstantIRPipeline(
|
|
| 77 |
|
| 78 |
# Add Previewer LoRA.
|
| 79 |
lora_state_dict, alpha_dict = StableDiffusionXLPipeline.lora_state_dict(
|
| 80 |
-
"
|
| 81 |
# weight_name="previewer_lora_weights.bin",
|
| 82 |
|
| 83 |
)
|
|
@@ -145,7 +151,7 @@ lcm_scheduler = LCMSingleStepScheduler.from_config(pipe.scheduler.config)
|
|
| 145 |
# Load weights.
|
| 146 |
print("Loading checkpoint...")
|
| 147 |
aggregator_state_dict = torch.load(
|
| 148 |
-
"
|
| 149 |
map_location="cpu"
|
| 150 |
)
|
| 151 |
aggregator.load_state_dict(aggregator_state_dict, strict=True)
|
|
|
|
| 22 |
from module.aggregator import Aggregator
|
| 23 |
from pipelines.sdxl_instantir import InstantIRPipeline, LCM_LORA_MODULES, PREVIEWER_LORA_MODULES
|
| 24 |
|
| 25 |
+
from huggingface_hub import hf_hub_download
|
| 26 |
+
|
| 27 |
+
hf_hub_download(repo_id="InstantX/InstantIR", filename="adapter.pt", local_dir="./checkpoints")
|
| 28 |
+
hf_hub_download(repo_id="InstantX/InstantIR", filename="aggregator.pt", local_dir="./checkpoints")
|
| 29 |
+
hf_hub_download(repo_id="InstantX/InstantIR", filename="previewer_lora_weights.bin", local_dir="./checkpoints")
|
| 30 |
+
|
| 31 |
|
| 32 |
transform = transforms.Compose([
|
| 33 |
transforms.Resize(1024, interpolation=transforms.InterpolationMode.BILINEAR),
|
|
|
|
| 72 |
init_ip_adapter_in_unet(
|
| 73 |
unet,
|
| 74 |
image_proj_model,
|
| 75 |
+
"checkpoints/adapter.pt",
|
| 76 |
adapter_tokens=64,
|
| 77 |
)
|
| 78 |
print("Initializing InstantIR...")
|
|
|
|
| 83 |
|
| 84 |
# Add Previewer LoRA.
|
| 85 |
lora_state_dict, alpha_dict = StableDiffusionXLPipeline.lora_state_dict(
|
| 86 |
+
"checkpoints/previewer_lora_weights.bin",
|
| 87 |
# weight_name="previewer_lora_weights.bin",
|
| 88 |
|
| 89 |
)
|
|
|
|
| 151 |
# Load weights.
|
| 152 |
print("Loading checkpoint...")
|
| 153 |
aggregator_state_dict = torch.load(
|
| 154 |
+
"checkpoints/aggregator.pt",
|
| 155 |
map_location="cpu"
|
| 156 |
)
|
| 157 |
aggregator.load_state_dict(aggregator_state_dict, strict=True)
|