File size: 3,901 Bytes
abf0f61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import os
import torch
from diffusers import ZImagePipeline
import gradio as gr
import threading
import queue
import psutil

# =================== CPU GOD MODE SETTINGS ===================
# Use every available core for intra-op parallelism.
# (The original called set_num_threads(get_num_threads()), which is a no-op.)
torch.set_num_threads(os.cpu_count() or torch.get_num_threads())
# NOTE: a bare `torch.inference_mode()` expression only constructs a context
# manager and immediately discards it — it does NOT disable autograd.
# Disable gradient tracking globally instead; this process only does inference.
torch.set_grad_enabled(False)

MODEL_ID = "Tongyi-MAI/Z-Image-Turbo"

# Load in bf16 to roughly halve memory versus fp32 on CPU.
pipe = ZImagePipeline.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
)
pipe.to("cpu")
pipe.enable_attention_slicing()

# enable_model_cpu_offload() is designed for CUDA setups (it shuttles
# submodules between GPU and CPU); after .to("cpu") it is useless and can
# raise depending on the accelerate/diffusers version, so attempt it
# best-effort only.
try:
    pipe.enable_model_cpu_offload()
except Exception:
    pass

# torch.compile of the transformer is optional — not all builds support it.
try:
    pipe.transformer.compile(fullgraph=True, dynamic=True)
except Exception:
    pass

# Re-apply slicing with an explicit auto slice size where supported.
try:
    pipe.enable_attention_slicing(slice_size="auto")
except Exception:
    pass

# Cap the worker pool at the number of usable cores.
MAX_THREADS = min(torch.get_num_threads(), os.cpu_count() or 4)

# =================== QUEUE & WORKERS ===================
job_queue = queue.Queue()   # holds (job_id, prompt, width, height, steps, seed, batch, out_folder)
status_dict = {}            # job_id -> human-readable status string

def worker(worker_id):
    """Daemon worker loop: pull jobs off ``job_queue`` and render them.

    A job is the 8-tuple produced by ``enqueue_job``; ``None`` is the
    shutdown sentinel. Progress is reported through ``status_dict``.
    """
    while True:
        job = job_queue.get()
        if job is None:
            # Acknowledge the sentinel too, so job_queue.join() can complete.
            job_queue.task_done()
            break
        job_id, prompt, width, height, steps, seed, batch, out_folder = job
        status_dict[job_id] = f"Worker {worker_id}: Processing..."
        try:
            for i in range(batch):
                # Per-image seed: each image in the batch differs, but the
                # whole batch stays reproducible from the base seed.
                img_seed = seed + i
                image = pipe(
                    prompt=prompt,
                    height=height,
                    width=width,
                    num_inference_steps=steps,
                    guidance_scale=0.0,
                    generator=torch.Generator("cpu").manual_seed(img_seed),
                ).images[0]
                out_path = os.path.join(out_folder, f"{job_id}_{i}.png")
                image.save(out_path)
            status_dict[job_id] = f"Worker {worker_id}: Done ({batch} images)"
        except Exception as exc:
            # Without this, one failed job would kill the worker thread
            # silently and task_done() would never run (hanging any join()).
            status_dict[job_id] = f"Worker {worker_id}: Error: {exc}"
        finally:
            job_queue.task_done()

# Spin up one daemon rendering thread per available slot; worker ids are 1-based.
workers = []
for worker_id in range(1, MAX_THREADS + 1):
    thread = threading.Thread(target=worker, args=(worker_id,), daemon=True)
    thread.start()
    workers.append(thread)

# =================== JOB MANAGEMENT ===================
job_counter = 0          # monotonically increasing id source for job_<n>
OUTPUT_DIR = "outputs"   # all rendered PNGs land here
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Serializes job_counter increments: Gradio can invoke enqueue_job from
# several request threads at once, and `job_counter += 1` alone is not atomic.
_job_counter_lock = threading.Lock()

def enqueue_job(prompt, width, height, steps, seed, batch):
    """Assign a unique job id, mark it Queued, and hand it to the workers.

    Returns the job id string (``"job_<n>"``).
    """
    global job_counter
    with _job_counter_lock:
        job_counter += 1
        job_id = f"job_{job_counter}"
    # Set the status BEFORE enqueueing. The original set it afterwards, which
    # could overwrite a worker's "Processing..." update if the worker grabbed
    # the job first (race).
    status_dict[job_id] = "Queued"
    job_queue.put((job_id, prompt, width, height, steps, seed, batch, OUTPUT_DIR))
    return job_id

# =================== GRADIO INTERFACE ===================
# =================== GRADIO INTERFACE ===================
with gr.Blocks() as demo:
    gr.Markdown("# ⚑ CPU God Mode Z-Image + Gradio Ultimate")

    with gr.Row():
        prompt_input = gr.Textbox(label="Prompt", placeholder="Type prompt here...")
        seed_input = gr.Number(label="Seed", value=42, precision=0)

    with gr.Row():
        width_input = gr.Dropdown(["256","512","768","1024"], value="512", label="Width")
        height_input = gr.Dropdown(["256","512","768","1024"], value="512", label="Height")
        batch_input = gr.Slider(1, 5, value=1, step=1, label="Batch Size")

    steps_input = gr.Slider(1, 25, value=9, step=1, label="Inference Steps")
    # Gallery.style() was removed in Gradio 4 — layout options are now
    # constructor arguments.
    output_gallery = gr.Gallery(label="Generated Images", columns=2, height="auto")
    status_box = gr.Textbox(label="Queue Status", interactive=False)

    generate_btn = gr.Button("Generate")

    def on_generate(prompt, width, height, steps, seed, batch):
        """Coerce UI values to ints and enqueue one generation job."""
        width = int(width)
        height = int(height)
        steps = int(steps)  # sliders can deliver floats; the pipeline needs an int
        seed = int(seed)
        batch = int(batch)
        job_id = enqueue_job(prompt, width, height, steps, seed, batch)
        return [], f"Job {job_id} queued ({batch} images)"

    def poll_status():
        """Render the current status of every submitted job, one per line."""
        if status_dict:
            return "\n".join(f"{k}: {v}" for k, v in status_dict.items())
        return "No jobs in queue."

    generate_btn.click(
        on_generate,
        inputs=[prompt_input, width_input, height_input, steps_input, seed_input, batch_input],
        outputs=[output_gallery, status_box],
    )

    # The original wired poll_status to status_box.change, which only fires
    # AFTER the box's value has already changed — so statuses never refreshed
    # on their own. Poll on a fixed interval instead.
    demo.load(poll_status, inputs=None, outputs=status_box, every=2)

demo.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", 7860)))