AkashKumarave committed on
Commit
9b003e1
·
verified ·
1 Parent(s): f52c5b7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +92 -57
app.py CHANGED
@@ -6,93 +6,128 @@ from diffusers import StableDiffusionXLPipeline
6
  from insightface.app import FaceAnalysis
7
  from huggingface_hub import hf_hub_download
8
  import os
 
9
 
10
- # Allow network access for runtime downloads
 
 
 
 
11
  os.environ["HF_HUB_OFFLINE"] = "0"
12
 
13
- # Set device to CPU (Hugging Face free tier is CPU-only)
14
  device = "cpu"
15
- dtype = torch.float32 # Use float32 to avoid GPU-specific optimizations
 
 
 
 
16
 
17
- # Load face encoder (InsightFace will download its weights on first run)
18
  try:
19
  face_app = FaceAnalysis(providers=["CPUExecutionProvider"])
20
  face_app.prepare(ctx_id=0, det_size=(480, 480))
21
- print("InsightFace model loaded successfully.")
22
  except Exception as e:
23
- raise RuntimeError(f"Failed to load InsightFace model: {e}. Ensure network access.")
24
 
25
- # Define paths for temporary storage (ephemeral in Spaces)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  kolors_unet_path = "./unet"
27
  ip_adapter_path = "./"
 
 
28
 
29
- # Download Kolors unet weights at runtime
30
- kolors_weights = os.path.join(kolors_unet_path, "diffusion_pytorch_model.fp16.safetensors")
31
- if not os.path.exists(kolors_weights):
32
- print("Downloading Kolors unet weights...")
33
- os.makedirs(kolors_unet_path, exist_ok=True)
34
- hf_hub_download(
35
- repo_id="Kwai-Kolors/Kolors",
36
- filename="unet/diffusion_pytorch_model.fp16.safetensors",
37
- local_dir=kolors_unet_path,
38
- local_files_only=False
39
- )
40
- print("Kolors unet weights downloaded to", kolors_weights)
41
-
42
- # Download IP-Adapter weights at runtime
43
- ip_adapter_weights = os.path.join(ip_adapter_path, "ipa-faceid-plus.bin")
44
- if not os.path.exists(ip_adapter_weights):
45
- print("Downloading IP-Adapter weights...")
46
- hf_hub_download(
47
- repo_id="Kwai-Kolors/Kolors-IP-Adapter-FaceID-Plus",
48
- filename="ipa-faceid-plus.bin",
49
- local_dir=ip_adapter_path,
50
- local_files_only=False
51
- )
52
- print("IP-Adapter weights downloaded to", ip_adapter_weights)
53
-
54
- # Load the base SDXL pipeline directly from Hugging Face Hub
55
- print("Loading Stable Diffusion XL base model...")
56
- pipe = StableDiffusionXLPipeline.from_pretrained(
57
- "stabilityai/stable-diffusion-xl-base-1.0",
58
- torch_dtype=dtype,
59
- safety_checker=None,
60
- local_files_only=False, # Download from Hub at runtime
61
- cache_dir="./cache" # Use temporary cache directory
62
  )
63
 
64
- # Replace unet with Kolors weights
65
- print("Loading Kolors unet weights into pipeline...")
66
- pipe.unet.load_state_dict(torch.load(kolors_weights, map_location=device))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
  # Load IP-Adapter
69
- print("Loading IP-Adapter...")
70
- pipe.load_ip_adapter(ip_adapter_path, subfolder=None, weight_name="ipa-faceid-plus.bin")
 
 
 
 
71
 
72
  # Move pipeline to CPU
73
  pipe.to(device)
74
 
75
  def generate_image(uploaded_image, prompt):
76
- img = cv2.cvtColor(np.array(uploaded_image), cv2.COLOR_RGB2BGR)
77
- faces = face_app.get(img)
78
- if not faces:
79
- return "No face detected!", None
 
80
 
81
- face_info = faces[-1]
82
- face_emb = face_info["embedding"]
83
 
84
- try:
85
- # Reduce inference steps and resolution to fit free tier limits
86
  image = pipe(
87
  prompt=prompt,
88
  image_embeds=face_emb,
89
- num_inference_steps=15, # Lower steps for faster execution
90
  guidance_scale=7.5,
91
- height=384, # Smaller resolution to reduce memory usage
92
  width=384,
93
  ).images[0]
94
  return "Image generated successfully!", image
95
  except Exception as e:
 
96
  return f"Generation failed: {e}", None
97
 
98
  # Gradio interface
@@ -106,8 +141,8 @@ interface = gr.Interface(
106
  gr.Textbox(label="Status"),
107
  gr.Image(label="Generated Image")
108
  ],
109
- title="Face Reference Image Generator (Kolors-IP-Adapter-FaceID-Plus)",
110
- description="Upload an image with a face, enter a prompt, and generate a new image preserving the reference face."
111
  )
112
 
113
  interface.launch()
 
6
  from insightface.app import FaceAnalysis
7
  from huggingface_hub import hf_hub_download
8
  import os
9
+ import logging
10
 
11
+ # Set up logging
12
+ logging.basicConfig(level=logging.INFO)
13
+ logger = logging.getLogger(__name__)
14
+
15
+ # Allow network access
16
  os.environ["HF_HUB_OFFLINE"] = "0"
17
 
18
+ # Set device to CPU
19
  device = "cpu"
20
+ dtype = torch.float32
21
+
22
+ # Define cache directory
23
+ cache_dir = "./cache"
24
+ os.makedirs(cache_dir, exist_ok=True)
25
 
26
+ # Load face encoder
27
  try:
28
  face_app = FaceAnalysis(providers=["CPUExecutionProvider"])
29
  face_app.prepare(ctx_id=0, det_size=(480, 480))
30
+ logger.info("InsightFace model loaded successfully.")
31
  except Exception as e:
32
+ raise RuntimeError(f"Failed to load InsightFace model: {e}")
33
 
34
+ # Download function with explicit path return
35
+ def download_file(repo_id, filename, local_dir):
36
+ file_path = os.path.join(local_dir, filename)
37
+ if not os.path.exists(file_path):
38
+ logger.info(f"Downloading {filename} from {repo_id} to {local_dir}...")
39
+ try:
40
+ downloaded_path = hf_hub_download(
41
+ repo_id=repo_id,
42
+ filename=filename,
43
+ local_dir=local_dir,
44
+ cache_dir=cache_dir,
45
+ local_files_only=False
46
+ )
47
+ logger.info(f"Downloaded to {downloaded_path}")
48
+ return downloaded_path # Return the actual path from hf_hub_download
49
+ except Exception as e:
50
+ logger.error(f"Download failed: {e}")
51
+ raise
52
+ else:
53
+ logger.info(f"Using cached file at {file_path}")
54
+ return file_path
55
+
56
+ # Define paths
57
  kolors_unet_path = "./unet"
58
  ip_adapter_path = "./"
59
+ os.makedirs(kolors_unet_path, exist_ok=True)
60
+ os.makedirs(ip_adapter_path, exist_ok=True)
61
 
62
+ # Download weights and get exact paths
63
+ kolors_weights = download_file(
64
+ "Kwai-Kolors/Kolors",
65
+ "unet/diffusion_pytorch_model.fp16.safetensors",
66
+ kolors_unet_path
67
+ )
68
+ ip_adapter_weights = download_file(
69
+ "Kwai-Kolors/Kolors-IP-Adapter-FaceID-Plus",
70
+ "ipa-faceid-plus.bin",
71
+ ip_adapter_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  )
73
 
74
+ # Load the pipeline
75
+ logger.info("Loading Stable Diffusion XL base model...")
76
+ try:
77
+ pipe = StableDiffusionXLPipeline.from_pretrained(
78
+ "stabilityai/stable-diffusion-xl-base-1.0",
79
+ torch_dtype=dtype,
80
+ safety_checker=None,
81
+ local_files_only=False,
82
+ cache_dir=cache_dir
83
+ )
84
+ except Exception as e:
85
+ logger.error(f"Failed to load SDXL base model: {e}")
86
+ raise
87
+
88
+ # Load Kolors unet weights
89
+ logger.info(f"Loading Kolors unet weights from {kolors_weights}...")
90
+ try:
91
+ state_dict = torch.load(kolors_weights, map_location=device)
92
+ pipe.unet.load_state_dict(state_dict)
93
+ logger.info("Kolors unet weights loaded successfully.")
94
+ except Exception as e:
95
+ logger.error(f"Failed to load Kolors unet weights: {e}")
96
+ raise
97
 
98
  # Load IP-Adapter
99
+ logger.info(f"Loading IP-Adapter from {ip_adapter_weights}...")
100
+ try:
101
+ pipe.load_ip_adapter(ip_adapter_path, subfolder=None, weight_name="ipa-faceid-plus.bin")
102
+ except Exception as e:
103
+ logger.error(f"Failed to load IP-Adapter: {e}")
104
+ raise
105
 
106
  # Move pipeline to CPU
107
  pipe.to(device)
108
 
109
  def generate_image(uploaded_image, prompt):
110
+ try:
111
+ img = cv2.cvtColor(np.array(uploaded_image), cv2.COLOR_RGB2BGR)
112
+ faces = face_app.get(img)
113
+ if not faces:
114
+ return "No face detected!", None
115
 
116
+ face_info = faces[-1]
117
+ face_emb = face_info["embedding"]
118
 
119
+ logger.info(f"Generating image with prompt: {prompt}")
 
120
  image = pipe(
121
  prompt=prompt,
122
  image_embeds=face_emb,
123
+ num_inference_steps=15,
124
  guidance_scale=7.5,
125
+ height=384,
126
  width=384,
127
  ).images[0]
128
  return "Image generated successfully!", image
129
  except Exception as e:
130
+ logger.error(f"Generation failed: {e}")
131
  return f"Generation failed: {e}", None
132
 
133
  # Gradio interface
 
141
  gr.Textbox(label="Status"),
142
  gr.Image(label="Generated Image")
143
  ],
144
+ title="Face Reference Image Generator",
145
+ description="Upload an image with a face and generate a new image."
146
  )
147
 
148
  interface.launch()