|
|
| import os |
| import torch |
| import spaces |
| import gradio as gr |
| from diffusers import DiffusionPipeline |
| from diffusers.utils import load_image, export_to_video |
|
|
| |
| |
# Prefer weights bundled next to the app; otherwise fall back to the Hub repo id.
if os.path.exists("./pretrained_weights"):
    model_id = "./pretrained_weights"
else:
    model_id = "huaichang/PersonaLive"

# Load the pipeline once at import time and move it to the GPU. Any failure is
# reported and `pipe` is left as None so the UI can still start; the request
# handler checks for `pipe is None` before doing any work.
try:
    pipe = DiffusionPipeline.from_pretrained(
        model_id,
        use_safetensors=True,
        torch_dtype=torch.bfloat16,
    )
    pipe.to("cuda")
except Exception as load_error:
    print(f"Error loading model: {load_error}")
    pipe = None
|
|
| |
@spaces.GPU(duration=120)
def generate_video(input_image, driving_video, acceleration_type):
    """Produce an animated video from a reference image and a driving video.

    Args:
        input_image: Reference image. The UI's ``gr.Image`` is configured with
            ``type="filepath"``, so this is a file path (or None if empty).
        driving_video: Driving video from the UI's ``gr.Video`` component;
            presumably a file path (or None if empty) — confirm with Gradio.
        acceleration_type: One of ``"none"``, ``"xformers"``, ``"tensorrt"``
            as offered by the UI radio. Currently unused.

    Returns:
        Path of the video to display in the single ``gr.Video`` output.

    Raises:
        gr.Error: If the model failed to load or an input is missing. Gradio
            shows the message in the UI.
    """
    if pipe is None:
        # The click handler wires exactly one output component. Returning a
        # (None, message) tuple would be fed to that component instead of
        # being shown as an error, so raise gr.Error to surface the message.
        raise gr.Error("Model failed to load.")
    if not input_image or not driving_video:
        raise gr.Error("Please provide both a reference image and a driving video.")

    # NOTE(review): inference is not implemented yet. `pipe` is never invoked
    # and nothing in this function writes output.mp4. TODO: run the pipeline
    # on the inputs and write the frames to `output_path` (e.g. with the
    # already-imported export_to_video).
    output_path = "output.mp4"

    return output_path
|
|
| |
with gr.Blocks() as demo:
    gr.Markdown("# PersonaLive ZeroGPU Demo")

    with gr.Row():
        # Left column: user inputs plus the button that starts generation.
        with gr.Column():
            ref_img = gr.Image(type="filepath", label="Reference Image")
            drive_vid = gr.Video(label="Driving Video (or Webcam)")
            accel = gr.Radio(
                ["none", "xformers", "tensorrt"],
                value="xformers",
                label="Acceleration",
            )
            submit_btn = gr.Button("Generate")

        # Right column: the generated result.
        with gr.Column():
            output_vid = gr.Video(label="Generated Video")

    # Wire the button to the handler inside the Blocks context so the event is
    # attached to `demo`.
    submit_btn.click(
        generate_video,
        inputs=[ref_img, drive_vid, accel],
        outputs=output_vid,
    )
|
|
if __name__ == "__main__":
    # Start the Gradio server only when this file is executed directly.
    demo.launch()
|
|