Update app.py
Browse files
app.py
CHANGED
|
@@ -13,23 +13,23 @@ else:
|
|
| 13 |
|
| 14 |
# Load the Stable Diffusion 3.5 model
|
| 15 |
model_id = "stabilityai/stable-diffusion-3.5-medium"
|
| 16 |
-
pipe = StableDiffusion3Pipeline.from_pretrained(model_id
|
| 17 |
-
pipe.to("cpu")
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
# Custom method to load and apply LoRA weights to the Stable Diffusion pipeline
|
| 20 |
def load_lora_model(pipe, lora_model_path):
|
| 21 |
-
# Load the LoRA weights
|
| 22 |
lora_weights = torch.load(lora_model_path, map_location="cpu")
|
| 23 |
-
|
| 24 |
-
# Apply
|
| 25 |
-
for name, param in pipe.named_parameters():
|
| 26 |
if name in lora_weights:
|
| 27 |
-
param.data += lora_weights[name]
|
| 28 |
-
|
| 29 |
-
return pipe # Return the updated model
|
| 30 |
|
| 31 |
-
|
| 32 |
-
lora_model_path = "./lora_model.pth" # Local path to LoRA model
|
| 33 |
|
| 34 |
# Load and apply the LoRA model weights
|
| 35 |
pipe = load_lora_model(pipe, lora_model_path)
|
|
|
|
| 13 |
|
| 14 |
# Load the Stable Diffusion 3.5 model
|
| 15 |
model_id = "stabilityai/stable-diffusion-3.5-medium"
|
| 16 |
+
pipe = StableDiffusion3Pipeline.from_pretrained(model_id) # Removed torch_dtype argument
|
| 17 |
+
pipe.to("cpu") # Ensuring it runs on CPU
|
| 18 |
+
|
| 19 |
+
# Define the path to the LoRA model
|
| 20 |
+
lora_model_path = "./lora_model.pth" # Assuming the file is saved locally
|
| 21 |
|
| 22 |
# Custom method to load and apply LoRA weights to the Stable Diffusion pipeline
|
| 23 |
def load_lora_model(pipe, lora_model_path):
|
| 24 |
+
# Load the LoRA weights
|
| 25 |
lora_weights = torch.load(lora_model_path, map_location="cpu")
|
| 26 |
+
|
| 27 |
+
# Apply weights to the UNet submodule
|
| 28 |
+
for name, param in pipe.unet.named_parameters(): # Accessing unet parameters
|
| 29 |
if name in lora_weights:
|
| 30 |
+
param.data += lora_weights[name]
|
|
|
|
|
|
|
| 31 |
|
| 32 |
+
return pipe
|
|
|
|
| 33 |
|
| 34 |
# Load and apply the LoRA model weights
|
| 35 |
pipe = load_lora_model(pipe, lora_model_path)
|