|
|
|
|
|
import os |
|
|
# Upgrade torch (nightly CUDA 12.6 wheels, capped below 2.9) and the `spaces`
# helper before they are imported below. `sys.executable -m pip` guarantees the
# packages land in the environment of the interpreter actually running this app
# (a bare `pip` on PATH may belong to another Python). Best-effort, like the
# previous os.system call: a failed install does not abort startup.
import subprocess
import sys

subprocess.run(
    [
        sys.executable, "-m", "pip", "install",
        "--upgrade", "--pre",
        "--extra-index-url", "https://download.pytorch.org/whl/nightly/cu126",
        "torch<2.9", "spaces",
    ],
    check=False,
)
|
|
|
|
|
|
|
|
try:
    import spaces
except ImportError:
    # Outside Hugging Face ZeroGPU the `spaces` package may be unavailable; provide
    # a stand-in exposing the `spaces.GPU` decorator used below.
    class spaces:
        """Minimal no-op replacement for the `spaces` package."""

        @staticmethod
        def GPU(*args, **kwargs):
            """
            No-op stand-in for ``spaces.GPU``.

            Supports both ``@spaces.GPU`` and ``@spaces.GPU(duration=...)``. The
            decorated function is returned unchanged (not wrapped in a lambda), so
            its signature and docstring stay visible to anything that inspects it,
            such as Gradio building the UI/API/MCP description.
            """
            # Bare form: @spaces.GPU applied directly to the function.
            if len(args) == 1 and not kwargs and callable(args[0]):
                return args[0]

            def decorator(function):
                return function

            return decorator
|
|
|
|
|
import gradio as gr |
|
|
import numpy as np |
|
|
import torch |
|
|
import random |
|
|
import os |
|
|
from datetime import datetime |
|
|
|
|
|
from PIL import Image |
|
|
import tempfile |
|
|
import zipfile |
|
|
import shutil |
|
|
from pathlib import Path |
|
|
|
|
|
from diffusers import FluxKontextPipeline |
|
|
from diffusers.utils import load_image |
|
|
|
|
|
from optimization import optimize_pipeline_ |
|
|
|
|
|
# Largest accepted seed: the maximum signed 32-bit integer (2**31 - 1).
MAX_SEED = np.iinfo(np.int32).max

# Load the FLUX.1 Kontext [dev] pipeline in bfloat16 and move it to the GPU once,
# at import time; every request reuses this single instance.
pipe = FluxKontextPipeline.from_pretrained("yuvraj108c/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16).to("cuda")

# Run the project-local optimizer with a blank 512x512 RGB image and a dummy prompt.
# NOTE(review): optimize_pipeline_ lives in optimization.py (not shown here);
# presumably it compiles/warms up `pipe` in place — confirm.
optimize_pipeline_(pipe, image=Image.new("RGB", (512, 512)), prompt='prompt')

# Debug overrides read by infer_example() and written by handle_field_debug_change().
# Each is a one-element list so the callback can replace the value in place
# without a `global` statement.
# NOTE(review): this is module-level state, shared by every session of the app.
input_image_debug_value = [None]

prompt_debug_value = [None]

number_debug_value = [None]
|
|
def save_on_path(img: "Image.Image", filename: str, format_: "str | None" = None) -> Path:
    """
    Save ``img`` as ``filename`` inside a fresh temporary directory and return its path.

    A new directory (name prefixed ``pil_tmp_``) is created on every call, so calls
    that reuse the same ``filename`` never overwrite each other. The directory is not
    removed here; cleanup is left to the caller or the OS temp-dir policy.

    Args:
        img: PIL image to save.
        filename: Base name of the file created inside the temporary directory.
        format_: Explicit format passed to ``Image.save`` (e.g. ``"WEBP"``). When
            omitted, ``img.format`` is used; if that is also None, PIL infers the
            format from the filename extension.

    Returns:
        Absolute path of the saved file.
    """
    target = Path(tempfile.mkdtemp(prefix="pil_tmp_")) / filename
    img.save(target, format=format_ or img.format)
    return target
|
|
|
|
|
@spaces.GPU(duration=40)
def infer(
    input_image,
    prompt,
    seed=42,
    randomize_seed=False,
    guidance_scale=2.5,
    steps=28,
    width=-1,
    height=-1,
    progress=gr.Progress(track_tqdm=True)
):
    """
    Perform image editing using the FLUX.1 Kontext pipeline.

    This function takes an input image and a text prompt to generate a modified version
    of the image based on the provided instructions. It uses the FLUX.1 Kontext model
    for contextual image editing tasks. When no input image is given, the image is
    generated from the prompt alone. The result is saved as a WEBP file in a fresh
    temporary folder.

    Args:
        input_image (PIL.Image.Image): The input image to be edited. Will be converted
            to RGB format if not already in that format. May be None, in which case
            the image is generated from the prompt alone.
        prompt (str): Text description of the desired edit to apply to the image.
            Examples: "Remove glasses", "Add a hat", "Change background to beach".
        seed (int, optional): Random seed for reproducible generation. Defaults to 42.
            Must be between 0 and MAX_SEED (2^31 - 1).
        randomize_seed (bool, optional): If True, generates a random seed instead of
            using the provided seed value. Defaults to False.
        guidance_scale (float, optional): Controls how closely the model follows the
            prompt. Higher values mean stronger adherence to the prompt but may reduce
            image quality. Range: 1.0-10.0. Defaults to 2.5.
        steps (int, optional): Controls how many steps to run the diffusion model for.
            Range: 1-30. Defaults to 28.
        width (int, optional): Output width in pixels. -1 keeps the input image's
            width, or the pipeline's default width when there is no input image.
            Defaults to -1.
        height (int, optional): Output height in pixels. -1 keeps the input image's
            height, or the pipeline's default height when there is no input image.
            Defaults to -1.
        progress (gr.Progress, optional): Gradio progress tracker for monitoring
            generation progress. Defaults to gr.Progress(track_tqdm=True).

    Returns:
        tuple: A 4-tuple containing:
            - pathlib.Path: Path of the generated/edited image, saved as WEBP
            - gr.update: Update that points the download button at that file and shows it
            - int: The seed value used for generation (useful when randomize_seed=True)
            - gr.update: Gradio update object to make the reuse button visible

    Example:
        >>> image_path, download_update, used_seed, reuse_update = infer(
        ...     input_image=my_image,
        ...     prompt="Add sunglasses",
        ...     seed=123,
        ...     randomize_seed=False,
        ...     guidance_scale=2.5
        ... )
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # Coerce UI-provided numbers: the torch generator and the scheduler expect ints.
    seed = int(seed)

    pipe_kwargs = {
        "prompt": prompt,
        "guidance_scale": guidance_scale,
        "num_inference_steps": int(steps),
        "generator": torch.Generator().manual_seed(seed),
    }

    if input_image is not None:
        input_image = input_image.convert("RGB")
        pipe_kwargs["image"] = input_image
        # -1 means "keep the input image's size".
        default_width, default_height = input_image.size
    else:
        # Text-to-image: None lets the pipeline use its own default resolution,
        # while an explicit width/height from the caller is still honored.
        default_width = default_height = None

    pipe_kwargs["width"] = default_width if width == -1 else int(width)
    pipe_kwargs["height"] = default_height if height == -1 else int(height)

    image = pipe(**pipe_kwargs).images[0]

    # Microsecond timestamp gives each result a distinct, sortable file name.
    image_filename = datetime.now().strftime("%Y-%m-%d_%H-%M-%S.%f") + '.webp'
    path = save_on_path(image, image_filename, format_="WEBP")
    return path, gr.update(value=path, visible=True), seed, gr.update(visible=True)
|
|
|
|
|
def infer_example(input_image, prompt):
    """
    Generate one or more edited images for an example and bundle them into a zip.

    Normally a single image is produced from ``input_image`` and ``prompt``. If any of
    the debug overrides (``input_image_debug_value``, ``prompt_debug_value``,
    ``number_debug_value``) has been set, the debug image and prompt replace the
    arguments and ``number_debug_value[0]`` images are generated (1 if that value
    is None).

    Generation is best-effort: the first failure is logged with its traceback and
    stops the loop; the images produced so far are still returned and zipped.

    Args:
        input_image (PIL.Image.Image | None): Image to edit; None means the images
            are generated from the prompt alone.
        prompt (str): Edit instruction.

    Returns:
        tuple: (list of generated image paths, last seed used, path of the zip archive)
    """
    import traceback

    number = 1
    # Any debug field set via handle_field_debug_change switches to batch mode.
    if (
        input_image_debug_value[0] is not None
        or prompt_debug_value[0] is not None
        or number_debug_value[0] is not None
    ):
        input_image = input_image_debug_value[0]
        prompt = prompt_debug_value[0]
        if number_debug_value[0] is not None:
            number = int(number_debug_value[0])

    gallery = []
    # Pre-bind so the return below is valid even if no image gets generated.
    seed = random.randint(0, MAX_SEED)
    try:
        for i in range(number):
            print(f"Generating #{i + 1} image...")
            # randomize_seed=True: infer() draws a fresh seed and returns the one it used.
            image_path, _download_update, seed, _reuse_update = infer(input_image, prompt, seed, True)
            gallery.append(image_path)
    except Exception:
        # Best-effort batch: log the failure, stop, and still zip what was produced.
        print(f"Error while generating image #{len(gallery) + 1}:")
        traceback.print_exc()

    zip_path = export_images_to_zip(gallery)
    return gallery, seed, zip_path
|
|
|
|
|
def export_images_to_zip(gallery) -> str:
    """
    Bundle the image files listed in ``gallery`` into a new zip archive.

    Each file is stored under its base name (directory components dropped) with
    DEFLATE compression. The archive is a named temporary file created with
    ``delete=False`` so it outlives this call; removing it is left to the caller.

    Args:
        gallery: Sequence of image file paths (str or path-like), e.g. the paths
            collected by ``infer_example``.

    Returns:
        str: Path of the created ``.zip`` file.
    """
    tmp_zip = tempfile.NamedTemporaryFile(suffix=".zip", delete=False)
    # Only the unique name is needed; ZipFile reopens the file by that name.
    tmp_zip.close()

    with zipfile.ZipFile(tmp_zip.name, "w", compression=zipfile.ZIP_DEFLATED) as zf:
        for image_path in gallery:
            zf.write(image_path, arcname=os.path.basename(image_path))

    print(f"{len(gallery)} images zipped")
    return tmp_zip.name
|
|
|
|
|
css=""" |
|
|
#col-container { |
|
|
margin: 0 auto; |
|
|
max-width: 960px; |
|
|
} |
|
|
""" |
|
|
|
|
|
# Gradio UI: image + prompt inputs, advanced settings, result panel, a hidden
# batch-generation area, example prompts, and the event wiring.
with gr.Blocks(css=css) as demo:

    with gr.Column(elem_id="col-container"):
        gr.Markdown(f"""# FLUX.1 Kontext [dev]
Image editing and manipulation model guidance-distilled from FLUX.1 Kontext [pro], [[blog]](https://bfl.ai/announcements/flux-1-kontext-dev) [[model]](https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev)
        """)
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Upload the image for editing", type="pil")
                with gr.Row():
                    prompt = gr.Text(
                        label="Prompt",
                        show_label=False,
                        max_lines=1,
                        placeholder="Enter your prompt for editing (e.g., 'Remove glasses', 'Add a hat')",
                        container=False,
                    )
                    run_button = gr.Button(value="🚀 Edit", variant="primary", scale=0)
                with gr.Accordion("Advanced Settings", open=False):

                    seed = gr.Slider(
                        label="Seed",
                        minimum=0,
                        maximum=MAX_SEED,
                        step=1,
                        value=0,
                    )

                    randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

                    guidance_scale = gr.Slider(
                        label="Guidance Scale",
                        minimum=1,
                        maximum=10,
                        step=0.1,
                        value=2.5,
                    )

                    steps = gr.Slider(
                        label="Steps",
                        minimum=1,
                        maximum=30,
                        value=30,
                        step=1
                    )

                    # -1 is the sentinel infer() treats as "keep the input image's size".
                    width = gr.Slider(
                        label="Output width",
                        info="-1 = original width",
                        minimum=-1,
                        maximum=1024,
                        value=-1,
                        step=1
                    )

                    height = gr.Slider(
                        label="Output height",
                        info="-1 = original height",
                        minimum=-1,
                        maximum=1024,
                        value=-1,
                        step=1
                    )

            with gr.Column():
                result = gr.Image(label="Result", show_label=False, interactive=False)
                # Hidden until infer() returns an update that fills and shows it.
                download_button = gr.DownloadButton(elem_id="download_btn", visible=False)
                reuse_button = gr.Button("Reuse this image", visible=False)

        # Hidden batch-generation area: running the example below calls
        # infer_example(), which uses the debug fields (when set) to generate
        # several images and offers them as a zip.
        with gr.Row(visible=False):
            # Distinct name and elem_id: re-binding `download_button` here used to
            # shadow the visible button above, so infer() updated this hidden one.
            zip_download_button = gr.DownloadButton(elem_id="zip_download_btn", interactive=True)
            result_gallery = gr.Gallery(label='Downloadable results', show_label=True, interactive=False, elem_id="gallery1")
            gr.Examples(
                examples=[
                    ["monster.png", "Make this monster ride a skateboard on the beach"]
                ],
                inputs=[input_image, prompt],
                outputs=[result_gallery, seed, zip_download_button],
                fn=infer_example,
                run_on_click=True,
                cache_examples=True,
                cache_mode='lazy'
            )
            prompt_debug = gr.Textbox(label="Prompt Debug")
            input_image_debug = gr.Image(type="pil", label="Image Debug")
            number_debug = gr.Slider(label="Number Debug", minimum=1, maximum=50, step=1, value=50)

        gr.Examples(
            label="Examples from demo",
            examples=[
                ["flowers.png", "turn the flowers into sunflowers"],
                ["monster.png", "make this monster ride a skateboard on the beach"],
                ["cat.png", "make this cat happy"]
            ],
            inputs=[input_image, prompt],
            # infer() returns 4 values; list a target for each of them.
            outputs=[result, download_button, seed, reuse_button],
            fn=infer
        )

    def handle_field_debug_change(input_image_debug_data, prompt_debug_data, number_debug_data):
        """Store the debug field values in the module-level overrides read by infer_example()."""
        prompt_debug_value[0] = prompt_debug_data
        input_image_debug_value[0] = input_image_debug_data
        number_debug_value[0] = number_debug_data
        return []

    inputs_debug = [input_image_debug, prompt_debug, number_debug]

    input_image_debug.upload(fn=handle_field_debug_change, inputs=inputs_debug, outputs=[])
    prompt_debug.change(fn=handle_field_debug_change, inputs=inputs_debug, outputs=[])
    number_debug.change(fn=handle_field_debug_change, inputs=inputs_debug, outputs=[])

    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[input_image, prompt, seed, randomize_seed, guidance_scale, steps, width, height],
        outputs=[result, download_button, seed, reuse_button]
    )
    # Copy the current result back into the input image for iterative editing.
    reuse_button.click(
        fn=lambda image: image,
        inputs=[result],
        outputs=[input_image]
    )

demo.launch(mcp_server=True)