KibbleHF commited on
Commit
d57e93c
·
verified ·
1 Parent(s): 364e56c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -35
app.py CHANGED
@@ -2,56 +2,64 @@
2
  from diffusers import StableDiffusionPipeline
3
  import torch
4
  import gradio as gr
5
- from huggingface_hub import hf_hub_download # Import to download the specific file
6
  import os
7
 
8
- # Define the base model and the specific checkpoint file
9
- BASE_MODEL = "runwayml/stable-diffusion-v1-5"
10
  MODEL_REPO_ID = "uhralk/Indigo_Furry_mix"
11
  CHECKPOINT_FILENAME = "indigo_Furrymix_v120_hybrid_fin_fp16.safetensors"
12
 
13
  # Determine the device (GPU or CPU)
14
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
15
  dtype = torch.float16 if device == "cuda" else torch.float32
16
 
17
- # Path to the downloaded checkpoint file
18
- # This will download the file to the Hugging Face cache directory
19
- checkpoint_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=CHECKPOINT_FILENAME)
20
 
21
  # Load the model
22
  try:
23
- # 1. Load the base Stable Diffusion 1.5 pipeline
24
- pipe = StableDiffusionPipeline.from_pretrained(
25
- BASE_MODEL,
 
26
  torch_dtype=dtype,
27
- use_safetensors=True
 
 
 
 
 
 
 
 
 
28
  )
29
  pipe.to(device)
30
- print(f"Base model '{BASE_MODEL}' loaded successfully on {device}")
31
-
32
- # 2. Load the custom checkpoint's state dict into the pipeline
33
- # This is often the most compatible way for "merged" checkpoints
34
- pipe.load_lora_weights(checkpoint_path, adapter_name="indigo_mix") # Use a dummy adapter name
35
- print(f"Custom checkpoint '{CHECKPOINT_FILENAME}' loaded successfully.")
36
 
37
  except Exception as e:
38
- print(f"Error loading model: {e}")
39
- # Fallback to CPU if GPU loading fails or if no GPU, and retry loading
40
  try:
41
- print("Attempting to load on CPU with float32 as a fallback...")
42
- pipe = StableDiffusionPipeline.from_pretrained(
43
- BASE_MODEL,
44
- torch_dtype=torch.float32,
45
- use_safetensors=True
 
 
 
 
 
46
  )
47
  pipe.to("cpu")
48
- print(f"Base model '{BASE_MODEL}' loaded successfully on CPU")
49
-
50
- pipe.load_lora_weights(checkpoint_path, adapter_name="indigo_mix") # Load on CPU as well
51
- print(f"Custom checkpoint '{CHECKPOINT_FILENAME}' loaded successfully on CPU.")
52
-
53
  except Exception as e_cpu:
54
- print(f"Critical error: Failed to load model even on CPU: {e_cpu}")
55
  exit() # Exit if model cannot be loaded at all
56
 
57
 
@@ -63,10 +71,6 @@ def generate_image(prompt, negative_prompt, num_inference_steps, guidance_scale,
63
  generator = torch.Generator(device=device).manual_seed(seed) if seed != -1 else None
64
 
65
  try:
66
- # If you loaded the LoRA/merged checkpoint, you might need to enable it for inference
67
- # If it was loaded as a direct state dict, it's already "active"
68
- # For pipe.load_lora_weights, if not using a specific adapter name for later enable/disable:
69
- # you just call it and it applies immediately.
70
  image = pipe(
71
  prompt=prompt,
72
  negative_prompt=negative_prompt,
@@ -99,6 +103,6 @@ iface = gr.Interface(
99
  ["A cute indigo furry wizard casting a spell, magical effects, cartoon style", "disfigured, ugly"],
100
  ]
101
  )
102
- # Launch the Gradio app
103
  if __name__ == "__main__":
104
- iface.launch() # <--- Make sure this line is exactly 'iface.launch()'
 
2
  from diffusers import StableDiffusionPipeline
3
  import torch
4
  import gradio as gr
5
+ from huggingface_hub import hf_hub_download
6
  import os
7
 
8
+ # Define the model ID for the specific checkpoint
9
+ # We will use this to get the direct download URL for from_single_file
10
  MODEL_REPO_ID = "uhralk/Indigo_Furry_mix"
11
  CHECKPOINT_FILENAME = "indigo_Furrymix_v120_hybrid_fin_fp16.safetensors"
12
 
13
  # Determine the device (GPU or CPU)
14
  device = "cuda" if torch.cuda.is_available() else "cpu"
15
+ # Use float16 for GPU to save VRAM and speed up, float32 for CPU
16
+ # Stable Diffusion models are often best run in fp16 on GPU
17
  dtype = torch.float16 if device == "cuda" else torch.float32
18
 
19
+ # Construct the full Hub URL for from_single_file
20
+ # The format is "hf_hub_url/repo_id/filename"
21
+ full_checkpoint_url = f"hf_hub:{MODEL_REPO_ID}/{CHECKPOINT_FILENAME}"
22
 
23
  # Load the model
24
  try:
25
+ print(f"Attempting to load model from single file: {full_checkpoint_url} on {device}")
26
+ # Load the pipeline directly from the single .safetensors file
27
+ pipe = StableDiffusionPipeline.from_single_file(
28
+ full_checkpoint_url,
29
  torch_dtype=dtype,
30
+ use_safetensors=True,
31
+ # Ensure VAE, Text Encoder, and Scheduler are fetched from a compatible base if not included
32
+ # SD 1.5 components are compatible with most fine-tunes
33
+ vae=None, # Automatically inferred or loaded from base
34
+ text_encoder=None, # Automatically inferred or loaded from base
35
+ tokenizer=None, # Automatically inferred or loaded from base
36
+ scheduler=None, # Automatically inferred or loaded from base
37
+ # Specifying a compatible base ensures non-Unet components are loaded
38
+ from_transformers=False, # Important for diffusers checkpoints
39
+ load_safety_checker=True # Good practice
40
  )
41
  pipe.to(device)
42
+ print(f"Model loaded successfully from single file on {device}")
 
 
 
 
 
43
 
44
  except Exception as e:
45
+ print(f"Error loading model directly from single file on {device}: {e}")
46
+ print("Attempting to load on CPU with float32 as a fallback (may be very slow)...")
47
  try:
48
+ pipe = StableDiffusionPipeline.from_single_file(
49
+ full_checkpoint_url,
50
+ torch_dtype=torch.float32, # Force float32 for CPU
51
+ use_safetensors=True,
52
+ vae=None,
53
+ text_encoder=None,
54
+ tokenizer=None,
55
+ scheduler=None,
56
+ from_transformers=False,
57
+ load_safety_checker=True
58
  )
59
  pipe.to("cpu")
60
+ print("Model forced loaded on CPU.")
 
 
 
 
61
  except Exception as e_cpu:
62
+ print(f"Critical error: Failed to load model even from single file on CPU: {e_cpu}")
63
  exit() # Exit if model cannot be loaded at all
64
 
65
 
 
71
  generator = torch.Generator(device=device).manual_seed(seed) if seed != -1 else None
72
 
73
  try:
 
 
 
 
74
  image = pipe(
75
  prompt=prompt,
76
  negative_prompt=negative_prompt,
 
103
  ["A cute indigo furry wizard casting a spell, magical effects, cartoon style", "disfigured, ugly"],
104
  ]
105
  )
106
+
107
  if __name__ == "__main__":
108
+ iface.launch()