File size: 2,670 Bytes
ac53067
 
 
96c9fdb
ac53067
c901cb7
 
 
ac53067
c901cb7
 
2f7fdab
c901cb7
ac53067
c901cb7
 
 
 
 
 
2f7fdab
ac53067
 
c901cb7
 
 
ac53067
 
 
c901cb7
96c9fdb
c901cb7
2f7fdab
c901cb7
2f7fdab
c901cb7
96c9fdb
ac53067
2f7fdab
 
 
 
 
 
 
 
 
 
 
 
 
 
96c9fdb
 
ac53067
96c9fdb
ac53067
 
96c9fdb
2f7fdab
96c9fdb
2f7fdab
ac53067
 
2f7fdab
ac53067
 
c901cb7
 
 
ac53067
 
 
2f7fdab
c901cb7
 
 
2f7fdab
ac53067
 
c901cb7
2f7fdab
ac53067
 
 
 
2f7fdab
ac53067
 
2f7fdab
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import gradio as gr
from diffusers import DiffusionPipeline
import torch
import numpy as np
from PIL import Image
import time
import warnings
warnings.filterwarnings("ignore")

# Run entirely on CPU; float32 is used because half precision is poorly
# supported for CPU inference.
torch_device = "cpu"
torch_dtype = torch.float32

def load_model():
    """Build the ModelScope text-to-video pipeline and prepare it for CPU.

    Downloads/loads the 1.7B text-to-video model, moves it to the
    configured device, and enables attention slicing to lower peak
    memory usage during inference.

    Returns:
        The ready-to-use DiffusionPipeline instance.
    """
    pipeline = DiffusionPipeline.from_pretrained(
        "damo-vilab/text-to-video-ms-1.7b",
        torch_dtype=torch_dtype,
    ).to(torch_device)
    pipeline.enable_attention_slicing()
    return pipeline

def generate_video(prompt, num_frames=8, num_inference_steps=20):
    """Generate a short video from *prompt* and save it as an animated GIF.

    Args:
        prompt: Text description of the video to generate.
        num_frames: Requested frame count; clamped to at most 8 to keep
            CPU generation time bounded.
        num_inference_steps: Requested diffusion steps; clamped to at
            most 20 for the same reason.

    Returns:
        Path of the GIF written to the current working directory.

    Raises:
        ValueError: If the pipeline returns frames in a format this
            function does not recognize.
    """
    start_time = time.time()

    # Lazily load and cache the pipeline on the function object so the
    # slow model load happens only on the first call.
    if not hasattr(generate_video, "pipe"):
        generate_video.pipe = load_model()

    with torch.no_grad():
        output = generate_video.pipe(
            prompt,
            num_frames=min(num_frames, 8),
            num_inference_steps=min(num_inference_steps, 20),
            height=256,
            width=256,
        )

    frames = _frames_to_pil(output.frames)
    gif_path = _save_gif(frames, "output.gif")

    print(f"Generation took {time.time() - start_time:.2f} seconds")
    return gif_path


def _frames_to_pil(video_frames):
    """Convert pipeline output frames to a list of PIL images.

    Accepts a numpy array (optionally with a leading batch dimension) or
    a list/tuple of per-frame arrays or PIL images — recent diffusers
    releases return a list rather than a single ndarray, which the
    previous implementation rejected.
    """
    if isinstance(video_frames, (list, tuple)):
        # Already decoded PIL frames: nothing to convert.
        if video_frames and isinstance(video_frames[0], Image.Image):
            return list(video_frames)
        # List of per-batch/per-frame arrays: stack into one ndarray.
        video_frames = np.asarray(video_frames)

    if not isinstance(video_frames, np.ndarray):
        raise ValueError("Unexpected frame format")

    # Drop the batch dimension: (1, num_frames, H, W, 3) -> (num_frames, H, W, 3)
    if video_frames.ndim == 5:
        video_frames = video_frames[0]

    frames = []
    for frame in video_frames:
        if frame.dtype != np.uint8:
            # Clip before scaling: diffusion outputs can dip slightly
            # outside [0, 1], and an unclipped uint8 cast wraps around,
            # producing speckle artifacts.
            frame = (np.clip(frame, 0.0, 1.0) * 255).astype(np.uint8)
        frames.append(Image.fromarray(frame))
    return frames


def _save_gif(frames, gif_path):
    """Write a non-empty list of PIL images to *gif_path* as a looping GIF."""
    if not frames:
        raise ValueError("No frames to save")
    # Note: PIL's GIF writer has no `quality` option; the previous
    # `quality=80` argument was silently ignored and has been removed.
    frames[0].save(
        gif_path,
        save_all=True,
        append_images=frames[1:],
        duration=100,  # 100 ms per frame
        loop=0,  # loop forever
    )
    return gif_path

# Gradio Interface: two-column layout — inputs (prompt + advanced sliders)
# on the left, generated result on the right.
with gr.Blocks(title="CPU Text-to-Video") as demo:
    gr.Markdown("# 🐢 CPU Text-to-Video Generator")
    
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt")
            # Sliders are capped low (12 frames / 30 steps) because
            # generate_video clamps to 8 frames and 20 steps anyway and
            # CPU inference is slow.
            with gr.Accordion("Advanced Options", open=False):
                frames = gr.Slider(4, 12, value=8, step=4, label="Frames")
                steps = gr.Slider(10, 30, value=20, step=5, label="Steps")
            submit = gr.Button("Generate")
        
        with gr.Column():
            # generate_video returns a GIF file path; presumably
            # gr.Image(format="gif") renders it animated — TODO confirm
            # against the installed gradio version.
            output = gr.Image(label="Result", format="gif")
            gr.Markdown("Note: CPU generation may take several minutes")
    
    # Wire the button to the generator: slider values map positionally
    # onto generate_video(prompt, num_frames, num_inference_steps).
    submit.click(
        fn=generate_video,
        inputs=[prompt, frames, steps],
        outputs=output
    )

demo.launch()