Spaces:
Running
Running
| import gradio as gr | |
| import spaces | |
| import torch | |
| from diffusers import FluxKontextPipeline | |
| from diffusers.utils import load_image | |
| from PIL import Image | |
| import os | |
# Style name -> LoRA weight file inside the "Owen777/Kontext-Style-Loras" repo.
# Every weight file follows the "<style>_lora_weights.safetensors" pattern, so
# the mapping is derived from the ordered list of style keys below.
# Key order matters: it defines the dropdown choices and the thumbnail gallery
# order, and gallery clicks are mapped back to styles through this order.
style_type_lora_dict = {
    style: f"{style}_lora_weights.safetensors"
    for style in (
        "3D_Chibi",
        "American_Cartoon",
        "Chinese_Ink",
        "Clay_Toy",
        "Fabric",
        "Ghibli",
        "Irasutoya",
        "Jojo",
        "Oil_Painting",
        "Pixel",
        "Snoopy",
        "Poly",
        "LEGO",
        "Origami",
        "Pop_Art",
        "Van_Gogh",
        "Paper_Cutting",
        "Line",
        "Vector",
        "Picasso",
        "Macaron",
        "Rick_Morty",
    )
}
# One-line, human-readable blurb per style key; shown in the read-only
# "Style Description" textbox whenever the selected style changes.
style_descriptions = dict([
    ("3D_Chibi", "Cute, miniature 3D character style with big heads"),
    ("American_Cartoon", "Classic American animation style"),
    ("Chinese_Ink", "Traditional Chinese ink painting aesthetic"),
    ("Clay_Toy", "Playful clay/plasticine toy appearance"),
    ("Fabric", "Soft, textile-like rendering"),
    ("Ghibli", "Studio Ghibli's distinctive anime style"),
    ("Irasutoya", "Simple, flat Japanese illustration style"),
    ("Jojo", "JoJo's Bizarre Adventure manga style"),
    ("Oil_Painting", "Classic oil painting texture and strokes"),
    ("Pixel", "Retro pixel art style"),
    ("Snoopy", "Peanuts comic strip style"),
    ("Poly", "Low-poly 3D geometric style"),
    ("LEGO", "LEGO brick construction style"),
    ("Origami", "Paper folding art style"),
    ("Pop_Art", "Bold, colorful pop art style"),
    ("Van_Gogh", "Van Gogh's expressive brushstroke style"),
    ("Paper_Cutting", "Paper cut-out art style"),
    ("Line", "Clean line art/sketch style"),
    ("Vector", "Clean vector graphics style"),
    ("Picasso", "Cubist art style inspired by Picasso"),
    ("Macaron", "Soft, pastel macaron-like style"),
    ("Rick_Morty", "Rick and Morty cartoon style"),
])
# Style key -> thumbnail image file, resolved relative to the working directory.
# The file names are not uniformly cased or separated (e.g. "pop-art.webp",
# "Irasutoya.webp"), so they are listed explicitly instead of derived.
thumbnail_mapping = dict((
    ("3D_Chibi", "3D_Chibi.webp"),
    ("American_Cartoon", "american_cartoon.webp"),
    ("Chinese_Ink", "chinese_ink.webp"),
    ("Clay_Toy", "clay_toy.webp"),
    ("Fabric", "fabric.webp"),
    ("Ghibli", "ghibli.webp"),
    ("Irasutoya", "Irasutoya.webp"),
    ("Jojo", "jojo.webp"),
    ("Oil_Painting", "oil_painting.webp"),
    ("Pixel", "pixel.webp"),
    ("Snoopy", "snoopy.webp"),
    ("Poly", "poly.webp"),
    ("LEGO", "LEGO.webp"),
    ("Origami", "origami.webp"),
    ("Pop_Art", "pop-art.webp"),
    ("Van_Gogh", "van_gogh.webp"),
    ("Paper_Cutting", "Paper_Cutting.webp"),
    ("Line", "line.webp"),
    ("Vector", "vector.webp"),
    ("Picasso", "picasso.webp"),
    ("Macaron", "Macaron.webp"),
    ("Rick_Morty", "Rick_Morty.webp"),
))
# Lazily-populated pipeline cache: load_pipeline() creates the pipeline on
# first use, stores it here, and flips pipeline_loaded to True.
pipeline, pipeline_loaded = None, False
def load_pipeline():
    """Return the shared FLUX.1-Kontext-dev pipeline, loading it on first use.

    The pipeline is cached in the module-level ``pipeline`` global and
    ``pipeline_loaded`` is set to True once it has been created; later calls
    return the cached object without reloading the weights.
    """
    global pipeline, pipeline_loaded
    if pipeline is None:
        print("Loading FLUX.1-Kontext-dev model...")
        # Auto-detect the Hugging Face token: use HF_TOKEN when it is set,
        # otherwise pass True, which asks the Hub client to fall back to the
        # locally stored login token.
        token = os.getenv("HF_TOKEN", True)
        pipeline = FluxKontextPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-Kontext-dev",
            torch_dtype=torch.bfloat16,
            # `token` is the current argument name; `use_auth_token` is the
            # legacy alias for the same setting.
            token=token,
        )
        pipeline_loaded = True
    return pipeline
def load_default_image():
    """Return the bundled ``man.webp`` sample image, or None.

    None is returned when the file does not exist in the working directory or
    cannot be opened; open errors are printed instead of raised so the UI can
    still start without a default image.
    """
    default_path = "man.webp"
    if not os.path.exists(default_path):
        return None
    try:
        return Image.open(default_path)
    except Exception as err:
        print(f"Error loading default image: {err}")
    return None
def _prepare_input_image(input_image):
    """Return *input_image* as a 1024x1024 RGB PIL image.

    A ``str`` is loaded with ``diffusers.utils.load_image``; anything else is
    used as-is (the UI supplies PIL images). The resize does not preserve the
    aspect ratio, so non-square inputs are stretched to the square canvas.
    """
    image = load_image(input_image) if isinstance(input_image, str) else input_image
    return image.convert("RGB").resize((1024, 1024), Image.Resampling.LANCZOS)


def _build_style_prompt(style_name, prompt_suffix):
    """Return the Kontext edit instruction for *style_name* plus optional extra text."""
    style_name_readable = style_name.replace('_', ' ')
    prompt = f"Turn this image into the {style_name_readable} style."
    if prompt_suffix and prompt_suffix.strip():
        prompt += f" {prompt_suffix.strip()}"
    return prompt


def _activate_style_lora(pipe, style_name):
    """Replace the pipeline's active LoRA with the one mapped to *style_name*."""
    lora_filename = style_type_lora_dict[style_name]
    # Drop any LoRA left over from a previous request. Best-effort: a failure
    # is logged and loading the new adapter is still attempted.
    try:
        pipe.unload_lora_weights()
    except Exception as unload_error:
        print(f"Could not unload previous LoRA weights: {unload_error}")
    pipe.load_lora_weights(
        "Owen777/Kontext-Style-Loras",
        weight_name=lora_filename,
        adapter_name="style",
    )
    pipe.set_adapters(["style"], adapter_weights=[1.0])


# NOTE(review): `spaces` is imported at the top of the file, but this function is
# not decorated with `@spaces.GPU`. If the Space runs on ZeroGPU hardware, the
# decorator is needed for CUDA access here — confirm the target hardware.
def style_transfer(input_image, style_name, prompt_suffix, num_inference_steps, guidance_scale, seed):
    """
    Apply the selected style LoRA to the input image with FLUX.1 Kontext.

    Args:
        input_image: PIL image from the UI, or a path/URL string.
        style_name: key of ``style_type_lora_dict``.
        prompt_suffix: optional extra instructions appended to the prompt.
        num_inference_steps: number of denoising steps passed to the pipeline.
        guidance_scale: guidance scale passed to the pipeline.
        seed: generator seed; a value <= 0 or an empty field leaves the
            generator unset, so the result is not reproducible.

    Returns:
        The first image produced by the pipeline (requested at 1024x1024), or
        None when no input image or no valid style was given (a warning toast
        is shown instead).

    Raises:
        gr.Error: when loading the model/LoRA or generation fails, so Gradio
            displays the message to the user.
    """
    if input_image is None:
        gr.Warning("Please upload an image first!")
        return None
    if style_name not in style_type_lora_dict:
        # e.g. the dropdown was cleared; warn instead of failing with a KeyError.
        gr.Warning("Please select a style first!")
        return None

    try:
        pipe = load_pipeline()
        # Let diffusers place each sub-model on the GPU only while it runs.
        # The pipeline is intentionally not moved with pipe.to('cuda') first:
        # enable_model_cpu_offload() moves it back to the CPU, so that upload
        # was wasted, and moving an offloaded pipeline to the GPU manually
        # undermines the offload hooks (diffusers warns about it).
        pipe.enable_model_cpu_offload()

        # Seed only for positive values; None comes from a cleared Number field.
        generator = None
        if seed is not None and seed > 0:
            generator = torch.Generator(device="cuda").manual_seed(int(seed))

        image = _prepare_input_image(input_image)
        _activate_style_lora(pipe, style_name)

        prompt = _build_style_prompt(style_name, prompt_suffix)
        print(f"Generating with prompt: {prompt}")

        result = pipe(
            image=image,
            prompt=prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=int(num_inference_steps),
            generator=generator,
            height=1024,
            width=1024,
        )
        return result.images[0]
    except Exception as e:
        print(f"Error: {str(e)}")
        # gr.Error is an exception class: it only reaches the UI when raised.
        # Merely constructing it silently returned an empty result.
        raise gr.Error(f"Error during style transfer: {str(e)}") from e
    finally:
        # Release cached GPU memory on both the success and the failure path.
        torch.cuda.empty_cache()
def create_thumbnail_grid():
    """Build the ``(image, caption)`` pairs shown in the style gallery.

    One entry is produced per key of ``style_type_lora_dict``, in that order,
    with underscores in the caption replaced by spaces. When a thumbnail file
    is missing or fails to open, a 256x256 light-gray placeholder is used
    instead (open failures are printed).
    """
    def _placeholder():
        return Image.new('RGB', (256, 256), color='lightgray')

    grid = []
    for style_key in style_type_lora_dict:
        caption = style_key.replace('_', ' ')
        path = thumbnail_mapping.get(style_key, "")
        if not (path and os.path.exists(path)):
            grid.append((_placeholder(), caption))
            continue
        try:
            picture = Image.open(path)
        except Exception as err:
            print(f"Error loading thumbnail {path}: {err}")
            picture = _placeholder()
        grid.append((picture, caption))
    return grid
# Create the Gradio interface: a clickable grid of style thumbnails on top,
# then an input column (image, style, prompt, advanced settings) beside the
# output column.
with gr.Blocks(title="Flux Kontext Style LoRA", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🎨 Flux Styler : Flux Kontext Style LoRA")

    # Thumbnail grid section. create_thumbnail_grid() emits one thumbnail per
    # key of style_type_lora_dict, in key order; on_gallery_select relies on
    # that order to map a click back to a style.
    with gr.Row():
        style_gallery = gr.Gallery(
            value=create_thumbnail_grid(),
            label="Style Thumbnails",
            show_label=False,
            elem_id="style_gallery",
            columns=6,
            rows=4,
            object_fit="cover",
            height="auto",
            # Display-only. An interactive Gallery also lets users upload or
            # remove images, which would shift indexes away from the style
            # list; select events are still emitted without interactivity.
            interactive=False,
            show_download_button=False,
        )

    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(
                label="Input Image",
                type="pil",
                height=400,
                value=load_default_image(),
            )
            style_dropdown = gr.Dropdown(
                choices=list(style_type_lora_dict.keys()),
                value="Ghibli",
                label="Selected Style",
                elem_id="style_dropdown",
            )
            style_info = gr.Textbox(
                label="Style Description",
                value=style_descriptions["Ghibli"],
                interactive=False,
                lines=2,
            )
            prompt_suffix = gr.Textbox(
                label="Additional Instructions (Optional)",
                placeholder="Add extra details...",
                lines=2,
            )
            with gr.Accordion("Advanced Settings", open=False):
                num_steps = gr.Slider(
                    minimum=10,
                    maximum=50,
                    value=24,
                    step=1,
                    label="Inference Steps",
                )
                guidance = gr.Slider(
                    minimum=1.0,
                    maximum=5.0,
                    value=2.5,
                    step=0.1,
                    label="Guidance Scale",
                )
                # style_transfer only seeds its generator when seed > 0.
                seed = gr.Number(
                    label="Seed",
                    value=42,
                    precision=0,
                )
            generate_btn = gr.Button("🎨 Transform Image", variant="primary", size="lg")
        with gr.Column(scale=1):
            output_image = gr.Image(
                label="Styled Result",
                type="pil",
                height=400,
            )

    def on_gallery_select(evt: gr.SelectData):
        """Map a clicked thumbnail to its style name and description.

        Returns no-op updates for an unexpected index so the dropdown keeps
        its current style; clearing it would make the next generation fail.
        """
        styles = list(style_type_lora_dict.keys())
        selected_index = evt.index
        if isinstance(selected_index, int) and 0 <= selected_index < len(styles):
            selected_style = styles[selected_index]
            return selected_style, style_descriptions.get(selected_style, "")
        return gr.update(), gr.update()

    style_gallery.select(
        fn=on_gallery_select,
        inputs=None,
        outputs=[style_dropdown, style_info],
    )

    def update_description(style):
        """Return the description text for *style* ("" if unknown)."""
        return style_descriptions.get(style, "")

    # Keep the description box in sync with the dropdown, including changes
    # made through the gallery.
    style_dropdown.change(
        fn=update_description,
        inputs=[style_dropdown],
        outputs=[style_info],
    )

    # Connect the generate button.
    generate_btn.click(
        fn=style_transfer,
        inputs=[input_image, style_dropdown, prompt_suffix, num_steps, guidance, seed],
        outputs=output_image,
    )

    gr.Markdown("""
---
Powered by ❤️ https://discord.gg/openfreeai
""")

if __name__ == "__main__":
    demo.launch()