akhaliq HF Staff committed on
Commit
baaacdb
·
verified ·
1 Parent(s): f8de4f0

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +193 -0
app.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import spaces
4
+ from diffusers import DiffusionPipeline
5
+ import os
6
+ import random
7
+
8
# --- Model Loading and Setup ---

# Hugging Face model id for the pruned Qwen-Image text-to-image model.
model_name = "OPPOer/Qwen-Image-Pruning"
# Fixed resolution used for AoT compilation (static shapes); the UI defaults
# to this size so compiled kernels are actually reused.
COMPILATION_WIDTH = 1328
COMPILATION_HEIGHT = 1328

# Configure device and dtype
if torch.cuda.is_available():
    # Use bfloat16 for optimal performance on modern NVIDIA GPUs (A100/H200 recommended)
    torch_dtype = torch.bfloat16
    device = "cuda"
else:
    # Fallback for CPU, note: diffusion on CPU is extremely slow
    torch_dtype = torch.float32
    device = "cpu"

try:
    # Load the pipeline. trust_remote_code=True is required because the
    # Qwen-Image pipeline class ships with the model repo, not diffusers.
    pipe = DiffusionPipeline.from_pretrained(model_name, torch_dtype=torch_dtype, trust_remote_code=True)
    pipe.to(device)
except Exception as e:
    # Handle environment where bfloat16 is not fully supported or other loading issues.
    # NOTE(review): this retries with float16 on CUDA / float32 on CPU; if the
    # first failure was not dtype-related the retry will fail too and re-raise.
    print(f"Failed to load model with bfloat16: {e}. Trying float16/32 fallback.")
    try:
        torch_dtype = torch.float16 if device == "cuda" else torch.float32
        pipe = DiffusionPipeline.from_pretrained(model_name, torch_dtype=torch_dtype, trust_remote_code=True)
        pipe.to(device)
    except Exception as e2:
        print(f"Failed to load model even with fallback: {e2}")
        raise e2


# Qwen-specific prompt extension appended to every user prompt
# (Chinese "magic" suffix: ultra-clear, 4K, cinematic composition).
positive_magic = ", 超清,4K,电影级构图。"
# Standard quality negative prompt shared by capture and inference calls.
negative_prompt = "bad anatomy, blurry, disfigured, poorly drawn face, mutation, mutated, extra limb, missing limb, floating limbs, disconnected limbs, malformed hands, ugly, low-resolution, artifacts, text, watermark, signature"
43
+
44
# --- ZeroGPU AoT Compilation (Mandatory for Diffusion Models) ---

if device == "cuda":
    @spaces.GPU(duration=1500)
    def compile_transformer():
        """Capture, export, and AoT-compile the pipeline's transformer.

        Runs a single 1-step inference at the fixed COMPILATION_WIDTH x
        COMPILATION_HEIGHT resolution to record example transformer inputs,
        exports the module with ``torch.export.export`` (static shapes), and
        returns the ``spaces.aoti_compile`` artifact.

        Raises:
            AttributeError: if the pipeline exposes no ``transformer`` module.
        """
        print("Starting AOT compilation...")

        # Qwen-Image uses a transformer (DiT-style architecture).
        if not hasattr(pipe, 'transformer'):
            raise AttributeError("Pipeline does not have a 'transformer' attribute for AoT compilation.")

        # 1. Capture example inputs by running a minimal (1-step) inference.
        # BUGFIX: true_cfg_scale is a per-call argument, not a pipeline config
        # entry; the previous read/write of pipe.config.true_cfg_scale raised
        # (diffusers pipeline configs are frozen mappings) and silently
        # aborted compilation via the outer try/except. Passing
        # true_cfg_scale=1.0 directly to the call below is sufficient.
        prompt_for_capture = "test prompt for compilation"

        with spaces.aoti_capture(pipe.transformer) as call:
            pipe(
                prompt=prompt_for_capture,
                negative_prompt=negative_prompt,
                width=COMPILATION_WIDTH,
                height=COMPILATION_HEIGHT,
                num_inference_steps=1,
                true_cfg_scale=1.0,
                generator=torch.Generator(device=device).manual_seed(42),
            )

        # 2. Export the model (static shapes based on COMPILATION_WIDTH/HEIGHT)
        exported = torch.export.export(
            pipe.transformer,
            args=call.args,
            kwargs=call.kwargs,
        )

        # 3. Compile the exported model
        print(f"Export successful. Compiling for {COMPILATION_WIDTH}x{COMPILATION_HEIGHT}...")
        return spaces.aoti_compile(exported)

    # 4. Apply compiled model to pipeline during startup; on any failure the
    # app keeps running with the uncompiled (slower) transformer.
    try:
        compiled_transformer = compile_transformer()
        spaces.aoti_apply(compiled_transformer, pipe.transformer)
        print("✅ AOT Compilation successful and applied.")
    except Exception as e:
        print(f"⚠️ AOT Compilation failed (falling back to standard GPU mode). Performance may be lower. Error: {e}")
94
+
95
# --- Inference Function ---

@spaces.GPU(duration=120)
def generate_image(prompt: str, steps: int, width: int, height: int, seed: int):
    """Run text-to-image inference and return a single PIL image.

    Args:
        prompt: User prompt; the Chinese ``positive_magic`` suffix is appended.
        steps: Number of diffusion inference steps.
        width: Output width in pixels (should be divisible by 8).
        height: Output height in pixels (should be divisible by 8).
        seed: RNG seed for reproducible generation.

    Returns:
        The first generated image from the pipeline output.
    """
    # Apply the Chinese positive magic suffix to every prompt.
    full_prompt = prompt + positive_magic

    # BUGFIX: gr.Number/gr.Slider may deliver floats; torch.Generator.manual_seed
    # requires an int, and the pipeline expects an integer step count.
    generator = torch.Generator(device=device).manual_seed(int(seed))

    if width % 8 != 0 or height % 8 != 0:
        gr.Warning("Width and Height should be divisible by 8 for optimal performance.")

    # true_cfg_scale=1.0 (CFG effectively disabled), matching the value used
    # during AoT capture so the compiled transformer path is exercised.
    image = pipe(
        prompt=full_prompt,
        negative_prompt=negative_prompt,
        width=width,
        height=height,
        num_inference_steps=int(steps),
        true_cfg_scale=1.0,
        generator=generator
    ).images[0]

    return image
119
+
120
# --- Gradio Interface ---

with gr.Blocks(theme=gr.themes.Soft(), title="Qwen-Image Text-to-Image Generation (AoT Optimized)") as demo:
    gr.HTML(f"""
    <div style="text-align: center; max-width: 800px; margin: 0 auto;">
        <h1>Qwen-Image Pruning Text-to-Image</h1>
        <p>Optimized for speed using Gradio ZeroGPU AoT Compilation.</p>
        <p>🚨 Prompts should ideally be in Chinese for best results due to the model training and included magic prompts.</p>
        <p style="margin-top: 10px;">Built with <a href="https://huggingface.co/spaces/akhaliq/anycoder" target="_blank">anycoder</a></p>
    </div>
    """)

    with gr.Row():
        with gr.Column(scale=1):
            prompt_input = gr.Textbox(
                label="Prompt (Chinese Recommended)",
                value='一个穿着"QWEN"标志的T恤的中国美女正拿着黑色的马克笔面相镜头微笑。',
                lines=3
            )

            with gr.Accordion("Generation Settings", open=True):
                steps_slider = gr.Slider(
                    minimum=4, maximum=50, value=8, step=1, label="Inference Steps"
                )

                with gr.Row():
                    # Size sliders are locked on CUDA because the transformer is
                    # AoT-compiled for one fixed resolution.
                    width_input = gr.Slider(
                        minimum=512, maximum=1536, value=COMPILATION_WIDTH, step=8, label="Width", interactive=(device != "cuda")
                    )
                    height_input = gr.Slider(
                        minimum=512, maximum=1536, value=COMPILATION_HEIGHT, step=8, label="Height", interactive=(device != "cuda")
                    )
                if device == "cuda":
                    gr.Markdown(f"Note: For maximum performance (AoT), recommended resolution is {COMPILATION_WIDTH}x{COMPILATION_HEIGHT}")

                seed_input = gr.Number(value=42, label="Seed", precision=0)
                random_seed_btn = gr.Button("🎲 Random Seed", scale=0)

            generate_btn = gr.Button("Generate Image", variant="primary")

        with gr.Column(scale=2):
            output_image = gr.Image(label="Generated Image", show_share_button=True)

    # Example prompts.
    # BUGFIX: generate_image takes five parameters, but the examples previously
    # supplied only the prompt while run_on_click=True — clicking an example
    # invoked the function with missing arguments. Each example row now carries
    # values for all five inputs.
    gr.Examples(
        examples=[
            ['一个穿着"QWEN"标志的T恤的中国美女正拿着黑色的马克笔面相镜头微笑。', 8, COMPILATION_WIDTH, COMPILATION_HEIGHT, 42],
            ['海报,温馨家庭场景,柔和阳光洒在野餐布上,色彩温暖明亮。文字内容:“共享阳光,共享爱。”', 8, COMPILATION_WIDTH, COMPILATION_HEIGHT, 42],
            ['一个穿着校服的年轻女孩站在教室里,在黑板上写字。黑板中央用整洁的白粉笔写着“Introducing Qwen-Image”。', 8, COMPILATION_WIDTH, COMPILATION_HEIGHT, 42],
        ],
        inputs=[prompt_input, steps_slider, width_input, height_input, seed_input],
        outputs=output_image,
        fn=generate_image,
        cache_examples=False,
        run_on_click=True
    )

    # Event handlers
    generate_btn.click(
        fn=generate_image,
        inputs=[prompt_input, steps_slider, width_input, height_input, seed_input],
        outputs=output_image,
        show_progress="minimal"
    )

    random_seed_btn.click(
        fn=lambda: random.randint(0, 1000000),
        inputs=[],
        outputs=seed_input,
        queue=False,
        show_progress="hidden"
    )

demo.queue().launch()