# Spaces: Running on Zero
import torch
from PIL import Image

from config import Config
from utils import resize_image_to_1mp, get_caption
class Generator:
    """Runs the img2img + ControlNet (+ optional InstantID) generation pipeline.

    The heavy lifting (detectors, face embedding, diffusers pipeline) lives on
    the injected model handler; this class only orchestrates the calls.
    """

    def __init__(self, model_handler):
        # model_handler exposes: leres_detector, lineart_anime_detector,
        # get_face_embedding, and the diffusers `pipeline` used in predict().
        self.mh = model_handler

    def prepare_control_images(self, image, width, height):
        """
        Generate conditioning maps (depth + lineart), resized to the exact
        target dimensions (width, height).

        Returns:
            (depth_map, lineart_map): two PIL images of size (width, height).
        """
        print(f"Generating control maps for {width}x{height}...")
        depth_map_raw = self.mh.leres_detector(image)
        lineart_map_raw = self.mh.lineart_anime_detector(image)
        # The detectors may emit maps at their own working resolution, so
        # resize explicitly to match the pipeline's output size.
        depth_map = depth_map_raw.resize((width, height), Image.LANCZOS)
        lineart_map = lineart_map_raw.resize((width, height), Image.LANCZOS)
        return depth_map, lineart_map

    def _build_prompt(self, processed_image, user_prompt):
        """Build the final prompt: user text, else an auto-caption fallback."""
        if user_prompt.strip():
            return f"{Config.STYLE_TRIGGER}, {user_prompt}"
        try:
            generated_caption = get_caption(processed_image)
            return f"{Config.STYLE_TRIGGER}, {generated_caption}"
        except Exception as e:
            # Captioning is best-effort; fall back to a generic prompt.
            print(f"Captioning failed: {e}, using default prompt.")
            return f"{Config.STYLE_TRIGGER}, a beautiful pixel art image"

    def predict(
        self,
        input_image,
        user_prompt="",
        guidance_scale=1.5,
        num_inference_steps=6,
        img2img_strength=0.3,
        depth_strength=0.3,
        lineart_strength=0.3,
    ):
        """
        Run one generation pass.

        Args:
            input_image: source PIL image (resized to ~1 MP before use).
            user_prompt: optional prompt text; auto-captioned when blank.
            guidance_scale / num_inference_steps / img2img_strength:
                standard diffusers img2img parameters, exposed to the UI.
            depth_strength / lineart_strength: ControlNet conditioning
                scales for the depth and lineart maps.

        Returns:
            The first generated PIL image from the pipeline.
        """
        # 1. Pre-process input to ~1 megapixel and record the final size.
        print("Processing Input...")
        processed_image = resize_image_to_1mp(input_image)
        target_width, target_height = processed_image.size

        # 2. Face embedding (None when no face is detected).
        face_emb = self.mh.get_face_embedding(processed_image)

        # 3. Prompt.
        final_prompt = self._build_prompt(processed_image, user_prompt)
        print(f"Prompt: {final_prompt}")

        # 4. Structure control maps, matched to the output resolution.
        print("Generating Control Maps (Depth, LineArt)...")
        depth_map, lineart_map = self.prepare_control_images(
            processed_image, target_width, target_height
        )

        # 5. Face vs no-face conditioning.
        # ControlNet order: [InstantID, Zoe, LineArt].
        # The guidance-end schedule is identical either way: stop InstantID
        # early (0.3), run the structural nets to 0.6.
        control_guidance_end = [0.3, 0.6, 0.6]
        if face_emb is not None:
            print("Face detected: Applying InstantID.")
            # Use strengths from UI; fixed 0.6 for InstantID.
            controlnet_conditioning_scale = [0.6, depth_strength, lineart_strength]
            self.mh.pipeline.set_ip_adapter_scale(0.6)  # IP-Adapter (likeness) strength
        else:
            print("No face detected: Disabling InstantID.")
            # Keep InstantID at 0.0, still honour the UI strengths.
            controlnet_conditioning_scale = [0.0, depth_strength, lineart_strength]
            self.mh.pipeline.set_ip_adapter_scale(0.0)
            # The pipeline rejects a None embedding, so pass a zero tensor
            # of the expected (batch_size, embedding_dim) shape instead.
            face_emb = torch.zeros((1, 512), dtype=Config.DTYPE, device=Config.DEVICE)

        # 6. Run inference.
        print("Running pipeline...")
        result = self.mh.pipeline(
            prompt=final_prompt,
            image=processed_image,  # Base image for Img2Img
            control_image=[processed_image, depth_map, lineart_map],  # ControlNet inputs
            image_embeds=face_emb,  # Face embedding (or zero dummy)
            strength=img2img_strength,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            controlnet_conditioning_scale=controlnet_conditioning_scale,
            control_guidance_end=control_guidance_end,
            clip_skip=2,
            # LoRA strength no longer passed: LoRA is fused into the weights.
        ).images[0]
        return result