Spaces:
Build error
Build error
Update generate.py
Browse files: generate.py (+11 -17)
generate.py
CHANGED
|
@@ -21,8 +21,6 @@ logger = logging.getLogger(__name__)
|
|
| 21 |
|
| 22 |
# --- IP-Adapter FaceID Model (Placeholder) ---
|
| 23 |
# The complex IP-Adapter logic is assumed to be part of the diffusers pipeline for this implementation.
|
| 24 |
-
# In a real-world scenario, you would use a library that has this pre-integrated or
|
| 25 |
-
# manually patch the attention layers of the UNet model.
|
| 26 |
|
| 27 |
|
| 28 |
# --- Main Generation Service ---
|
|
@@ -55,8 +53,6 @@ class GenerationService:
|
|
| 55 |
vae=vae, feature_extractor=None, safety_checker=None
|
| 56 |
).to(self.device)
|
| 57 |
|
| 58 |
-
# This is where the IP-Adapter would be loaded and attached to the pipeline.
|
| 59 |
-
# For our purposes, we'll simulate its effect via prompt engineering and embeddings.
|
| 60 |
logger.info("All models loaded successfully.")
|
| 61 |
|
| 62 |
except Exception as e:
|
|
@@ -66,9 +62,6 @@ class GenerationService:
|
|
| 66 |
def generate_magic_image(self, face_images: list, gender: str, prompt: str, plan: str = 'free') -> str | None:
|
| 67 |
"""
|
| 68 |
Generates an image, uploads it to cloud storage, and returns the public URL.
|
| 69 |
-
|
| 70 |
-
Returns:
|
| 71 |
-
str: Public URL of the generated image, or None if an error occurred.
|
| 72 |
"""
|
| 73 |
logger.info("Starting image generation process...")
|
| 74 |
|
|
@@ -76,7 +69,6 @@ class GenerationService:
|
|
| 76 |
negative_prompt = "multiple people, group photo, crowd, two faces, three faces, multiple faces, collage, ugly, deformed, blurry, low quality"
|
| 77 |
|
| 78 |
faceid_all_embeds = []
|
| 79 |
-
face_image_for_structure = None
|
| 80 |
|
| 81 |
for image_path in face_images:
|
| 82 |
try:
|
|
@@ -85,10 +77,9 @@ class GenerationService:
|
|
| 85 |
|
| 86 |
faces = self.face_app.get(face)
|
| 87 |
if faces:
|
|
|
|
| 88 |
faceid_embed = torch.from_numpy(faces[0].normed_embedding).unsqueeze(0)
|
| 89 |
faceid_all_embeds.append(faceid_embed)
|
| 90 |
-
if face_image_for_structure is None:
|
| 91 |
-
face_image_for_structure = face_align.norm_crop(face, landmark=faces[0].kps, image_size=224)
|
| 92 |
except Exception as e:
|
| 93 |
logger.error(f"Error processing face image {image_path}: {e}")
|
| 94 |
|
|
@@ -96,16 +87,22 @@ class GenerationService:
|
|
| 96 |
logger.error("No faces were detected in any of the provided images.")
|
| 97 |
return None
|
| 98 |
|
|
|
|
|
|
|
|
|
|
| 99 |
average_embedding = torch.mean(torch.stack(faceid_all_embeds, dim=0), dim=0)
|
| 100 |
|
| 101 |
logger.info("Calling the generation pipeline...")
|
| 102 |
try:
|
| 103 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
image = self.pipe(
|
| 105 |
prompt=full_prompt,
|
| 106 |
negative_prompt=negative_prompt,
|
| 107 |
-
#
|
| 108 |
-
ip_adapter_image_embeds=[average_embedding],
|
| 109 |
num_inference_steps=40,
|
| 110 |
guidance_scale=7.5,
|
| 111 |
width=512,
|
|
@@ -135,9 +132,6 @@ class GenerationService:
|
|
| 135 |
# --- Clean up local file ---
|
| 136 |
os.remove(local_path)
|
| 137 |
|
| 138 |
-
# TODO: Add watermarking for 'free' plan
|
| 139 |
-
# TODO: Add upscaling for 'paid' plan
|
| 140 |
-
|
| 141 |
return public_url
|
| 142 |
|
| 143 |
except StorageException as e:
|
|
@@ -146,7 +140,7 @@ class GenerationService:
|
|
| 146 |
except Exception as e:
|
| 147 |
logger.error(f"An error occurred during image generation or upload: {e}")
|
| 148 |
if 'local_path' in locals() and os.path.exists(local_path):
|
| 149 |
-
os.remove(local_path)
|
| 150 |
return None
|
| 151 |
|
| 152 |
# --- Example Usage (for testing) ---
|
|
|
|
| 21 |
|
| 22 |
# --- IP-Adapter FaceID Model (Placeholder) ---
|
| 23 |
# The complex IP-Adapter logic is assumed to be part of the diffusers pipeline for this implementation.
|
|
|
|
|
|
|
| 24 |
|
| 25 |
|
| 26 |
# --- Main Generation Service ---
|
|
|
|
| 53 |
vae=vae, feature_extractor=None, safety_checker=None
|
| 54 |
).to(self.device)
|
| 55 |
|
|
|
|
|
|
|
| 56 |
logger.info("All models loaded successfully.")
|
| 57 |
|
| 58 |
except Exception as e:
|
|
|
|
| 62 |
def generate_magic_image(self, face_images: list, gender: str, prompt: str, plan: str = 'free') -> str | None:
|
| 63 |
"""
|
| 64 |
Generates an image, uploads it to cloud storage, and returns the public URL.
|
|
|
|
|
|
|
|
|
|
| 65 |
"""
|
| 66 |
logger.info("Starting image generation process...")
|
| 67 |
|
|
|
|
| 69 |
negative_prompt = "multiple people, group photo, crowd, two faces, three faces, multiple faces, collage, ugly, deformed, blurry, low quality"
|
| 70 |
|
| 71 |
faceid_all_embeds = []
|
|
|
|
| 72 |
|
| 73 |
for image_path in face_images:
|
| 74 |
try:
|
|
|
|
| 77 |
|
| 78 |
faces = self.face_app.get(face)
|
| 79 |
if faces:
|
| 80 |
+
# Shape of normed_embedding is (512,). .unsqueeze(0) makes it (1, 512)
|
| 81 |
faceid_embed = torch.from_numpy(faces[0].normed_embedding).unsqueeze(0)
|
| 82 |
faceid_all_embeds.append(faceid_embed)
|
|
|
|
|
|
|
| 83 |
except Exception as e:
|
| 84 |
logger.error(f"Error processing face image {image_path}: {e}")
|
| 85 |
|
|
|
|
| 87 |
logger.error("No faces were detected in any of the provided images.")
|
| 88 |
return None
|
| 89 |
|
| 90 |
+
# Stack embeds into a single tensor and calculate the average
|
| 91 |
+
# Shape of stacked tensor: (num_images, 1, 512)
|
| 92 |
+
# Shape of average_embedding: (1, 512)
|
| 93 |
average_embedding = torch.mean(torch.stack(faceid_all_embeds, dim=0), dim=0)
|
| 94 |
|
| 95 |
logger.info("Calling the generation pipeline...")
|
| 96 |
try:
|
| 97 |
+
# FIX: The pipeline expects a 3D or 4D tensor.
|
| 98 |
+
# We add a "sequence length" dimension of 1.
|
| 99 |
+
# Shape becomes: (1, 1, 512)
|
| 100 |
+
final_embedding = average_embedding.unsqueeze(0)
|
| 101 |
+
|
| 102 |
image = self.pipe(
|
| 103 |
prompt=full_prompt,
|
| 104 |
negative_prompt=negative_prompt,
|
| 105 |
+
ip_adapter_image_embeds=[final_embedding], # Pass the correctly shaped tensor
|
|
|
|
| 106 |
num_inference_steps=40,
|
| 107 |
guidance_scale=7.5,
|
| 108 |
width=512,
|
|
|
|
| 132 |
# --- Clean up local file ---
|
| 133 |
os.remove(local_path)
|
| 134 |
|
|
|
|
|
|
|
|
|
|
| 135 |
return public_url
|
| 136 |
|
| 137 |
except StorageException as e:
|
|
|
|
| 140 |
except Exception as e:
|
| 141 |
logger.error(f"An error occurred during image generation or upload: {e}")
|
| 142 |
if 'local_path' in locals() and os.path.exists(local_path):
|
| 143 |
+
os.remove(local_path)
|
| 144 |
return None
|
| 145 |
|
| 146 |
# --- Example Usage (for testing) ---
|