File size: 2,039 Bytes
40cfce6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
# handler.py
import base64
import os
import shutil
import subprocess
import sys
import tempfile
import uuid
from typing import Dict

class Handler:
    """Run Wan2.2 speech-to-video generation in a subprocess for one request.

    The handler decodes a base64 image and audio clip into a scratch
    directory, invokes ``generate.py`` from the bundled Wan2.2 checkout and
    returns the resulting video as base64.
    """

    def __init__(self, model_dir: str):
        """Record where the Wan2.2 code and the S2V checkpoint live.

        Args:
            model_dir: Directory containing the ``Wan2.2`` code checkout and
                the ``Wan2.2-S2V-14B`` checkpoint directory.
        """
        self.model_dir = model_dir
        self.code_dir = os.path.join(model_dir, "Wan2.2")
        self.ckpt_dir = os.path.join(model_dir, "Wan2.2-S2V-14B")

    def __call__(self, inputs: Dict) -> Dict[str, str]:
        """Generate a video from a reference image and a driving audio clip.

        Args:
            inputs: Request payload. Recognised keys:
                ``prompt`` (optional text prompt),
                ``image_b64`` (required, base64-encoded image),
                ``audio_b64`` (required, base64-encoded audio).

        Returns:
            ``{"video_b64": <base64-encoded video bytes>}``.

        Raises:
            ValueError: If ``image_b64`` or ``audio_b64`` is missing/empty,
                or a payload is not valid base64.
            subprocess.CalledProcessError: If ``generate.py`` exits non-zero.
            FileNotFoundError: If generation succeeded but produced no file
                at the requested output path.
        """
        prompt = inputs.get("prompt")
        if prompt is None:
            # Covers both a missing key and an explicit null, which would
            # otherwise put None into the argv list.
            prompt = "a person is talking"
        image_b64 = inputs.get("image_b64")
        audio_b64 = inputs.get("audio_b64")

        # Fail fast: without these the subprocess would be handed paths to
        # files that were never written.
        missing = [
            key
            for key, value in (("image_b64", image_b64), ("audio_b64", audio_b64))
            if not value
        ]
        if missing:
            raise ValueError(
                "missing required input(s): " + ", ".join(missing)
            )

        tmpd = tempfile.mkdtemp()
        try:
            image_path = os.path.join(tmpd, "input.jpg")
            audio_path = os.path.join(tmpd, "input.wav")

            with open(image_path, "wb") as f:
                f.write(base64.b64decode(image_b64))
            with open(audio_path, "wb") as f:
                f.write(base64.b64decode(audio_b64))

            out_path = os.path.join(tmpd, f"out_{uuid.uuid4().hex}.mp4")

            cmd = [
                # Use the interpreter running this handler so the subprocess
                # sees the same installed packages.
                sys.executable or "python", "generate.py",
                "--task", "s2v-14B",
                "--size", "1024*704",
                "--ckpt_dir", self.ckpt_dir,
                "--offload_model", "True",
                "--convert_model_dtype",
                "--prompt", prompt,
                "--image", image_path,
                "--audio", audio_path,
                "--num_clip", "1",
                # Ask generate.py to write to a per-request path instead of
                # guessing at its default output location afterwards.
                "--save_file", out_path,
            ]

            subprocess.check_call(cmd, cwd=self.code_dir)

            if not os.path.isfile(out_path):
                raise FileNotFoundError(
                    f"generate.py finished but did not create {out_path}"
                )
            with open(out_path, "rb") as f:
                return {"video_b64": base64.b64encode(f.read()).decode("utf-8")}
        finally:
            # Always drop the scratch directory (inputs and output video).
            shutil.rmtree(tmpd, ignore_errors=True)