AkashKumarave commited on
Commit
4932f9e
·
verified ·
1 Parent(s): a453b4c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -27
app.py CHANGED
@@ -25,7 +25,7 @@ except Exception as e:
25
  raise RuntimeError(f"Failed to load InsightFace model: {e}. Ensure network access for initial download.")
26
 
27
  # Define paths for preloaded or downloaded weights
28
- model_path = "./unet/" # Updated to your unet directory
29
  ip_adapter_path = "./"
30
 
31
  # Debug: List files to confirm preloading or download
@@ -33,39 +33,48 @@ print("Files in root directory:", os.listdir("."))
33
  print("Files in ./unet/ directory:", os.listdir("./unet") if os.path.exists("./unet") else "No ./unet/ directory")
34
 
35
  # Check if base model weights exist or download them
36
- kolors_weights = model_path + "diffusion_pytorch_model.fp16.safetensors"
37
  if not os.path.exists(kolors_weights):
38
- print("Kolors base weights not found. Attempting runtime download...")
39
- os.makedirs("./unet", exist_ok=True)
40
- max_retries = 3
41
- for attempt in range(max_retries):
42
- try:
43
- print(f"Download attempt {attempt + 1} of {max_retries}")
44
- hf_hub_download(
45
- repo_id="AkashKumarave/my3",
46
- filename="unet/diffusion_pytorch_model.fp16.safetensors",
47
- local_dir="./unet",
48
- local_files_only=False
49
- )
50
- print("Kolors base weights downloaded to", kolors_weights)
51
- break
52
- except Exception as e:
53
- print(f"Download attempt {attempt + 1} failed: {e}")
54
- if attempt < max_retries - 1:
55
- time.sleep(5)
56
- else:
57
- raise FileNotFoundError(f"Failed to download Kolors base weights after {max_retries} attempts: {e}. Verify the repo or contact support.")
 
 
 
 
 
 
 
 
 
58
 
59
  # Check if IP-Adapter weights exist or download them
60
  ip_adapter_weights = ip_adapter_path + "ipa-faceid-plus.bin"
61
  if not os.path.exists(ip_adapter_weights):
62
- print("IP-Adapter preloading failed. Attempting runtime download...")
63
  max_retries = 3
64
  for attempt in range(max_retries):
65
  try:
66
  print(f"IP-Adapter download attempt {attempt + 1} of {max_retries}")
67
  hf_hub_download(
68
- repo_id="AkashKumarave/my3",
69
  filename="ipa-faceid-plus.bin",
70
  local_dir="./",
71
  local_files_only=False
@@ -82,14 +91,14 @@ if not os.path.exists(ip_adapter_weights):
82
  # Initialize model with empty weights
83
  with init_empty_weights():
84
  pipe = StableDiffusionXLPipeline.from_pretrained(
85
- "AkashKumarave/my3",
86
  torch_dtype=dtype,
87
  safety_checker=None,
88
  )
89
 
90
  # Load and dispatch model with accelerate
91
  pipe = load_checkpoint_and_dispatch(pipe, model_path, device_map="cpu", offload_folder=None)
92
- pipe.load_ip_adapter("AkashKumarave/my3", subfolder=None, weight_name="ipa-faceid-plus.bin")
93
 
94
  def generate_image(uploaded_image, prompt):
95
  img = cv2.cvtColor(np.array(uploaded_image), cv2.COLOR_RGB2BGR)
@@ -121,4 +130,4 @@ interface = gr.Interface(
121
  description="Upload an image with a face, enter a prompt, and generate a new image preserving the reference face."
122
  )
123
 
124
- interface.launch()
 
25
  raise RuntimeError(f"Failed to load InsightFace model: {e}. Ensure network access for initial download.")
26
 
27
  # Define paths for preloaded or downloaded weights
28
+ model_path = "./" # Start with root
29
  ip_adapter_path = "./"
30
 
31
  # Debug: List files to confirm preloading or download
 
33
  print("Files in ./unet/ directory:", os.listdir("./unet") if os.path.exists("./unet") else "No ./unet/ directory")
34
 
35
  # Check if base model weights exist or download them
36
+ kolors_weights = model_path + "diffusers_weights.safetensors"
37
  if not os.path.exists(kolors_weights):
38
+ kolors_weights = model_path + "diffusion_pytorch_model.fp16.safetensors"
39
+ if not os.path.exists(kolors_weights):
40
+ kolors_weights_unet = "./unet/diffusion_pytorch_model.fp16.safetensors"
41
+ if not os.path.exists(kolors_weights_unet):
42
+ print("Preloading failed. Attempting runtime download with retry...")
43
+ os.makedirs("./unet", exist_ok=True)
44
+ max_retries = 3
45
+ for attempt in range(max_retries):
46
+ try:
47
+ print(f"Download attempt {attempt + 1} of {max_retries}")
48
+ hf_hub_download(
49
+ repo_id="Kwai-Kolors/Kolors",
50
+ filename="unet/diffusion_pytorch_model.fp16.safetensors",
51
+ local_dir="./unet",
52
+ local_files_only=False
53
+ )
54
+ print("Kolors base weights downloaded to", kolors_weights_unet)
55
+ model_path = "./unet/"
56
+ kolors_weights = kolors_weights_unet
57
+ break
58
+ except Exception as e:
59
+ print(f"Download attempt {attempt + 1} failed: {e}")
60
+ if attempt < max_retries - 1:
61
+ time.sleep(5)
62
+ else:
63
+ raise FileNotFoundError(f"Failed to download Kolors base weights after {max_retries} attempts: {e}. Verify the repo or contact support.")
64
+ else:
65
+ model_path = "./unet/"
66
+ kolors_weights = kolors_weights_unet
67
 
68
  # Check if IP-Adapter weights exist or download them
69
  ip_adapter_weights = ip_adapter_path + "ipa-faceid-plus.bin"
70
  if not os.path.exists(ip_adapter_weights):
71
+ print("IP-Adapter preloading failed. Attempting runtime download with retry...")
72
  max_retries = 3
73
  for attempt in range(max_retries):
74
  try:
75
  print(f"IP-Adapter download attempt {attempt + 1} of {max_retries}")
76
  hf_hub_download(
77
+ repo_id="Kwai-Kolors/Kolors-IP-Adapter-FaceID-Plus",
78
  filename="ipa-faceid-plus.bin",
79
  local_dir="./",
80
  local_files_only=False
 
91
  # Initialize model with empty weights
92
  with init_empty_weights():
93
  pipe = StableDiffusionXLPipeline.from_pretrained(
94
+ "Kwai-Kolors/Kolors-IP-Adapter-FaceID-Plus",
95
  torch_dtype=dtype,
96
  safety_checker=None,
97
  )
98
 
99
  # Load and dispatch model with accelerate
100
  pipe = load_checkpoint_and_dispatch(pipe, model_path, device_map="cpu", offload_folder=None)
101
+ pipe.load_ip_adapter("Kwai-Kolors/Kolors-IP-Adapter-FaceID-Plus", subfolder=None, weight_name="ipa-faceid-plus.bin")
102
 
103
  def generate_image(uploaded_image, prompt):
104
  img = cv2.cvtColor(np.array(uploaded_image), cv2.COLOR_RGB2BGR)
 
130
  description="Upload an image with a face, enter a prompt, and generate a new image preserving the reference face."
131
  )
132
 
133
+ interface.launch()