File size: 3,269 Bytes
d374f60
 
 
 
 
 
 
 
 
 
 
fbc5276
 
d374f60
fbc5276
 
 
d374f60
fbc5276
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d374f60
 
 
 
 
 
 
 
 
 
 
fbc5276
 
 
 
d374f60
 
fbc5276
 
 
 
 
 
 
 
 
d374f60
fbc5276
 
d374f60
fbc5276
 
 
 
 
 
 
d374f60
fbc5276
 
d374f60
fbc5276
d374f60
fbc5276
 
 
 
 
 
d374f60
fbc5276
d374f60
fbc5276
 
 
 
 
 
d374f60
fbc5276
d374f60
fbc5276
 
 
 
 
 
 
 
d374f60
fbc5276
 
d374f60
fbc5276
d374f60
fbc5276
d374f60
 
 
fbc5276
d374f60
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import gradio as gr
import torch
from diffusers import StableVideoDiffusionPipeline
from PIL import Image
import imageio
import uuid
import numpy as np
import cv2

device = "cuda" if torch.cuda.is_available() else "cpu"

pipe = None
current_model = None

# 🔄 Load model only when needed (fixes slow startup)
def load_model(model_name):
    """Lazily load the requested Stable Video Diffusion pipeline.

    Caches the pipeline in the module-level ``pipe`` so repeated calls with
    the same ``model_name`` reuse the loaded model (keeps startup fast).

    Parameters:
        model_name: "Fast (SVD)" for the base img2vid model; anything else
            selects the higher-quality XT variant.

    Returns:
        The loaded pipeline, or ``None`` if loading failed.
    """
    global pipe, current_model

    # Cache hit: requested model is already loaded and usable.
    if current_model == model_name and pipe is not None:
        return pipe

    try:
        if model_name == "Fast (SVD)":
            model_id = "stabilityai/stable-video-diffusion-img2vid"
        else:
            model_id = "stabilityai/stable-video-diffusion-img2vid-xt"

        # Build into a local first so a mid-load failure cannot clobber a
        # previously working global pipeline.
        new_pipe = StableVideoDiffusionPipeline.from_pretrained(
            model_id,
            torch_dtype=torch.float16 if device == "cuda" else torch.float32
        )

        if device == "cuda":
            # enable_model_cpu_offload() manages device placement itself;
            # combining it with a prior .to("cuda") is explicitly discouraged
            # by the diffusers docs, so do one or the other — not both.
            new_pipe.enable_attention_slicing()
            new_pipe.enable_model_cpu_offload()
        else:
            new_pipe = new_pipe.to(device)

        # Commit the globals only after everything succeeded.
        pipe = new_pipe
        current_model = model_name
        return pipe

    except Exception as e:
        # Best-effort loader: report and signal failure to the caller.
        print("Model load error:", e)
        return None


# 🎥 Extract frame from video
def extract_frame(video_path):
    """Return the first frame of *video_path* as an RGB PIL image, or None."""
    capture = cv2.VideoCapture(video_path)
    try:
        ok, bgr_frame = capture.read()
    finally:
        # Always free the capture handle, even if read() raises.
        capture.release()

    if not ok:
        return None

    # OpenCV decodes as BGR; convert to RGB before handing to PIL.
    rgb_frame = cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGB)
    return Image.fromarray(rgb_frame)


def generate_video(image, video, fps, motion_strength, model_choice):
    """Generate a short MP4 from an image (or the first frame of a video).

    Parameters:
        image: optional PIL image (takes priority over *video*).
        video: optional video file path; its first frame is used.
        fps: output frame rate for the encoded MP4.
        motion_strength: mapped to the pipeline's motion_bucket_id.
        model_choice: model name passed through to load_model().

    Returns:
        The generated MP4 filename, or ``None`` on any failure
        (missing input, model load error, generation error).
    """
    try:
        pipe = load_model(model_choice)
        if pipe is None:
            return None

        # Select the conditioning input; a direct image wins over a video.
        if image is not None:
            input_image = image.convert("RGB")
        elif video is not None:
            input_image = extract_frame(video)
            if input_image is None:
                return None
        else:
            return None

        # Resize to a small fixed size (⚡ HUGE speed boost).
        input_image = input_image.resize((512, 512))

        # Generate frames (reduced count for speed).
        output = pipe(
            input_image,
            num_frames=16,  # ⚡ faster
            decode_chunk_size=4,
            motion_bucket_id=int(motion_strength)
        )

        # The SVD pipeline returns PIL images by default (output_type="pil");
        # float arrays in [0, 1] appear only with output_type="np".  The old
        # code assumed float arrays and crashed on PIL frames — normalize
        # both cases to HxWx3 uint8 for the encoder.
        frames = []
        for frame in output.frames[0]:
            arr = np.asarray(frame)
            if arr.dtype != np.uint8:
                arr = (np.clip(arr, 0.0, 1.0) * 255).astype(np.uint8)
            frames.append(arr)

        # Unique name so concurrent/repeated generations never collide.
        filename = f"video_{uuid.uuid4().hex}.mp4"

        imageio.mimsave(
            filename,
            frames,
            fps=fps,
            codec="libx264"
        )

        return filename

    except Exception as e:
        # Gradio handler: report and return None so the UI shows no video.
        print("Generation error:", e)
        return None


# 🎨 UI — declarative Gradio layout; component creation order defines the
# on-screen order, so these statements must not be rearranged.
with gr.Blocks() as demo:
    gr.Markdown("# 🎬 StuffMotion AI (FAST + MODEL SELECT)")

    # Two alternative inputs; generate_video prefers the image if both are set.
    image_input = gr.Image(type="pil", label="🖼️ Image Input")
    video_input = gr.Video(label="🎥 Video Input")

    # Model names here must match the strings load_model() checks against.
    model_choice = gr.Dropdown(
        ["Fast (SVD)", "High Quality (XT)"],
        value="Fast (SVD)",
        label="🧠 Model"
    )

    fps = gr.Slider(8, 24, value=12, step=1, label="FPS")
    # 1–255 range matches the pipeline's motion_bucket_id domain.
    motion = gr.Slider(1, 255, value=100, label="Motion")

    generate_btn = gr.Button("⚡ Generate")

    video_output = gr.Video()

    # Input order here must match generate_video's parameter order.
    generate_btn.click(
        fn=generate_video,
        inputs=[image_input, video_input, fps, motion, model_choice],
        outputs=video_output
    )

demo.launch()