File size: 4,392 Bytes
c838058
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
499128d
c838058
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
483e424
c838058
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
031028a
 
c838058
 
 
031028a
 
 
 
 
 
 
 
 
 
 
 
 
c838058
031028a
c838058
 
031028a
c838058
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
"""
HF Spaces version of generate.py — same logic, paths adapted for /tmp/Wan2GP.
Called as a subprocess from app.py inside @spaces.GPU.
"""

import argparse
import os
import sys
from pathlib import Path

# Root of the Wan2GP checkout; the WAN2GP_ROOT environment variable overrides
# the /tmp/Wan2GP default.
WAN2GP_ROOT = Path(os.getenv("WAN2GP_ROOT", "/tmp/Wan2GP"))

# Short --model aliases accepted on the command line -> model type identifier.
MODEL_SHORTHANDS = {"sulphur-2": "sulphur_2_base"}

# Per-model generation settings, used for every option the CLI leaves unset.
DEFAULTS = {
    "sulphur_2_base": dict(
        num_inference_steps=8,
        guidance_scale=5.0,
        resolution="832x480",
        video_length=81,
    ),
}


def p(*args, **kwargs):
    """Print like the built-in ``print``, flushing the stream by default.

    All positional and keyword arguments are forwarded to ``print``.
    ``flush`` defaults to ``True`` so output appears immediately when this
    script runs as a subprocess; a caller may still pass ``flush=...``
    explicitly without triggering a duplicate-keyword ``TypeError``.
    """
    kwargs.setdefault("flush", True)
    print(*args, **kwargs)


def parse_args(argv=None):
    """Parse the command-line options for a single generation run.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, which is the original behaviour.

    Returns:
        argparse.Namespace with ``image``, ``prompt`` and ``output``
        (required), ``model`` (default ``"sulphur-2"``), ``seed`` (default
        ``-1``), and ``steps``, ``guidance_scale``, ``frames`` and
        ``resolution``, which are ``None`` unless given on the command line.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--image",          required=True)
    ap.add_argument("--prompt",         required=True)
    ap.add_argument("--output",         required=True)
    ap.add_argument("--model",          default="sulphur-2")
    # These stay None when omitted so the caller can tell "not given" apart
    # from an explicit value and substitute per-model defaults.
    ap.add_argument("--steps",          type=int,   default=None)
    ap.add_argument("--guidance_scale", type=float, default=None)
    ap.add_argument("--frames",         type=int,   default=None)
    ap.add_argument("--resolution",     default=None)
    ap.add_argument("--seed",           type=int,   default=-1)
    return ap.parse_args(argv)


def _auto_resolution(image_path):
    """Return ``("WxH", (iw, ih))`` sized to the aspect ratio of *image_path*.

    The shorter side is fixed at 480 and the longer side is scaled to keep the
    aspect ratio, rounded to the nearest multiple of 32.
    """
    from PIL import Image as _PIL

    # Context manager so the file handle is released as soon as the size is read.
    with _PIL.open(image_path) as img:
        iw, ih = img.size
    if ih > iw:
        tw = 480
        th = round(ih / iw * tw / 32) * 32
    else:
        th = 480
        tw = round(iw / ih * th / 32) * 32
    return f"{tw}x{th}", (iw, ih)


def _locate_output(result, output_dir, not_before):
    """Return the path of the video produced by this run, or ``None``.

    The first artifact reported in *result* wins when its file exists.
    Otherwise *output_dir* is scanned recursively for ``.mp4`` files whose
    modification time is at or after *not_before* (epoch seconds), so a video
    left behind by an earlier run is never mistaken for this run's output.
    """
    if result.artifacts:
        src = result.artifacts[0].path
        if src and Path(src).exists():
            return src

    # Fallback: scan the output dir for any video file Wan2GP may have written
    stamped = [(f.stat().st_mtime, f) for f in output_dir.glob("**/*.mp4")]
    fresh = [pair for pair in stamped if pair[0] >= not_before]
    if not fresh:
        return None
    newest = str(max(fresh, key=lambda pair: pair[0])[1])
    p(f"Found output via dir scan: {newest}")
    return newest


def main():
    """Run one image-to-video generation from the command line.

    Steps: parse the CLI options, validate the input image, choose a
    resolution (explicit, auto-detected from the image, or the model default),
    build the task dict, run it through a ``WanGPSession`` rooted at
    ``WAN2GP_ROOT``, then copy the resulting video to ``--output``.

    Exits with status 1 when the image is missing or no output video is found.
    """
    import shutil
    import time

    args = parse_args()

    model_type = MODEL_SHORTHANDS.get(args.model, args.model)
    defaults   = DEFAULTS.get(model_type, DEFAULTS["sulphur_2_base"])

    image_path = str(Path(args.image.strip()).resolve())
    if not Path(image_path).is_file():
        p(f"Fatal: image not found: {image_path}")
        sys.exit(1)

    # Resolve the destination now: the working directory changes below, and a
    # relative --output must stay relative to the caller's directory.
    output_path = Path(args.output).resolve()
    output_dir = output_path.parent

    resolution = args.resolution or defaults["resolution"]
    if not args.resolution:
        # Best effort: any failure (PIL missing, unreadable image, ...) falls
        # back to the model default, but the reason is reported.
        try:
            resolution, (iw, ih) = _auto_resolution(image_path)
            p(f"Auto-detected resolution: {resolution} (from {iw}x{ih} input)")
        except Exception as exc:
            p(f"Resolution auto-detect failed ({exc}); using {resolution}")

    # An explicit --guidance_scale 0 must be honoured, so test against None
    # rather than truthiness. Steps/frames keep the `or` fallback.
    if args.guidance_scale is not None:
        guidance_scale = args.guidance_scale
    else:
        guidance_scale = defaults["guidance_scale"]

    task = {
        "model_type":           model_type,
        "base_model_type":      model_type,
        "prompt":               args.prompt,
        "image_start":          image_path,
        "num_inference_steps":  args.steps  or defaults["num_inference_steps"],
        "guidance_scale":       guidance_scale,
        "resolution":           resolution,
        "video_length":         args.frames or defaults["video_length"],
        "seed":                 args.seed,
        "image_prompt_type":    "S",
        "input_video_strength": 1.0,
        "activated_loras": [
            "ltx-2.3-22b-distilled-lora-1.1_fro90_ceil72_condsafe.safetensors",
        ],
        "loras_multipliers": ["0.5"],
    }

    p(f"Model:      {model_type}")
    p(f"Image:      {image_path}")
    p(f"Steps:      {task['num_inference_steps']}  Guidance: {task['guidance_scale']}")
    p(f"Resolution: {task['resolution']}  Frames: {task['video_length']}")
    p(f"Prompt:     {args.prompt[:80]}")

    # Make the Wan2GP checkout importable and run from inside it.
    sys.path.insert(0, str(WAN2GP_ROOT))
    os.chdir(WAN2GP_ROOT)

    from shared.api import WanGPSession

    output_dir.mkdir(parents=True, exist_ok=True)

    # Whole-second floor so filesystems that truncate mtimes still count files
    # written right after start-up as belonging to this run.
    started_at = int(time.time())

    p("Starting session...")
    session = WanGPSession(root=WAN2GP_ROOT, output_dir=output_dir, console_output=True)
    try:
        p("Running generation...")
        result = session.run_task(task)

        output_file = _locate_output(result, output_dir, started_at)
        if not output_file:
            p(f"No output found in {output_dir}")
            if result.errors:
                p(f"Errors: {result.errors}")
            sys.exit(1)

        try:
            shutil.copy2(output_file, output_path)
        except shutil.SameFileError:
            # The video was already written to the requested destination.
            pass
        p(f"Done: {args.output}")
    finally:
        # Close the session on every path, including sys.exit and exceptions.
        session.close()


if __name__ == "__main__":
    main()