matthewkram committed
Commit 25a48e3 · verified · Parent: 961e0b4

Update app.py

Files changed (1)
app.py +81 -76
app.py CHANGED
@@ -1,98 +1,103 @@
 import os
 import sys
-import uuid
-import shutil
 import time
-import gradio as gr
 import torch
-from diffusers import StableVideoDiffusionPipeline
-from PIL import Image
 import numpy as np
-import cv2
-import subprocess
 import tempfile
+from PIL import Image
+from datetime import datetime
+import gradio as gr
+from torch import autocast
+from pytorch_lightning import seed_everything
+import torchvision.transforms as T
+from diffusers import StableVideoDiffusionPipeline
+from diffusers.utils import load_image, export_to_video
 
-class WanAnimateApp:
+class WorldAnimate:
     def __init__(self):
-        model_name = "stabilityai/stable-video-diffusion-img2vid-xt"
+        model_id = "stabilityai/stable-video-diffusion-img2vid-xt"
         self.pipe = StableVideoDiffusionPipeline.from_pretrained(
-            model_name,
-            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
-            variant="fp16",
-            device_map="cpu"
+            model_id, torch_dtype=torch.float16, variant="fp16"
         )
+        self.pipe.enable_model_cpu_offload()
+        self.pipe.enable_vae_slicing()  # requires a diffusers version that exposes VAE slicing on SVD
+        self.pipe.unet.enable_forward_chunking(chunk_size=1, dim=1)
+        # NOTE: device placement is handled by enable_model_cpu_offload(); an explicit pipe.to("cuda") would conflict with it
+        torch.backends.cuda.matmul.allow_tf32 = True
 
-    def predict(
-        self,
-        ref_img,
-        video,
-        model_id,
-        model,
-    ):
-        if ref_img is None or video is None:
-            return None, "Upload both image and video."
-
-        try:
-            # Local processing — PIL for image (no open for type="pil")
-            if isinstance(ref_img, Image.Image):
-                ref_image = ref_img.convert("RGB").resize((576, 320))
-            else:
-                ref_image = Image.open(ref_img).convert("RGB").resize((576, 320))
-
-            cap = cv2.VideoCapture(video)
-            frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
-            cap.release()
-            motion_hint = f" with dynamic motion from {frame_count} frames"
-
-            # Prompt based on mode
-            if model_id == "wan2.2-animate-move":
-                prompt = f"Animate the character in the reference image{motion_hint}, high quality, smooth movements."
-            else:
-                prompt = f"Replace the character in the video with the reference image{motion_hint}, seamless, detailed."
-
-            # Parameters
-            num_frames = 25 if model == "wan-pro" else 14
-            num_steps = 25 if model == "wan-pro" else 15
-
-            # Local generation
-            generator = torch.Generator(device="cpu").manual_seed(42)
-            output = self.pipe(
-                ref_image,
-                num_inference_steps=num_steps,
-                num_frames=num_frames,
-                generator=generator,
-                decode_chunk_size=2
-            ).frames[0]
-
-            # Save MP4 with ffmpeg
-            temp_dir = tempfile.mkdtemp()
-            for i, frame in enumerate(output):
-                frame.save(f"{temp_dir}/frame_{i:04d}.png")
-            temp_video = f"/tmp/output_{uuid.uuid4()}.mp4"
-            subprocess.run([
-                'ffmpeg', '-y', '-framerate', '7', '-i', f"{temp_dir}/frame_%04d.png",
-                '-c:v', 'libx264', '-pix_fmt', 'yuv420p', temp_video
-            ], check=True)
-            shutil.rmtree(temp_dir)
-
-            return temp_video, "SUCCEEDED"
-
-        except Exception as e:
-            return None, f"Failed: {str(e)}"
+    def process_input(self, image, seed, num_frames, fps, decode_chunk_size, motion_bucket_id, noise_aug_strength):
+        if seed == -1:
+            seed = int.from_bytes(os.urandom(2), "big")
+        seed_everything(seed)
+
+        if isinstance(image, str):
+            image = load_image(image)
+        image = image.resize((1024, 576))
+
+        generator = torch.manual_seed(seed)
+        frames = self.pipe(
+            image,
+            num_frames=num_frames,
+            fps=fps,
+            decode_chunk_size=decode_chunk_size,
+            motion_bucket_id=motion_bucket_id,
+            noise_aug_strength=noise_aug_strength,
+            generator=generator,
+        ).frames[0]
+
+        return frames
+
+def app():
+    with gr.Blocks(title="World 2.2 Animate (Local No API)") as demo:
+        gr.HTML("""
+        <h1 style="text-align: center; font-family: Arial; color: white;">World 2.2 Animate</h1>
+        <p style="text-align: center; font-family: Arial; color: white;">
+            This is a local processing app for image-to-video conversion using Stable Video Diffusion.<br>
+            Upload an image, adjust parameters, and generate a video with smooth motion.<br>
+            Parameters:<br>
+            - Seed: Random seed for reproducibility (-1 for random).<br>
+            - Num Frames: Number of frames in the video (default 25).<br>
+            - FPS: Frames per second (default 7).<br>
+            - Decode Chunk Size: For memory optimization (default 8).<br>
+            - Motion Bucket ID: Controls motion amount (1-255, default 127).<br>
+            - Noise Aug Strength: Adds noise for variation (0-1, default 0.02).
+        </p>
+        """)  # Close the triple-quoted string properly here!
+
+        with gr.Row():
+            with gr.Column():
+                input_image = gr.Image(label="Upload Image", type="pil")
+                seed = gr.Number(label="Seed", value=-1)
+                num_frames = gr.Slider(label="Num Frames", minimum=1, maximum=25, value=25, step=1)
+                fps = gr.Slider(label="FPS", minimum=1, maximum=30, value=7, step=1)
+                decode_chunk_size = gr.Slider(label="Decode Chunk Size", minimum=1, maximum=16, value=8, step=1)
+                motion_bucket_id = gr.Slider(label="Motion Bucket ID", minimum=1, maximum=255, value=127, step=1)
+                noise_aug_strength = gr.Slider(label="Noise Aug Strength", minimum=0.0, maximum=1.0, value=0.02, step=0.01)
+                generate_btn = gr.Button(value="Generate Video")
+
+            with gr.Column():
+                output_video = gr.Video(label="Generated Video")
+                status = gr.Textbox(label="Status")
+
+        generate_btn.click(
+            fn=process,
+            inputs=[input_image, seed, num_frames, fps, decode_chunk_size, motion_bucket_id, noise_aug_strength],
+            outputs=[output_video, status]
+        )
+    return demo  # return the Blocks instance so start_app() can launch it
+
+def process(image, seed, num_frames, fps, decode_chunk_size, motion_bucket_id, noise_aug_strength):
+    try:
+        animator = WorldAnimate()  # note: this reloads the full pipeline on every click
+        frames = animator.process_input(image, seed, num_frames, fps, decode_chunk_size, motion_bucket_id, noise_aug_strength)
+        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_video:
+            export_to_video(frames, temp_video.name, fps=fps)
+        return temp_video.name, "Success!"
+    except Exception as e:
+        return None, f"Failed: {str(e)}"
 
 def start_app():
-    app = WanAnimateApp()
-
-    with gr.Blocks(title="Wan2.2-Animate (Local No API)") as demo:
-        gr.HTML("""
-            <div style="padding: 2rem; text-align: center; max-width: 1200px; margin: 0 auto; font-family: Arial, sans-serif;">
-                <h1 style="font-size: 2.5rem; font-weight: bold; margin-bottom: 0.5rem; color: #333;">
-                    Wan2.2-Animate: Unified Character Animation and Replacement with Holistic Replication
-                </h1>
-                <h3 style="font-size: 1.5rem; font-weight: bold; margin-bottom: 0.5rem; color: #333;">
-                    Local version without API (SVD Proxy)
-                </h3>
-                <div style="font-size: 1.25rem; margin-bottom: 1.5rem; color: #555;">
-                    Tongyi Lab, Alibaba
-                </div>
-                <div style="display: flex; flex-wrap: wrap; justify-content: center; gap: 1rem; margin-bottom:
+    app().launch()
+
+if __name__ == "__main__":
+    start_app()
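
For a quick sanity check of the refactored pipeline outside the Gradio UI, something like the following headless sketch should work. It is an illustration, not part of the commit: the script name, the example image URL (taken from the diffusers SVD documentation), and the parameter values (mirroring the UI defaults above) are all assumptions, and it presumes a CUDA GPU plus the diffusers, accelerate, and pytorch_lightning packages.

    # smoke_test.py (hypothetical) - drives the committed WorldAnimate class headlessly.
    from diffusers.utils import load_image, export_to_video

    from app import WorldAnimate  # the class added in this commit

    animator = WorldAnimate()  # first run downloads the SVD-XT weights

    # Sample conditioning image from the diffusers SVD documentation.
    image = load_image(
        "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/svd/rocket.png"
    )

    frames = animator.process_input(
        image,
        seed=42,                  # fixed seed so runs are reproducible
        num_frames=25,            # SVD-XT is trained for 25-frame clips
        fps=7,                    # conditioning fps; reused below for the MP4
        decode_chunk_size=8,      # lower this if VAE decoding exhausts GPU memory
        motion_bucket_id=127,     # higher values produce more motion
        noise_aug_strength=0.02,  # noise added to the conditioning image
    )

    # frames is a list of PIL images; write them out as an MP4.
    export_to_video(frames, "rocket.mp4", fps=7)

If this runs, the same parameter values should behave identically through the UI, since process() calls the same process_input() under the hood.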