rastof9 commited on
Commit
d26d672
·
verified ·
1 Parent(s): c3ecb86

Update generate.py

Browse files
Files changed (1) hide show
  1. generate.py +5 -10
generate.py CHANGED
@@ -77,7 +77,6 @@ class GenerationService:
77
 
78
  faces = self.face_app.get(face)
79
  if faces:
80
- # Shape of normed_embedding is (512,). .unsqueeze(0) makes it (1, 512)
81
  faceid_embed = torch.from_numpy(faces[0].normed_embedding).unsqueeze(0)
82
  faceid_all_embeds.append(faceid_embed)
83
  except Exception as e:
@@ -87,27 +86,23 @@ class GenerationService:
87
  logger.error("No faces were detected in any of the provided images.")
88
  return None
89
 
90
- # Stack embeds into a single tensor and calculate the average
91
- # Shape of stacked tensor: (num_images, 1, 512)
92
- # Shape of average_embedding: (1, 512)
93
  average_embedding = torch.mean(torch.stack(faceid_all_embeds, dim=0), dim=0)
94
 
95
  logger.info("Calling the generation pipeline...")
96
  try:
97
- # FIX: The pipeline expects a 3D or 4D tensor.
98
- # We add a "sequence length" dimension of 1.
99
- # Shape becomes: (1, 1, 512)
100
  final_embedding = average_embedding.unsqueeze(0)
101
 
102
- image = self.pipe(
 
103
  prompt=full_prompt,
104
  negative_prompt=negative_prompt,
105
- ip_adapter_image_embeds=[final_embedding], # Pass the correctly shaped tensor
106
  num_inference_steps=40,
107
  guidance_scale=7.5,
108
  width=512,
109
  height=768,
110
- ).images[0]
 
111
 
112
  # --- Save image locally first ---
113
  temp_dir = "temp_images"
 
77
 
78
  faces = self.face_app.get(face)
79
  if faces:
 
80
  faceid_embed = torch.from_numpy(faces[0].normed_embedding).unsqueeze(0)
81
  faceid_all_embeds.append(faceid_embed)
82
  except Exception as e:
 
86
  logger.error("No faces were detected in any of the provided images.")
87
  return None
88
 
 
 
 
89
  average_embedding = torch.mean(torch.stack(faceid_all_embeds, dim=0), dim=0)
90
 
91
  logger.info("Calling the generation pipeline...")
92
  try:
 
 
 
93
  final_embedding = average_embedding.unsqueeze(0)
94
 
95
+ # FIX: The pipeline returns an object, not a tuple. Access the .images attribute.
96
+ pipeline_output = self.pipe(
97
  prompt=full_prompt,
98
  negative_prompt=negative_prompt,
99
+ ip_adapter_image_embeds=[final_embedding],
100
  num_inference_steps=40,
101
  guidance_scale=7.5,
102
  width=512,
103
  height=768,
104
+ )
105
+ image = pipeline_output.images[0]
106
 
107
  # --- Save image locally first ---
108
  temp_dir = "temp_images"