KibbleHF committed on
Commit
3df0ca4
·
verified ·
1 Parent(s): 73222b2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -28
app.py CHANGED
@@ -1,49 +1,55 @@
1
  # app.py
2
- from diffusers import StableDiffusionPipeline, AutoPipelineForText2Image
3
  import torch
4
  import gradio as gr
 
5
  import os
6
 
7
- # Define the model ID for the specific checkpoint
8
- # This is the full ID of the repository where the safetensors file lives
9
- MODEL_ID = "uhralk/Indigo_Furry_mix"
10
- # This is the exact filename of the model checkpoint you want to load
11
- # Make sure this matches the filename on the Hugging Face Hub EXACTLY.
12
  CHECKPOINT_FILENAME = "indigo_Furrymix_v120_hybrid_fin_fp16.safetensors"
13
 
14
  # Determine the device (GPU or CPU)
15
  device = "cuda" if torch.cuda.is_available() else "cpu"
16
- # Use float16 for GPU to save VRAM and speed up, float32 for CPU or if fp16 causes issues
17
  dtype = torch.float16 if device == "cuda" else torch.float32
18
 
 
 
 
 
19
  # Load the model
20
  try:
21
- # Attempt to load the model directly as a full pipeline from the checkpoint
22
- # This assumes the checkpoint is a full model merge, not just LoRA weights
23
- pipe = AutoPipelineForText2Image.from_pretrained(
24
- MODEL_ID,
25
  torch_dtype=dtype,
26
- use_safetensors=True,
27
- # Specify the exact filename within the repository
28
- # This tells diffusers to look for this specific file as the main model weights
29
- # without needing an external base model or explicit LoRA loading.
30
- model_file=CHECKPOINT_FILENAME
31
  )
32
  pipe.to(device)
33
- print(f"Model loaded successfully on {device} using direct checkpoint loading.")
 
 
 
 
 
34
 
35
  except Exception as e:
36
- print(f"Error loading model directly on {device}: {e}")
37
- print("Attempting to load on CPU with float32 as a fallback (may be very slow)...")
38
  try:
39
- pipe = AutoPipelineForText2Image.from_pretrained(
40
- MODEL_ID,
 
41
  torch_dtype=torch.float32,
42
- use_safetensors=True,
43
- model_file=CHECKPOINT_FILENAME
44
  )
45
  pipe.to("cpu")
46
- print("Model forced loaded on CPU.")
 
 
 
 
47
  except Exception as e_cpu:
48
  print(f"Critical error: Failed to load model even on CPU: {e_cpu}")
49
  exit() # Exit if model cannot be loaded at all
@@ -54,10 +60,13 @@ def generate_image(prompt, negative_prompt, num_inference_steps, guidance_scale,
54
  if not prompt:
55
  return None, "Please enter a prompt."
56
 
57
- # Set up random seed if not -1
58
  generator = torch.Generator(device=device).manual_seed(seed) if seed != -1 else None
59
 
60
  try:
 
 
 
 
61
  image = pipe(
62
  prompt=prompt,
63
  negative_prompt=negative_prompt,
@@ -89,8 +98,7 @@ iface = gr.Interface(
89
  ["An indigo furry warrior, wielding a glowing sword, intricate armor, epic fantasy art", "blurry, low quality"],
90
  ["A cute indigo furry wizard casting a spell, magical effects, cartoon style", "disfigured, ugly"],
91
  ]
92
- ) # <--- THIS CLOSING PARENTHESIS IS CRUCIAL AND OFTEN MISSED!
93
 
94
- # Launch the Gradio app
95
  if __name__ == "__main__":
96
- iface.launch()
 
1
  # app.py
2
+ from diffusers import StableDiffusionPipeline
3
  import torch
4
  import gradio as gr
5
+ from huggingface_hub import hf_hub_download # Import to download the specific file
6
  import os
7
 
8
+ # Define the base model and the specific checkpoint file
9
+ BASE_MODEL = "runwayml/stable-diffusion-v1-5"
10
+ MODEL_REPO_ID = "uhralk/Indigo_Furry_mix"
 
 
11
  CHECKPOINT_FILENAME = "indigo_Furrymix_v120_hybrid_fin_fp16.safetensors"
12
 
13
  # Determine the device (GPU or CPU)
14
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
15
  dtype = torch.float16 if device == "cuda" else torch.float32
16
 
17
+ # Path to the downloaded checkpoint file
18
+ # This will download the file to the Hugging Face cache directory
19
+ checkpoint_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=CHECKPOINT_FILENAME)
20
+
21
  # Load the model
22
  try:
23
+ # 1. Load the base Stable Diffusion 1.5 pipeline
24
+ pipe = StableDiffusionPipeline.from_pretrained(
25
+ BASE_MODEL,
 
26
  torch_dtype=dtype,
27
+ use_safetensors=True
 
 
 
 
28
  )
29
  pipe.to(device)
30
+ print(f"Base model '{BASE_MODEL}' loaded successfully on {device}")
31
+
32
+ # 2. Load the custom checkpoint's state dict into the pipeline
33
+ # This is often the most compatible way for "merged" checkpoints
34
+ pipe.load_lora_weights(checkpoint_path, adapter_name="indigo_mix") # Use a dummy adapter name
35
+ print(f"Custom checkpoint '{CHECKPOINT_FILENAME}' loaded successfully.")
36
 
37
  except Exception as e:
38
+ print(f"Error loading model: {e}")
39
+ # Fallback to CPU if GPU loading fails or if no GPU, and retry loading
40
  try:
41
+ print("Attempting to load on CPU with float32 as a fallback...")
42
+ pipe = StableDiffusionPipeline.from_pretrained(
43
+ BASE_MODEL,
44
  torch_dtype=torch.float32,
45
+ use_safetensors=True
 
46
  )
47
  pipe.to("cpu")
48
+ print(f"Base model '{BASE_MODEL}' loaded successfully on CPU")
49
+
50
+ pipe.load_lora_weights(checkpoint_path, adapter_name="indigo_mix") # Load on CPU as well
51
+ print(f"Custom checkpoint '{CHECKPOINT_FILENAME}' loaded successfully on CPU.")
52
+
53
  except Exception as e_cpu:
54
  print(f"Critical error: Failed to load model even on CPU: {e_cpu}")
55
  exit() # Exit if model cannot be loaded at all
 
60
  if not prompt:
61
  return None, "Please enter a prompt."
62
 
 
63
  generator = torch.Generator(device=device).manual_seed(seed) if seed != -1 else None
64
 
65
  try:
66
+ # If you loaded the LoRA/merged checkpoint, you might need to enable it for inference
67
+ # If it was loaded as a direct state dict, it's already "active"
68
+ # For pipe.load_lora_weights, if not using a specific adapter name for later enable/disable:
69
+ # you just call it and it applies immediately.
70
  image = pipe(
71
  prompt=prompt,
72
  negative_prompt=negative_prompt,
 
98
  ["An indigo furry warrior, wielding a glowing sword, intricate armor, epic fantasy art", "blurry, low quality"],
99
  ["A cute indigo furry wizard casting a spell, magical effects, cartoon style", "disfigured, ugly"],
100
  ]
101
+ )
102
 
 
103
  if __name__ == "__main__":
104
+ iface.