# Alexander Bagus
# 22
# bf66885
# raw
# history blame
# 8.18 kB
import json
import os
import random
import time

import gradio as gr
import numpy as np
import spaces
import torch
from diffsynth.pipelines.qwen_image import (
    ModelConfig,
    QwenImagePipeline,
    QwenImageUnit_Image2LoRADecode,
    QwenImageUnit_Image2LoRAEncode,
)
from PIL import Image
from safetensors.torch import save_file
from ulid import ULID

from utils import image_utils, prompt_utils, repo_utils
# repo_utils.clone_repo_if_not_exists("git clone https://huggingface.co/DiffSynth-Studio/General-Image-Encoders", "app/repos")
# repo_utils.clone_repo_if_not_exists("https://huggingface.co/apple/starflow", "app/models")
# Base URL of this Space's files on the Hugging Face Hub.
URL_PUBLIC = "https://huggingface.co/spaces/AiSudo/Qwen-Image-to-LoRA/blob/main"

# Default tensor precision for the app.
DTYPE = torch.bfloat16

# Upper bound for the seed slider: the largest signed 32-bit integer.
MAX_SEED = np.iinfo(np.int32).max

# Keyword arguments shared by every ModelConfig that uses this offload
# policy: "disk" for the offload/onload stages, CUDA + bfloat16 for the
# preparing and computation stages.
vram_config_disk_offload = dict(
    offload_dtype="disk",
    offload_device="disk",
    onload_dtype="disk",
    onload_device="disk",
    preparing_dtype=torch.bfloat16,
    preparing_device="cuda",
    computation_dtype=torch.bfloat16,
    computation_device="cuda",
)
# Build the shared pipeline once at import time. The two image encoders and
# the image-to-LoRA weights are all fetched from the Hugging Face Hub and use
# the same disk-offload VRAM policy, so their configs are generated from a
# table of (repository id, weight file pattern) pairs.
pipe = QwenImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(
            download_source="huggingface",
            model_id=repo_id,
            origin_file_pattern=weight_pattern,
            **vram_config_disk_offload,
        )
        for repo_id, weight_pattern in (
            ("DiffSynth-Studio/General-Image-Encoders", "SigLIP2-G384/model.safetensors"),
            ("DiffSynth-Studio/General-Image-Encoders", "DINOv3-7B/model.safetensors"),
            ("DiffSynth-Studio/Qwen-Image-i2L", "Qwen-Image-i2L-Style.safetensors"),
        )
    ],
    processor_config=ModelConfig(
        model_id="Qwen/Qwen-Image-Edit",
        origin_file_pattern="processor/",
    ),
    # Second element of mem_get_info is the device's total memory in bytes;
    # convert to GiB and leave 0.5 GiB of headroom.
    vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
def _gallery_item_to_image(item):
    """Turn one gallery entry into something the image encoders can consume.

    Gradio's Gallery hands its value to event handlers as a list of
    ``(media, caption)`` pairs, where ``media`` is a file path, a NumPy array
    or a PIL image depending on the component's ``type``. Bare media values
    (no caption tuple) are accepted as well.

    PIL images are passed through untouched; arrays and file paths are
    loaded and normalised to RGB so palette/alpha files do not surprise the
    encoders.
    """
    media = item[0] if isinstance(item, (tuple, list)) else item
    if isinstance(media, Image.Image):
        return media
    if isinstance(media, np.ndarray):
        return Image.fromarray(media).convert("RGB")
    return Image.open(media).convert("RGB")


@spaces.GPU
def generate_lora(
    input_images,
    progress=gr.Progress(track_tqdm=True),
):
    """Encode the uploaded images into a LoRA and save it as a safetensors file.

    Args:
        input_images: Value of the input gallery; ``None`` or an empty list
            when nothing has been uploaded.
        progress: Gradio progress tracker (mirrors tqdm progress bars).

    Returns:
        A 3-tuple for the wired outputs: the generated file name, an update
        that reveals the download button pointing at the saved file, and an
        update that reveals the image-generation panel.

    Raises:
        gr.Error: If no input images were provided.
    """
    ulid = str(ULID()).lower()[:12]
    print(f"ulid: {ulid}")

    # Stop early with a user-visible message instead of running inference on
    # nothing (the gallery value is None/empty before any upload).
    if not input_images:
        print("images are empty.")
        raise gr.Error("Please upload at least one image before generating a LoRA.")

    images = [_gallery_item_to_image(item) for item in input_images]

    # Model inference
    with torch.no_grad():
        embs = QwenImageUnit_Image2LoRAEncode().process(pipe, image2lora_images=images)
        lora = QwenImageUnit_Image2LoRADecode().process(pipe, **embs)["lora"]

    lora_name = f"{ulid}.safetensors"
    lora_path = f"loras/{lora_name}"
    # save_file does not create missing parent directories.
    os.makedirs("loras", exist_ok=True)
    save_file(lora, lora_path)

    return lora_name, gr.update(visible=True, value=lora_path), gr.update(visible=True)
@spaces.GPU
def generate_image(
    lora_name,
    prompt,
    negative_prompt="blurry ugly bad",
    width=1024,
    height=1024,
    seed=42,
    randomize_seed=True,
    guidance_scale=3.5,
    num_inference_steps=8,
    progress=gr.Progress(track_tqdm=True),
):
    """Placeholder for generating an image with a previously created LoRA.

    Every argument is currently ignored; no image is produced.

    Returns:
        ``True``, unconditionally.
    """
    # TODO: implement image generation; the constant below is only a stub value.
    return True
def read_file(path: str) -> str:
    """Return the entire contents of the UTF-8 text file at *path*."""
    with open(path, encoding="utf-8") as handle:
        return handle.read()
# Custom stylesheet for the app (handed to demo.launch at the bottom of the
# file): centres the #col-container element and caps its width at 960px,
# centres h3 headings, and pads the #imagen-container element.
css = """
#col-container {
margin: 0 auto;
max-width: 960px;
}
h3{
text-align: center;
display:block;
}
#imagen-container {
padding: 12px;
}
"""
# Load the example definitions. JSON is decoded as UTF-8 explicitly so the
# result does not depend on the platform's default locale encoding.
with open('examples/0_examples.json', 'r', encoding='utf-8') as file:
    examples = json.load(file)
with gr.Blocks() as demo:
    with gr.Column(elem_id="col-container"):
        with gr.Column():
            gr.HTML(read_file("static/header.html"))

        # --- Step 1: upload reference images and build the LoRA -------------
        with gr.Row():
            with gr.Column():
                input_images = gr.Gallery(
                    label="Input images",
                    file_types=["image"],
                    show_label=False,
                    elem_id="gallery",
                    columns=2,
                    object_fit="cover",
                    height=300,
                )
                lora_button = gr.Button("Generate LoRA", variant="primary")

            with gr.Column():
                lora_name = gr.Textbox(
                    label="Generated LoRA path",
                    lines=2,
                    interactive=False,
                )
                lora_download = gr.DownloadButton(label="Download LoRA", visible=False)

        # --- Step 2: try the LoRA (hidden until a LoRA has been generated) --
        # The stylesheet addresses this panel by id (#imagen-container), so it
        # needs an elem_id; the class is kept for anything that selects on it.
        with gr.Column(
            visible=False,
            elem_id="imagen-container",
            elem_classes='imagen-container',
        ) as imagen_container:
            gr.Markdown("### After your LoRA is ready, you can try generate image here.")
            with gr.Row():
                with gr.Column():
                    prompt = gr.Textbox(
                        label="Prompt",
                        show_label=False,
                        lines=2,
                        placeholder="Enter your prompt",
                        value="a man in a fishing boat. high quality, detailed",
                        container=False,
                    )
                    imagen_button = gr.Button("Generate Image", variant="primary")

                    with gr.Accordion("Advanced Settings", open=False):
                        negative_prompt = gr.Textbox(
                            label="Negative prompt",
                            lines=2,
                            container=False,
                            placeholder="Enter your negative prompt",
                            value="blurry ugly bad",
                        )
                        num_inference_steps = gr.Slider(
                            label="Steps",
                            minimum=1,
                            maximum=50,
                            step=1,
                            value=20,
                        )
                        with gr.Row():
                            width = gr.Slider(
                                label="Width",
                                minimum=512,
                                maximum=1280,
                                step=32,
                                value=768,
                            )
                            height = gr.Slider(
                                label="Height",
                                minimum=512,
                                maximum=1280,
                                step=32,
                                value=1024,
                            )
                        with gr.Row():
                            seed = gr.Slider(
                                label="Seed",
                                minimum=0,
                                maximum=MAX_SEED,
                                step=1,
                                value=42,
                            )
                            guidance_scale = gr.Slider(
                                label="Guidance scale",
                                minimum=0.0,
                                maximum=10.0,
                                step=0.1,
                                value=3.5,
                            )
                        randomize_seed = gr.Checkbox(label="Randomize seed", value=False)

                with gr.Column():
                    output_image = gr.Image(label="Generated image", show_label=False)

        # gr.Examples(examples=examples, inputs=[input_image])
        gr.Markdown(read_file("static/footer.md"))

    # Building a LoRA fills in its name, reveals the download button and
    # un-hides the image-generation panel.
    lora_button.click(
        fn=generate_lora,
        inputs=[
            input_images,
        ],
        outputs=[lora_name, lora_download, imagen_container],
    )

    # The image handler's result belongs in the image preview, not in the
    # LoRA widgets.
    imagen_button.click(
        fn=generate_image,
        inputs=[
            lora_name,
            prompt,
            negative_prompt,
            width,
            height,
            seed,
            randomize_seed,
            guidance_scale,
            num_inference_steps,
        ],
        outputs=[output_image],
    )
if __name__ == "__main__":
    # Start the app only when executed as a script: MCP server enabled and
    # the custom stylesheet applied.
    demo.launch(mcp_server=True, css=css)