AkashKumarave commited on
Commit
596bd62
·
verified ·
1 Parent(s): 7412973

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -31
app.py CHANGED
@@ -32,48 +32,56 @@ ip_adapter_path = "./"
32
  print("Files in root directory:", os.listdir("."))
33
  print("Files in ./unet/ directory:", os.listdir("./unet") if os.path.exists("./unet") else "No ./unet/ directory")
34
 
35
- # Check if weights exist or download them with retry logic
36
- kolors_weights = model_path + "diffusion_pytorch_model.safetensors"
37
  if not os.path.exists(kolors_weights):
38
- kolors_weights_unet = "./unet/diffusion_pytorch_model.safetensors"
39
- if not os.path.exists(kolors_weights_unet):
40
- print("Preloading failed. Attempting runtime download with retry...")
41
- os.makedirs("./unet", exist_ok=True)
42
- max_retries = 3
43
- for attempt in range(max_retries):
44
- try:
45
- print(f"Download attempt {attempt + 1} of {max_retries}")
46
- urllib.request.urlretrieve(
47
- "https://huggingface.co/Kwai-Kolors/Kolors-diffusers/resolve/main/unet/diffusion_pytorch_model.safetensors",
48
- kolors_weights_unet
49
- )
50
- print("Kolors weights downloaded to", kolors_weights_unet)
51
- model_path = "./unet/"
52
- kolors_weights = kolors_weights_unet
53
- break
54
- except Exception as e:
55
- print(f"Download attempt {attempt + 1} failed: {e}")
56
- if attempt < max_retries - 1:
57
- time.sleep(5) # Wait 5 seconds before retrying
58
- else:
59
- raise FileNotFoundError(f"Failed to download Kolors weights after {max_retries} attempts: {e}. Check network access or contact support.")
60
- else:
61
- model_path = "./unet/"
62
- kolors_weights = kolors_weights_unet
63
- if not os.path.exists(ip_adapter_path + "ip-adapter.bin"):
 
 
 
 
 
 
 
 
64
  raise FileNotFoundError(f"IP-Adapter weights not found at {ip_adapter_path}")
65
 
66
  # Initialize model with empty weights
67
  with init_empty_weights():
68
  pipe = StableDiffusionXLPipeline.from_pretrained(
69
- "Kwai-Kolors/Kolors-diffusers",
70
  torch_dtype=dtype,
71
  safety_checker=None,
72
  )
73
 
74
  # Load and dispatch model with accelerate
75
  pipe = load_checkpoint_and_dispatch(pipe, model_path, device_map="cpu", offload_folder=None)
76
- pipe.load_ip_adapter("h94/IP-Adapter-FaceID-Plus-SDXL", subfolder=None, weight_name="ip-adapter.bin")
77
 
78
  def generate_image(uploaded_image, prompt):
79
  img = cv2.cvtColor(np.array(uploaded_image), cv2.COLOR_RGB2BGR)
@@ -101,7 +109,7 @@ interface = gr.Interface(
101
  fn=generate_image,
102
  inputs=[gr.Image(type="pil", label="Upload Reference Image"), gr.Textbox(label="Enter Prompt", placeholder="e.g., A photorealistic astronaut in space")],
103
  outputs=[gr.Textbox(label="Status"), gr.Image(label="Generated Image")],
104
- title="Face Reference Image Generator (Kolors with IP-Adapter)",
105
  description="Upload an image with a face, enter a prompt, and generate a new image preserving the reference face."
106
  )
107
 
 
32
  print("Files in root directory:", os.listdir("."))
33
  print("Files in ./unet/ directory:", os.listdir("./unet") if os.path.exists("./unet") else "No ./unet/ directory")
34
 
35
+ # Check if base model weights exist or download them
36
+ kolors_weights = model_path + "diffusers_weights.safetensors"
37
  if not os.path.exists(kolors_weights):
38
+ kolors_weights = model_path + "diffusion_pytorch_model.fp16.safetensors"
39
+ if not os.path.exists(kolors_weights):
40
+ kolors_weights_unet = "./unet/diffusion_pytorch_model.fp16.safetensors"
41
+ if not os.path.exists(kolors_weights_unet):
42
+ print("Preloading failed. Attempting runtime download with retry...")
43
+ os.makedirs("./unet", exist_ok=True)
44
+ max_retries = 3
45
+ correct_url = "https://huggingface.co/Kwai-Kolors/Kolors/raw/main/unet/diffusion_pytorch_model.fp16.safetensors"
46
+ for attempt in range(max_retries):
47
+ try:
48
+ print(f"Download attempt {attempt + 1} of {max_retries}")
49
+ urllib.request.urlretrieve(correct_url, kolors_weights_unet)
50
+ print("Kolors base weights downloaded to", kolors_weights_unet)
51
+ model_path = "./unet/"
52
+ kolors_weights = kolors_weights_unet
53
+ break
54
+ except urllib.error.HTTPError as e:
55
+ print(f"Download attempt {attempt + 1} failed: HTTP Error {e.code} - {e.reason}")
56
+ if attempt < max_retries - 1:
57
+ time.sleep(5)
58
+ else:
59
+ raise FileNotFoundError(f"Failed to download Kolors base weights after {max_retries} attempts: HTTP Error {e.code} - {e.reason}. Verify the URL or contact support.")
60
+ except Exception as e:
61
+ print(f"Download attempt {attempt + 1} failed: {e}")
62
+ if attempt < max_retries - 1:
63
+ time.sleep(5)
64
+ else:
65
+ raise FileNotFoundError(f"Failed to download Kolors base weights after {max_retries} attempts: {e}. Check network access or contact support.")
66
+ else:
67
+ model_path = "./unet/"
68
+ kolors_weights = kolors_weights_unet
69
+
70
+ # Check if IP-Adapter weights exist (preloaded)
71
+ if not os.path.exists(ip_adapter_path + "ipa-faceid-plus.bin"):
72
  raise FileNotFoundError(f"IP-Adapter weights not found at {ip_adapter_path}")
73
 
74
  # Initialize model with empty weights
75
  with init_empty_weights():
76
  pipe = StableDiffusionXLPipeline.from_pretrained(
77
+ "Kwai-Kolors/Kolors-IP-Adapter-FaceID-Plus",
78
  torch_dtype=dtype,
79
  safety_checker=None,
80
  )
81
 
82
  # Load and dispatch model with accelerate
83
  pipe = load_checkpoint_and_dispatch(pipe, model_path, device_map="cpu", offload_folder=None)
84
+ pipe.load_ip_adapter("Kwai-Kolors/Kolors-IP-Adapter-FaceID-Plus", subfolder=None, weight_name="ipa-faceid-plus.bin")
85
 
86
  def generate_image(uploaded_image, prompt):
87
  img = cv2.cvtColor(np.array(uploaded_image), cv2.COLOR_RGB2BGR)
 
109
  fn=generate_image,
110
  inputs=[gr.Image(type="pil", label="Upload Reference Image"), gr.Textbox(label="Enter Prompt", placeholder="e.g., A photorealistic astronaut in space")],
111
  outputs=[gr.Textbox(label="Status"), gr.Image(label="Generated Image")],
112
+ title="Face Reference Image Generator (Kolors-IP-Adapter-FaceID-Plus)",
113
  description="Upload an image with a face, enter a prompt, and generate a new image preserving the reference face."
114
  )
115