| import io |
| import gradio as gr |
| import numpy as np |
| import random |
| import spaces |
| import torch |
| from diffusers import Flux2KleinPipeline |
| from PIL import Image |
| import base64 |
| import gc |
|
|
# Inference dtype and device. bfloat16 keeps VRAM usage manageable on GPUs;
# fall back to CPU when CUDA is unavailable.
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"

# Upper bounds for user-supplied seeds and output dimensions.
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024

REPO_ID_REGULAR = "black-forest-labs/FLUX.2-klein-base-9B"
REPO_ID_DISTILLED = "black-forest-labs/FLUX.2-klein-9B"

print("Loading 9B Regular model...")
pipe_regular = Flux2KleinPipeline.from_pretrained(REPO_ID_REGULAR, torch_dtype=dtype)
# Use the computed device rather than a hard-coded "cuda" so the script can
# still start on CPU-only machines (the original `.to("cuda")` crashed there).
pipe_regular.to(device)

print("Loading 9B Distilled model...")
pipe_distilled = Flux2KleinPipeline.from_pretrained(REPO_ID_DISTILLED, torch_dtype=dtype)
pipe_distilled.to(device)

# UI mode label -> pipeline instance.
pipes = {
    "Distilled (4 steps)": pipe_distilled,
    "Base (50 steps)": pipe_regular,
}

# Per-mode defaults applied when the user switches modes in the UI
# (see update_steps_from_mode).
DEFAULT_STEPS = {
    "Distilled (4 steps)": 4,
    "Base (50 steps)": 50,
}

DEFAULT_CFG = {
    "Distilled (4 steps)": 1.0,
    "Base (50 steps)": 4.0,
}
|
|
| |
# Best-effort load of an image-classification model used to screen both the
# input images and the generated output for NSFW content. The broad
# `except Exception` is deliberate: if loading fails (e.g. no network, missing
# weights), the app starts with screening disabled instead of crashing — the
# generation code checks `nsfw_model`/`nsfw_processor` for None before use.
try:
    print("Loading NSFW detector...")
    from transformers import AutoProcessor, AutoModelForImageClassification
    nsfw_processor = AutoProcessor.from_pretrained("Falconsai/nsfw_image_detection")
    nsfw_model = AutoModelForImageClassification.from_pretrained(
        "Falconsai/nsfw_image_detection"
    ).to(device)
    print("NSFW detector loaded successfully.")
except Exception as e:
    print(f"Failed to load NSFW detector: {e}")
    nsfw_model = None
    nsfw_processor = None
| |
|
|
class GenerationError(Exception):
    """Exception for user-reportable generation failures.

    NOTE(review): this type is defined but the generation code below raises
    plain ``Exception`` for NSFW rejections — consider raising this instead.
    """
    pass
|
|
def detect_nsfw(image: Image.Image, threshold: float = 0.5) -> bool:
    """Classify *image* with the global NSFW model.

    Returns True when the NSFW class probability exceeds *threshold*.
    Callers must ensure the global detector loaded (not None).
    """
    # assumes logits index 1 is the "nsfw" label for this model — TODO confirm
    # against the model's id2label mapping.
    batch = nsfw_processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        logits = nsfw_model(**batch).logits
    scores = torch.softmax(logits, dim=-1)
    return scores[0][1].item() > threshold
|
|
|
|
def image_to_data_uri(img):
    """
    Serialize a PIL Image as a base64 ``data:`` URI.

    Args:
        img: The PIL Image to encode.

    Returns:
        str: A ``data:image/png;base64,...`` string with the PNG payload.
    """
    raw = io.BytesIO()
    img.save(raw, format="PNG")
    encoded = base64.b64encode(raw.getvalue()).decode("utf-8")
    return "data:image/png;base64," + encoded
|
|
|
|
def get_dimensions_from_image(image_list, width, height, match_input_ratio):
    """
    Derive generation dimensions from the first input image's aspect ratio.

    Args:
        image_list: Gallery items as (PIL.Image, caption) tuples, or None.
        width: Fallback generation width when no matching is performed.
        height: Fallback generation height when no matching is performed.
        match_input_ratio: Whether to match the first image's aspect ratio.

    Returns:
        tuple: (matched, gen_width, gen_height, final_width, final_height).
        gen_* are multiples of 8 clamped to [256, 1024] for the pipeline;
        final_* are the exact target dimensions used to resize the output.
        When matched is False, all four fall back to width/height.
    """
    # Idiomatic guard (was `match_input_ratio == False` plus explicit
    # None/len checks): nothing to match against.
    if not image_list or not match_input_ratio:
        return False, width, height, width, height

    img = image_list[0][0]
    img_width, img_height = img.size
    aspect_ratio = img_width / img_height

    # Fit the longer side to 1024 while preserving the aspect ratio.
    if aspect_ratio >= 1:
        source_width = 1024
        source_height = int(1024 / aspect_ratio)
    else:
        source_height = 1024
        source_width = int(1024 * aspect_ratio)

    # The pipeline requires dimensions that are multiples of 8; round to the
    # nearest multiple and clamp to the supported [256, 1024] range.
    new_width = max(256, min(1024, round(source_width / 8) * 8))
    new_height = max(256, min(1024, round(source_height / 8) * 8))

    return True, new_width, new_height, source_width, source_height
|
|
|
|
def update_steps_from_mode(mode_choice):
    """
    Look up the default inference steps and guidance scale for a mode.

    Args:
        mode_choice (str): Either "Distilled (4 steps)" or "Base (50 steps)".

    Returns:
        tuple: (num_inference_steps, guidance_scale) defaults for the mode.
    """
    steps = DEFAULT_STEPS[mode_choice]
    cfg = DEFAULT_CFG[mode_choice]
    return steps, cfg
|
|
|
|
# NOTE(review): a module-level gr.Progress() is shared across every request
# instead of being injected per event by Gradio (normally it is declared as a
# default parameter of the event handler) — confirm this tracks progress
# correctly under concurrent use.
progress=gr.Progress()
@spaces.GPU(duration=85)
def _infer(
    prompt: str,
    input_images=None,
    mode_choice: str = "Distilled (4 steps)",
    seed: int = 42,
    randomize_seed: bool = False,
    width: int = 1024,
    height: int = 1024,
    num_inference_steps: int = 4,
    guidance_scale: float = 4.0,
    match_input_ratio: bool = False,
):
    """
    Run one text/image-to-image generation on the selected pipeline.

    Parameters mirror the Gradio controls; string-typed values (as sent by
    some API clients) are coerced to their numeric/boolean types first.

    Returns:
        tuple: (PIL.Image | None, seed | None, info dict). ``info["status"]``
        is "success" or "failed"; on failure ``info["error"]`` holds the
        message and image/seed are None.
    """
    if progress:
        progress(0, desc="Starting generation...")

    def callback_fn(pipe, step, timestep, callback_kwargs):
        # Report per-step progress to the UI while the pipeline runs.
        print(f"[Step {step}] Timestep: {timestep}")
        progress_value = (step + 1.0) / num_inference_steps
        progress(progress_value, desc=f"Image generating, {step + 1}/{num_inference_steps} steps")
        return callback_kwargs

    # Coerce string-typed inputs (e.g. from raw API calls) to real types.
    if isinstance(seed, str):
        seed = int(seed)
    if isinstance(randomize_seed, str):
        randomize_seed = randomize_seed.lower() == "true"
    if isinstance(width, str):
        width = int(width)
    if isinstance(height, str):
        height = int(height)
    if isinstance(num_inference_steps, str):
        num_inference_steps = int(num_inference_steps)
    if isinstance(guidance_scale, str):
        guidance_scale = float(guidance_scale)
    if isinstance(match_input_ratio, str):
        match_input_ratio = match_input_ratio.lower() == "true"

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    try:
        pipe = pipes[mode_choice]

        # Unpack gallery items: each is a (PIL.Image, caption) tuple.
        image_list = None
        if input_images is not None and len(input_images) > 0:
            image_list = [item[0] for item in input_images]

        matched, new_width, new_height, source_width, source_height = get_dimensions_from_image(
            input_images, width, height, match_input_ratio
        )

        # Screen every input image before spending GPU time on generation.
        if image_list:
            for img in image_list:
                if nsfw_model and nsfw_processor:
                    if detect_nsfw(img):
                        # Raise the dedicated exception type (was a bare Exception).
                        raise GenerationError(
                            "The input image contains NSFW content and cannot be generated. Please modify the input image or prompt and try again."
                        )

        final_prompt = prompt
        generator = torch.Generator(device=device).manual_seed(seed)

        pipe_kwargs = {
            "prompt": final_prompt,
            "height": new_height,
            "width": new_width,
            "num_inference_steps": num_inference_steps,
            "guidance_scale": guidance_scale,
            "generator": generator,
            "callback_on_step_end": callback_fn,
        }

        # Only pass reference images when the user supplied some.
        if image_list is not None:
            pipe_kwargs["image"] = image_list

        image = pipe(**pipe_kwargs).images[0]

        # When matching the input ratio, resize back to the exact source size
        # (generation ran at the 8-aligned, clamped dimensions).
        if matched:
            image = image.resize((source_width, source_height), resample=Image.LANCZOS)

        # Screen the generated image as well.
        if nsfw_model and nsfw_processor:
            if detect_nsfw(image):
                raise GenerationError(
                    "Generated image contains NSFW content and cannot be displayed. Please modify the input image or prompt and try again."
                )

        progress(1, desc="Complete")

        info = {
            "status": "success"
        }
        return image, seed, info

    except Exception as e:
        # GenerationError and any unexpected failure are reported identically
        # (the two original except branches had byte-identical bodies), so a
        # single handler suffices; `infer` turns this into gr.Error.
        error_info = {
            "error": str(e),
            "status": "failed",
        }
        return None, None, error_info
    finally:
        # Free GPU memory between requests; guard the CUDA call so CPU-only
        # runs don't crash in the finally block.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
| |
def infer(
    prompt: str,
    input_images=None,
    mode_choice: str = "Distilled (4 steps)",
    seed: int = 42,
    randomize_seed: bool = False,
    width: int = 1024,
    height: int = 1024,
    num_inference_steps: int = 4,
    guidance_scale: float = 4.0,
    match_input_ratio: bool = False,
):
    """Gradio-facing wrapper around _infer that surfaces failures as gr.Error."""
    image, out_seed, info = _infer(
        prompt,
        input_images,
        mode_choice,
        seed,
        randomize_seed,
        width,
        height,
        num_inference_steps,
        guidance_scale,
        match_input_ratio,
    )
    if info["status"] == "failed":
        raise gr.Error(info["error"])
    return image, out_seed
| |
|
|
# Text-only prompt examples shown under the prompt box.
examples = [
    ["Create a vase on a table in living room, the color of the vase is a gradient of color, starting with #02eb3c color and finishing with #edfa3c. The flowers inside the vase have the color #ff0088"],
    ["Photorealistic infographic showing the complete Berlin TV Tower (Fernsehturm) from ground base to antenna tip, full vertical view with entire structure visible including concrete shaft, metallic sphere, and antenna spire. Slight upward perspective angle looking up toward the iconic sphere, perfectly centered on clean white background. Left side labels with thin horizontal connector lines: the text '368m' in extra large bold dark grey numerals (#2D3748) positioned at exactly the antenna tip with 'TOTAL HEIGHT' in small caps below. The text '207m' in extra large bold with 'TELECAFÉ' in small caps below, with connector line touching the sphere precisely at the window level. Right side label with horizontal connector line touching the sphere's equator: the text '32m' in extra large bold dark grey numerals with 'SPHERE DIAMETER' in small caps below. Bottom section arranged in three balanced columns: Left - Large text '986' in extra bold dark grey with 'STEPS' in caps below. Center - 'BERLIN TV TOWER' in bold caps with 'FERNSEHTURM' in lighter weight below. Right - 'INAUGURATED' in bold caps with 'OCTOBER 3, 1969' below. All typography in modern sans-serif font (such as Inter or Helvetica), color #2D3748, clean minimal technical diagram style. Horizontal connector lines are thin, precise, and clearly visible, touching the tower structure at exact corresponding measurement points. Professional architectural elevation drawing aesthetic with dynamic low angle perspective creating sense of height and grandeur, poster-ready infographic design with perfect visual hierarchy."],
    ["Soaking wet capybara taking shelter under a banana leaf in the rainy jungle, close up photo"],
    ["A kawaii die-cut sticker of a chubby orange cat, featuring big sparkly eyes and a happy smile with paws raised in greeting and a heart-shaped pink nose. The design should have smooth rounded lines with black outlines and soft gradient shading with pink cheeks."],
]

# Prompt + reference-image examples (multi-image conditioning); the image
# paths are resolved relative to the Space's working directory.
examples_images = [
    ["The person from image 1 is petting the cat from image 2, the bird from image 3 is next to them", ["woman1.webp", "cat_window.webp", "bird.webp"]]
]

# Center the layout and cap its width.
css = """
#col-container {
    margin: 0 auto;
    max-width: 1200px;
}
"""

# Markdown shown at the top of the page.
title = "# AI Image Generator from Image"
description = "Transform your photos into stunning art with our AI Image Generator! Simply upload an image to remix, restyle, and reimagine it instantly. This Space offers a limited free demo. If you reach your usage limit or want faster generation speeds, please visit our full platform at [AI Image Generator from Image](https://www.image2image.ai/) to continue creating."
|
|
|
|
# Build the Gradio UI: prompt + optional reference images on the left,
# generated result on the right, with per-mode defaults and event wiring below.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(title)
        gr.Markdown(description)
        with gr.Row():
            with gr.Column():
                gr.Markdown("### Input")
                with gr.Row():
                    prompt = gr.Text(
                        label="Prompt",
                        show_label=False,
                        max_lines=2,
                        placeholder="Enter your prompt",
                        container=False,
                        scale=3
                    )

                    run_button = gr.Button("Run", scale=1)

                # Optional reference images for image-to-image conditioning.
                with gr.Accordion("Input image(s) (optional)", open=True):
                    input_images = gr.Gallery(
                        label="Input Image(s)",
                        type="pil",
                        columns=3,
                        rows=1,
                    )

                mode_choice = gr.Radio(
                    label="Mode",
                    choices=["Distilled (4 steps)", "Base (50 steps)"],
                    value="Distilled (4 steps)",
                )

                with gr.Accordion("Advanced Settings", open=False):

                    seed = gr.Slider(
                        label="Seed",
                        minimum=0,
                        maximum=MAX_SEED,
                        step=1,
                        value=0,
                    )

                    randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

                    with gr.Group():
                        with gr.Row():
                            width = gr.Slider(
                                label="Width",
                                minimum=256,
                                maximum=MAX_IMAGE_SIZE,
                                step=8,
                                value=1024,
                            )

                            height = gr.Slider(
                                label="Height",
                                minimum=256,
                                maximum=MAX_IMAGE_SIZE,
                                step=8,
                                value=1024,
                            )
                        # When checked, generation dimensions are derived from
                        # the first input image (see get_dimensions_from_image).
                        match_input = gr.Checkbox(label="Match input ratio", value=False)

                    with gr.Row():

                        num_inference_steps = gr.Slider(
                            label="Number of inference steps",
                            minimum=1,
                            maximum=100,
                            step=1,
                            value=4,
                        )

                        guidance_scale = gr.Slider(
                            label="Guidance scale",
                            minimum=0.0,
                            maximum=10.0,
                            step=0.1,
                            value=1.0,
                        )

            with gr.Column():
                gr.Markdown("### Output")
                result = gr.Image(label="Result", show_label=False)

        # Text-only examples; lazy caching renders them on first use.
        gr.Examples(
            examples=examples,
            fn=infer,
            inputs=[prompt],
            outputs=[result, seed],
            cache_examples=True,
            cache_mode="lazy"
        )

        # Examples that also supply reference images.
        gr.Examples(
            examples=examples_images,
            fn=infer,
            inputs=[prompt, input_images],
            outputs=[result, seed],
            cache_examples=True,
            cache_mode="lazy"
        )

    # Switching modes resets steps/CFG sliders to that mode's defaults.
    mode_choice.change(
        fn=update_steps_from_mode,
        inputs=[mode_choice],
        outputs=[num_inference_steps, guidance_scale]
    )

    # Run on button click or Enter in the prompt box; also exposed as the
    # "generate" API endpoint.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[prompt, input_images, mode_choice, seed, randomize_seed, width, height, num_inference_steps, guidance_scale, match_input],
        outputs=[result, seed],
        api_name="generate"
    )

# mcp_server=True additionally exposes the app's API as an MCP server.
demo.launch(mcp_server=True)