Spaces:
Runtime error
Runtime error
| import logging | |
| import random | |
| import warnings | |
| import os | |
| import gradio as gr | |
| import numpy as np | |
| import spaces | |
| import torch | |
| from diffusers import FluxImg2ImgPipeline | |
| from transformers import AutoProcessor, AutoModelForCausalLM | |
| from gradio_imageslider import ImageSlider | |
| from PIL import Image | |
| from huggingface_hub import snapshot_download | |
| import requests | |
# Custom stylesheet handed to gr.Blocks: centres a 800px-wide container and
# the page header.
# NOTE(review): no component in the visible layout sets
# elem_id="col-container", so the first rule may currently be unused — confirm.
css = """
#col-container {
margin: 0 auto;
max-width: 800px;
}
.main-header {
text-align: center;
margin-bottom: 2rem;
}
"""
# Device setup: run on CUDA when torch can see a GPU, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Human-readable label for the selected device, shown in the page header.
power_device = "GPU" if device == "cuda" else "CPU"
# Access token for the Hugging Face Hub, read from the HF_TOKEN environment
# variable (None when the variable is not set).
huggingface_token = os.environ.get("HF_TOKEN")

# Fetch a local snapshot of the FLUX.1-dev repository into ./FLUX.1-dev,
# skipping markdown files and .gitattributes.
print("๐ฅ Downloading FLUX model...")
model_path = snapshot_download(
    repo_id="black-forest-labs/FLUX.1-dev",
    repo_type="model",
    local_dir="FLUX.1-dev",
    ignore_patterns=["*.md", "*.gitattributes"],
    token=huggingface_token,
)
# Florence-2 (captioning): load the model in half precision and move it to
# the selected device, then load its matching processor.
print("๐ฅ Loading Florence-2 model...")
florence_model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Florence-2-large",
    attn_implementation="eager",  # Fix for SDPA compatibility issue
    trust_remote_code=True,
    torch_dtype=torch.float16,
)
florence_model = florence_model.to(device)

florence_processor = AutoProcessor.from_pretrained(
    "microsoft/Florence-2-large", trust_remote_code=True
)
# FLUX image-to-image pipeline, built from the snapshot downloaded above and
# held in bfloat16 on the selected device.
print("๐ฅ Loading FLUX Img2Img...")
pipe = FluxImg2ImgPipeline.from_pretrained(model_path, torch_dtype=torch.bfloat16)
pipe.to(device)

# Turn on both VAE slicing and VAE tiling for the pipeline.
pipe.enable_vae_slicing()
pipe.enable_vae_tiling()

print("โ All models loaded successfully!")
# Upper bound (inclusive) for the seed slider and for randomly drawn seeds.
MAX_SEED = 1_000_000
# Largest number of output pixels (width * height) allowed before the input
# is shrunk.
MAX_PIXEL_BUDGET = 4096 ** 2
def generate_caption(image):
    """Generate a detailed caption for ``image`` using Florence-2.

    Args:
        image: PIL image to describe.

    Returns:
        The caption string produced for the ``<MORE_DETAILED_CAPTION>`` task,
        or the generic fallback ``"a high quality detailed image"`` when
        captioning fails for any reason (the failure is printed, not raised,
        so enhancement can still proceed with a usable prompt).
    """
    task_prompt = "<MORE_DETAILED_CAPTION>"
    try:
        # Defensive: hand the processor a plain 3-channel image regardless of
        # the upload's mode (RGBA, P, L, ...).
        if image.mode != "RGB":
            image = image.convert("RGB")

        # The model weights are loaded in half precision, so the floating
        # point inputs (pixel_values) must be cast to the same dtype; moving
        # them to the device alone leaves them float32, which makes the
        # forward pass fail and silently forces the fallback caption.
        inputs = florence_processor(
            text=task_prompt, images=image, return_tensors="pt"
        ).to(device, florence_model.dtype)

        generated_ids = florence_model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            num_beams=3,
            # Plain beam search: sampling would make the caption (and thus
            # the diffusion prompt) differ between runs with the same seed.
            do_sample=False,
        )
        # Special tokens are kept because post_process_generation parses them.
        generated_text = florence_processor.batch_decode(
            generated_ids, skip_special_tokens=False
        )[0]
        parsed_answer = florence_processor.post_process_generation(
            generated_text,
            task=task_prompt,
            image_size=(image.width, image.height),
        )
        return parsed_answer[task_prompt]
    except Exception as e:
        # Best-effort by design: captioning must never abort the enhancement.
        print(f"Caption generation failed: {e}")
        return "a high quality detailed image"
def process_input(input_image, upscale_factor):
    """Shrink the input if needed and snap its size to a multiple of 8.

    If upscaling ``input_image`` by ``upscale_factor`` would exceed
    ``MAX_PIXEL_BUDGET`` output pixels, the image is first scaled down
    uniformly (aspect ratio preserved) so that the upscaled result fits the
    budget. The image is then resized so both sides are multiples of 8.

    Args:
        input_image: PIL image to prepare.
        upscale_factor: Multiplier that will later be applied to both sides.

    Returns:
        A tuple ``(image, w_original, h_original, was_resized)`` where
        ``image`` is the prepared PIL image, ``w_original``/``h_original``
        are the dimensions of the image as passed in, and ``was_resized``
        tells whether the pixel-budget downscale was applied.
    """
    w_original, h_original = input_image.size
    was_resized = False

    if w_original * h_original * upscale_factor**2 > MAX_PIXEL_BUDGET:
        warnings.warn(
            f"Requested output image is too large ({w_original * upscale_factor}x{h_original * upscale_factor}). Resizing to fit budget."
        )
        gr.Info(
            "Requested output image is too large. Resizing input to fit within pixel budget."
        )
        # Uniform scale s with (s*w) * (s*h) * factor**2 == MAX_PIXEL_BUDGET.
        # Scaling both sides by the same s keeps the aspect ratio; scaling
        # one side by the ratio and the other by its inverse would square it.
        scale = (MAX_PIXEL_BUDGET / (w_original * h_original)) ** 0.5 / upscale_factor
        input_image = input_image.resize(
            (
                max(8, int(w_original * scale)),
                max(8, int(h_original * scale)),
            ),
            resample=Image.LANCZOS,
        )
        was_resized = True

    # Resize to multiple of 8; never let a side collapse to 0 for tiny inputs.
    w, h = input_image.size
    w = max(8, w - w % 8)
    h = max(8, h - h % 8)
    return input_image.resize((w, h), resample=Image.LANCZOS), w_original, h_original, was_resized
def load_image_from_url(url):
    """Download ``url`` and return its content as a PIL image.

    Args:
        url: Address of the image to fetch.

    Returns:
        The decoded PIL image.

    Raises:
        gr.Error: If the request fails, the server answers with an error
            status, or the payload cannot be decoded as an image.
    """
    from io import BytesIO

    # NOTE(review): the URL is user-supplied and fetched server-side; consider
    # restricting the allowed hosts if this app is exposed publicly.
    try:
        # One request only, with a timeout so a stalled server cannot hang
        # the worker indefinitely.
        response = requests.get(url, timeout=30)
        response.raise_for_status()
        # response.content is the fully read, content-decoded body; wrapping
        # it in BytesIO gives PIL a seekable stream that does not depend on
        # the connection staying open.
        image = Image.open(BytesIO(response.content))
        # Force decoding now so a corrupt payload is reported here rather
        # than later in the pipeline.
        image.load()
        return image
    except Exception as e:
        raise gr.Error(f"Failed to load image from URL: {e}") from e
# Request a GPU for the duration of the call when running on ZeroGPU hardware.
# NOTE(review): `spaces` was imported but never used; confirm the Space's
# hardware tier and whether a longer `duration` is needed for large outputs.
@spaces.GPU
def enhance_image(
    image_input,
    image_url,
    seed,
    randomize_seed,
    num_inference_steps,
    upscale_factor,
    guidance_scale,
    denoising_strength,
    use_generated_caption,
    custom_prompt,
    progress=gr.Progress(track_tqdm=True),
):
    """Upscale and enhance an image with FLUX img2img.

    The image comes from ``image_input`` when given, otherwise it is
    downloaded from ``image_url``. It is prepared by ``process_input``,
    optionally captioned with Florence-2, upscaled with LANCZOS and then
    refined by the FLUX pipeline. The result is finally resized to the
    original dimensions multiplied by ``upscale_factor``.

    Args:
        image_input: Uploaded PIL image, or ``None``.
        image_url: URL used when no image was uploaded.
        seed: Seed for the diffusion generator (ignored if ``randomize_seed``).
        randomize_seed: Draw a fresh seed in ``[0, MAX_SEED]`` when true.
        num_inference_steps: Number of diffusion steps requested.
        upscale_factor: Multiplier applied to both image sides.
        guidance_scale: Guidance scale forwarded to the pipeline.
        denoising_strength: img2img strength forwarded to the pipeline.
        use_generated_caption: Use a Florence-2 caption as the prompt.
        custom_prompt: Prompt used when ``use_generated_caption`` is false.
        progress: Gradio progress tracker; not referenced directly, its
            presence lets Gradio mirror tqdm progress bars in the UI.

    Returns:
        ``([original_image, enhanced_image], seed, caption)`` where
        ``caption`` is the generated caption or ``""`` when none was used.

    Raises:
        gr.Error: If no image source is provided, the URL cannot be loaded,
            or ``denoising_strength`` is not positive.
    """
    # Handle image input
    if image_input is not None:
        input_image = image_input
    elif image_url and image_url.strip():
        input_image = load_image_from_url(image_url.strip())
    else:
        raise gr.Error("Please provide an image (upload or URL)")

    # The strength slider allows 0.0, which leaves the pipeline no denoising
    # work to do; reject it here with a readable message instead of letting
    # the request fail deep inside the pipeline.
    if denoising_strength <= 0:
        raise gr.Error("Denoising strength must be greater than 0")

    # Slider values may arrive as floats; the generator seed and PIL sizes
    # must be integers.
    seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
    upscale_factor = int(upscale_factor)
    num_inference_steps = int(num_inference_steps)

    # Keep the untouched input for the before/after slider.
    true_input_image = input_image

    # Work on a plain RGB copy so palette/alpha/greyscale uploads do not
    # reach the models in an unexpected channel layout.
    if input_image.mode != "RGB":
        input_image = input_image.convert("RGB")

    # Process input image
    input_image, w_original, h_original, was_resized = process_input(
        input_image, upscale_factor
    )

    # Generate caption if requested
    generated_caption = ""
    if use_generated_caption:
        gr.Info("๐ Generating image caption...")
        generated_caption = generate_caption(input_image)
        prompt = generated_caption
    else:
        prompt = custom_prompt if custom_prompt and custom_prompt.strip() else ""

    # Rescale with upscale factor using LANCZOS
    w, h = input_image.size
    control_image = input_image.resize(
        (w * upscale_factor, h * upscale_factor), resample=Image.LANCZOS
    )

    generator = torch.Generator().manual_seed(seed)

    gr.Info("๐ Upscaling image...")

    # Generate upscaled image
    image = pipe(
        prompt=prompt,
        image=control_image,
        strength=denoising_strength,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        height=control_image.size[1],
        width=control_image.size[0],
        generator=generator,
    ).images[0]

    target_size = (w_original * upscale_factor, h_original * upscale_factor)
    if was_resized:
        gr.Info(f"๐ Resizing output to target size: {target_size[0]}x{target_size[1]}")

    # Resize to target desired size (also undoes the multiple-of-8 rounding).
    final_image = image.resize(target_size, resample=Image.LANCZOS)

    return [true_input_image, final_image], seed, generated_caption
# Build the Gradio interface: inputs and settings in one column, results in
# the other, followed by examples and an explanatory footer.
# NOTE(review): the source indentation was lost; the nesting below is
# reconstructed from the order of the components — confirm against the
# intended layout.
with gr.Blocks(css=css, title="๐จ AI Image Enhancer - Florence-2 + FLUX") as demo:
    # Page header; the placeholder is filled with the device label.
    gr.HTML("""
<div class="main-header">
<h1>๐จ AI Image Enhancer</h1>
<p>Upload an image or provide a URL to enhance it using Florence-2 captioning and FLUX upscaling</p>
<p>Currently running on <strong>{}</strong></p>
</div>
""".format(power_device))

    with gr.Row():
        # Inputs and settings
        with gr.Column(scale=1):
            gr.HTML("<h3>๐ค Input</h3>")

            with gr.Tabs():
                with gr.TabItem("๐ Upload Image"):
                    input_image = gr.Image(label="Upload Image", type="pil", height=300)

                with gr.TabItem("๐ Image URL"):
                    image_url = gr.Textbox(
                        label="Image URL",
                        placeholder="https://example.com/image.jpg",
                        value="https://upload.wikimedia.org/wikipedia/commons/thumb/a/a7/Example.jpg/800px-Example.jpg",
                    )

            gr.HTML("<h3>๐๏ธ Caption Settings</h3>")

            use_generated_caption = gr.Checkbox(
                label="Use AI-generated caption (Florence-2)",
                value=True,
                info="Generate detailed caption automatically",
            )
            custom_prompt = gr.Textbox(
                label="Custom Prompt (optional)",
                placeholder="Enter custom prompt or leave empty for generated caption",
                lines=2,
            )

            gr.HTML("<h3>โ๏ธ Enhancement Settings</h3>")

            upscale_factor = gr.Slider(
                label="Upscale Factor",
                minimum=1,
                maximum=4,
                step=1,
                value=2,
                info="How much to upscale the image",
            )
            num_inference_steps = gr.Slider(
                label="Number of Inference Steps",
                minimum=8,
                maximum=50,
                step=1,
                value=25,
                info="More steps = better quality but slower",
            )
            guidance_scale = gr.Slider(
                label="Guidance Scale",
                minimum=1.0,
                maximum=10.0,
                step=0.5,
                value=3.5,
                info="How closely to follow the prompt",
            )
            denoising_strength = gr.Slider(
                label="Denoising Strength",
                minimum=0.0,
                maximum=1.0,
                step=0.05,
                value=0.3,
                info="Controls how much the image is transformed",
            )

            with gr.Row():
                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=42,
                    interactive=True,
                )

            enhance_btn = gr.Button("๐ Enhance Image", variant="primary", size="lg")

        # Results
        with gr.Column(scale=1):
            gr.HTML("<h3>๐ Results</h3>")

            result_slider = ImageSlider(
                label="Input / Enhanced",
                type="pil",
                interactive=True,
                height=400,
            )

            with gr.Row():
                output_seed = gr.Number(label="Used Seed", precision=0, interactive=False)

            generated_caption_output = gr.Textbox(
                label="Generated Caption",
                placeholder="AI-generated caption will appear here...",
                lines=3,
                interactive=False,
            )

    # Components passed to enhance_image, in its positional parameter order;
    # the examples table and the click handler use the same list.
    enhance_inputs = [
        input_image,
        image_url,
        seed,
        randomize_seed,
        num_inference_steps,
        upscale_factor,
        guidance_scale,
        denoising_strength,
        use_generated_caption,
        custom_prompt,
    ]

    # Examples
    gr.Examples(
        examples=[
            [None, "https://upload.wikimedia.org/wikipedia/commons/thumb/a/a7/Example.jpg/800px-Example.jpg", 42, False, 25, 2, 3.5, 0.3, True, ""],
            [None, "https://picsum.photos/512/512", 123, False, 25, 3, 4.0, 0.3, True, ""],
        ],
        inputs=enhance_inputs,
    )

    # Event handler
    enhance_btn.click(
        fn=enhance_image,
        inputs=enhance_inputs,
        outputs=[result_slider, output_seed, generated_caption_output],
    )

    # Footer explaining the processing steps and the size limit.
    gr.HTML("""
<div style="margin-top: 2rem; padding: 1rem; background: #f0f0f0; border-radius: 8px;">
<h4>๐ก How it works:</h4>
<ol>
<li><strong>Florence-2</strong> analyzes your image and generates a detailed caption</li>
<li>Initial upscale with LANCZOS interpolation</li>
<li><strong>FLUX Img2Img</strong> enhances the upscaled image with AI diffusion guided by the caption</li>
</ol>
<p><strong>Note:</strong> Output limited to 4096x4096 pixels total budget to prevent memory issues.</p>
</div>
""")
if __name__ == "__main__":
    # Enable the request queue, then serve on all interfaces at port 7860
    # with a share link requested.
    demo.queue()
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)