File size: 9,237 Bytes
e2d7812
 
 
 
ff06db6
e2d7812
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
import gradio as gr
import torch
from diffusers import DiffusionPipeline, QwenImageEditPipeline
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from qwen_vl_utils import QwenVLTokenizer
from PIL import Image

# Module-level pipeline handles, filled in lazily by the "Load" buttons.
# Each stays None until its setup_* function succeeds, and is reset to None
# whenever a load attempt fails.
gen_pipe, edit_pipe = None, None

# --- Model Loading Functions ---

def setup_generation_model():
    """
    Load the Qwen/Qwen-Image text-to-image pipeline into the global ``gen_pipe``.

    Qwen/Qwen-Image is published as a diffusers pipeline (transformer, text
    encoder, VAE, scheduler), so it is loaded with ``DiffusionPipeline`` rather
    than as a causal language model. On CUDA the diffusion transformer is
    quantized to 4-bit NF4 (double quantization, bfloat16 compute) through
    diffusers' bitsandbytes integration; on CPU it is loaded unquantized in
    bfloat16.

    Returns:
        str: A status message for the UI. Load errors are caught, reset
        ``gen_pipe`` to None, and are described in the returned message.
    """
    global gen_pipe
    if gen_pipe is not None:
        return "Generation Model already loaded. ✨"

    model_id = "Qwen/Qwen-Image"
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # bitsandbytes quantization is only requested on CUDA, as before.
    quantize = device == "cuda"

    mode = "with bitsandbytes quantization" if quantize else "without quantization"
    print(f"Loading Qwen-Image Generation Model on {device} {mode}...")

    try:
        # diffusers models need diffusers' own BitsAndBytesConfig; the transformers
        # class of the same name only configures transformers models.
        from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
        from diffusers import QwenImageTransformer2DModel

        quantization_config = None
        if quantize:
            quantization_config = DiffusersBitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_use_double_quant=True,
            )

        # Only the diffusion transformer is (optionally) quantized and injected;
        # the pipeline loads its tokenizer, text encoder, VAE and scheduler itself.
        transformer = QwenImageTransformer2DModel.from_pretrained(
            model_id,
            subfolder="transformer",
            quantization_config=quantization_config,
            torch_dtype=torch.bfloat16,
        )
        gen_pipe = DiffusionPipeline.from_pretrained(
            model_id,
            transformer=transformer,
            torch_dtype=torch.bfloat16,
            use_safetensors=True,
        )
        gen_pipe.to(device)
        print("Qwen-Image Generation Model loaded successfully.")
        return "Generation Model loaded! 🚀"
    except Exception as e:
        # Top-level UI boundary: report the failure instead of crashing the app.
        gen_pipe = None
        print(f"Qwen-Image Generation Model failed to load: {e}")
        return f"Generation Model setup failed. Error: {e}"

def setup_editing_model():
    """
    Load the Qwen/Qwen-Image-Edit pipeline into the global ``edit_pipe``.

    On CUDA the pipeline's diffusion transformer is quantized to 4-bit NF4
    (double quantization, bfloat16 compute) through diffusers' bitsandbytes
    integration; on CPU it is loaded unquantized in bfloat16.

    Returns:
        str: A status message for the UI. Load errors are caught, reset
        ``edit_pipe`` to None, and are described in the returned message.
    """
    global edit_pipe
    if edit_pipe is not None:
        return "Editing Model already loaded. ✨"

    model_id = "Qwen/Qwen-Image-Edit"
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # bitsandbytes quantization is only requested on CUDA.
    quantize = device == "cuda"

    mode = "with bitsandbytes quantization" if quantize else "without quantization"
    print(f"Loading Qwen-Image-Edit Model on {device} {mode}...")

    try:
        # Previously nothing was quantized despite the messages saying so. Quantize
        # the transformer explicitly using diffusers' BitsAndBytesConfig (the
        # transformers class of the same name does not apply to diffusers models).
        from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
        from diffusers import QwenImageTransformer2DModel

        quantization_config = None
        if quantize:
            quantization_config = DiffusersBitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_use_double_quant=True,
            )

        transformer = QwenImageTransformer2DModel.from_pretrained(
            model_id,
            subfolder="transformer",
            quantization_config=quantization_config,
            torch_dtype=torch.bfloat16,
        )
        edit_pipe = QwenImageEditPipeline.from_pretrained(
            model_id,
            transformer=transformer,
            torch_dtype=torch.bfloat16,
            use_safetensors=True,
        )
        edit_pipe.to(device)
        print("Qwen-Image-Edit model loaded successfully.")
        return "Editing Model loaded! ✂️"
    except Exception as e:
        # Top-level UI boundary: report the failure instead of crashing the app.
        edit_pipe = None
        print(f"Qwen-Image-Edit model failed to load: {e}")
        return f"Editing Model setup failed. Error: {e}"

# --- Generation and Editing Functions (remain the same as before) ---

def generate_image(prompt, negative_prompt, num_inference_steps, guidance_scale, seed):
    """
    Generate one image from a text prompt with the loaded ``gen_pipe``.

    Args:
        prompt: Text describing the image to create.
        negative_prompt: Text describing what to avoid (may be empty).
        num_inference_steps: Denoising step count; coerced to ``int`` because UI
            widgets may deliver it as a float.
        guidance_scale: Forwarded unchanged to the pipeline.
        seed: Seed for reproducible output. ``-1`` or ``None`` (e.g. a cleared
            number box) means "random"; other values are coerced to ``int``.

    Returns:
        tuple: ``(image, status, error_details)``. ``image`` is None when nothing
        was produced; ``error_details`` is "" unless an exception occurred.
    """
    if gen_pipe is None:
        return None, "Model not loaded.", ""
    try:
        # Build the generator inside the try so a bad seed is reported, not raised.
        generator = None
        if seed is not None and int(seed) != -1:
            generator = torch.Generator(device=gen_pipe.device).manual_seed(int(seed))
        image = gen_pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=int(num_inference_steps),
            guidance_scale=guidance_scale,
            generator=generator,
        ).images[0]
        return image, "Image generated successfully!", ""
    except Exception as e:
        # UI boundary: surface the error text to the interface.
        return None, "An error occurred during image generation.", f"Error: {e}"

def edit_image(input_image_pil, prompt, negative_prompt, num_inference_steps, guidance_scale, true_cfg_scale, denoising_strength, seed):
    """
    Edit an uploaded image according to a text prompt with the loaded ``edit_pipe``.

    Args:
        input_image_pil: The uploaded image; converted to RGB before editing.
        prompt: Text describing the desired edit.
        negative_prompt: Text describing what to avoid (may be empty).
        num_inference_steps: Denoising step count; coerced to ``int`` because UI
            widgets may deliver it as a float.
        guidance_scale: Forwarded unchanged to the pipeline.
        true_cfg_scale: Forwarded unchanged to the pipeline.
        denoising_strength: Forwarded as ``strength`` only when the loaded
            pipeline's ``__call__`` accepts a ``strength`` argument; otherwise
            ignored.
        seed: Seed for reproducible output. ``-1`` or ``None`` means "random";
            other values are coerced to ``int``.

    Returns:
        tuple: ``(image, status, error_details)``. ``image`` is None when nothing
        was produced; ``error_details`` is "" unless an exception occurred.
    """
    import inspect

    if edit_pipe is None:
        return None, "Model not loaded.", ""
    if input_image_pil is None:
        return None, "Please upload an image.", ""
    try:
        # Build the generator inside the try so a bad seed is reported, not raised.
        generator = None
        if seed is not None and int(seed) != -1:
            generator = torch.Generator(device=edit_pipe.device).manual_seed(int(seed))
        call_kwargs = {
            "image": input_image_pil.convert("RGB"),
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "num_inference_steps": int(num_inference_steps),
            "guidance_scale": guidance_scale,
            "true_cfg_scale": true_cfg_scale,
            "generator": generator,
        }
        # QwenImageEditPipeline takes no strength/denoising_strength argument, so
        # passing `denoising_strength=` made every call fail with TypeError. Forward
        # the slider value as `strength` only if this pipeline actually accepts it.
        if "strength" in inspect.signature(edit_pipe.__call__).parameters:
            call_kwargs["strength"] = denoising_strength
        edited_image = edit_pipe(**call_kwargs).images[0]
        return edited_image, "Image edited successfully!", ""
    except Exception as e:
        # UI boundary: surface the error text to the interface.
        return None, "An error occurred during image editing.", f"Error: {e}"

# --- Gradio UI ---

def _show_error_box_if_needed(error_details):
    """Return a Gradio update that shows an error textbox only when it is non-empty.

    The generate/edit handlers write "" into their "Error Details" box when there
    is nothing to report and an "Error: ..." message on exceptions. Those boxes
    are created hidden, so without this follow-up step the details never appear.
    """
    return gr.update(visible=bool(error_details))


with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🎨 Qwen Image Studio: Generation & Editing")
    gr.Markdown("Explore the power of Qwen models for advanced image generation and detailed editing.")

    with gr.Tab("Image Generation (Qwen/Qwen-Image)"):
        gr.Markdown("### Text-to-Image Generation")
        gr.Markdown("Create new images from text prompts. ")
        
        with gr.Row():
            gen_model_status = gr.Textbox(value="Generation Model not loaded. Click 'Load' to begin.", interactive=False, label="Model Status")
            load_gen_button = gr.Button("Load Generation Model", variant="primary")
            load_gen_button.click(fn=setup_generation_model, outputs=gen_model_status)

        with gr.Column():
            gen_prompt = gr.Textbox(label="Prompt", placeholder="A majestic dragon flying over a futuristic city at sunset, highly detailed, photorealistic", lines=2)
            gen_negative_prompt = gr.Textbox(label="Negative Prompt (Optional)", placeholder="blurry, low quality, distorted, bad anatomy", lines=1)
            
            with gr.Accordion("Advanced Settings", open=False):
                with gr.Row():
                    gen_num_steps = gr.Slider(minimum=10, maximum=150, step=1, value=50, label="Inference Steps")
                    gen_guidance_scale = gr.Slider(minimum=1.0, maximum=20.0, step=0.5, value=7.5, label="Guidance Scale")
                with gr.Row():
                    gen_seed = gr.Number(label="Seed (-1 for random)", value=-1, precision=0)
            
            generate_button = gr.Button("Generate Image", variant="secondary")
            
            gen_output_image = gr.Image(label="Generated Image")
            gen_status_text = gr.Textbox(label="Status", interactive=False)
            # Hidden until a generation attempt reports an error (see .then below).
            gen_error_text = gr.Textbox(label="Error Details", interactive=False, visible=False)

            generate_button.click(
                fn=generate_image,
                inputs=[gen_prompt, gen_negative_prompt, gen_num_steps, gen_guidance_scale, gen_seed],
                outputs=[gen_output_image, gen_status_text, gen_error_text]
            ).then(
                fn=_show_error_box_if_needed,
                inputs=gen_error_text,
                outputs=gen_error_text,
            )

    with gr.Tab("Image Editing (Qwen/Qwen-Image-Edit)"):
        gr.Markdown("### Image-to-Image Editing")
        gr.Markdown("Upload an image and provide a text prompt to transform it. This model excels at semantic and appearance editing.")

        with gr.Row():
            edit_model_status = gr.Textbox(value="Editing Model not loaded. Click 'Load' to begin.", interactive=False, label="Model Status")
            load_edit_button = gr.Button("Load Editing Model", variant="primary")
            load_edit_button.click(fn=setup_editing_model, outputs=edit_model_status)

        with gr.Column():
            edit_input_image = gr.Image(label="Upload Image to Edit", type="pil")
            edit_prompt = gr.Textbox(label="Edit Prompt", placeholder="Change the dog's fur to a vibrant blue and add a red collar", lines=2)
            edit_negative_prompt = gr.Textbox(label="Negative Prompt (Optional)", placeholder="blurry, low quality, distorted, messy", lines=1)

            with gr.Accordion("Advanced Settings", open=False):
                with gr.Row():
                    edit_num_steps = gr.Slider(minimum=10, maximum=150, step=1, value=50, label="Inference Steps")
                    edit_guidance_scale = gr.Slider(minimum=1.0, maximum=20.0, step=0.5, value=7.5, label="Guidance Scale")
                with gr.Row():
                    edit_true_cfg_scale = gr.Slider(minimum=1.0, maximum=10.0, step=0.1, value=4.0, label="True CFG Scale (for more precise control)")
                    edit_denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.8, label="Denoising Strength (how much to change original)")
                with gr.Row():
                    edit_seed = gr.Number(label="Seed (-1 for random)", value=-1, precision=0)

            edit_button = gr.Button("Edit Image", variant="secondary")
            
            edit_output_image = gr.Image(label="Edited Image")
            edit_status_text = gr.Textbox(label="Status", interactive=False)
            # Hidden until an editing attempt reports an error (see .then below).
            edit_error_text = gr.Textbox(label="Error Details", interactive=False, visible=False)

            edit_button.click(
                fn=edit_image,
                inputs=[edit_input_image, edit_prompt, edit_negative_prompt, edit_num_steps, edit_guidance_scale, edit_true_cfg_scale, edit_denoising_strength, edit_seed],
                outputs=[edit_output_image, edit_status_text, edit_error_text]
            ).then(
                fn=_show_error_box_if_needed,
                inputs=edit_error_text,
                outputs=edit_error_text,
            )

# Launch the app only when this file is run as a script, so importing the
# module (e.g. from tests or another app) builds `demo` without starting a
# server or opening a browser tab.
if __name__ == "__main__":
    demo.launch(inbrowser=True, share=False)