File size: 1,838 Bytes
f4884e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59

import os
import torch
import spaces
import gradio as gr
from diffusers import DiffusionPipeline
from diffusers.utils import load_image, export_to_video

# 1. Global model loading.
# The pipeline is built once at import time (the original note marks this as
# crucial for ZeroGPU). A local 'pretrained_weights' folder uploaded to the
# Space takes priority over the Hub repository.
_LOCAL_WEIGHTS_DIR = "./pretrained_weights"
if os.path.exists(_LOCAL_WEIGHTS_DIR):
    model_id = _LOCAL_WEIGHTS_DIR
else:
    model_id = "huaichang/PersonaLive"

# `pipe` stays None unless both loading and the move to CUDA succeed, so the
# inference function can detect a failed start-up.
pipe = None
try:
    _candidate_pipe = DiffusionPipeline.from_pretrained(
        model_id,
        use_safetensors=True,
        torch_dtype=torch.bfloat16,
    )
    _candidate_pipe.to("cuda")
except Exception as load_error:
    # Top-level boundary: report the problem and keep the app importable.
    print(f"Error loading model: {load_error}")
else:
    pipe = _candidate_pipe

# 2. Inference Function with @spaces.GPU
@spaces.GPU(duration=120)
def generate_video(input_image, driving_video, acceleration_type):
    """Produce the output video for the Gradio "Generate" button.

    Args:
        input_image: Reference image as passed by the UI (the image component
            is configured with ``type="filepath"``).
        driving_video: Driving video value as passed by the video component.
        acceleration_type: Selected acceleration mode
            (one of ``"none"``, ``"xformers"``, ``"tensorrt"`` in the UI).

    Returns:
        Path of the generated MP4 file, suitable for a single ``gr.Video``
        output component.

    Raises:
        gr.Error: If the model failed to load at start-up, or if no output
            video file exists after inference.
    """
    if pipe is None:
        # Raise instead of returning ``(None, "message")``: this function
        # feeds exactly one Video output, so a 2-tuple would not be shown as
        # an error message (per the Gradio docs a tuple for Video is read as
        # (video, subtitle)). gr.Error surfaces the text to the user.
        raise gr.Error("Model failed to load.")

    # Place the core logic from your inference_online.py here
    # output = pipe(image=input_image, video=driving_video, acceleration=acceleration_type).frames[0]

    output_path = "output.mp4"
    # Example: export_to_video(output, output_path)

    # Until the export step above is implemented (or if it fails silently),
    # there is no file to show; report that clearly instead of handing the UI
    # a path that does not exist.
    if not os.path.exists(output_path):
        raise gr.Error("No output video was produced.")
    return output_path

# 3. Gradio interface: inputs in the left column, generated video on the right.
with gr.Blocks() as demo:
    gr.Markdown(value="# PersonaLive ZeroGPU Demo")

    with gr.Row():
        # Left column: everything the user provides, plus the trigger button.
        with gr.Column():
            ref_img = gr.Image(type="filepath", label="Reference Image")
            drive_vid = gr.Video(label="Driving Video (or Webcam)")
            accel = gr.Radio(
                choices=["none", "xformers", "tensorrt"],
                value="xformers",
                label="Acceleration",
            )
            submit_btn = gr.Button(value="Generate")

        # Right column: the result.
        with gr.Column():
            output_vid = gr.Video(label="Generated Video")

    # Run inference when the button is pressed; the returned path is shown in
    # the output video component.
    submit_btn.click(
        generate_video,
        inputs=[ref_img, drive_vid, accel],
        outputs=output_vid,
    )

# Start the app only when executed as a script.
if __name__ == "__main__":
    demo.launch()