AkashKumarave committed on
Commit
a1bd508
·
verified ·
1 Parent(s): 06e8f08

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -9
app.py CHANGED
@@ -8,7 +8,7 @@ from huggingface_hub import hf_hub_download
8
  import os
9
  import logging
10
 
11
- # Set up logging
12
  logging.basicConfig(level=logging.INFO)
13
  logger = logging.getLogger(__name__)
14
 
@@ -24,12 +24,14 @@ cache_dir = "./cache"
24
  os.makedirs(cache_dir, exist_ok=True)
25
 
26
  # Load face encoder
 
27
  try:
28
  face_app = FaceAnalysis(providers=["CPUExecutionProvider"])
29
  face_app.prepare(ctx_id=0, det_size=(480, 480))
30
  logger.info("InsightFace model loaded successfully.")
31
  except Exception as e:
32
- raise RuntimeError(f"Failed to load InsightFace model: {e}")
 
33
 
34
  # Download function with explicit path return
35
  def download_file(repo_id, filename, local_dir):
@@ -59,7 +61,8 @@ ip_adapter_path = "./"
59
  os.makedirs(kolors_unet_path, exist_ok=True)
60
  os.makedirs(ip_adapter_path, exist_ok=True)
61
 
62
- # Download weights and get exact paths
 
63
  kolors_weights = download_file(
64
  "Kwai-Kolors/Kolors",
65
  "unet/diffusion_pytorch_model.fp16.safetensors",
@@ -71,7 +74,7 @@ ip_adapter_weights = download_file(
71
  ip_adapter_path
72
  )
73
 
74
- # Load the pipeline
75
  logger.info("Loading Stable Diffusion XL base model...")
76
  try:
77
  pipe = StableDiffusionXLPipeline.from_pretrained(
@@ -79,13 +82,16 @@ try:
79
  torch_dtype=dtype,
80
  safety_checker=None,
81
  local_files_only=False,
82
- cache_dir=cache_dir
 
 
83
  )
 
84
  except Exception as e:
85
  logger.error(f"Failed to load SDXL base model: {e}")
86
  raise
87
 
88
- # Load Kolors unet weights with weights_only=False
89
  logger.info(f"Loading Kolors unet weights from {kolors_weights}...")
90
  try:
91
  state_dict = torch.load(kolors_weights, map_location=device, weights_only=False)
@@ -99,18 +105,23 @@ except Exception as e:
99
  logger.info(f"Loading IP-Adapter from {ip_adapter_weights}...")
100
  try:
101
  pipe.load_ip_adapter(ip_adapter_path, subfolder=None, weight_name="ipa-faceid-plus.bin")
 
102
  except Exception as e:
103
  logger.error(f"Failed to load IP-Adapter: {e}")
104
  raise
105
 
106
  # Move pipeline to CPU
 
107
  pipe.to(device)
 
108
 
109
  def generate_image(uploaded_image, prompt):
 
110
  try:
111
  img = cv2.cvtColor(np.array(uploaded_image), cv2.COLOR_RGB2BGR)
112
  faces = face_app.get(img)
113
  if not faces:
 
114
  return "No face detected!", None
115
 
116
  face_info = faces[-1]
@@ -120,11 +131,12 @@ def generate_image(uploaded_image, prompt):
120
  image = pipe(
121
  prompt=prompt,
122
  image_embeds=face_emb,
123
- num_inference_steps=15,
124
  guidance_scale=7.5,
125
- height=384,
126
- width=384,
127
  ).images[0]
 
128
  return "Image generated successfully!", image
129
  except Exception as e:
130
  logger.error(f"Generation failed: {e}")
@@ -145,4 +157,5 @@ interface = gr.Interface(
145
  description="Upload an image with a face and generate a new image."
146
  )
147
 
 
148
  interface.launch()
 
8
  import os
9
  import logging
10
 
11
+ # Set up detailed logging
12
  logging.basicConfig(level=logging.INFO)
13
  logger = logging.getLogger(__name__)
14
 
 
24
  os.makedirs(cache_dir, exist_ok=True)
25
 
26
  # Load face encoder
27
+ logger.info("Starting InsightFace initialization...")
28
  try:
29
  face_app = FaceAnalysis(providers=["CPUExecutionProvider"])
30
  face_app.prepare(ctx_id=0, det_size=(480, 480))
31
  logger.info("InsightFace model loaded successfully.")
32
  except Exception as e:
33
+ logger.error(f"Failed to load InsightFace model: {e}")
34
+ raise
35
 
36
  # Download function with explicit path return
37
  def download_file(repo_id, filename, local_dir):
 
61
  os.makedirs(kolors_unet_path, exist_ok=True)
62
  os.makedirs(ip_adapter_path, exist_ok=True)
63
 
64
+ # Download weights
65
+ logger.info("Starting weights download...")
66
  kolors_weights = download_file(
67
  "Kwai-Kolors/Kolors",
68
  "unet/diffusion_pytorch_model.fp16.safetensors",
 
74
  ip_adapter_path
75
  )
76
 
77
+ # Load the pipeline with verbose logging
78
  logger.info("Loading Stable Diffusion XL base model...")
79
  try:
80
  pipe = StableDiffusionXLPipeline.from_pretrained(
 
82
  torch_dtype=dtype,
83
  safety_checker=None,
84
  local_files_only=False,
85
+ cache_dir=cache_dir,
86
+ variant="fp16", # Use FP16 weights to reduce memory usage
87
+ use_safetensors=True # Prefer safetensors format if available
88
  )
89
+ logger.info("SDXL base model loaded successfully.")
90
  except Exception as e:
91
  logger.error(f"Failed to load SDXL base model: {e}")
92
  raise
93
 
94
+ # Load Kolors unet weights
95
  logger.info(f"Loading Kolors unet weights from {kolors_weights}...")
96
  try:
97
  state_dict = torch.load(kolors_weights, map_location=device, weights_only=False)
 
105
  logger.info(f"Loading IP-Adapter from {ip_adapter_weights}...")
106
  try:
107
  pipe.load_ip_adapter(ip_adapter_path, subfolder=None, weight_name="ipa-faceid-plus.bin")
108
+ logger.info("IP-Adapter loaded successfully.")
109
  except Exception as e:
110
  logger.error(f"Failed to load IP-Adapter: {e}")
111
  raise
112
 
113
  # Move pipeline to CPU
114
+ logger.info("Moving pipeline to CPU...")
115
  pipe.to(device)
116
+ logger.info("Pipeline moved to CPU.")
117
 
118
  def generate_image(uploaded_image, prompt):
119
+ logger.info("Starting image generation...")
120
  try:
121
  img = cv2.cvtColor(np.array(uploaded_image), cv2.COLOR_RGB2BGR)
122
  faces = face_app.get(img)
123
  if not faces:
124
+ logger.warning("No face detected in uploaded image.")
125
  return "No face detected!", None
126
 
127
  face_info = faces[-1]
 
131
  image = pipe(
132
  prompt=prompt,
133
  image_embeds=face_emb,
134
+ num_inference_steps=10, # Reduced steps for faster execution
135
  guidance_scale=7.5,
136
+ height=256, # Smaller resolution to fit memory
137
+ width=256
138
  ).images[0]
139
+ logger.info("Image generated successfully.")
140
  return "Image generated successfully!", image
141
  except Exception as e:
142
  logger.error(f"Generation failed: {e}")
 
157
  description="Upload an image with a face and generate a new image."
158
  )
159
 
160
+ logger.info("Launching Gradio interface...")
161
  interface.launch()