Spaces: Running on Zero
import torch
from PIL import Image

from config import Config
from utils import resize_image_to_1mp, get_caption, draw_kps
class Generator:
    """Image-to-image generation with InstantID + ControlNet conditioning.

    Orchestrates a diffusers pipeline held by ``model_handler``: resizes the
    input, optionally captions it to build a prompt, produces depth and
    lineart control maps, and conditions on a detected face (InstantID)
    when one is present.
    """

    def __init__(self, model_handler):
        # model_handler exposes: pipeline, leres_detector,
        # lineart_anime_detector, get_face_info (see calls below).
        self.mh = model_handler

    def prepare_control_images(self, image, width, height):
        """
        Generates conditioning maps, ensuring they are resized
        to the exact target dimensions (width, height).
        """
        print(f"Generating control maps for {width}x{height}...")
        # Generate depth map
        depth_map_raw = self.mh.leres_detector(image)
        # Generate lineart map
        lineart_map_raw = self.mh.lineart_anime_detector(image)
        # Detector output size may differ from the target resolution;
        # force an exact match so the pipeline's control inputs line up.
        depth_map = depth_map_raw.resize((width, height), Image.LANCZOS)
        lineart_map = lineart_map_raw.resize((width, height), Image.LANCZOS)
        return depth_map, lineart_map

    def _build_prompt(self, processed_image, user_prompt):
        """Return the final prompt: style trigger + user text, or an
        auto-generated caption when no user prompt is given."""
        # None-safe: an empty or missing prompt falls back to captioning.
        if user_prompt and user_prompt.strip():
            return f"{Config.STYLE_TRIGGER}, {user_prompt}"
        try:
            generated_caption = get_caption(processed_image)
            return f"{Config.STYLE_TRIGGER}, {generated_caption}"
        except Exception as e:
            # Best-effort captioning: fall back to a generic prompt.
            print(f"Captioning failed: {e}, using default prompt.")
            return f"{Config.STYLE_TRIGGER}, a beautiful pixel art image"

    def _resolve_seed(self, seed):
        """Resolve -1/None to a fresh random seed and return a seeded
        torch.Generator on the configured device."""
        if seed is None or seed == -1:
            # Draw a random seed so the run is reproducible after logging it.
            seed = torch.Generator().seed()
        generator = torch.Generator(device=Config.DEVICE).manual_seed(int(seed))
        print(f"Using seed: {seed}")
        return generator

    def predict(
        self,
        input_image,
        user_prompt="",
        negative_prompt="",
        guidance_scale=1.5,
        num_inference_steps=6,
        img2img_strength=0.3,
        depth_strength=0.3,
        lineart_strength=0.3,
        seed=-1
    ):
        """Run the full img2img pipeline and return the generated PIL image.

        Args:
            input_image: source image (PIL); resized to ~1 megapixel.
            user_prompt: optional text; empty/None triggers auto-captioning.
            negative_prompt: passed through to the pipeline.
            guidance_scale / num_inference_steps / img2img_strength:
                standard diffusion knobs, tuned for a few-step model.
            depth_strength / lineart_strength: per-ControlNet scales.
            seed: -1 or None picks a random seed.
        """
        # 1. Pre-process Inputs
        print("Processing Input...")
        processed_image = resize_image_to_1mp(input_image)
        target_width, target_height = processed_image.size

        # 2. Get Face Info
        face_info = self.mh.get_face_info(processed_image)

        # 3. Generate Prompt
        final_prompt = self._build_prompt(processed_image, user_prompt)
        print(f"Prompt: {final_prompt}")

        # 4. Generate Control Maps
        print("Generating Control Maps (Depth, LineArt)...")
        depth_map, lineart_map = self.prepare_control_images(
            processed_image, target_width, target_height
        )

        # 5. Logic for Face vs No-Face
        if face_info is not None:
            print("Face detected: Applying InstantID with keypoints.")
            # Use the raw face embedding as the IP-Adapter image embed.
            face_emb = torch.tensor(
                face_info['embedding'],
                dtype=Config.DTYPE,
                device=Config.DEVICE
            ).unsqueeze(0)
            face_kps = draw_kps(processed_image, face_info['kps'])
            controlnet_conditioning_scale = [0.8, depth_strength, lineart_strength]
            self.mh.pipeline.set_ip_adapter_scale(0.8)
        else:
            print("No face detected: Disabling InstantID.")
            # Zero embedding + black keypoint canvas + zero scale neutralize
            # the InstantID branch while keeping the pipeline call shape.
            face_emb = torch.zeros((1, 512), dtype=Config.DTYPE, device=Config.DEVICE)
            face_kps = Image.new('RGB', (target_width, target_height), (0, 0, 0))
            controlnet_conditioning_scale = [0.0, depth_strength, lineart_strength]
            self.mh.pipeline.set_ip_adapter_scale(0.0)

        # Stop each ControlNet early in the denoising schedule
        # (face kps at 30%, depth/lineart at 60%).
        control_guidance_end = [0.3, 0.6, 0.6]
        generator = self._resolve_seed(seed)

        # 6. Run Inference
        print("Running pipeline...")
        result = self.mh.pipeline(
            prompt=final_prompt,
            negative_prompt=negative_prompt,
            image=processed_image,
            control_image=[face_kps, depth_map, lineart_map],
            image_embeds=face_emb,
            generator=generator,
            strength=img2img_strength,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            controlnet_conditioning_scale=controlnet_conditioning_scale,
            control_guidance_end=control_guidance_end,
            clip_skip=2,
            # eta=0.0 removes stochastic noise for deterministic sampling.
            eta=0.0,
        ).images[0]
        return result