Sulphur / generate.py
Daankular's picture
Pre-download Gemma at startup; bump guidance to 5.0 for prompt adherence
499128d
"""
HF Spaces version of generate.py — same logic, paths adapted for /tmp/Wan2GP.
Called as a subprocess from app.py inside @spaces.GPU.
"""
import argparse
import os
import sys
from pathlib import Path
# Root of the Wan2GP checkout; overridable through the WAN2GP_ROOT env var.
WAN2GP_ROOT = Path(os.getenv("WAN2GP_ROOT", "/tmp/Wan2GP"))

# Short names accepted by --model, mapped to the model type placed in the task.
MODEL_SHORTHANDS = {
    "sulphur-2": "sulphur_2_base",
}

# Per-model generation settings used when the matching CLI flag is omitted.
DEFAULTS = {
    "sulphur_2_base": {
        "num_inference_steps": 8,
        "guidance_scale": 5.0,
        "resolution": "832x480",
        "video_length": 81,
    },
}
def p(*args, **kwargs):
    """Print like the builtin ``print`` but flush the stream by default.

    All positional and keyword arguments are forwarded to ``print``.
    ``flush`` defaults to ``True``; a caller may still pass ``flush=``
    explicitly without triggering a duplicate-keyword ``TypeError``.
    """
    kwargs.setdefault("flush", True)
    print(*args, **kwargs)
def parse_args(argv=None):
    """Parse the command-line options of the generation script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, which is the original behavior.

    Returns:
        argparse.Namespace with ``image``, ``prompt``, ``output``, ``model``,
        ``steps``, ``guidance_scale``, ``frames``, ``resolution`` and ``seed``.
        Options left at ``None`` are filled from the per-model defaults by the
        caller.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--image", required=True, help="path to the start image")
    ap.add_argument("--prompt", required=True, help="text prompt")
    ap.add_argument("--output", required=True, help="destination path of the generated video")
    ap.add_argument("--model", default="sulphur-2", help="model shorthand or full model type")
    ap.add_argument("--steps", type=int, default=None, help="number of inference steps")
    ap.add_argument("--guidance_scale", type=float, default=None, help="guidance scale")
    ap.add_argument("--frames", type=int, default=None, help="video length in frames")
    ap.add_argument("--resolution", default=None, help="WIDTHxHEIGHT; derived from the image when omitted")
    ap.add_argument("--seed", type=int, default=-1, help="seed forwarded to the task")
    return ap.parse_args(argv)
def _auto_resolution(image_path, short_side=480, multiple=32):
    """Derive a "WxH" resolution string that follows the input image's aspect ratio.

    The shorter side is fixed at *short_side*; the longer side is scaled to
    keep the aspect ratio and rounded to the nearest multiple of *multiple*.

    Returns:
        The resolution string, or ``None`` when the image could not be
        inspected, in which case the caller falls back to its default.
    """
    try:
        from PIL import Image as _PIL

        # The context manager closes the file handle once the size is read.
        with _PIL.open(image_path) as img:
            iw, ih = img.size
        if ih > iw:
            tw = short_side
            th = round(ih / iw * tw / multiple) * multiple
        else:
            th = short_side
            tw = round(iw / ih * th / multiple) * multiple
    except Exception as exc:
        # Best effort only: report the reason instead of hiding it, then let
        # the caller use the model default.
        p(f"Could not auto-detect resolution ({exc}); using default")
        return None
    resolution = f"{tw}x{th}"
    p(f"Auto-detected resolution: {resolution} (from {iw}x{ih} input)")
    return resolution


def _find_output(result, output_dir):
    """Return the path of the generated video, or ``None`` if none was found.

    The first artifact reported by the run is preferred; otherwise the most
    recently modified ``*.mp4`` anywhere under *output_dir* is used.
    """
    if result.artifacts:
        src = result.artifacts[0].path
        if src and Path(src).exists():
            return src
    # Fallback: scan the output dir for any video file Wan2GP may have written
    newest = max(
        output_dir.glob("**/*.mp4"),
        key=lambda f: f.stat().st_mtime,
        default=None,
    )
    if newest is None:
        return None
    found = str(newest)
    p(f"Found output via dir scan: {found}")
    return found


def main():
    """Run one image-to-video generation described by the CLI arguments.

    Builds the task dict from the arguments and per-model defaults, runs it
    through a ``WanGPSession`` and copies the resulting video to ``--output``.
    Exits with status 1 when the input image is missing or no output video
    can be located.
    """
    args = parse_args()
    model_type = MODEL_SHORTHANDS.get(args.model, args.model)
    defaults = DEFAULTS.get(model_type, DEFAULTS["sulphur_2_base"])

    image_path = str(Path(args.image.strip()).resolve())
    if not Path(image_path).exists():
        p(f"Fatal: image not found: {image_path}")
        sys.exit(1)

    # Resolve the destination before the os.chdir() below so a relative
    # --output is interpreted against the caller's working directory, the
    # same way --image is.
    output_path = Path(args.output).resolve()
    output_dir = output_path.parent

    resolution = (
        args.resolution
        or _auto_resolution(image_path)
        or defaults["resolution"]
    )

    # "is None" rather than "or": an explicit --guidance_scale 0 must not be
    # silently replaced by the default.
    if args.guidance_scale is None:
        guidance_scale = defaults["guidance_scale"]
    else:
        guidance_scale = args.guidance_scale

    task = {
        "model_type": model_type,
        "base_model_type": model_type,
        "prompt": args.prompt,
        "image_start": image_path,
        "num_inference_steps": args.steps or defaults["num_inference_steps"],
        "guidance_scale": guidance_scale,
        "resolution": resolution,
        "video_length": args.frames or defaults["video_length"],
        "seed": args.seed,
        "image_prompt_type": "S",
        "input_video_strength": 1.0,
        "activated_loras": [
            "ltx-2.3-22b-distilled-lora-1.1_fro90_ceil72_condsafe.safetensors",
        ],
        "loras_multipliers": ["0.5"],
    }

    p(f"Model: {model_type}")
    p(f"Image: {image_path}")
    p(f"Steps: {task['num_inference_steps']} Guidance: {task['guidance_scale']}")
    p(f"Resolution: {task['resolution']} Frames: {task['video_length']}")
    p(f"Prompt: {args.prompt[:80]}")

    # The Wan2GP package is imported from its checkout, with that checkout as
    # the working directory.
    sys.path.insert(0, str(WAN2GP_ROOT))
    os.chdir(WAN2GP_ROOT)
    from shared.api import WanGPSession

    output_dir.mkdir(parents=True, exist_ok=True)

    p("Starting session...")
    session = WanGPSession(root=WAN2GP_ROOT, output_dir=output_dir, console_output=True)
    try:
        p("Running generation...")
        result = session.run_task(task)

        output_file = _find_output(result, output_dir)
        if not output_file:
            p(f"No output found in {output_dir}")
            if result.errors:
                p(f"Errors: {result.errors}")
            sys.exit(1)

        # copy2 raises SameFileError when source and destination coincide,
        # so skip the copy if the video is already at the requested path.
        if Path(output_file).resolve() != output_path:
            import shutil

            shutil.copy2(output_file, output_path)
        p(f"Done: {args.output}")
    finally:
        # Always release the session, including on sys.exit() and on errors.
        session.close()
# Run the generation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()