nemece commited on
Commit
10fb632
·
verified ·
1 Parent(s): 10ae1b0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -36
app.py CHANGED
@@ -7,47 +7,34 @@ import gradio as gr
7
 
8
  import os
9
  HF_TOKEN = os.getenv("HF_TOKEN")
 
10
 
 
 
 
 
 
 
 
11
 
 
 
12
 
13
- def generate_video(prompt_img_url):
14
- pipe = StableVideoDiffusionPipeline.from_pretrained(
15
- "stabilityai/stable-video-diffusion-img2vid",
16
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
17
- use_safetensors=True,
18
- token=HF_TOKEN
19
- )
20
 
21
- device = "cuda" if torch.cuda.is_available() else "cpu"
22
- pipe = pipe.to(device)
23
 
24
- # 2. Load input image
25
- img = Image.open(requests.get(prompt_img_url, stream=True).raw).convert("RGB")
26
- img = img.resize((576, 320))
 
27
 
28
- # 3. Generate video frames
29
- frames = pipe(img, num_frames=5).frames[0]
30
 
31
- # 4. Save frames
32
- os.makedirs("frames", exist_ok=True)
33
 
34
- for i, f in enumerate(frames):
35
- f.save(f"frames/frame_{i:03d}.png")
36
-
37
- # 5. Build MP4
38
- os.system("ffmpeg -y -framerate 10 -i frames/frame_%03d.png output.mp4")
39
-
40
- return "output.mp4"
41
-
42
-
43
- # Gradio UI
44
- with gr.Blocks() as demo:
45
- gr.Markdown("# Stable Video Diffusion (IMG2VID)")
46
-
47
- img_url = gr.Textbox(label="Image URL", value="https://image.pollinations.ai/prompt/cat%20running")
48
- btn = gr.Button("Generate video")
49
- output_video = gr.Video()
50
-
51
- btn.click(generate_video, inputs=img_url, outputs=output_video)
52
-
53
- demo.launch()
 
7
 
8
  import os
9
  HF_TOKEN = os.getenv("HF_TOKEN")
10
+
11
 
12
+ # 1. Load pipeline
13
+ pipe = StableVideoDiffusionPipeline.from_pretrained(
14
+ "stabilityai/stable-video-diffusion-img2vid",
15
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
16
+ use_safetensors=True,
17
+ token=HF_TOKEN
18
+ )
19
 
20
+ device = "cuda" if torch.cuda.is_available() else "cpu"
21
+ pipe = pipe.to(device)
22
 
23
+ # 2. Load input image
24
+ img_url = "https://image.pollinations.ai/prompt/cat%20running"
25
+ img = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
26
+ img = img.resize((576, 320))
 
 
 
27
 
28
+ # 3. Generate video frames
29
+ frames = pipe(img, num_frames=5).frames[0]
30
 
31
+ # 4. Save frames
32
+ os.makedirs("frames", exist_ok=True)
33
+ for i, f in enumerate(frames):
34
+ f.save(f"frames/frame_{i:03d}.png")
35
 
36
+ # 5. Build MP4 video
37
+ os.system("ffmpeg -y -framerate 10 -i frames/frame_%03d.png output.mp4")
38
 
39
+ print("Video saved as output.mp4")
 
40