AkashKumarave commited on
Commit
f52c5b7
·
verified ·
1 Parent(s): bd44cb3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +62 -83
app.py CHANGED
@@ -4,102 +4,73 @@ import numpy as np
4
  import gradio as gr
5
  from diffusers import StableDiffusionXLPipeline
6
  from insightface.app import FaceAnalysis
7
- from accelerate import init_empty_weights, load_checkpoint_and_dispatch
8
- import os
9
  from huggingface_hub import hf_hub_download
10
- import time
11
 
12
- # Allow network access for runtime download
13
  os.environ["HF_HUB_OFFLINE"] = "0"
14
 
15
- # Set device to CPU
16
  device = "cpu"
17
- dtype = torch.float32
18
 
19
- # Load face encoder (InsightFace handles its own download)
20
  try:
21
  face_app = FaceAnalysis(providers=["CPUExecutionProvider"])
22
  face_app.prepare(ctx_id=0, det_size=(480, 480))
23
  print("InsightFace model loaded successfully.")
24
  except Exception as e:
25
- raise RuntimeError(f"Failed to load InsightFace model: {e}. Ensure network access for initial download.")
26
 
27
- # Define paths for preloaded or downloaded weights
28
- model_path = "./" # Start with root
29
  ip_adapter_path = "./"
30
 
31
- # Debug: List files to confirm preloading or download
32
- print("Files in root directory:", os.listdir("."))
33
- print("Files in ./unet/ directory:", os.listdir("./unet") if os.path.exists("./unet") else "No ./unet/ directory")
34
-
35
- # Check if base model weights exist or download them
36
- kolors_weights = model_path + "diffusers_weights.safetensors"
37
  if not os.path.exists(kolors_weights):
38
- kolors_weights = model_path + "diffusion_pytorch_model.fp16.safetensors"
39
- if not os.path.exists(kolors_weights):
40
- kolors_weights_unet = "./unet/diffusion_pytorch_model.fp16.safetensors"
41
- if not os.path.exists(kolors_weights_unet):
42
- print("Preloading failed. Attempting runtime download with retry...")
43
- os.makedirs("./unet", exist_ok=True)
44
- max_retries = 3
45
- for attempt in range(max_retries):
46
- try:
47
- print(f"Download attempt {attempt + 1} of {max_retries}")
48
- hf_hub_download(
49
- repo_id="Kwai-Kolors/Kolors",
50
- filename="unet/diffusion_pytorch_model.fp16.safetensors",
51
- local_dir="./unet",
52
- local_files_only=False
53
- )
54
- print("Kolors base weights downloaded to", kolors_weights_unet)
55
- model_path = "./unet/"
56
- kolors_weights = kolors_weights_unet
57
- break
58
- except Exception as e:
59
- print(f"Download attempt {attempt + 1} failed: {e}")
60
- if attempt < max_retries - 1:
61
- time.sleep(5)
62
- else:
63
- raise FileNotFoundError(f"Failed to download Kolors base weights after {max_retries} attempts: {e}. Verify the repo or contact support.")
64
- else:
65
- model_path = "./unet/"
66
- kolors_weights = kolors_weights_unet
67
 
68
- # Check if IP-Adapter weights exist or download them
69
- ip_adapter_weights = ip_adapter_path + "ipa-faceid-plus.bin"
70
  if not os.path.exists(ip_adapter_weights):
71
- print("IP-Adapter preloading failed. Attempting runtime download with retry...")
72
- max_retries = 3
73
- for attempt in range(max_retries):
74
- try:
75
- print(f"IP-Adapter download attempt {attempt + 1} of {max_retries}")
76
- hf_hub_download(
77
- repo_id="Kwai-Kolors/Kolors-IP-Adapter-FaceID-Plus",
78
- filename="ipa-faceid-plus.bin",
79
- local_dir="./",
80
- local_files_only=False
81
- )
82
- print("IP-Adapter weights downloaded to", ip_adapter_weights)
83
- break
84
- except Exception as e:
85
- print(f"IP-Adapter download attempt {attempt + 1} failed: {e}")
86
- if attempt < max_retries - 1:
87
- time.sleep(5)
88
- else:
89
- raise FileNotFoundError(f"Failed to download IP-Adapter weights after {max_retries} attempts: {e}. Verify the repo or contact support.")
90
-
91
- # Initialize model with empty weights
92
- with init_empty_weights():
93
- pipe = StableDiffusionXLPipeline.from_pretrained(
94
- "./", # Use local model path
95
- torch_dtype=dtype,
96
- safety_checker=None,
97
- local_files_only=True # Force local file usage
98
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
 
100
- # Load and dispatch model with accelerate
101
- pipe = load_checkpoint_and_dispatch(pipe, model_path, device_map="cpu", offload_folder=None)
102
- pipe.load_ip_adapter("./", subfolder=None, weight_name="ipa-faceid-plus.bin")
103
 
104
  def generate_image(uploaded_image, prompt):
105
  img = cv2.cvtColor(np.array(uploaded_image), cv2.COLOR_RGB2BGR)
@@ -111,24 +82,32 @@ def generate_image(uploaded_image, prompt):
111
  face_emb = face_info["embedding"]
112
 
113
  try:
 
114
  image = pipe(
115
  prompt=prompt,
116
  image_embeds=face_emb,
117
- num_inference_steps=20,
118
  guidance_scale=7.5,
119
- height=512,
120
- width=512,
121
  ).images[0]
122
  return "Image generated successfully!", image
123
  except Exception as e:
124
  return f"Generation failed: {e}", None
125
 
 
126
  interface = gr.Interface(
127
  fn=generate_image,
128
- inputs=[gr.Image(type="pil", label="Upload Reference Image"), gr.Textbox(label="Enter Prompt", placeholder="e.g., A photorealistic astronaut in space")],
129
- outputs=[gr.Textbox(label="Status"), gr.Image(label="Generated Image")],
 
 
 
 
 
 
130
  title="Face Reference Image Generator (Kolors-IP-Adapter-FaceID-Plus)",
131
  description="Upload an image with a face, enter a prompt, and generate a new image preserving the reference face."
132
  )
133
 
134
- interface.launch()
 
4
  import gradio as gr
5
  from diffusers import StableDiffusionXLPipeline
6
  from insightface.app import FaceAnalysis
 
 
7
  from huggingface_hub import hf_hub_download
8
+ import os
9
 
10
+ # Allow network access for runtime downloads
11
  os.environ["HF_HUB_OFFLINE"] = "0"
12
 
13
+ # Set device to CPU (Hugging Face free tier is CPU-only)
14
  device = "cpu"
15
+ dtype = torch.float32 # Use float32 to avoid GPU-specific optimizations
16
 
17
+ # Load face encoder (InsightFace will download its weights on first run)
18
  try:
19
  face_app = FaceAnalysis(providers=["CPUExecutionProvider"])
20
  face_app.prepare(ctx_id=0, det_size=(480, 480))
21
  print("InsightFace model loaded successfully.")
22
  except Exception as e:
23
+ raise RuntimeError(f"Failed to load InsightFace model: {e}. Ensure network access.")
24
 
25
+ # Define paths for temporary storage (ephemeral in Spaces)
26
+ kolors_unet_path = "./unet"
27
  ip_adapter_path = "./"
28
 
29
+ # Download Kolors unet weights at runtime
30
+ kolors_weights = os.path.join(kolors_unet_path, "diffusion_pytorch_model.fp16.safetensors")
 
 
 
 
31
  if not os.path.exists(kolors_weights):
32
+ print("Downloading Kolors unet weights...")
33
+ os.makedirs(kolors_unet_path, exist_ok=True)
34
+ hf_hub_download(
35
+ repo_id="Kwai-Kolors/Kolors",
36
+ filename="unet/diffusion_pytorch_model.fp16.safetensors",
37
+ local_dir=kolors_unet_path,
38
+ local_files_only=False
39
+ )
40
+ print("Kolors unet weights downloaded to", kolors_weights)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
+ # Download IP-Adapter weights at runtime
43
+ ip_adapter_weights = os.path.join(ip_adapter_path, "ipa-faceid-plus.bin")
44
  if not os.path.exists(ip_adapter_weights):
45
+ print("Downloading IP-Adapter weights...")
46
+ hf_hub_download(
47
+ repo_id="Kwai-Kolors/Kolors-IP-Adapter-FaceID-Plus",
48
+ filename="ipa-faceid-plus.bin",
49
+ local_dir=ip_adapter_path,
50
+ local_files_only=False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  )
52
+ print("IP-Adapter weights downloaded to", ip_adapter_weights)
53
+
54
+ # Load the base SDXL pipeline directly from Hugging Face Hub
55
+ print("Loading Stable Diffusion XL base model...")
56
+ pipe = StableDiffusionXLPipeline.from_pretrained(
57
+ "stabilityai/stable-diffusion-xl-base-1.0",
58
+ torch_dtype=dtype,
59
+ safety_checker=None,
60
+ local_files_only=False, # Download from Hub at runtime
61
+ cache_dir="./cache" # Use temporary cache directory
62
+ )
63
+
64
+ # Replace unet with Kolors weights
65
+ print("Loading Kolors unet weights into pipeline...")
66
+ pipe.unet.load_state_dict(torch.load(kolors_weights, map_location=device))
67
+
68
+ # Load IP-Adapter
69
+ print("Loading IP-Adapter...")
70
+ pipe.load_ip_adapter(ip_adapter_path, subfolder=None, weight_name="ipa-faceid-plus.bin")
71
 
72
+ # Move pipeline to CPU
73
+ pipe.to(device)
 
74
 
75
  def generate_image(uploaded_image, prompt):
76
  img = cv2.cvtColor(np.array(uploaded_image), cv2.COLOR_RGB2BGR)
 
82
  face_emb = face_info["embedding"]
83
 
84
  try:
85
+ # Reduce inference steps and resolution to fit free tier limits
86
  image = pipe(
87
  prompt=prompt,
88
  image_embeds=face_emb,
89
+ num_inference_steps=15, # Lower steps for faster execution
90
  guidance_scale=7.5,
91
+ height=384, # Smaller resolution to reduce memory usage
92
+ width=384,
93
  ).images[0]
94
  return "Image generated successfully!", image
95
  except Exception as e:
96
  return f"Generation failed: {e}", None
97
 
98
+ # Gradio interface
99
  interface = gr.Interface(
100
  fn=generate_image,
101
+ inputs=[
102
+ gr.Image(type="pil", label="Upload Reference Image"),
103
+ gr.Textbox(label="Enter Prompt", placeholder="e.g., A photorealistic astronaut in space")
104
+ ],
105
+ outputs=[
106
+ gr.Textbox(label="Status"),
107
+ gr.Image(label="Generated Image")
108
+ ],
109
  title="Face Reference Image Generator (Kolors-IP-Adapter-FaceID-Plus)",
110
  description="Upload an image with a face, enter a prompt, and generate a new image preserving the reference face."
111
  )
112
 
113
+ interface.launch()