import gradio as gr
import numpy as np
import torch
import random
import os
import spaces
from PIL import Image, ImageOps, ImageFilter
# The fallback path below passes image/strength at call time, which the
# text-to-image FluxPipeline does not accept, so the img2img variant is imported.
from diffusers import FluxImg2ImgPipeline
# Constants
MAX_SEED = np.iinfo(np.int32).max
HF_TOKEN = os.getenv("HF_TOKEN")

# Model configuration
KONTEXT_MODEL = "black-forest-labs/FLUX.1-Kontext-dev"
FALLBACK_MODEL = "black-forest-labs/FLUX.1-dev"
LORA_MODEL = "thedeoxen/refcontrol-flux-kontext-reference-pose-lora"
TRIGGER_WORD = "refcontrolpose"

# Initialize pipeline
print("Loading models...")
def load_pipeline():
    """Load the appropriate pipeline based on availability."""
    try:
        # Check whether the Kontext pipeline and PEFT (needed for LoRA) are available
        try:
            from diffusers import FluxKontextPipeline
            import peft  # noqa: F401 -- only imported so load_lora_weights can use the PEFT backend
            print("PEFT library found")
            use_kontext = True
        except ImportError:
            print("FluxKontextPipeline or PEFT not available, using fallback")
            use_kontext = False

        if use_kontext and HF_TOKEN:
            # Load the Kontext model (gated repo, so a token is required)
            pipe = FluxKontextPipeline.from_pretrained(
                KONTEXT_MODEL,
                torch_dtype=torch.bfloat16,
                token=HF_TOKEN
            )
            # Load the RefControl LoRA if possible
            try:
                pipe.load_lora_weights(
                    LORA_MODEL,
                    adapter_name="refcontrol",
                    token=HF_TOKEN
                )
                model_status = "✅ Flux Kontext + RefControl LoRA loaded"
            except Exception as e:
                print(f"Could not load LoRA: {e}")
                model_status = "⚠️ Flux Kontext loaded (without LoRA - PEFT required)"
            pipe = pipe.to("cuda")
        else:
            # Fallback: img2img with standard FLUX.1-dev. Note: the original code
            # loaded the text-to-image FluxPipeline here, which rejects the
            # image/strength arguments used later, so the img2img variant is loaded.
            pipe = FluxImg2ImgPipeline.from_pretrained(
                FALLBACK_MODEL,
                torch_dtype=torch.bfloat16,
                token=HF_TOKEN if HF_TOKEN else True
            )
            pipe = pipe.to("cuda")
            model_status = "⚠️ Using FLUX.1-dev (fallback mode)"
    except Exception as e:
        print(f"Error loading models: {e}")
        model_status = f"❌ Error: {str(e)}"
        pipe = None
    return pipe, model_status


# Load the pipeline
pipe, MODEL_STATUS = load_pipeline()
print(MODEL_STATUS)
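# A sketch of the requirements.txt this Space assumes (my reading of the imports,
# not pinned by the original code; entries are illustrative):
#
#   gradio
#   torch
#   diffusers      # a recent release that ships FluxKontextPipeline
#   transformers
#   accelerate
#   peft           # enables LoRA loading (see MODEL_STATUS handling above)
#   sentencepiece  # tokenizer dependency for the FLUX text encoders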
def prepare_images_for_kontext(reference_image, pose_image, target_size=512):
    """
    Prepare reference and pose images for Kontext processing.
    Following the RefControl format: reference (left) | pose (right)
    """
    if reference_image is None or pose_image is None:
        return None

    # Convert to RGB
    reference_image = reference_image.convert("RGB")
    pose_image = pose_image.convert("RGB")

    # Calculate dimensions maintaining aspect ratio
    ref_ratio = reference_image.width / reference_image.height
    pose_ratio = pose_image.width / pose_image.height

    # Set heights to target size
    height = target_size
    ref_width = int(height * ref_ratio)
    pose_width = int(height * pose_ratio)

    # Ensure dimensions are divisible by 8 (FLUX requirement)
    ref_width = (ref_width // 8) * 8
    pose_width = (pose_width // 8) * 8
    height = (height // 8) * 8

    # Resize images
    reference_resized = reference_image.resize((ref_width, height), Image.LANCZOS)
    pose_resized = pose_image.resize((pose_width, height), Image.LANCZOS)

    # Concatenate horizontally: reference | pose
    total_width = ref_width + pose_width
    concatenated = Image.new('RGB', (total_width, height))
    concatenated.paste(reference_resized, (0, 0))
    concatenated.paste(pose_resized, (ref_width, 0))

    return concatenated
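# Worked example: a 768x1024 reference (ratio 0.75) and a 512x1024 pose (ratio 0.5)
# at target_size=512 give widths of 384 and 256 (both already multiples of 8),
# so the concatenated canvas is 640x512 with the reference on the left.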
def process_pose_for_control(pose_image):
    """
    Process pose image to ensure maximum contrast and clarity for control.
    """
    if pose_image is None:
        return None

    # Convert to grayscale first
    gray = pose_image.convert("L")

    # Apply strong edge detection
    edges = gray.filter(ImageFilter.FIND_EDGES)
    edges = edges.filter(ImageFilter.EDGE_ENHANCE_MORE)

    # Maximize contrast
    edges = ImageOps.autocontrast(edges, cutoff=2)

    # Convert to pure black and white
    threshold = 128
    edges = edges.point(lambda x: 255 if x > threshold else 0, mode='1')

    # Convert back to RGB with inverted colors (black lines on white)
    edges = edges.convert("RGB")
    edges = ImageOps.invert(edges)

    return edges
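# Example usage (the file name is hypothetical): turn an ordinary photo into a
# black-on-white line drawing suitable as the pose control input.
#
#   control = process_pose_for_control(Image.open("photo.jpg"))
#   control.save("pose_control.png")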
@spaces.GPU  # request a GPU slot on ZeroGPU Spaces for the duration of this call
def generate_pose_transfer(
    reference_image,
    pose_image,
    prompt="",
    negative_prompt="",
    seed=42,
    randomize_seed=False,
    guidance_scale=7.5,  # Increased for better pose adherence
    num_inference_steps=28,
    lora_scale=1.0,
    enhance_pose=False,
    progress=gr.Progress(track_tqdm=True)
):
    """
    Main generation function using the RefControl approach.
    """
    if pipe is None:
        # Raise instead of returning a status string: the third output is an
        # Image component, which cannot display text.
        raise gr.Error("Model not loaded. Please check HF_TOKEN and restart the Space")
    if reference_image is None or pose_image is None:
        raise gr.Error("Please upload both reference and pose images")
    # Randomize seed if requested
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    # Enhance pose if requested
    if enhance_pose:
        pose_image = process_pose_for_control(pose_image)

    # Prepare concatenated input with fixed size
    concatenated_input = prepare_images_for_kontext(reference_image, pose_image, target_size=512)
    if concatenated_input is None:
        raise gr.Error("Failed to process images")

    # Ensure dimensions are model-compatible
    width, height = concatenated_input.size

    # Round down to the nearest multiple of 64 for stability
    width = (width // 64) * 64
    height = (height // 64) * 64

    # Limit maximum size to prevent memory issues
    max_size = 1024
    if width > max_size:
        ratio = max_size / width
        width = max_size
        height = int(height * ratio)
        height = (height // 64) * 64
    if height > max_size:
        ratio = max_size / height
        height = max_size
        width = int(width * ratio)
        width = (width // 64) * 64

    # Resize if needed
    if (width, height) != concatenated_input.size:
        concatenated_input = concatenated_input.resize((width, height), Image.LANCZOS)
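    # Worked example: a 1100x512 concatenation rounds down to 1088x512, exceeds
    # max_size=1024, and is rescaled to 1024x481, which then rounds down to 1024x448.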
    # Construct prompt with trigger word - CRITICAL FOR POSE CONTROL
    # The prompt must explicitly describe the pose transfer task
    base_instruction = f"{TRIGGER_WORD}, A photo composed of two images side by side. Left: reference person. Right: target pose skeleton. Task: Generate the person from the left image in the exact pose shown in the right image"
    if prompt:
        full_prompt = f"{base_instruction}. Additional details: {prompt}"
    else:
        full_prompt = base_instruction

    # Add strong pose control instructions
    full_prompt += ". IMPORTANT: Strictly follow the pose/skeleton from the right image while preserving the identity, clothing, and appearance from the left image. The output should show ONLY the transformed person, not the side-by-side layout."
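    # The trigger word leads the prompt because the RefControl LoRA is keyed to it.
    # With an empty user prompt, the string sent to the model therefore starts with
    # "refcontrolpose, A photo composed of two images side by side. ..."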
    # Set generator for reproducibility
    generator = torch.Generator("cuda").manual_seed(seed)
    try:
        # Check if we have LoRA capabilities. The substring test must target the
        # success status, since the "without LoRA" status also contains "LoRA".
        has_lora = hasattr(pipe, 'set_adapters') and "RefControl LoRA loaded" in MODEL_STATUS

        # Set LoRA with higher strength for better pose control
        if has_lora:
            try:
                # Increase LoRA strength for pose control
                actual_lora_scale = lora_scale * 1.5  # Boost LoRA influence
                pipe.set_adapters(["refcontrol"], adapter_weights=[actual_lora_scale])
                print(f"LoRA adapter set with boosted strength: {actual_lora_scale}")
            except Exception as e:
                print(f"LoRA adapter not set: {e}")

        print(f"Generating with size: {width}x{height}")
        print(f"Prompt: {full_prompt[:200]}...")

        # Generate image with stronger pose control
        with torch.autocast("cuda", dtype=torch.bfloat16):
            if "Kontext" in MODEL_STATUS:
                # Use Kontext pipeline - removed unsupported controlnet_conditioning_scale.
                # Note: diffusers' Flux pipelines only apply the negative prompt when
                # true CFG is enabled (true_cfg_scale > 1), so it may be ignored here.
                result = pipe(
                    image=concatenated_input,
                    prompt=full_prompt,
                    negative_prompt=negative_prompt if negative_prompt else "blurry, distorted, deformed, wrong pose, incorrect posture",
                    guidance_scale=guidance_scale,  # Higher for better control
                    num_inference_steps=num_inference_steps,
                    generator=generator,
                    width=width,
                    height=height,
                ).images[0]
            else:
                # Fallback: standard FLUX img2img over the concatenated input
                result = pipe(
                    prompt=full_prompt,
                    negative_prompt=negative_prompt if negative_prompt else "",
                    image=concatenated_input,
                    guidance_scale=guidance_scale,
                    num_inference_steps=num_inference_steps,
                    generator=generator,
                    strength=0.85,
                ).images[0]

        print("Generation successful!")
        return result, seed, concatenated_input

    except RuntimeError as e:
        if "out of memory" in str(e).lower():
            raise gr.Error("GPU out of memory. Try reducing image size or inference steps.")
        else:
            raise gr.Error(f"Generation failed: {str(e)}")
    except Exception as e:
        print(f"Error details: {e}")
        raise gr.Error(f"Generation failed: {str(e)}")
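# Minimal direct-call sketch for local testing, bypassing the UI. The file names
# are hypothetical; assumes a CUDA GPU and a successfully loaded pipeline:
#
#   ref = Image.open("reference.png")
#   pose = Image.open("pose.png")
#   image, used_seed, concat = generate_pose_transfer(ref, pose, prompt="studio portrait")
#   image.save("result.png")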
# CSS styling
css = """
#col-container {
    margin: 0 auto;
    max-width: 1280px;
}
.header {
    text-align: center;
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    padding: 20px;
    border-radius: 12px;
    margin-bottom: 20px;
}
.header h1 {
    color: white;
    margin: 0;
    font-size: 2em;
}
.status-box {
    padding: 10px;
    border-radius: 8px;
    margin: 10px 0;
    font-weight: bold;
    text-align: center;
}
.input-image {
    border: 2px solid #e0e0e0;
    border-radius: 8px;
    overflow: hidden;
}
.result-image {
    border: 3px solid #4CAF50;
    border-radius: 8px;
    overflow: hidden;
}
.info-box {
    background: #f0f0f0;
    padding: 10px;
    border-radius: 8px;
    margin: 10px 0;
}
"""
# Create Gradio interface
with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
    with gr.Column(elem_id="col-container"):
        # Header
        gr.HTML("""
            <div class="header">
                <h1>🎭 FLUX Pose Transfer System</h1>
                <p style="color: white;">Transfer poses while preserving identity</p>
            </div>
        """)

        # Model status
        status_color = "#d4edda" if "✅" in MODEL_STATUS else "#fff3cd" if "⚠️" in MODEL_STATUS else "#f8d7da"
        gr.HTML(f"""
            <div class="status-box" style="background: {status_color};">
                {MODEL_STATUS}
            </div>
        """)

        # Authentication check
        if not HF_TOKEN:
            gr.Markdown("""
                ### 🔐 Authentication Required
                To use this Space with full features:
                1. Go to **Settings** → **Variables and secrets**
                2. Add `HF_TOKEN` with your Hugging Face token
                3. Restart the Space

                Or click below to sign in:
            """)
            gr.LoginButton("Sign in with Hugging Face", size="lg")

        # Info box for PEFT requirement
        if "PEFT required" in MODEL_STATUS:
            gr.HTML("""
                <div class="info-box">
                    <b>Note:</b> For full LoRA support, the PEFT library is required.
                    Add <code>peft</code> to your requirements.txt file.
                </div>
            """)
        # Main interface
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### 📥 Input Images")

                # Reference image
                reference_image = gr.Image(
                    label="Reference Image (Subject to transform)",
                    type="pil",
                    elem_classes=["input-image"],
                    height=300
                )

                # Pose image
                pose_image = gr.Image(
                    label="Pose Control (Line art or skeleton)",
                    type="pil",
                    elem_classes=["input-image"],
                    height=300
                )

                # Pose extraction tool
                with gr.Accordion("🔧 Extract Pose from Image", open=False):
                    extract_source = gr.Image(
                        label="Source image for pose extraction",
                        type="pil",
                        height=200
                    )
                    extract_btn = gr.Button("Extract Pose", size="sm")

                # Prompts
                prompt = gr.Textbox(
                    label=f"Prompt ('{TRIGGER_WORD}' added automatically)",
                    placeholder="e.g., wearing elegant dress, professional photography",
                    lines=2
                )
                negative_prompt = gr.Textbox(
                    label="Negative Prompt (optional)",
                    placeholder="e.g., blurry, low quality, distorted",
                    lines=1,
                    value="blurry, low quality, distorted, deformed"
                )

                # Generate button
                generate_btn = gr.Button(
                    "🎨 Generate Pose Transfer",
                    variant="primary",
                    size="lg"
                )
                # Advanced settings
                with gr.Accordion("⚙️ Advanced Settings", open=False):
                    with gr.Row():
                        seed = gr.Slider(
                            label="Seed",
                            minimum=0,
                            maximum=MAX_SEED,
                            step=1,
                            value=42
                        )
                        randomize_seed = gr.Checkbox(
                            label="Randomize",
                            value=True
                        )
                    guidance_scale = gr.Slider(
                        label="Guidance Scale",
                        minimum=5.0,
                        maximum=15.0,
                        step=0.5,
                        value=7.5,
                        info="Higher = stricter pose following (7-10 recommended)"
                    )
                    num_inference_steps = gr.Slider(
                        label="Inference Steps",
                        minimum=20,
                        maximum=50,
                        step=1,
                        value=30
                    )
| if "LoRA" in MODEL_STATUS: | |
| lora_scale = gr.Slider( | |
| label="LoRA Strength", | |
| minimum=0.5, | |
| maximum=2.0, | |
| step=0.1, | |
| value=1.2, | |
| info="RefControl LoRA influence (1.0-1.5 recommended)" | |
| ) | |
| else: | |
| lora_scale = gr.Slider( | |
| label="LoRA Strength (not available)", | |
| minimum=0.0, | |
| maximum=2.0, | |
| step=0.1, | |
| value=1.0, | |
| interactive=False | |
| ) | |
                    enhance_pose = gr.Checkbox(
                        label="Auto-enhance pose edges",
                        value=False
                    )
            with gr.Column(scale=1):
                gr.Markdown("### 🖼️ Result")

                # Result image (type="pil" so its value can be fed back into the
                # PIL-based reuse handlers below)
                result_image = gr.Image(
                    label="Generated Image",
                    type="pil",
                    elem_classes=["result-image"],
                    interactive=False,
                    height=500
                )

                # Seed display
                seed_used = gr.Number(
                    label="Seed Used",
                    interactive=False
                )

                # Debug view
                with gr.Accordion("🔍 Debug View", open=False):
                    concat_preview = gr.Image(
                        label="Input Concatenation (Reference | Pose)",
                        height=200
                    )

                # Action buttons
                with gr.Row():
                    reuse_ref_btn = gr.Button("♻️ Use as Reference", size="sm")
                    reuse_pose_btn = gr.Button("🔄 Extract Pose", size="sm")
                    clear_btn = gr.Button("🗑️ Clear All", size="sm")
        # Examples
        gr.Markdown("### 💡 Example Prompts")
        gr.Examples(
            examples=[
                ["professional portrait, studio lighting"],
                ["wearing red dress, outdoor garden"],
                ["business attire, office setting"],
                ["casual streetwear, urban background"],
                ["athletic wear, gym environment"],
            ],
            inputs=[prompt]
        )

        # Instructions
        with gr.Accordion("📖 Instructions", open=False):
            gr.Markdown(f"""
                ## How to Use:
                1. **Upload Reference Image**: The person whose appearance you want to keep
                2. **Upload Pose Image**: Line art or skeleton pose to follow
                3. **Add Prompt** (optional): Describe additional details
                4. **Click Generate**: Create your pose-transferred image

                ## Model Information:
                - **Current Model**: {MODEL_STATUS}
                - **Trigger Word**: `{TRIGGER_WORD}` (added automatically)

                ## Tips:
                - Use clear, high-contrast pose images
                - Black lines on white background work best for poses
                - Adjust guidance scale for pose adherence strength
                - Higher steps = better quality but slower

                ## Requirements:
                - **HF_TOKEN**: Required for model access
                - **peft**: Required for LoRA support (add to requirements.txt)
            """)
    # Event handlers
    generate_btn.click(
        fn=generate_pose_transfer,
        inputs=[
            reference_image,
            pose_image,
            prompt,
            negative_prompt,
            seed,
            randomize_seed,
            guidance_scale,
            num_inference_steps,
            lora_scale,
            enhance_pose
        ],
        outputs=[result_image, seed_used, concat_preview]
    )

    extract_btn.click(
        fn=process_pose_for_control,
        inputs=[extract_source],
        outputs=[pose_image]
    )

    reuse_ref_btn.click(
        fn=lambda x: x,
        inputs=[result_image],
        outputs=[reference_image]
    )

    reuse_pose_btn.click(
        fn=process_pose_for_control,
        inputs=[result_image],
        outputs=[pose_image]
    )

    clear_btn.click(
        fn=lambda: [None, None, "", "blurry, low quality, distorted, deformed", 42, None, None],
        outputs=[
            reference_image,
            pose_image,
            prompt,
            negative_prompt,
            seed_used,
            result_image,
            concat_preview
        ]
    )
# Launch the app
if __name__ == "__main__":
    demo.queue()
    demo.launch()