rastof9 commited on
Commit
c3ecb86
·
verified ·
1 Parent(s): 183897a

Update generate.py

Browse files
Files changed (1) hide show
  1. generate.py +11 -17
generate.py CHANGED
@@ -21,8 +21,6 @@ logger = logging.getLogger(__name__)
21
 
22
  # --- IP-Adapter FaceID Model (Placeholder) ---
23
  # The complex IP-Adapter logic is assumed to be part of the diffusers pipeline for this implementation.
24
- # In a real-world scenario, you would use a library that has this pre-integrated or
25
- # manually patch the attention layers of the UNet model.
26
 
27
 
28
  # --- Main Generation Service ---
@@ -55,8 +53,6 @@ class GenerationService:
55
  vae=vae, feature_extractor=None, safety_checker=None
56
  ).to(self.device)
57
 
58
- # This is where the IP-Adapter would be loaded and attached to the pipeline.
59
- # For our purposes, we'll simulate its effect via prompt engineering and embeddings.
60
  logger.info("All models loaded successfully.")
61
 
62
  except Exception as e:
@@ -66,9 +62,6 @@ class GenerationService:
66
  def generate_magic_image(self, face_images: list, gender: str, prompt: str, plan: str = 'free') -> str | None:
67
  """
68
  Generates an image, uploads it to cloud storage, and returns the public URL.
69
-
70
- Returns:
71
- str: Public URL of the generated image, or None if an error occurred.
72
  """
73
  logger.info("Starting image generation process...")
74
 
@@ -76,7 +69,6 @@ class GenerationService:
76
  negative_prompt = "multiple people, group photo, crowd, two faces, three faces, multiple faces, collage, ugly, deformed, blurry, low quality"
77
 
78
  faceid_all_embeds = []
79
- face_image_for_structure = None
80
 
81
  for image_path in face_images:
82
  try:
@@ -85,10 +77,9 @@ class GenerationService:
85
 
86
  faces = self.face_app.get(face)
87
  if faces:
 
88
  faceid_embed = torch.from_numpy(faces[0].normed_embedding).unsqueeze(0)
89
  faceid_all_embeds.append(faceid_embed)
90
- if face_image_for_structure is None:
91
- face_image_for_structure = face_align.norm_crop(face, landmark=faces[0].kps, image_size=224)
92
  except Exception as e:
93
  logger.error(f"Error processing face image {image_path}: {e}")
94
 
@@ -96,16 +87,22 @@ class GenerationService:
96
  logger.error("No faces were detected in any of the provided images.")
97
  return None
98
 
 
 
 
99
  average_embedding = torch.mean(torch.stack(faceid_all_embeds, dim=0), dim=0)
100
 
101
  logger.info("Calling the generation pipeline...")
102
  try:
103
- # This is a conceptual representation of how the IP-Adapter is used.
 
 
 
 
104
  image = self.pipe(
105
  prompt=full_prompt,
106
  negative_prompt=negative_prompt,
107
- # FIX: The pipeline expects a list of embeddings.
108
- ip_adapter_image_embeds=[average_embedding],
109
  num_inference_steps=40,
110
  guidance_scale=7.5,
111
  width=512,
@@ -135,9 +132,6 @@ class GenerationService:
135
  # --- Clean up local file ---
136
  os.remove(local_path)
137
 
138
- # TODO: Add watermarking for 'free' plan
139
- # TODO: Add upscaling for 'paid' plan
140
-
141
  return public_url
142
 
143
  except StorageException as e:
@@ -146,7 +140,7 @@ class GenerationService:
146
  except Exception as e:
147
  logger.error(f"An error occurred during image generation or upload: {e}")
148
  if 'local_path' in locals() and os.path.exists(local_path):
149
- os.remove(local_path) # Clean up even on failure
150
  return None
151
 
152
  # --- Example Usage (for testing) ---
 
21
 
22
  # --- IP-Adapter FaceID Model (Placeholder) ---
23
  # The complex IP-Adapter logic is assumed to be part of the diffusers pipeline for this implementation.
 
 
24
 
25
 
26
  # --- Main Generation Service ---
 
53
  vae=vae, feature_extractor=None, safety_checker=None
54
  ).to(self.device)
55
 
 
 
56
  logger.info("All models loaded successfully.")
57
 
58
  except Exception as e:
 
62
  def generate_magic_image(self, face_images: list, gender: str, prompt: str, plan: str = 'free') -> str | None:
63
  """
64
  Generates an image, uploads it to cloud storage, and returns the public URL.
 
 
 
65
  """
66
  logger.info("Starting image generation process...")
67
 
 
69
  negative_prompt = "multiple people, group photo, crowd, two faces, three faces, multiple faces, collage, ugly, deformed, blurry, low quality"
70
 
71
  faceid_all_embeds = []
 
72
 
73
  for image_path in face_images:
74
  try:
 
77
 
78
  faces = self.face_app.get(face)
79
  if faces:
80
+ # Shape of normed_embedding is (512,). .unsqueeze(0) makes it (1, 512)
81
  faceid_embed = torch.from_numpy(faces[0].normed_embedding).unsqueeze(0)
82
  faceid_all_embeds.append(faceid_embed)
 
 
83
  except Exception as e:
84
  logger.error(f"Error processing face image {image_path}: {e}")
85
 
 
87
  logger.error("No faces were detected in any of the provided images.")
88
  return None
89
 
90
+ # Stack embeds into a single tensor and calculate the average
91
+ # Shape of stacked tensor: (num_images, 1, 512)
92
+ # Shape of average_embedding: (1, 512)
93
  average_embedding = torch.mean(torch.stack(faceid_all_embeds, dim=0), dim=0)
94
 
95
  logger.info("Calling the generation pipeline...")
96
  try:
97
+ # FIX: The pipeline expects a 3D or 4D tensor.
98
+ # We add a "sequence length" dimension of 1.
99
+ # Shape becomes: (1, 1, 512)
100
+ final_embedding = average_embedding.unsqueeze(0)
101
+
102
  image = self.pipe(
103
  prompt=full_prompt,
104
  negative_prompt=negative_prompt,
105
+ ip_adapter_image_embeds=[final_embedding], # Pass the correctly shaped tensor
 
106
  num_inference_steps=40,
107
  guidance_scale=7.5,
108
  width=512,
 
132
  # --- Clean up local file ---
133
  os.remove(local_path)
134
 
 
 
 
135
  return public_url
136
 
137
  except StorageException as e:
 
140
  except Exception as e:
141
  logger.error(f"An error occurred during image generation or upload: {e}")
142
  if 'local_path' in locals() and os.path.exists(local_path):
143
+ os.remove(local_path)
144
  return None
145
 
146
  # --- Example Usage (for testing) ---