Spaces:
Running
on
Zero
Running
on
Zero
| import gradio as gr | |
| import torch | |
| import spaces | |
| from diffusers import FluxInpaintPipeline | |
| from PIL import Image #, ImageFile | |
| import io | |
| import numpy as np | |
| # Enable loading of truncated images | |
| # ImageFile.LOAD_TRUNCATED_IMAGES = True | |
# Build the FLUX inpainting pipeline once, at import time, so that every
# request reuses the same loaded weights.
_BASE_MODEL_ID = "black-forest-labs/FLUX.1-dev"
_LORA_REPO_ID = "ali-vilab/In-Context-LoRA"
_LORA_WEIGHT_FILE = "visual-identity-design.safetensors"

pipe = FluxInpaintPipeline.from_pretrained(_BASE_MODEL_ID, torch_dtype=torch.bfloat16)
pipe.to("cuda")
# The LoRA adapter is attached after the base pipeline has been moved to the GPU.
pipe.load_lora_weights(_LORA_REPO_ID, weight_name=_LORA_WEIGHT_FILE)
def safe_open_image(image):
    """Normalise an input image to an RGB image object.

    Args:
        image: A numpy array, raw encoded image bytes (``bytes`` or
            ``bytearray``), or an already-opened image object exposing
            ``mode`` and ``convert`` (e.g. a PIL image).

    Returns:
        The image in RGB mode. An input that is already an RGB image object
        is returned unchanged.

    Raises:
        ValueError: If the input is of an unsupported type or cannot be
            decoded/converted. The original exception is chained as the cause.
    """
    try:
        if isinstance(image, np.ndarray):
            # Convert numpy array to PIL Image
            image = Image.fromarray(image)
        elif isinstance(image, (bytes, bytearray)):
            # Decode raw encoded bytes
            image = Image.open(io.BytesIO(image))
        elif not hasattr(image, "convert"):
            # Fail with an explicit message instead of an incidental
            # AttributeError on ``.mode`` further down.
            raise TypeError(f"unsupported input type: {type(image).__name__}")

        # Ensure the image is in RGB mode
        if image.mode != 'RGB':
            image = image.convert('RGB')
        return image
    except Exception as e:
        raise ValueError(f"Error processing input image: {str(e)}") from e
def square_center_crop(img, target_size=768):
    """Crop the largest centred square out of ``img`` and resize it.

    The input is first normalised with ``safe_open_image``. Images with a
    side shorter than 64 pixels are rejected. The result is a
    ``target_size`` x ``target_size`` image resampled with LANCZOS.

    Raises:
        ValueError: On any failure, including the minimum-size check; the
            message is prefixed with "Error during image cropping: ".
    """
    try:
        picture = safe_open_image(img)

        # Ensure minimum size
        if picture.size[0] < 64 or picture.size[1] < 64:
            raise ValueError("Image is too small. Minimum size is 64x64 pixels.")

        w, h = picture.size
        side = min(w, h)

        # The square is centred; because side <= w and side <= h the offsets
        # are never negative and the box never exceeds the image bounds.
        x0 = (w - side) // 2
        y0 = (h - side) // 2
        square = picture.crop((x0, y0, x0 + side, y0 + side))

        # Use high-quality resizing
        output_size = (target_size, target_size)
        return square.resize(output_size, Image.Resampling.LANCZOS, reducing_gap=3.0)
    except Exception as e:
        raise ValueError(f"Error during image cropping: {str(e)}")
def duplicate_horizontally(img):
    """Return a two-panel image holding ``img`` twice, side by side.

    The input must be square. The result is an RGB image twice as wide as
    the input, with one copy pasted in each half.

    Raises:
        ValueError: If the input is not square or any imaging step fails;
            the message is prefixed with "Error during image duplication: ".
    """
    try:
        panel_w, panel_h = img.size
        if panel_w != panel_h:
            raise ValueError(f"Input image must be square, got {panel_w}x{panel_h}")

        # Destination canvas is explicitly RGB
        canvas = Image.new('RGB', (panel_w * 2, panel_h))

        # Paste an RGB version of the source into both halves
        panel = img if img.mode == 'RGB' else img.convert('RGB')
        for x_offset in (0, panel_w):
            canvas.paste(panel, (x_offset, 0))
        return canvas
    except Exception as e:
        raise ValueError(f"Error during image duplication: {str(e)}")
def safe_crop_output(img):
    """Return the right half of ``img``.

    The crop starts at ``width // 2`` and runs to the right edge, over the
    full height.

    Raises:
        ValueError: If reading the size or cropping fails; the message is
            prefixed with "Error cropping output image: ".
    """
    try:
        full_w, full_h = img.size
        right_half_box = (full_w // 2, 0, full_w, full_h)
        return img.crop(right_half_box)
    except Exception as e:
        raise ValueError(f"Error cropping output image: {str(e)}")
# Load the inpainting mask once at import time; fail fast if it is missing
# or unreadable.
try:
    # Image.open is lazy and keeps the file handle open until the pixel data
    # is read. The context manager closes the file, and convert() returns a
    # fully loaded, independent RGB image (a copy when already RGB).
    with Image.open("mask_square.png") as _mask_file:
        mask = _mask_file.convert('RGB')
except Exception as e:
    raise RuntimeError(f"Error loading mask image: {str(e)}") from e
# NOTE(review): the file imports `spaces` and the Space reports running on
# ZeroGPU, where GPU work must happen inside a function decorated with
# @spaces.GPU. Confirm against the deployment target.
@spaces.GPU
def generate(image, prompt_user, progress=gr.Progress(track_tqdm=True)):
    """Apply the uploaded logo to a described scene, streaming intermediate results.

    This is a generator. Every yield is a 4-tuple
    ``(cropped logo, duplicated logo, full two-panel output, right panel)``;
    entries that are not available yet are ``None``.

    Args:
        image: The uploaded logo image.
        prompt_user: Free-text description of where the logo should appear.
        progress: Gradio progress tracker (tracks tqdm progress bars).

    Raises:
        gr.Error: On missing input or any failure during preprocessing or
            generation, with a user-facing message.
    """
    try:
        if image is None:
            raise ValueError("No input image provided")
        if not prompt_user or not prompt_user.strip():
            raise ValueError("Please provide a prompt")

        prompt_structure = "The two-panel image showcases the logo of a brand, [LEFT] the left panel is showing the logo [RIGHT] the right panel has this logo applied to "
        prompt = prompt_structure + prompt_user

        # Process input image
        try:
            cropped_image = square_center_crop(image)
        except Exception as e:
            error_message = f"Error during cropping: {str(e)}"
            print(error_message)  # For logging
            raise gr.Error(error_message) from e
        # Yield the actual cropped image (not the output component object).
        yield cropped_image, None, None, None
        print("Size after cropping", cropped_image.size)

        try:
            logo_dupli = duplicate_horizontally(cropped_image)
        except Exception as e:
            error_message = f"Error during duplication: {str(e)}"
            print(error_message)  # For logging
            raise gr.Error(error_message) from e
        yield cropped_image, logo_dupli, None, None
        print("just before getting into pipe")

        # Generate output: the canvas is two 768x768 panels side by side.
        out = pipe(
            prompt=prompt,
            image=logo_dupli,
            mask_image=mask,
            guidance_scale=6,
            height=768,
            width=1536,
            num_inference_steps=28,
            max_sequence_length=256,
            strength=1,
        ).images[0]

        # Show the full two-panel result as soon as it is available
        yield cropped_image, logo_dupli, out, None

        # Process and return final output (right panel only)
        image_2 = safe_crop_output(out)
        yield cropped_image, logo_dupli, out, image_2
    except gr.Error:
        # Already a user-facing error raised above; do not wrap it a second time.
        raise
    except Exception as e:
        error_message = f"Error during generation: {str(e)}"
        print(error_message)  # For logging
        raise gr.Error(error_message) from e
# Assemble the Gradio UI.
# NOTE(review): nesting of the components below is reconstructed; the debug
# images are assumed to live in the right-hand column — confirm against the
# intended layout.
_INSTRUCTIONS_MD = (
    "\n"
    "### Instructions:\n"
    "1. Upload a logo image (preferably square)\n"
    "2. Describe where you'd like to see the logo applied\n"
    "3. Click 'Generate Application' and wait for the result\n"
    "Note: The generation process might take a few moments.\n"
)

with gr.Blocks() as demo:
    gr.Markdown("# Logo in Context")
    gr.Markdown("### In-Context LoRA + Image-to-Image, apply your logo to anything")

    with gr.Row():
        # Left column: user inputs
        with gr.Column():
            input_image = gr.Image(label="Upload Logo Image", type="pil", height=384)
            prompt_input = gr.Textbox(
                label="Where should the logo be applied?",
                placeholder="e.g., a coffee cup on a wooden table",
                lines=2,
            )
            generate_btn = gr.Button("Generate Application", variant="primary")

        # Right column: final results plus intermediate debug previews
        with gr.Column():
            output_image = gr.Image(label="Generated Application", type="pil")
            output_side = gr.Image(label="Side by side", type="pil")
            debug_resize = gr.Image()
            debug_duplicate = gr.Image()

    with gr.Row():
        gr.Markdown(_INSTRUCTIONS_MD)

    # Wire the button; the output order matches the tuples yielded by generate.
    generate_btn.click(
        fn=generate,
        inputs=[input_image, prompt_input],
        outputs=[debug_resize, debug_duplicate, output_side, output_image],
        api_name="generate",
    )

# Launch the interface
if __name__ == "__main__":
    demo.launch()