Z-Image / ui.py
AndrewKapok's picture
Upload ui.py with huggingface_hub
98f37f6 verified
import os
import time
import logging
import gradio as gr
import torch
from PIL import Image
# 导入自定义模块
from models.model_manager import ModelManager
from inference.inference_service import InferenceService
logger = logging.getLogger(__name__)
class GradioUI:
def __init__(self, inference_service: InferenceService):
self.inference_service = inference_service
# 初始化Gradio界面
self.interface = None
def create_interface(self) -> gr.Blocks:
"""
创建Gradio界面
Returns:
gr.Blocks: Gradio界面实例
"""
with gr.Blocks(
title="Qwen-Image-2512 文本到图像生成",
theme=gr.themes.Soft(),
css="""
.main-container { max-width: 1200px; margin: 0 auto; }
.title { text-align: center; margin-bottom: 2rem; }
.upload-section { margin-bottom: 2rem; }
.params-section { margin-bottom: 2rem; }
.status-section { margin-bottom: 2rem; }
.result-section { margin-bottom: 2rem; }
.param-group { display: flex; flex-wrap: wrap; gap: 1rem; margin-bottom: 1rem; }
.param-item { flex: 1 1 200px; }
"""
) as interface:
# 标题
gr.HTML("""
<h1 class="title">Qwen-Image-2512 文本到图像生成</h1>
<p class="title" style="font-size: 1.2rem; color: #666;">基于阿里通义千问的高性能图像生成模型</p>
""")
# 状态显示
with gr.Row(elem_id="status-section"):
status_text = gr.Textbox(
label="状态",
value="模型加载中...",
interactive=False,
elem_id="status-text"
)
progress_bar = gr.Progress(track_tqdm=True)
# 主要内容区域
with gr.Row(elem_id="main-content"):
# 左侧:输入和参数
with gr.Column(scale=1, min_width=300):
# 文本提示输入
with gr.Group(elem_id="prompt-section"):
prompt = gr.Textbox(
label="生成提示",
placeholder="输入您想要生成的图像描述...",
lines=3,
max_lines=5,
elem_id="prompt-input"
)
negative_prompt = gr.Textbox(
label="负面提示",
placeholder="输入您想要避免的内容...",
lines=2,
max_lines=3,
elem_id="negative-prompt-input"
)
# 参数控制面板
with gr.Group(elem_id="params-section"):
gr.Markdown("### 生成参数")
with gr.Row(elem_id="param-group"):
# 图像尺寸
with gr.Column(elem_id="param-item"):
width = gr.Slider(
label="宽度",
minimum=256,
maximum=2512,
step=64,
value=1024,
elem_id="width-slider"
)
height = gr.Slider(
label="高度",
minimum=256,
maximum=2512,
step=64,
value=1024,
elem_id="height-slider"
)
# 推理参数
with gr.Column(elem_id="param-item"):
num_inference_steps = gr.Slider(
label="推理步数",
minimum=1,
maximum=100,
step=1,
value=50,
elem_id="steps-slider"
)
guidance_scale = gr.Slider(
label="引导尺度",
minimum=0.0,
maximum=20.0,
step=0.1,
value=7.5,
elem_id="guidance-slider"
)
# 其他参数
with gr.Column(elem_id="param-item"):
seed = gr.Number(
label="随机种子",
value=None,
precision=0,
elem_id="seed-input"
)
num_images = gr.Slider(
label="生成数量",
minimum=1,
maximum=4,
step=1,
value=1,
elem_id="num-images-slider"
)
# 生成按钮
with gr.Row(elem_id="button-section"):
generate_btn = gr.Button(
"生成图像",
variant="primary",
size="lg",
elem_id="generate-btn"
)
clear_btn = gr.Button(
"清除",
variant="secondary",
size="lg",
elem_id="clear-btn"
)
# 右侧:结果展示
with gr.Column(scale=2, min_width=500):
with gr.Group(elem_id="result-section"):
gr.Markdown("### 生成结果")
# 图像输出区域
gallery = gr.Gallery(
label="生成的图像",
show_label=False,
elem_id="gallery",
columns=2,
rows=2,
object_fit="contain",
height="auto"
)
# 生成信息
with gr.Row(elem_id="info-section"):
execution_time = gr.Textbox(
label="生成时间",
interactive=False,
elem_id="execution-time"
)
image_count = gr.Textbox(
label="图像数量",
interactive=False,
elem_id="image-count"
)
# 示例提示
with gr.Row(elem_id="examples-section"):
gr.Markdown("### 示例提示")
examples = gr.Examples(
examples=[
["一只可爱的柯基犬在草地上奔跑,阳光明媚,高清细节", "模糊, 低质量, 变形", 1024, 1024, 50, 7.5, None, 1],
["一个未来主义城市的夜景,霓虹灯闪烁,飞行器穿梭", "模糊, 低质量, 变形", 1024, 1024, 50, 7.5, None, 1],
["一朵盛开的玫瑰花,特写镜头,超高清细节,自然光线", "模糊, 低质量, 变形", 1024, 1024, 50, 7.5, None, 1],
],
inputs=[prompt, negative_prompt, width, height, num_inference_steps, guidance_scale, seed, num_images],
outputs=[gallery, execution_time, image_count],
fn=self.generate_images,
cache_examples=False
)
# 事件监听
generate_btn.click(
fn=self.generate_images,
inputs=[prompt, negative_prompt, width, height, num_inference_steps, guidance_scale, seed, num_images],
outputs=[gallery, execution_time, image_count, status_text],
show_progress=True
)
clear_btn.click(
fn=self.clear_all,
inputs=[],
outputs=[prompt, negative_prompt, width, height, num_inference_steps, guidance_scale, seed, num_images, gallery, execution_time, image_count, status_text]
)
# 初始化状态
status_text.value = "就绪,可以生成图像"
return interface
def generate_images(
self,
prompt: str,
negative_prompt: str,
width: int,
height: int,
num_inference_steps: int,
guidance_scale: float,
seed: int,
num_images: int
):
"""
生成图像的处理函数
Args:
prompt: 生成提示
negative_prompt: 负面提示
width: 生成图像宽度
height: 生成图像高度
num_inference_steps: 推理步数
guidance_scale: 引导尺度
seed: 随机种子
num_images: 生成图像数量
Returns:
tuple: (生成的图像列表, 执行时间, 图像数量, 状态)
"""
if not prompt:
return [], "0.00秒", "0", "请输入生成提示"
try:
start_time = time.time()
# 生成图像
images = self.inference_service.generate_image(
prompt=prompt,
negative_prompt=negative_prompt if negative_prompt else None,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
width=width,
height=height,
seed=seed if seed is not None else None,
num_images_per_prompt=num_images
)
end_time = time.time()
execution_time = end_time - start_time
return (
images,
f"{execution_time:.2f}秒",
f"{len(images)}",
"生成完成"
)
except Exception as e:
logger.error(f"图像生成失败: {str(e)}")
return [], "0.00秒", "0", f"生成失败: {str(e)}"
def clear_all(self):
"""
清除所有输入和输出
Returns:
tuple: 清除后的状态
"""
return (
"", # prompt
"", # negative_prompt
1024, # width
1024, # height
50, # num_inference_steps
7.5, # guidance_scale
None, # seed
1, # num_images
[], # gallery
"", # execution_time
"", # image_count
"就绪,可以生成图像" # status_text
)
def launch(self, share: bool = False, server_name: str = "0.0.0.0", server_port: int = 7860):
"""
启动Gradio界面
Args:
share: 是否生成公共链接
server_name: 服务器地址
server_port: 服务器端口
"""
if self.interface is None:
self.interface = self.create_interface()
logger.info(f"启动Gradio界面: http://{server_name}:{server_port}")
self.interface.launch(
share=share,
server_name=server_name,
server_port=server_port,
show_api=False,
quiet=True
)
if __name__ == "__main__":
# 配置日志
logging.basicConfig(level=logging.INFO)
# 初始化推理服务
inference_service = InferenceService(
model_path="./models",
device="cpu",
dtype=torch.float32
)
# 初始化模型
inference_service.initialize()
# 创建并启动Gradio界面
gradio_ui = GradioUI(inference_service)
gradio_ui.launch()