Spaces:
Running on Zero
Running on Zero
UI improvements: move status bar to right side, simplify layout, update defaults to Wang Xizhi
Browse files- app.py +80 -86
- inference.py +2 -2
- src/flux/modules/layers.py +8 -6
- src/flux/xflux_pipeline.py +4 -1
app.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
# -*- coding: utf-8 -*-
|
| 2 |
"""
|
| 3 |
Gradio Demo for Chinese Calligraphy Generation - HuggingFace Space Version
|
|
|
|
| 4 |
"""
|
| 5 |
|
| 6 |
# IMPORTANT: import spaces first before any CUDA-related packages
|
|
@@ -9,6 +10,7 @@ import spaces
|
|
| 9 |
import gradio as gr
|
| 10 |
import json
|
| 11 |
import csv
|
|
|
|
| 12 |
|
| 13 |
# Load author and font mappings from CSV
|
| 14 |
def load_author_fonts_from_csv(csv_path):
|
|
@@ -83,84 +85,81 @@ def init_generator():
|
|
| 83 |
def update_font_choices(author: str):
|
| 84 |
"""
|
| 85 |
Update available font choices based on selected author
|
| 86 |
-
|
| 87 |
-
Args:
|
| 88 |
-
author: Selected author name
|
| 89 |
-
|
| 90 |
-
Returns:
|
| 91 |
-
Updated dropdown with available fonts for the author
|
| 92 |
"""
|
| 93 |
if author == "None (Synthetic / 合成风格)" or author not in AUTHOR_FONTS:
|
| 94 |
-
# If no author or synthetic, show all font types
|
| 95 |
choices = list(FONT_STYLE_NAMES.values())
|
| 96 |
else:
|
| 97 |
-
# Show only fonts available for this author
|
| 98 |
available_fonts = AUTHOR_FONTS[author]
|
| 99 |
choices = [FONT_STYLE_NAMES[font] for font in available_fonts if font in FONT_STYLE_NAMES]
|
| 100 |
|
| 101 |
-
# Return updated dropdown with first choice as default
|
| 102 |
return gr.Dropdown(choices=choices, value=choices[0] if choices else None)
|
| 103 |
|
| 104 |
|
| 105 |
-
|
| 106 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
text: str,
|
| 108 |
author_dropdown: str,
|
| 109 |
font_style: str,
|
| 110 |
num_steps: int,
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
):
|
| 115 |
"""
|
| 116 |
-
|
| 117 |
|
| 118 |
Args:
|
| 119 |
text: Input text (1-7 characters)
|
| 120 |
-
author_dropdown: Selected author
|
| 121 |
-
font_style:
|
| 122 |
-
num_steps:
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
batch_size: Number of images to generate
|
| 126 |
|
| 127 |
-
|
| 128 |
-
|
| 129 |
"""
|
| 130 |
import torch
|
| 131 |
|
| 132 |
-
# Validate text
|
| 133 |
if len(text) < 1:
|
| 134 |
raise gr.Error("文本不能为空 / Text cannot be empty")
|
| 135 |
if len(text) > 7:
|
| 136 |
raise gr.Error(f"文本最多7个字符 / Text must be at most 7 characters. Current: {len(text)}")
|
| 137 |
|
| 138 |
-
#
|
| 139 |
-
font =
|
| 140 |
-
for font_key, font_display in FONT_STYLE_NAMES.items():
|
| 141 |
-
if font_display == font_style:
|
| 142 |
-
font = font_key
|
| 143 |
-
break
|
| 144 |
-
|
| 145 |
if font is None:
|
| 146 |
raise gr.Error(f"无法识别的字体风格 / Unknown font style: {font_style}")
|
| 147 |
|
| 148 |
# Determine author
|
| 149 |
author = author_dropdown if author_dropdown != "None (Synthetic / 合成风格)" else None
|
| 150 |
|
| 151 |
-
#
|
| 152 |
-
|
| 153 |
-
seed = torch.randint(0, 2**32, (1,)).item()
|
| 154 |
|
| 155 |
-
# Initialize generator if needed
|
| 156 |
gen = init_generator()
|
| 157 |
|
| 158 |
-
|
|
|
|
|
|
|
| 159 |
results = []
|
| 160 |
seeds_used = []
|
| 161 |
|
| 162 |
-
for i in range(
|
| 163 |
-
current_seed =
|
|
|
|
|
|
|
|
|
|
|
|
|
| 164 |
result_img, cond_img = gen.generate(
|
| 165 |
text=text,
|
| 166 |
font_style=font,
|
|
@@ -168,16 +167,19 @@ def generate_calligraphy(
|
|
| 168 |
num_steps=num_steps,
|
| 169 |
seed=current_seed,
|
| 170 |
)
|
| 171 |
-
|
|
|
|
| 172 |
seeds_used.append(current_seed)
|
|
|
|
|
|
|
|
|
|
| 173 |
|
| 174 |
-
#
|
| 175 |
-
if
|
| 176 |
-
|
| 177 |
else:
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
return results, seed_info
|
| 181 |
|
| 182 |
|
| 183 |
# Create Gradio interface
|
|
@@ -199,8 +201,8 @@ with gr.Blocks(title="UniCalli - Chinese Calligraphy Generator / 中国书法生
|
|
| 199 |
|
| 200 |
text_input = gr.Textbox(
|
| 201 |
label="输入文本 / Input Text (1-7个字符 / 1-7 characters)",
|
| 202 |
-
placeholder="请输入1-7个汉字 / Enter 1-7 Chinese characters, e.g.:
|
| 203 |
-
value="
|
| 204 |
max_lines=1
|
| 205 |
)
|
| 206 |
|
|
@@ -209,19 +211,19 @@ with gr.Blocks(title="UniCalli - Chinese Calligraphy Generator / 中国书法生
|
|
| 209 |
author_dropdown = gr.Dropdown(
|
| 210 |
label="1. 选择书法家 / Select Calligrapher",
|
| 211 |
choices=["None (Synthetic / 合成风格)"] + AUTHOR_LIST,
|
| 212 |
-
value="
|
| 213 |
info="先选择历史书法家 / Choose a historical calligrapher first"
|
| 214 |
)
|
| 215 |
|
| 216 |
-
# Get initial fonts for default author (
|
| 217 |
-
initial_author = "
|
| 218 |
initial_fonts = AUTHOR_FONTS.get(initial_author, ["楷", "草", "行"])
|
| 219 |
initial_font_choices = [FONT_STYLE_NAMES[f] for f in initial_fonts if f in FONT_STYLE_NAMES]
|
| 220 |
|
| 221 |
font_style = gr.Dropdown(
|
| 222 |
label="2. 选择字体风格 / Select Font Style",
|
| 223 |
choices=initial_font_choices,
|
| 224 |
-
value="
|
| 225 |
info="根据所选书法家显示可用字体 / Shows available fonts for selected calligrapher"
|
| 226 |
)
|
| 227 |
|
|
@@ -236,45 +238,40 @@ with gr.Blocks(title="UniCalli - Chinese Calligraphy Generator / 中国书法生
|
|
| 236 |
info="更多步数 = 更高质量,但更慢 / More steps = higher quality, but slower"
|
| 237 |
)
|
| 238 |
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
)
|
| 245 |
-
random_seed = gr.Checkbox(
|
| 246 |
-
label="随机种子 / Random Seed",
|
| 247 |
-
value=False
|
| 248 |
-
)
|
| 249 |
|
| 250 |
-
|
| 251 |
-
label="
|
| 252 |
minimum=1,
|
| 253 |
-
maximum=
|
| 254 |
value=1,
|
| 255 |
-
step=1
|
| 256 |
-
info="生成多张图片以选择最佳效果 / Generate multiple images to pick the best"
|
| 257 |
)
|
| 258 |
|
| 259 |
-
generate_btn = gr.Button("🎨 生成
|
| 260 |
|
| 261 |
with gr.Column(scale=1):
|
| 262 |
# Output section
|
| 263 |
-
gr.Markdown("### 🖼️ 生成结果 / Generated
|
| 264 |
-
gr.Markdown("
|
| 265 |
|
| 266 |
output_gallery = gr.Gallery(
|
| 267 |
label="生成结果 / Generated Results",
|
| 268 |
show_label=False,
|
| 269 |
columns=2,
|
| 270 |
rows=2,
|
| 271 |
-
height=
|
| 272 |
object_fit="contain",
|
| 273 |
allow_preview=True
|
| 274 |
)
|
| 275 |
|
| 276 |
-
|
| 277 |
-
label="
|
|
|
|
| 278 |
interactive=False
|
| 279 |
)
|
| 280 |
|
|
@@ -291,45 +288,42 @@ with gr.Blocks(title="UniCalli - Chinese Calligraphy Generator / 中国书法生
|
|
| 291 |
gr.Markdown(author_info_md)
|
| 292 |
|
| 293 |
# Event handlers
|
| 294 |
-
# Update font choices when author changes
|
| 295 |
author_dropdown.change(
|
| 296 |
fn=update_font_choices,
|
| 297 |
inputs=[author_dropdown],
|
| 298 |
outputs=[font_style]
|
| 299 |
)
|
| 300 |
|
| 301 |
-
# Generate button
|
| 302 |
generate_btn.click(
|
| 303 |
-
fn=
|
| 304 |
inputs=[
|
| 305 |
text_input,
|
| 306 |
author_dropdown,
|
| 307 |
font_style,
|
| 308 |
num_steps,
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
batch_size,
|
| 312 |
],
|
| 313 |
-
outputs=[
|
| 314 |
)
|
| 315 |
|
| 316 |
# Examples
|
| 317 |
gr.Markdown("### 📋 示例 / Examples")
|
| 318 |
gr.Examples(
|
| 319 |
examples=[
|
| 320 |
-
["
|
| 321 |
-
["
|
| 322 |
-
["
|
| 323 |
-
["宁静致远", "None (Synthetic / 合成风格)", "楷 (Regular Script)", 25, 42,
|
| 324 |
],
|
| 325 |
inputs=[
|
| 326 |
text_input,
|
| 327 |
author_dropdown,
|
| 328 |
font_style,
|
| 329 |
num_steps,
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
batch_size,
|
| 333 |
],
|
| 334 |
)
|
| 335 |
|
|
|
|
| 1 |
# -*- coding: utf-8 -*-
|
| 2 |
"""
|
| 3 |
Gradio Demo for Chinese Calligraphy Generation - HuggingFace Space Version
|
| 4 |
+
With interactive session mode to avoid model reloading
|
| 5 |
"""
|
| 6 |
|
| 7 |
# IMPORTANT: import spaces first before any CUDA-related packages
|
|
|
|
| 10 |
import gradio as gr
|
| 11 |
import json
|
| 12 |
import csv
|
| 13 |
+
import time
|
| 14 |
|
| 15 |
# Load author and font mappings from CSV
|
| 16 |
def load_author_fonts_from_csv(csv_path):
|
|
|
|
| 85 |
def update_font_choices(author: str):
|
| 86 |
"""
|
| 87 |
Update available font choices based on selected author
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
"""
|
| 89 |
if author == "None (Synthetic / 合成风格)" or author not in AUTHOR_FONTS:
|
|
|
|
| 90 |
choices = list(FONT_STYLE_NAMES.values())
|
| 91 |
else:
|
|
|
|
| 92 |
available_fonts = AUTHOR_FONTS[author]
|
| 93 |
choices = [FONT_STYLE_NAMES[font] for font in available_fonts if font in FONT_STYLE_NAMES]
|
| 94 |
|
|
|
|
| 95 |
return gr.Dropdown(choices=choices, value=choices[0] if choices else None)
|
| 96 |
|
| 97 |
|
| 98 |
+
def parse_font_style(font_style: str) -> str:
|
| 99 |
+
"""Extract font key from display name"""
|
| 100 |
+
for font_key, font_display in FONT_STYLE_NAMES.items():
|
| 101 |
+
if font_display == font_style:
|
| 102 |
+
return font_key
|
| 103 |
+
return None
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
@spaces.GPU(duration=600) # 10 minutes session for multiple generations
|
| 107 |
+
def interactive_session(
|
| 108 |
text: str,
|
| 109 |
author_dropdown: str,
|
| 110 |
font_style: str,
|
| 111 |
num_steps: int,
|
| 112 |
+
start_seed: int,
|
| 113 |
+
num_images: int,
|
| 114 |
+
progress=gr.Progress()
|
| 115 |
):
|
| 116 |
"""
|
| 117 |
+
Interactive session: load model once, generate multiple images
|
| 118 |
|
| 119 |
Args:
|
| 120 |
text: Input text (1-7 characters)
|
| 121 |
+
author_dropdown: Selected author
|
| 122 |
+
font_style: Font style
|
| 123 |
+
num_steps: Inference steps
|
| 124 |
+
start_seed: Starting seed
|
| 125 |
+
num_images: Number of images to generate (each with different seed)
|
|
|
|
| 126 |
|
| 127 |
+
Yields:
|
| 128 |
+
Progress status, gallery of results
|
| 129 |
"""
|
| 130 |
import torch
|
| 131 |
|
| 132 |
+
# Validate text
|
| 133 |
if len(text) < 1:
|
| 134 |
raise gr.Error("文本不能为空 / Text cannot be empty")
|
| 135 |
if len(text) > 7:
|
| 136 |
raise gr.Error(f"文本最多7个字符 / Text must be at most 7 characters. Current: {len(text)}")
|
| 137 |
|
| 138 |
+
# Parse font style
|
| 139 |
+
font = parse_font_style(font_style)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 140 |
if font is None:
|
| 141 |
raise gr.Error(f"无法识别的字体风格 / Unknown font style: {font_style}")
|
| 142 |
|
| 143 |
# Determine author
|
| 144 |
author = author_dropdown if author_dropdown != "None (Synthetic / 合成风格)" else None
|
| 145 |
|
| 146 |
+
# Step 1: Load model (only once per session)
|
| 147 |
+
yield "⏳ 正在加载模型... / Loading model...", []
|
|
|
|
| 148 |
|
|
|
|
| 149 |
gen = init_generator()
|
| 150 |
|
| 151 |
+
yield "✅ 模型加载完成!开始生成... / Model loaded! Starting generation...", []
|
| 152 |
+
|
| 153 |
+
# Step 2: Generate multiple images
|
| 154 |
results = []
|
| 155 |
seeds_used = []
|
| 156 |
|
| 157 |
+
for i in range(num_images):
|
| 158 |
+
current_seed = start_seed + i
|
| 159 |
+
progress((i + 1) / num_images, desc=f"生成第 {i+1}/{num_images} 张...")
|
| 160 |
+
|
| 161 |
+
yield f"🎨 正在生成第 {i+1}/{num_images} 张 (Seed: {current_seed})...", results
|
| 162 |
+
|
| 163 |
result_img, cond_img = gen.generate(
|
| 164 |
text=text,
|
| 165 |
font_style=font,
|
|
|
|
| 167 |
num_steps=num_steps,
|
| 168 |
seed=current_seed,
|
| 169 |
)
|
| 170 |
+
|
| 171 |
+
results.append((result_img, f"Seed: {current_seed}"))
|
| 172 |
seeds_used.append(current_seed)
|
| 173 |
+
|
| 174 |
+
# Yield intermediate results so user can see progress
|
| 175 |
+
yield f"✅ 已完成 {i+1}/{num_images} 张 (Seed: {current_seed})", results
|
| 176 |
|
| 177 |
+
# Final yield with all seeds info
|
| 178 |
+
if num_images > 1:
|
| 179 |
+
final_status = f"✅ 全部完成!共 {num_images} 张 (Seeds: {seeds_used[0]}-{seeds_used[-1]})"
|
| 180 |
else:
|
| 181 |
+
final_status = f"✅ 完成!Seed: {seeds_used[0]}"
|
| 182 |
+
yield final_status, results
|
|
|
|
| 183 |
|
| 184 |
|
| 185 |
# Create Gradio interface
|
|
|
|
| 201 |
|
| 202 |
text_input = gr.Textbox(
|
| 203 |
label="输入文本 / Input Text (1-7个字符 / 1-7 characters)",
|
| 204 |
+
placeholder="请输入1-7个汉字 / Enter 1-7 Chinese characters, e.g.: 天道酬勤",
|
| 205 |
+
value="天道酬勤",
|
| 206 |
max_lines=1
|
| 207 |
)
|
| 208 |
|
|
|
|
| 211 |
author_dropdown = gr.Dropdown(
|
| 212 |
label="1. 选择书法家 / Select Calligrapher",
|
| 213 |
choices=["None (Synthetic / 合成风格)"] + AUTHOR_LIST,
|
| 214 |
+
value="王羲之",
|
| 215 |
info="先选择历史书法家 / Choose a historical calligrapher first"
|
| 216 |
)
|
| 217 |
|
| 218 |
+
# Get initial fonts for default author (王羲之)
|
| 219 |
+
initial_author = "王羲之"
|
| 220 |
initial_fonts = AUTHOR_FONTS.get(initial_author, ["楷", "草", "行"])
|
| 221 |
initial_font_choices = [FONT_STYLE_NAMES[f] for f in initial_fonts if f in FONT_STYLE_NAMES]
|
| 222 |
|
| 223 |
font_style = gr.Dropdown(
|
| 224 |
label="2. 选择字体风格 / Select Font Style",
|
| 225 |
choices=initial_font_choices,
|
| 226 |
+
value="草 (Cursive Script)",
|
| 227 |
info="根据所选书法家显示可用字体 / Shows available fonts for selected calligrapher"
|
| 228 |
)
|
| 229 |
|
|
|
|
| 238 |
info="更多步数 = 更高质量,但更慢 / More steps = higher quality, but slower"
|
| 239 |
)
|
| 240 |
|
| 241 |
+
start_seed = gr.Number(
|
| 242 |
+
label="起始种子 / Start Seed",
|
| 243 |
+
value=42,
|
| 244 |
+
precision=0
|
| 245 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 246 |
|
| 247 |
+
num_images = gr.Slider(
|
| 248 |
+
label="生成数量 / Number of Images",
|
| 249 |
minimum=1,
|
| 250 |
+
maximum=8,
|
| 251 |
value=1,
|
| 252 |
+
step=1
|
|
|
|
| 253 |
)
|
| 254 |
|
| 255 |
+
generate_btn = gr.Button("🎨 开始生成 / Start Generation", variant="primary", size="lg")
|
| 256 |
|
| 257 |
with gr.Column(scale=1):
|
| 258 |
# Output section
|
| 259 |
+
gr.Markdown("### 🖼️ 生成结果 / Generated Results")
|
| 260 |
+
gr.Markdown("*点击图片可放大查看 / Click image to enlarge*")
|
| 261 |
|
| 262 |
output_gallery = gr.Gallery(
|
| 263 |
label="生成结果 / Generated Results",
|
| 264 |
show_label=False,
|
| 265 |
columns=2,
|
| 266 |
rows=2,
|
| 267 |
+
height=550,
|
| 268 |
object_fit="contain",
|
| 269 |
allow_preview=True
|
| 270 |
)
|
| 271 |
|
| 272 |
+
status_text = gr.Textbox(
|
| 273 |
+
label="状态 / Status",
|
| 274 |
+
value="准备就绪 / Ready",
|
| 275 |
interactive=False
|
| 276 |
)
|
| 277 |
|
|
|
|
| 288 |
gr.Markdown(author_info_md)
|
| 289 |
|
| 290 |
# Event handlers
|
|
|
|
| 291 |
author_dropdown.change(
|
| 292 |
fn=update_font_choices,
|
| 293 |
inputs=[author_dropdown],
|
| 294 |
outputs=[font_style]
|
| 295 |
)
|
| 296 |
|
| 297 |
+
# Generate button - uses streaming for live updates
|
| 298 |
generate_btn.click(
|
| 299 |
+
fn=interactive_session,
|
| 300 |
inputs=[
|
| 301 |
text_input,
|
| 302 |
author_dropdown,
|
| 303 |
font_style,
|
| 304 |
num_steps,
|
| 305 |
+
start_seed,
|
| 306 |
+
num_images,
|
|
|
|
| 307 |
],
|
| 308 |
+
outputs=[status_text, output_gallery]
|
| 309 |
)
|
| 310 |
|
| 311 |
# Examples
|
| 312 |
gr.Markdown("### 📋 示例 / Examples")
|
| 313 |
gr.Examples(
|
| 314 |
examples=[
|
| 315 |
+
["天道酬勤", "王羲之", "草 (Cursive Script)", 25, 42, 1],
|
| 316 |
+
["春风得意马蹄疾", "赵佶\\宋徽宗", "楷 (Regular Script)", 25, 42, 1],
|
| 317 |
+
["海内存知己", "黄庭坚", "行 (Running Script)", 25, 42, 1],
|
| 318 |
+
["宁静致远", "None (Synthetic / 合成风格)", "楷 (Regular Script)", 25, 42, 1],
|
| 319 |
],
|
| 320 |
inputs=[
|
| 321 |
text_input,
|
| 322 |
author_dropdown,
|
| 323 |
font_style,
|
| 324 |
num_steps,
|
| 325 |
+
start_seed,
|
| 326 |
+
num_images,
|
|
|
|
| 327 |
],
|
| 328 |
)
|
| 329 |
|
inference.py
CHANGED
|
@@ -341,8 +341,8 @@ class CalligraphyGenerator:
|
|
| 341 |
|
| 342 |
# Move to GPU only if NOT using DeepSpeed (DeepSpeed will handle device placement)
|
| 343 |
if not use_deepspeed:
|
| 344 |
-
print(f"Moving model to {self.device}...")
|
| 345 |
-
model = model.to(self.device)
|
| 346 |
|
| 347 |
# Enable optimized attention backends
|
| 348 |
try:
|
|
|
|
| 341 |
|
| 342 |
# Move to GPU only if NOT using DeepSpeed (DeepSpeed will handle device placement)
|
| 343 |
if not use_deepspeed:
|
| 344 |
+
print(f"Moving model to {self.device} and converting to float32...")
|
| 345 |
+
model = model.to(device=self.device, dtype=torch.float32)
|
| 346 |
|
| 347 |
# Enable optimized attention backends
|
| 348 |
try:
|
src/flux/modules/layers.py
CHANGED
|
@@ -34,19 +34,21 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 10
|
|
| 34 |
:param max_period: controls the minimum frequency of the embeddings.
|
| 35 |
:return: an (N, D) Tensor of positional embeddings.
|
| 36 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
t = time_factor * t
|
| 38 |
half = dim // 2
|
| 39 |
-
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half)
|
| 40 |
-
t.device
|
| 41 |
-
)
|
| 42 |
|
| 43 |
args = t[:, None].float() * freqs[None]
|
| 44 |
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
| 45 |
if dim % 2:
|
| 46 |
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
return embedding
|
| 50 |
|
| 51 |
|
| 52 |
class MLPEmbedder(nn.Module):
|
|
|
|
| 34 |
:param max_period: controls the minimum frequency of the embeddings.
|
| 35 |
:return: an (N, D) Tensor of positional embeddings.
|
| 36 |
"""
|
| 37 |
+
# Store original dtype and device
|
| 38 |
+
orig_dtype = t.dtype
|
| 39 |
+
orig_device = t.device
|
| 40 |
+
|
| 41 |
t = time_factor * t
|
| 42 |
half = dim // 2
|
| 43 |
+
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=orig_device) / half)
|
|
|
|
|
|
|
| 44 |
|
| 45 |
args = t[:, None].float() * freqs[None]
|
| 46 |
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
| 47 |
if dim % 2:
|
| 48 |
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
| 49 |
+
|
| 50 |
+
# Always convert to original dtype
|
| 51 |
+
return embedding.to(dtype=orig_dtype, device=orig_device)
|
| 52 |
|
| 53 |
|
| 54 |
class MLPEmbedder(nn.Module):
|
src/flux/xflux_pipeline.py
CHANGED
|
@@ -225,6 +225,7 @@ class XFluxPipeline:
|
|
| 225 |
if self.controlnet_loaded:
|
| 226 |
controlnet_image = self.annotator(controlnet_image, width, height)
|
| 227 |
controlnet_image = torch.from_numpy((np.array(controlnet_image) / 127.5) - 1)
|
|
|
|
| 228 |
controlnet_image = controlnet_image.permute(
|
| 229 |
2, 0, 1).unsqueeze(0).to(torch.float32).to(self.device)
|
| 230 |
|
|
@@ -311,6 +312,7 @@ class XFluxPipeline:
|
|
| 311 |
neg_ip_scale=1.0,
|
| 312 |
is_generation=True,
|
| 313 |
):
|
|
|
|
| 314 |
x = get_noise(
|
| 315 |
1, height, width, device=self.device,
|
| 316 |
dtype=torch.float32, seed=seed
|
|
@@ -328,7 +330,8 @@ class XFluxPipeline:
|
|
| 328 |
|
| 329 |
if not self.controlnet_loaded and controlnet_image is not None: # tianshuo
|
| 330 |
# width //= 2
|
| 331 |
-
|
|
|
|
| 332 |
|
| 333 |
inp_cond = prepare(t5=self.t5, clip=self.clip, img=x, prompt=prompt)
|
| 334 |
neg_inp_cond = prepare(t5=self.t5, clip=self.clip, img=x, prompt=neg_prompt)
|
|
|
|
| 225 |
if self.controlnet_loaded:
|
| 226 |
controlnet_image = self.annotator(controlnet_image, width, height)
|
| 227 |
controlnet_image = torch.from_numpy((np.array(controlnet_image) / 127.5) - 1)
|
| 228 |
+
# Keep as float32 for VAE encoding, will be converted to model dtype after
|
| 229 |
controlnet_image = controlnet_image.permute(
|
| 230 |
2, 0, 1).unsqueeze(0).to(torch.float32).to(self.device)
|
| 231 |
|
|
|
|
| 312 |
neg_ip_scale=1.0,
|
| 313 |
is_generation=True,
|
| 314 |
):
|
| 315 |
+
# Use float32 for stable inference
|
| 316 |
x = get_noise(
|
| 317 |
1, height, width, device=self.device,
|
| 318 |
dtype=torch.float32, seed=seed
|
|
|
|
| 330 |
|
| 331 |
if not self.controlnet_loaded and controlnet_image is not None: # tianshuo
|
| 332 |
# width //= 2
|
| 333 |
+
# VAE expects float32 (controlnet_image is already float32)
|
| 334 |
+
cond_latent = self.ae.encode(controlnet_image)
|
| 335 |
|
| 336 |
inp_cond = prepare(t5=self.t5, clip=self.clip, img=x, prompt=prompt)
|
| 337 |
neg_inp_cond = prepare(t5=self.t5, clip=self.clip, img=x, prompt=neg_prompt)
|