Update app.py
Browse files
app.py
CHANGED
|
@@ -3,10 +3,10 @@ import torch
|
|
| 3 |
import os
|
| 4 |
from diffusers import StableDiffusion3Pipeline
|
| 5 |
from safetensors.torch import load_file
|
| 6 |
-
from spaces import GPU # Remove
|
| 7 |
|
| 8 |
# 1. Define model ID and HF_TOKEN (at the VERY beginning)
|
| 9 |
-
model_id = "stabilityai/stable-diffusion-3.5-large" #
|
| 10 |
hf_token = os.getenv("HF_TOKEN") # For private models (set in HF Space settings)
|
| 11 |
|
| 12 |
# 2. Initialize pipeline (to None initially)
|
|
@@ -14,12 +14,18 @@ pipeline = None
|
|
| 14 |
|
| 15 |
# 3. Load Stable Diffusion and LoRA (before Gradio)
|
| 16 |
try:
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
lora_filename = "lora_trained_model.safetensors" # EXACT filename of your LoRA
|
| 25 |
lora_path = os.path.join("./", lora_filename)
|
|
|
|
| 3 |
import os
|
| 4 |
from diffusers import StableDiffusion3Pipeline
|
| 5 |
from safetensors.torch import load_file
|
| 6 |
+
from spaces import GPU # Remove if not in HF Space
|
| 7 |
|
| 8 |
# 1. Define model ID and HF_TOKEN (at the VERY beginning)
|
| 9 |
+
model_id = "stabilityai/stable-diffusion-3.5-large" # Or your preferred model ID
|
| 10 |
hf_token = os.getenv("HF_TOKEN") # For private models (set in HF Space settings)
|
| 11 |
|
| 12 |
# 2. Initialize pipeline (to None initially)
|
|
|
|
| 14 |
|
| 15 |
# 3. Load Stable Diffusion and LoRA (before Gradio)
|
| 16 |
try:
|
| 17 |
+
if hf_token: # check if the token exists, if not, then do not pass the token
|
| 18 |
+
pipeline = StableDiffusion3Pipeline.from_pretrained(
|
| 19 |
+
model_id,
|
| 20 |
+
torch_dtype=torch.float16,
|
| 21 |
+
cache_dir="./model_cache" # For caching
|
| 22 |
+
)
|
| 23 |
+
else:
|
| 24 |
+
pipeline = StableDiffusion3Pipeline.from_pretrained(
|
| 25 |
+
model_id,
|
| 26 |
+
torch_dtype=torch.float16,
|
| 27 |
+
cache_dir="./model_cache" # For caching
|
| 28 |
+
)
|
| 29 |
|
| 30 |
lora_filename = "lora_trained_model.safetensors" # EXACT filename of your LoRA
|
| 31 |
lora_path = os.path.join("./", lora_filename)
|