"""
Hugging Face Space demo for UniBioTransfer.
Gradio interface for face/hair/motion/head transfer.

ZeroGPU Compatible:
- Model initialized on CPU (no GPU memory during startup)
- Inference wrapped with @spaces.GPU decorator
- Thread-safe global variable access with Lock
"""

import threading
from typing import Optional

import torch
from PIL import Image
import numpy as np

# ==========================================
# Compatibility layer: local testing vs. the HF ZeroGPU environment
# ==========================================
try:
    import spaces
    print("Detected spaces library (Hugging Face environment).")
except ImportError:
    print("Local environment detected. Mocking spaces.GPU...")
    class spaces:
        @staticmethod
        def GPU(func):
            return func  # For local testing the decorator is a no-op and the original function runs directly.
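
# Note (assumption, not from the original repo): the stub above only covers the
# bare `@spaces.GPU` usage in this file. If the decorator were ever applied with
# arguments (e.g. `@spaces.GPU(duration=120)`), a slightly more general mock
# would be needed, along these lines:
#
#     class spaces:
#         @staticmethod
#         def GPU(func=None, **kwargs):
#             if func is None:            # used as @spaces.GPU(duration=...)
#                 return lambda f: f
#             return func                 # used as bare @spaces.GPU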

from infer_hf import UniBioTransferPipeline

# Lock and global singleton pipeline
inference_lock = threading.Lock()
global_pipeline: Optional[UniBioTransferPipeline] = None


def get_pipeline(task):
    """
    单例模式:全局只初始化一次模型(放在 CPU),后续只切换任务。
    强制写死 CPU,保证 ZeroGPU 全局初始化时不碰显卡。
    """
    global global_pipeline
    if global_pipeline is None:
        print("Initializing pipeline once on CPU...")
        # 强制写死 CPU,保证 ZeroGPU 全局初始化时不碰显卡
        global_pipeline = UniBioTransferPipeline.from_pretrained(
            repo_id="scy639/UniBioTransfer",
            task=task,
            device="cpu",
        )
    else:
        # The model is already in memory; only the task ID needs to be switched.
        print(f"Switching existing pipeline to task: {task}")
        global_pipeline.set_task(task)
    return global_pipeline


# Core: all forward-inference logic that touches the GPU is wrapped here.
@spaces.GPU
def run_gpu_inference(pipeline: UniBioTransferPipeline, tgt_pil, ref_pil, ddim_steps, scale, seed, num_images):
    """
    This is where ZeroGPU allocates compute; inside this function it is safe to call .to("cuda").
    On a local server the decorator is a no-op, but the internal .to("cuda") still takes effect.
    """
    return pipeline(
        tgt_pil,
        ref_pil,
        ddim_steps=ddim_steps,
        scale=scale,
        seed=seed,
        num_images=num_images,
    )


def inference(task, tgt_img, ref_img, ddim_steps, seed, num_images):
    """
    Run inference for the demo.
    """
    if tgt_img is None or ref_img is None:
        return None, "Please upload both target and reference images."

    try:
        # 1. Get the pipeline (still on CPU at this point).
        pipeline = get_pipeline(task)

        tgt_pil = Image.fromarray(tgt_img).convert("RGB")
        ref_pil = Image.fromarray(ref_img).convert("RGB")

        # 2. Acquire the lock so concurrent calls cannot clobber the global pipeline's task, then run GPU inference.
        with inference_lock:
            results = run_gpu_inference(
                pipeline,
                tgt_pil,
                ref_pil,
                int(ddim_steps),
                3.0,  # CFG scale; fixed at 3.0 (the scale slider in the UI is disabled)
                int(seed),
                int(num_images)
            )

        return results, f"Success! Task: {task} transfer completed."

    except Exception as e:
        import traceback
        error_msg = f"Error: {str(e)}\n{traceback.format_exc()}"
        print(f"{error_msg}")
        return None, error_msg


def create_demo():
    """Create Gradio demo interface."""
    import gradio as gr

    with gr.Blocks(title="UniBioTransfer") as demo:
        gr.Markdown(
            """
            # UniBioTransfer

            Perform face transfer, hair transfer, motion transfer (face reenactment), and head transfer.

            - **Face Transfer**: Transfer face identity from reference to target
            - **Hair Transfer**: Transfer hairstyle from reference to target
            - **Motion Transfer**: Transfer motion (expression + head pose) from reference to target
            - **Head Transfer**: Transfer entire head from reference to target

            [Code](https://github.com/scy639/UniBioTransfer)
            [Project Page](https://scy639.github.io/UniBioTransfer.github.io/)
            [Paper](https://arxiv.org/abs/2603.19637)
            """
        )

        with gr.Row():
            with gr.Column():
                task_dropdown = gr.Dropdown(
                    choices=["face", "hair", "motion", "head"],
                    value="face",
                    label="Task",
                    info="Select the transfer type",
                )

                with gr.Row():
                    tgt_image = gr.Image(
                        label="Target Image",
                        type="numpy",
                        height=300,
                    )
                    ref_image = gr.Image(
                        label="Reference Image",
                        type="numpy",
                        height=300,
                    )

                with gr.Row():
                    ddim_steps = gr.Slider(
                        minimum=4,
                        maximum=50,
                        value=50,
                        step=1,
                        label="DDIM Steps",
                        info="More steps = better quality but slower",
                    )
                    # scale = gr.Slider(
                    #     minimum=1.0,
                    #     maximum=10.0,
                    #     value=3.0,
                    #     step=0.5,
                    #     label="CFG Scale",
                    #     info="Guidance scale for conditioning",
                    # )

                seed = gr.Number(
                    value=42,
                    label="Random Seed",
                    info="For reproducibility",
                )

                num_images = gr.Slider(
                    minimum=1,
                    maximum=32,
                    value=4,
                    step=1,
                    label="Number of output images",
                    info="Multi-output with different initial noise",
                )

                run_btn = gr.Button("Run Inference", variant="primary")

            with gr.Column():
                output_gallery = gr.Gallery(
                    label="Results",
                    height=800,
                    columns=2,
                )
                status_text = gr.Textbox(
                    label="Status",
                    lines=3,
                )

        gr.Markdown(
"""
### Usage
1. Upload a **target image** (the person whose face/hair/motion/head will be modified)
2. Upload a **reference image** (the source of the attribute to transfer)
3. Select the **task** type
4. Click "Run Inference"

### Requirements
- Works best when the heads in the two input images have similar sizes.
"""
        )

        run_btn.click(
            fn=inference,
            inputs=[task_dropdown, tgt_image, ref_image, ddim_steps, seed, num_images],
            outputs=[output_gallery, status_text],
        )

        task_dropdown.change(
            fn=lambda t: f"Task switched to: {t} transfer",
            inputs=[task_dropdown],
            outputs=[status_text],
        )

        gr.Examples(
            examples=[
                ["face", "examples/face/tgt.png", "examples/face/ref.png",       20, 42, 4],
                ["hair", "examples/hair/tgt.png", "examples/hair/ref.png",       20, 42, 4],
                ["motion", "examples/motion/tgt.png", "examples/motion/ref.png", 20, 42, 4],
                ["head", "examples/head/tgt.png", "examples/head/ref.png",       20, 42, 4],
            ],
            inputs=[task_dropdown, tgt_image, ref_image, ddim_steps, seed, num_images],
            label="Examples",
        )

    return demo


if __name__ == "__main__":
    demo = create_demo()
    demo.launch()
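    # A possible local-testing tweak (assumption, not part of the repo's documented
    # usage): Gradio's request queue serializes incoming jobs, which complements the
    # inference_lock above when several users submit at once. For example:
    #
    #     demo.queue(max_size=8)  # max_size value is illustrative
    #     demo.launch()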