# Source: Hugging Face Space file view — commit e34ce73 ("Update app.py") by FilipeR (verified).
import gradio as gr
import numpy as np
import random
import torch
import spaces
from PIL import Image
from diffusers import QwenImageEditPlusPipeline
import os
import base64
import json
from huggingface_hub import login
# from prompt_augment import PromptAugment
# --- Authentication ---
# Only log in when a token is actually provided via the "hf" environment
# variable: huggingface_hub.login(token=None) falls back to an interactive
# prompt, which would block or crash a headless Space at import time.
_hf_token = os.environ.get('hf')
if _hf_token:
    login(token=_hf_token)
# --- Model Loading ---
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load the model pipeline once at import time; every request reuses it.
pipe = QwenImageEditPlusPipeline.from_pretrained("FireRedTeam/FireRed-Image-Edit-1.0", torch_dtype=dtype).to(device)
# prompt_handler = PromptAugment()
# --- UI Constants and Helpers ---
# Upper bound for random seeds (largest signed 32-bit integer).
MAX_SEED = np.iinfo(np.int32).max
# --- Main Inference Function (with hardcoded negative prompt) ---
@spaces.GPU()
def infer(
images,
prompt,
seed=5555,
randomize_seed=True,
true_guidance_scale=1.0,
num_inference_steps=50,
height=None,
width=None,
rewrite_prompt=False,
num_images_per_prompt=1,
progress=gr.Progress(track_tqdm=True),
):
"""
Generates an image using the local Qwen-Image diffusers pipeline.
"""
# Hardcode the negative prompt as requested
negative_prompt = " "
if randomize_seed:
seed = random.randint(0, MAX_SEED)
# Set up the generator for reproducibility
generator = torch.Generator(device=device).manual_seed(seed)
# Load input images into PIL Images
pil_images = []
if images is not None:
for item in images:
try:
if isinstance(item[0], Image.Image):
pil_images.append(item[0].convert("RGB"))
elif isinstance(item[0], str):
pil_images.append(Image.open(item[0]).convert("RGB"))
elif hasattr(item, "name"):
pil_images.append(Image.open(item.name).convert("RGB"))
except Exception:
continue
if height==256 and width==256:
height, width = None, None
print(f"Calling pipeline with prompt: '{prompt}'")
print(f"Negative Prompt: '{negative_prompt}'")
print(f"Seed: {seed}, Steps: {num_inference_steps}, Guidance: {true_guidance_scale}, Size: {width}x{height}")
if False and rewrite_prompt and len(pil_images) > 0:
# prompt = polish_prompt(prompt, pil_images[0])
# prompt = prompt_handler.predict(prompt, [pil_images[0]])
print(f"Rewritten Prompt: {prompt}")
# Generate the image
image = pipe(
image=pil_images if len(pil_images) > 0 else None,
prompt=prompt,
height=height,
width=width,
negative_prompt=negative_prompt,
num_inference_steps=num_inference_steps,
generator=generator,
true_cfg_scale=true_guidance_scale,
num_images_per_prompt=num_images_per_prompt,
).images
return image, seed
# Page stylesheet. The "#NOcol-container" selector appears to be a
# deliberately disabled copy of a "#col-container" rule; it only takes effect
# if some component uses that elem_id and the string is handed to gr.Blocks.
css = (
    "\n"
    "#NOcol-container {\n"
    "margin: 0 auto;\n"
    "max-width: 1024px;\n"
    "}\n"
)
def get_image_base64(image_path):
    """Read the file at *image_path* and return its bytes as base64 text.

    Args:
        image_path: Path of the file to encode (opened in binary mode).

    Returns:
        The base64 encoding of the file contents as a ``str``.
    """
    with open(image_path, mode="rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode('utf-8')
# Base64-encoded logo. It is only referenced by the currently commented-out
# header HTML, so a missing logo.png must not crash the app at import time.
try:
    logo_base64 = get_image_base64("logo.png")
except FileNotFoundError:
    logo_base64 = ""
# Shared range for the output-size sliders; 256 is also the value infer()
# treats as "no explicit size" when both sliders sit at it.
_SIZE_SLIDER_RANGE = dict(minimum=256, maximum=2048, step=8, value=1024)

with gr.Blocks() as demo:
    with gr.Column():
        # Branding header, currently disabled:
        # gr.HTML(f'<img src="data:image/png;base64,{logo_base64}" alt="FireRedTeam Logo" width="400" />')
        # gr.Markdown("[Learn more](https://github.com/FireRedTeam/FireRed-Image-Edit) about the FireRed-Image-Edit series.")

        # Upload gallery next to the gallery of generated results.
        with gr.Row():
            with gr.Column():
                input_images = gr.Gallery(
                    label="Input Images",
                    show_label=False,
                    type="pil",
                    interactive=True,
                )
            result = gr.Gallery(label="Result", show_label=False, type="pil")

        # Instruction box plus the button that submits it.
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                placeholder="describe the edit instruction",
                container=False,
            )
            run_button = gr.Button("Edit", variant="primary")

        with gr.Accordion("Advanced Settings", open=True):
            # No negative-prompt control: infer() hardcodes it.
            seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

            with gr.Row():
                true_guidance_scale = gr.Slider(
                    label="True guidance scale", minimum=1.0, maximum=10.0, step=0.1, value=4.0
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps", minimum=1, maximum=50, step=1, value=40
                )
                height = gr.Slider(label="Height", **_SIZE_SLIDER_RANGE)
                width = gr.Slider(label="Width", **_SIZE_SLIDER_RANGE)
                rewrite_prompt = gr.Checkbox(label="Rewrite prompt", value=True)

    # Positional order must match infer()'s parameter order.
    _infer_inputs = [
        input_images,
        prompt,
        seed,
        randomize_seed,
        true_guidance_scale,
        num_inference_steps,
        height,
        width,
        rewrite_prompt,
    ]
    # Clicking the button or pressing Enter in the prompt both run infer();
    # the seed actually used is written back into the seed slider.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=_infer_inputs,
        outputs=[result, seed],
    )
if __name__ == "__main__":
    # Start the Gradio server when this file is executed directly.
    demo.launch()