Update app.py
Browse files
app.py
CHANGED
|
@@ -17,29 +17,27 @@ model_id = "stabilityai/stable-diffusion-3.5-large"
|
|
| 17 |
pipe = StableDiffusion3Pipeline.from_pretrained(model_id)
|
| 18 |
|
| 19 |
# Check if GPU is available, then move the model to the appropriate device
|
| 20 |
-
|
|
|
|
| 21 |
|
| 22 |
# Define the path to the LoRA model
|
| 23 |
lora_model_path = "./lora_model.pth" # Assuming the file is saved locally
|
| 24 |
|
| 25 |
# Custom method to load and apply LoRA weights to the Stable Diffusion pipeline
|
| 26 |
def load_lora_model(pipe, lora_model_path):
|
| 27 |
-
# Set device to 'cuda' if available, otherwise 'cpu'
|
| 28 |
-
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 29 |
-
|
| 30 |
# When loading the LoRA weights
|
| 31 |
lora_weights = torch.load(lora_model_path, map_location=device, weights_only=True)
|
|
|
|
|
|
|
|
|
|
| 32 |
|
| 33 |
-
#
|
| 34 |
-
print(dir(pipe)) # This will list all attributes and methods of the `pipe` object
|
| 35 |
-
|
| 36 |
-
# Apply weights to the UNet submodule
|
| 37 |
try:
|
| 38 |
-
for name, param in pipe.
|
| 39 |
if name in lora_weights:
|
| 40 |
param.data += lora_weights[name]
|
| 41 |
except AttributeError:
|
| 42 |
-
print("The model doesn't have '
|
| 43 |
# Add alternative handling or exit
|
| 44 |
|
| 45 |
return pipe
|
|
@@ -47,7 +45,7 @@ def load_lora_model(pipe, lora_model_path):
|
|
| 47 |
# Load and apply the LoRA model weights
|
| 48 |
pipe = load_lora_model(pipe, lora_model_path)
|
| 49 |
|
| 50 |
-
# Use the @
|
| 51 |
@spaces.gpu
|
| 52 |
def generate(prompt, seed=None):
|
| 53 |
generator = torch.manual_seed(seed) if seed is not None else None
|
|
|
|
| 17 |
pipe = StableDiffusion3Pipeline.from_pretrained(model_id)
|
| 18 |
|
| 19 |
# Check if GPU is available, then move the model to the appropriate device
|
| 20 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 21 |
+
pipe.to(device)
|
| 22 |
|
| 23 |
# Define the path to the LoRA model
|
| 24 |
lora_model_path = "./lora_model.pth" # Assuming the file is saved locally
|
| 25 |
|
| 26 |
# Custom method to load and apply LoRA weights to the Stable Diffusion pipeline
|
| 27 |
def load_lora_model(pipe, lora_model_path):
|
|
|
|
|
|
|
|
|
|
| 28 |
# When loading the LoRA weights
|
| 29 |
lora_weights = torch.load(lora_model_path, map_location=device, weights_only=True)
|
| 30 |
+
|
| 31 |
+
# Check if the transformer folder has the necessary attributes
|
| 32 |
+
print(dir(pipe.transformer)) # List available attributes of the transformer (formerly 'unet')
|
| 33 |
|
| 34 |
+
# Apply weights to the transformer submodule
|
|
|
|
|
|
|
|
|
|
| 35 |
try:
|
| 36 |
+
for name, param in pipe.transformer.named_parameters(): # Accessing transformer parameters
|
| 37 |
if name in lora_weights:
|
| 38 |
param.data += lora_weights[name]
|
| 39 |
except AttributeError:
|
| 40 |
+
print("The model doesn't have 'transformer' attributes. Please check the model structure.")
|
| 41 |
# Add alternative handling or exit
|
| 42 |
|
| 43 |
return pipe
|
|
|
|
| 45 |
# Load and apply the LoRA model weights
|
| 46 |
pipe = load_lora_model(pipe, lora_model_path)
|
| 47 |
|
| 48 |
+
# Use the @spaces.gpu decorator to ensure compatibility with GPU or CPU as needed
|
| 49 |
@spaces.gpu
|
| 50 |
def generate(prompt, seed=None):
|
| 51 |
generator = torch.manual_seed(seed) if seed is not None else None
|