dippoo Claude Sonnet 4.6 committed on
Commit
e808ae1
·
1 Parent(s): 4340a68

Sync all local changes: video routes, pod management, wavespeed, UI updates

Browse files
config/models.yaml CHANGED
@@ -52,6 +52,35 @@ training_models:
52
  recommended_images: "15-30 high quality photos"
53
  training_script: "flux_train_network.py"
54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  # SD 1.5 Realistic Vision - Good balance of quality and speed
56
  sd15_realistic:
57
  name: "Realistic Vision V5.1"
 
52
  recommended_images: "15-30 high quality photos"
53
  training_script: "flux_train_network.py"
54
 
55
+ # WAN 2.2 - Text-to-Video LoRA training (14B params, uses musubi-tuner)
56
+ wan22_t2v:
57
+ name: "WAN 2.2 T2V (14B)"
58
+ description: "WAN 2.2 text-to-video model. Trains natural-looking video LoRAs. Requires A100 80GB."
59
+ model_type: "wan22"
60
+ training_framework: "musubi-tuner"
61
+ training_script: "wan_train_network.py"
62
+ network_module: "networks.lora_wan"
63
+ resolution: 512
64
+ learning_rate: 2e-4
65
+ network_rank: 64
66
+ network_alpha: 32
67
+ optimizer: "adamw8bit"
68
+ lr_scheduler: "constant"
69
+ timestep_sampling: "shift"
70
+ discrete_flow_shift: 5.0
71
+ gradient_checkpointing: true
72
+ max_train_steps: 2000
73
+ save_every_n_steps: 500
74
+ use_case: "images+video"
75
+ vram_required_gb: 48
76
+ recommended_gpu: "NVIDIA A100 80GB"
77
+ recommended_images: "20-50 high quality photos with detailed captions"
78
+ # Model paths on network volume:
79
+ # DiT low-noise: /workspace/models/WAN2.2/wan2.2_t2v_low_noise_14B_fp16.safetensors
80
+ # DiT high-noise: /workspace/models/WAN2.2/wan2.2_t2v_high_noise_14B_fp16.safetensors
81
+ # VAE: /workspace/models/WAN2.2/Wan2.1_VAE.pth
82
+ # T5: /workspace/models/WAN2.2/models_t5_umt5-xxl-enc-bf16.pth
83
+
84
  # SD 1.5 Realistic Vision - Good balance of quality and speed
85
  sd15_realistic:
86
  name: "Realistic Vision V5.1"
src/content_engine/api/routes_catalog.py CHANGED
@@ -117,9 +117,11 @@ async def serve_image_file(image_id: str):
117
  if not file_path.exists():
118
  raise HTTPException(404, f"Image file not found on disk")
119
 
 
 
120
  return FileResponse(
121
  file_path,
122
- media_type="image/png",
123
  headers={"Cache-Control": "public, max-age=3600"},
124
  )
125
 
 
117
  if not file_path.exists():
118
  raise HTTPException(404, f"Image file not found on disk")
119
 
120
+ ext = file_path.suffix.lower()
121
+ media_type = "video/mp4" if ext == ".mp4" else "video/webm" if ext == ".webm" else "image/png"
122
  return FileResponse(
123
  file_path,
124
+ media_type=media_type,
125
  headers={"Cache-Control": "public, max-age=3600"},
126
  )
127
 
src/content_engine/api/routes_generation.py CHANGED
@@ -169,12 +169,15 @@ async def generate_cloud(request: GenerationRequest):
169
 
170
  job_id = str(uuid.uuid4())
171
 
 
 
 
172
  asyncio.create_task(
173
  _run_cloud_generation(
174
  job_id=job_id,
175
  positive_prompt=request.positive_prompt or "",
176
  negative_prompt=request.negative_prompt or "",
177
- model=request.checkpoint, # Use checkpoint field for model selection
178
  width=request.width or 1024,
179
  height=request.height or 1024,
180
  seed=request.seed or -1,
@@ -182,6 +185,8 @@ async def generate_cloud(request: GenerationRequest):
182
  character_id=request.character_id,
183
  template_id=request.template_id,
184
  variables=request.variables,
 
 
185
  )
186
  )
187
 
@@ -399,6 +404,8 @@ async def _run_cloud_generation(
399
  character_id: str | None,
400
  template_id: str | None,
401
  variables: dict | None,
 
 
402
  ):
403
  """Background task to run a WaveSpeed cloud generation."""
404
  import time
@@ -451,6 +458,8 @@ async def _run_cloud_generation(
451
  width=width,
452
  height=height,
453
  seed=seed,
 
 
454
  )
455
 
456
  # Check if cancelled after API call
 
169
 
170
  job_id = str(uuid.uuid4())
171
 
172
+ lora_path = request.loras[0].name if request.loras else None
173
+ lora_strength = request.loras[0].strength_model if request.loras else 0.85
174
+
175
  asyncio.create_task(
176
  _run_cloud_generation(
177
  job_id=job_id,
178
  positive_prompt=request.positive_prompt or "",
179
  negative_prompt=request.negative_prompt or "",
180
+ model=request.checkpoint,
181
  width=request.width or 1024,
182
  height=request.height or 1024,
183
  seed=request.seed or -1,
 
185
  character_id=request.character_id,
186
  template_id=request.template_id,
187
  variables=request.variables,
188
+ lora_path=lora_path,
189
+ lora_strength=lora_strength,
190
  )
191
  )
192
 
 
404
  character_id: str | None,
405
  template_id: str | None,
406
  variables: dict | None,
407
+ lora_path: str | None = None,
408
+ lora_strength: float = 0.85,
409
  ):
410
  """Background task to run a WaveSpeed cloud generation."""
411
  import time
 
458
  width=width,
459
  height=height,
460
  seed=seed,
461
+ lora_name=lora_path,
462
+ lora_strength=lora_strength,
463
  )
464
 
465
  # Check if cancelled after API call
src/content_engine/api/routes_pod.py CHANGED
@@ -197,7 +197,9 @@ async def list_model_options():
197
  "models": {
198
  "flux2": {"name": "FLUX.2 Dev", "description": "Best for realistic txt2img (requires 48GB+ VRAM)", "use_case": "txt2img"},
199
  "flux1": {"name": "FLUX.1 Dev", "description": "Previous gen FLUX txt2img", "use_case": "txt2img"},
200
- "wan22": {"name": "WAN 2.2", "description": "Image-to-video and general generation", "use_case": "img2video"},
 
 
201
  }
202
  }
203
 
@@ -312,15 +314,45 @@ async def _wait_and_setup_pod(pod_id: str, model_type: str, timeout: int = 600):
312
  _pod_state["setup_status"] = "Connecting via SSH..."
313
 
314
  import paramiko
315
- ssh = paramiko.SSHClient()
316
- ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
317
 
318
  for attempt in range(30):
319
  try:
320
- await asyncio.to_thread(
321
- ssh.connect, ssh_host, port=int(ssh_port),
322
- username="root", password="runpod", timeout=10,
323
- )
324
  break
325
  except Exception:
326
  if attempt == 29:
@@ -329,9 +361,6 @@ async def _wait_and_setup_pod(pod_id: str, model_type: str, timeout: int = 600):
329
  return
330
  await asyncio.sleep(5)
331
 
332
- transport = ssh.get_transport()
333
- transport.set_keepalive(30)
334
-
335
  try:
336
  # Symlink network volume
337
  volume_id, _ = _get_volume_config()
@@ -404,22 +433,158 @@ print('Text encoder downloaded')
404
  await _ssh_exec_async(ssh, f"ln -sf /workspace/models/ae.safetensors {comfy_dir}/models/vae/ae.safetensors")
405
  await _ssh_exec_async(ssh, f"ln -sf /workspace/models/clip_l.safetensors {comfy_dir}/models/text_encoders/clip_l.safetensors")
406
  await _ssh_exec_async(ssh, f"ln -sf /workspace/models/t5xxl_fp16.safetensors {comfy_dir}/models/text_encoders/t5xxl_fp16.safetensors")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
407
  elif model_type == "wan22":
408
- # WAN 2.2 Image-to-Video (14B params)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
409
  wan_dir = "/workspace/models/Wan2.2-I2V-A14B"
410
  wan_exists = (await _ssh_exec_async(ssh, f"test -d {wan_dir} && echo EXISTS || echo MISSING")).strip()
411
  if wan_exists != "EXISTS":
412
- _pod_state["setup_status"] = "Downloading WAN 2.2 model (~28GB, first time only)..."
413
  await _ssh_exec_async(ssh, f"pip install huggingface_hub 2>&1 | tail -1", timeout=60)
414
  await _ssh_exec_async(ssh, f"""python -c "
415
  from huggingface_hub import snapshot_download
416
  snapshot_download('Wan-AI/Wan2.2-I2V-A14B', local_dir='{wan_dir}', ignore_patterns=['*.md', '*.txt'])
417
- print('WAN 2.2 downloaded')
418
  " 2>&1 | tail -10""", timeout=3600)
419
- # Symlink WAN model to ComfyUI diffusion_models dir
420
  await _ssh_exec_async(ssh, f"mkdir -p {comfy_dir}/models/diffusion_models")
421
  await _ssh_exec_async(ssh, f"ln -sf {wan_dir} {comfy_dir}/models/diffusion_models/Wan2.2-I2V-A14B")
422
- # Also need a VAE and text encoder for WAN — they use their own
423
  await _ssh_exec_async(ssh, f"ln -sf {wan_dir} {comfy_dir}/models/checkpoints/Wan2.2-I2V-A14B")
424
 
425
  # Install ComfyUI-WanVideoWrapper custom nodes
@@ -430,8 +595,121 @@ print('WAN 2.2 downloaded')
430
  await _ssh_exec_async(ssh, f"cd {comfy_dir}/custom_nodes && git clone --depth 1 https://github.com/kijai/ComfyUI-WanVideoWrapper.git", timeout=120)
431
  await _ssh_exec_async(ssh, f"cd {wan_nodes_dir} && pip install -r requirements.txt 2>&1 | tail -5", timeout=300)
432
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
433
  # Symlink all LoRAs from volume
434
- await _ssh_exec_async(ssh, f"ls /runpod-volume/loras/*.safetensors 2>/dev/null | while read f; do ln -sf \"$f\" {comfy_dir}/models/loras/; done")
435
 
436
  # Start ComfyUI in background (fire-and-forget — don't wait for output)
437
  _pod_state["setup_status"] = "Starting ComfyUI..."
@@ -500,6 +778,300 @@ def _ssh_exec_fire_and_forget(ssh, cmd: str):
500
  # Don't read stdout/stderr — just let it run
501
 
502
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
503
  @router.post("/stop")
504
  async def stop_pod():
505
  """Stop the GPU pod."""
@@ -561,8 +1133,8 @@ async def list_pod_loras():
561
 
562
  @router.post("/upload-lora")
563
  async def upload_lora_to_pod(file: UploadFile = File(...)):
564
- """Upload a LoRA file to the running pod."""
565
- import httpx
566
 
567
  if _pod_state["status"] != "running":
568
  raise HTTPException(400, "Pod not running - start it first")
@@ -570,24 +1142,77 @@ async def upload_lora_to_pod(file: UploadFile = File(...)):
570
  if not file.filename.endswith(".safetensors"):
571
  raise HTTPException(400, "Only .safetensors files supported")
572
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
573
  try:
574
- content = await file.read()
575
- comfyui_url = _get_comfyui_url()
 
 
 
 
576
 
577
- async with httpx.AsyncClient(timeout=120) as client:
578
- url = f"{comfyui_url}/upload/image"
579
- files = {"image": (file.filename, content, "application/octet-stream")}
580
- data = {"subfolder": "loras", "type": "input"}
581
- resp = await client.post(url, files=files, data=data)
582
 
583
- if resp.status_code == 200:
584
- return {"status": "uploaded", "filename": file.filename}
585
- else:
586
- raise HTTPException(500, f"Upload failed: {resp.text}")
 
587
 
588
- except httpx.TimeoutException:
589
- raise HTTPException(504, "Upload timed out")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
590
  except Exception as e:
 
591
  raise HTTPException(500, f"Upload failed: {e}")
592
 
593
 
@@ -601,6 +1226,8 @@ class PodGenerateRequest(BaseModel):
601
  seed: int = -1
602
  lora_name: str | None = None
603
  lora_strength: float = 0.85
 
 
604
  character_id: str | None = None
605
  template_id: str | None = None
606
  content_rating: str = "sfw"
@@ -623,18 +1250,33 @@ async def generate_on_pod(request: PodGenerateRequest):
623
  seed = request.seed if request.seed >= 0 else random.randint(0, 2**32 - 1)
624
 
625
  model_type = _pod_state.get("model_type", "flux2")
626
- workflow = _build_flux_workflow(
627
- prompt=request.prompt,
628
- negative_prompt=request.negative_prompt,
629
- width=request.width,
630
- height=request.height,
631
- steps=request.steps,
632
- cfg=request.cfg,
633
- seed=seed,
634
- lora_name=request.lora_name,
635
- lora_strength=request.lora_strength,
636
- model_type=model_type,
637
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
638
 
639
  comfyui_url = _get_comfyui_url()
640
 
@@ -939,3 +1581,249 @@ def _build_flux_workflow(
939
  workflow["7"]["inputs"]["clip"] = ["20", 1]
940
 
941
  return workflow
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
  "models": {
198
  "flux2": {"name": "FLUX.2 Dev", "description": "Best for realistic txt2img (requires 48GB+ VRAM)", "use_case": "txt2img"},
199
  "flux1": {"name": "FLUX.1 Dev", "description": "Previous gen FLUX txt2img", "use_case": "txt2img"},
200
+ "wan22": {"name": "WAN 2.2 Remix", "description": "Realistic generation — dual-DiT MoE split-step (NSFW OK)", "use_case": "txt2img"},
201
+ "wan22_i2v": {"name": "WAN 2.2 I2V", "description": "Image-to-video generation", "use_case": "img2video"},
202
+ "wan22_animate": {"name": "WAN 2.2 Animate", "description": "Dance/motion transfer — animate a character from a driving video", "use_case": "animate"},
203
  }
204
  }
205
 
 
314
  _pod_state["setup_status"] = "Connecting via SSH..."
315
 
316
  import paramiko
317
+
318
+ async def _ssh_connect_new() -> "paramiko.SSHClient":
319
+ """Create a fresh SSH connection to the pod."""
320
+ client = paramiko.SSHClient()
321
+ client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
322
+ for attempt in range(10):
323
+ try:
324
+ await asyncio.to_thread(
325
+ client.connect, ssh_host, port=int(ssh_port),
326
+ username="root", password="runpod", timeout=15,
327
+ banner_timeout=30,
328
+ )
329
+ client.get_transport().set_keepalive(30)
330
+ return client
331
+ except Exception:
332
+ if attempt == 9:
333
+ raise
334
+ await asyncio.sleep(5)
335
+ raise RuntimeError("SSH connection failed after retries")
336
+
337
+ async def _ssh_exec_r(cmd: str, timeout: int = 120) -> str:
338
+ """Execute SSH command, reconnecting once if the session dropped."""
339
+ nonlocal ssh
340
+ try:
341
+ t = ssh.get_transport()
342
+ if t is None or not t.is_active():
343
+ logger.info("SSH session dropped, reconnecting...")
344
+ ssh = await _ssh_connect_new()
345
+ return await _ssh_exec_async(ssh, cmd, timeout)
346
+ except Exception as e:
347
+ if "not active" in str(e).lower() or "session" in str(e).lower():
348
+ logger.info("SSH error '%s', reconnecting and retrying...", e)
349
+ ssh = await _ssh_connect_new()
350
+ return await _ssh_exec_async(ssh, cmd, timeout)
351
+ raise
352
 
353
  for attempt in range(30):
354
  try:
355
+ ssh = await _ssh_connect_new()
 
 
 
356
  break
357
  except Exception:
358
  if attempt == 29:
 
361
  return
362
  await asyncio.sleep(5)
363
 
 
 
 
364
  try:
365
  # Symlink network volume
366
  volume_id, _ = _get_volume_config()
 
433
  await _ssh_exec_async(ssh, f"ln -sf /workspace/models/ae.safetensors {comfy_dir}/models/vae/ae.safetensors")
434
  await _ssh_exec_async(ssh, f"ln -sf /workspace/models/clip_l.safetensors {comfy_dir}/models/text_encoders/clip_l.safetensors")
435
  await _ssh_exec_async(ssh, f"ln -sf /workspace/models/t5xxl_fp16.safetensors {comfy_dir}/models/text_encoders/t5xxl_fp16.safetensors")
436
+ elif model_type == "z_image":
437
+ # Z-Image Turbo — 6B param model by Tongyi-MAI, runs in 16GB VRAM
438
+ z_dir = "/runpod-volume/models/z_image"
439
+ await _ssh_exec_async(ssh, f"mkdir -p {z_dir}")
440
+ await _ssh_exec_async(ssh, "pip install huggingface_hub 2>&1 | tail -1", timeout=60)
441
+
442
+ # Delete FLUX.2 from volume to free space
443
+ _pod_state["setup_status"] = "Cleaning up FLUX.2 from volume..."
444
+ await _ssh_exec_async(ssh, "rm -rf /runpod-volume/models/FLUX.2-dev /runpod-volume/models/mistral_3_small_flux2_fp8.safetensors 2>/dev/null; echo done")
445
+
446
+ # Download diffusion model (~12GB)
447
+ diff_model = f"{z_dir}/z_image_turbo_bf16.safetensors"
448
+ exists = (await _ssh_exec_async(ssh, f"test -f {diff_model} && echo EXISTS || echo MISSING")).strip()
449
+ if exists != "EXISTS":
450
+ _pod_state["setup_status"] = "Downloading Z-Image Turbo diffusion model (~12GB)..."
451
+ await _ssh_exec_async(ssh, f"""python -c "
452
+ from huggingface_hub import hf_hub_download
453
+ import shutil, os
454
+ p = hf_hub_download('Comfy-Org/z_image_turbo', 'split_files/diffusion_models/z_image_turbo_bf16.safetensors', local_dir='/tmp/z_image')
455
+ shutil.move(p, '{diff_model}')
456
+ print('Diffusion model downloaded')
457
+ " 2>&1 | tail -5""", timeout=3600)
458
+
459
+ # Download text encoder (~8GB Qwen 3 4B)
460
+ te_model = f"{z_dir}/qwen_3_4b.safetensors"
461
+ exists = (await _ssh_exec_async(ssh, f"test -f {te_model} && echo EXISTS || echo MISSING")).strip()
462
+ if exists != "EXISTS":
463
+ _pod_state["setup_status"] = "Downloading Z-Image text encoder (~8GB)..."
464
+ await _ssh_exec_async(ssh, f"""python -c "
465
+ from huggingface_hub import hf_hub_download
466
+ import shutil
467
+ p = hf_hub_download('Comfy-Org/z_image_turbo', 'split_files/text_encoders/qwen_3_4b.safetensors', local_dir='/tmp/z_image')
468
+ shutil.move(p, '{te_model}')
469
+ print('Text encoder downloaded')
470
+ " 2>&1 | tail -5""", timeout=3600)
471
+
472
+ # Download VAE (~335MB)
473
+ vae_model = f"{z_dir}/ae.safetensors"
474
+ exists = (await _ssh_exec_async(ssh, f"test -f {vae_model} && echo EXISTS || echo MISSING")).strip()
475
+ if exists != "EXISTS":
476
+ _pod_state["setup_status"] = "Downloading Z-Image VAE..."
477
+ await _ssh_exec_async(ssh, f"""python -c "
478
+ from huggingface_hub import hf_hub_download
479
+ import shutil
480
+ p = hf_hub_download('Comfy-Org/z_image_turbo', 'split_files/vae/ae.safetensors', local_dir='/tmp/z_image')
481
+ shutil.move(p, '{vae_model}')
482
+ print('VAE downloaded')
483
+ " 2>&1 | tail -5""", timeout=600)
484
+
485
+ # Symlink into ComfyUI directories
486
+ await _ssh_exec_async(ssh, f"mkdir -p {comfy_dir}/models/diffusion_models {comfy_dir}/models/text_encoders {comfy_dir}/models/vae")
487
+ await _ssh_exec_async(ssh, f"ln -sf {diff_model} {comfy_dir}/models/diffusion_models/z_image_turbo_bf16.safetensors")
488
+ await _ssh_exec_async(ssh, f"ln -sf {te_model} {comfy_dir}/models/text_encoders/qwen_3_4b.safetensors")
489
+ await _ssh_exec_async(ssh, f"ln -sf {vae_model} {comfy_dir}/models/vae/ae_z_image.safetensors")
490
+
491
  elif model_type == "wan22":
492
+ # WAN 2.2 Remix NSFW — dual-DiT MoE split-step for realistic generation
493
+ wan_dir = "/workspace/models/WAN2.2"
494
+ await _ssh_exec_async(ssh, f"mkdir -p {wan_dir}")
495
+
496
+ civitai_token = os.environ.get("CIVITAI_API_TOKEN", "")
497
+ token_param = f"&token={civitai_token}" if civitai_token else ""
498
+
499
+ # CivitAI Remix models (fp8 ~14GB each)
500
+ civitai_models = {
501
+ "Remix T2V High-noise": {
502
+ "path": f"{wan_dir}/wan22_remix_t2v_high_fp8.safetensors",
503
+ "url": f"https://civitai.com/api/download/models/2424167?type=Model&format=SafeTensor&size=pruned{token_param}",
504
+ },
505
+ "Remix T2V Low-noise": {
506
+ "path": f"{wan_dir}/wan22_remix_t2v_low_fp8.safetensors",
507
+ "url": f"https://civitai.com/api/download/models/2424912?type=Model&format=SafeTensor&size=pruned{token_param}",
508
+ },
509
+ }
510
+
511
+ # HuggingFace models (T5 fp8 ~7GB, VAE ~1GB)
512
+ hf_models = {
513
+ "T5 text encoder (fp8)": {
514
+ "path": f"{wan_dir}/umt5_xxl_fp8_e4m3fn_scaled.safetensors",
515
+ "repo": "Comfy-Org/Wan_2.2_ComfyUI_Repackaged",
516
+ "filename": "split_files/text_encoders/umt5_xxl_fp8_e4m3fn_scaled.safetensors",
517
+ },
518
+ "VAE": {
519
+ "path": f"{wan_dir}/wan_2.1_vae.safetensors",
520
+ "repo": "Comfy-Org/Wan_2.2_ComfyUI_Repackaged",
521
+ "filename": "split_files/vae/wan_2.1_vae.safetensors",
522
+ },
523
+ }
524
+
525
+ # Download CivitAI Remix models
526
+ for label, info in civitai_models.items():
527
+ exists = (await _ssh_exec_async(ssh, f"test -f {info['path']} && echo EXISTS || echo MISSING")).strip()
528
+ if exists == "EXISTS":
529
+ logger.info("WAN 2.2 %s already cached", label)
530
+ else:
531
+ _pod_state["setup_status"] = f"Downloading {label} (~14GB)..."
532
+ await _ssh_exec_async(ssh, f"wget -q -O '{info['path']}' '{info['url']}'", timeout=1800)
533
+ # Verify download
534
+ check = (await _ssh_exec_async(ssh, f"test -f {info['path']} && stat -c%s {info['path']} || echo 0")).strip()
535
+ if check == "0" or int(check) < 1000000:
536
+ logger.error("Failed to download %s (size: %s). CivitAI API token may be required.", label, check)
537
+ _pod_state["setup_status"] = f"Failed: {label} download failed. Set CIVITAI_API_TOKEN env var for NSFW models."
538
+ return
539
+
540
+ # Download HuggingFace models
541
+ await _ssh_exec_async(ssh, "pip install huggingface_hub 2>&1 | tail -1", timeout=60)
542
+ for label, info in hf_models.items():
543
+ exists = (await _ssh_exec_async(ssh, f"test -f {info['path']} && echo EXISTS || echo MISSING")).strip()
544
+ if exists == "EXISTS":
545
+ logger.info("WAN 2.2 %s already cached", label)
546
+ else:
547
+ _pod_state["setup_status"] = f"Downloading {label}..."
548
+ await _ssh_exec_async(ssh, f"""python -c "
549
+ from huggingface_hub import hf_hub_download
550
+ import os, shutil
551
+ hf_hub_download('{info['repo']}', '{info['filename']}', local_dir='{wan_dir}')
552
+ downloaded = os.path.join('{wan_dir}', '{info['filename']}')
553
+ target = '{info['path']}'
554
+ if os.path.exists(downloaded) and downloaded != target:
555
+ os.makedirs(os.path.dirname(target), exist_ok=True)
556
+ shutil.move(downloaded, target)
557
+ print('Downloaded {label}')
558
+ " 2>&1 | tail -5""", timeout=1800)
559
+
560
+ # Symlink models into ComfyUI
561
+ await _ssh_exec_async(ssh, f"mkdir -p {comfy_dir}/models/diffusion_models {comfy_dir}/models/text_encoders")
562
+ await _ssh_exec_async(ssh, f"ln -sf {wan_dir}/wan22_remix_t2v_high_fp8.safetensors {comfy_dir}/models/diffusion_models/")
563
+ await _ssh_exec_async(ssh, f"ln -sf {wan_dir}/wan22_remix_t2v_low_fp8.safetensors {comfy_dir}/models/diffusion_models/")
564
+ await _ssh_exec_async(ssh, f"ln -sf {wan_dir}/wan_2.1_vae.safetensors {comfy_dir}/models/vae/")
565
+ await _ssh_exec_async(ssh, f"ln -sf {wan_dir}/umt5_xxl_fp8_e4m3fn_scaled.safetensors {comfy_dir}/models/text_encoders/")
566
+
567
+ # Install wanBlockSwap custom node (VRAM optimization for dual-DiT on 24GB GPUs)
568
+ _pod_state["setup_status"] = "Installing WAN 2.2 custom nodes..."
569
+ blockswap_dir = f"{comfy_dir}/custom_nodes/ComfyUI-wanBlockswap"
570
+ blockswap_exists = (await _ssh_exec_async(ssh, f"test -d {blockswap_dir} && echo EXISTS || echo MISSING")).strip()
571
+ if blockswap_exists != "EXISTS":
572
+ await _ssh_exec_async(ssh, f"cd {comfy_dir}/custom_nodes && git clone --depth 1 https://github.com/orssorbit/ComfyUI-wanBlockswap.git", timeout=120)
573
+
574
+ elif model_type == "wan22_i2v":
575
+ # WAN 2.2 Image-to-Video (14B params) — full model snapshot
576
  wan_dir = "/workspace/models/Wan2.2-I2V-A14B"
577
  wan_exists = (await _ssh_exec_async(ssh, f"test -d {wan_dir} && echo EXISTS || echo MISSING")).strip()
578
  if wan_exists != "EXISTS":
579
+ _pod_state["setup_status"] = "Downloading WAN 2.2 I2V model (~28GB, first time only)..."
580
  await _ssh_exec_async(ssh, f"pip install huggingface_hub 2>&1 | tail -1", timeout=60)
581
  await _ssh_exec_async(ssh, f"""python -c "
582
  from huggingface_hub import snapshot_download
583
  snapshot_download('Wan-AI/Wan2.2-I2V-A14B', local_dir='{wan_dir}', ignore_patterns=['*.md', '*.txt'])
584
+ print('WAN 2.2 I2V downloaded')
585
  " 2>&1 | tail -10""", timeout=3600)
 
586
  await _ssh_exec_async(ssh, f"mkdir -p {comfy_dir}/models/diffusion_models")
587
  await _ssh_exec_async(ssh, f"ln -sf {wan_dir} {comfy_dir}/models/diffusion_models/Wan2.2-I2V-A14B")
 
588
  await _ssh_exec_async(ssh, f"ln -sf {wan_dir} {comfy_dir}/models/checkpoints/Wan2.2-I2V-A14B")
589
 
590
  # Install ComfyUI-WanVideoWrapper custom nodes
 
595
  await _ssh_exec_async(ssh, f"cd {comfy_dir}/custom_nodes && git clone --depth 1 https://github.com/kijai/ComfyUI-WanVideoWrapper.git", timeout=120)
596
  await _ssh_exec_async(ssh, f"cd {wan_nodes_dir} && pip install -r requirements.txt 2>&1 | tail -5", timeout=300)
597
 
598
+ elif model_type == "wan22_animate":
599
+ # WAN 2.2 Animate (14B fp8) — dance/motion transfer via pose skeleton
600
+ animate_dir = "/workspace/models/WAN2.2-Animate"
601
+ wan22_dir = "/workspace/models/WAN2.2"
602
+ await _ssh_exec_async(ssh, f"mkdir -p {animate_dir}")
603
+ await _ssh_exec_async(ssh, "pip install huggingface_hub 2>&1 | tail -1", timeout=60)
604
+
605
+ # Download main Animate model (~28GB bf16 — only version available)
606
+ animate_model = f"{animate_dir}/wan2.2_animate_14B_bf16.safetensors"
607
+ exists = (await _ssh_exec_async(ssh, f"test -f {animate_model} && echo EXISTS || echo MISSING")).strip()
608
+ if exists != "EXISTS":
609
+ _pod_state["setup_status"] = "Downloading WAN 2.2 Animate model (~28GB, first time only)..."
610
+ await _ssh_exec_async(ssh, f"""python -c "
611
+ from huggingface_hub import hf_hub_download
612
+ import os, shutil
613
+ hf_hub_download('Comfy-Org/Wan_2.2_ComfyUI_Repackaged', 'split_files/diffusion_models/wan2.2_animate_14B_bf16.safetensors', local_dir='{animate_dir}')
614
+ src = os.path.join('{animate_dir}', 'split_files', 'diffusion_models', 'wan2.2_animate_14B_bf16.safetensors')
615
+ if os.path.exists(src):
616
+ shutil.move(src, '{animate_model}')
617
+ print('Animate model downloaded')
618
+ " 2>&1 | tail -5""", timeout=7200)
619
+
620
+ # CLIP Vision H (~2.5GB) — ViT-H vision encoder
621
+ clip_vision_target = f"{animate_dir}/clip_vision_h.safetensors"
622
+ exists = (await _ssh_exec_async(ssh, f"test -f {clip_vision_target} && echo EXISTS || echo MISSING")).strip()
623
+ if exists != "EXISTS":
624
+ _pod_state["setup_status"] = "Downloading CLIP Vision H (~2.5GB)..."
625
+ await _ssh_exec_async(ssh, f"""python -c "
626
+ from huggingface_hub import hf_hub_download
627
+ import os, shutil
628
+ result = hf_hub_download('h94/IP-Adapter', 'models/image_encoder/model.safetensors', local_dir='{animate_dir}/tmp_clip')
629
+ shutil.move(result, '{clip_vision_target}')
630
+ shutil.rmtree('{animate_dir}/tmp_clip', ignore_errors=True)
631
+ print('CLIP Vision H downloaded')
632
+ " 2>&1 | tail -5""", timeout=1800)
633
+
634
+ # VAE — reuse from WAN2.2 dir if available, else download (~1GB)
635
+ vae_target = f"{animate_dir}/wan_2.1_vae.safetensors"
636
+ exists = (await _ssh_exec_async(ssh, f"test -f {vae_target} && echo EXISTS || echo MISSING")).strip()
637
+ if exists != "EXISTS":
638
+ vae_from_wan22 = (await _ssh_exec_async(ssh, f"test -f {wan22_dir}/wan_2.1_vae.safetensors && echo EXISTS || echo MISSING")).strip()
639
+ if vae_from_wan22 == "EXISTS":
640
+ await _ssh_exec_async(ssh, f"ln -sf {wan22_dir}/wan_2.1_vae.safetensors {vae_target}")
641
+ else:
642
+ _pod_state["setup_status"] = "Downloading VAE (~1GB)..."
643
+ await _ssh_exec_async(ssh, f"""python -c "
644
+ from huggingface_hub import hf_hub_download
645
+ import os, shutil
646
+ hf_hub_download('Comfy-Org/Wan_2.2_ComfyUI_Repackaged', 'split_files/vae/wan_2.1_vae.safetensors', local_dir='{animate_dir}')
647
+ src = os.path.join('{animate_dir}', 'split_files', 'vae', 'wan_2.1_vae.safetensors')
648
+ if os.path.exists(src):
649
+ shutil.move(src, '{vae_target}')
650
+ print('VAE downloaded')
651
+ " 2>&1 | tail -5""", timeout=600)
652
+
653
+ # UMT5 T5 encoder fp8 (non-scaled) — use Kijai/WanVideo_comfy version
654
+ # which is compatible with LoadWanVideoT5TextEncoder (scaled_fp8 is not supported)
655
+ t5_filename = "umt5-xxl-enc-fp8_e4m3fn.safetensors"
656
+ t5_target = f"{animate_dir}/{t5_filename}"
657
+ t5_comfy_path = f"{comfy_dir}/models/text_encoders/{t5_filename}"
658
+ t5_in_comfy = (await _ssh_exec_async(ssh, f"test -f {t5_comfy_path} && echo EXISTS || echo MISSING")).strip()
659
+ t5_in_vol = (await _ssh_exec_async(ssh, f"test -f {t5_target} && echo EXISTS || echo MISSING")).strip()
660
+ if t5_in_comfy != "EXISTS" and t5_in_vol != "EXISTS":
661
+ _pod_state["setup_status"] = "Downloading UMT5 text encoder (~6.3GB, first time only)..."
662
+ await _ssh_exec_async(ssh, f"""python -c "
663
+ from huggingface_hub import hf_hub_download
664
+ hf_hub_download('Kijai/WanVideo_comfy', '{t5_filename}', local_dir='{animate_dir}')
665
+ print('UMT5 text encoder downloaded')
666
+ " 2>&1 | tail -5""", timeout=1800)
667
+ t5_in_vol = "EXISTS"
668
+
669
+ # Symlink models into ComfyUI directories
670
+ await _ssh_exec_async(ssh, f"mkdir -p {comfy_dir}/models/diffusion_models {comfy_dir}/models/vae {comfy_dir}/models/clip_vision {comfy_dir}/models/text_encoders")
671
+ await _ssh_exec_async(ssh, f"ln -sf {animate_model} {comfy_dir}/models/diffusion_models/")
672
+ await _ssh_exec_async(ssh, f"ln -sf {vae_target} {comfy_dir}/models/vae/")
673
+ await _ssh_exec_async(ssh, f"ln -sf {clip_vision_target} {comfy_dir}/models/clip_vision/")
674
+ if t5_in_vol == "EXISTS" and t5_in_comfy != "EXISTS":
675
+ await _ssh_exec_async(ssh, f"ln -sf {t5_target} {t5_comfy_path}")
676
+
677
+ # Reconnect SSH before custom node setup — connection may have dropped during long downloads
678
+ ssh = await _ssh_connect_new()
679
+
680
+ # Install required custom nodes
681
+ _pod_state["setup_status"] = "Installing WAN Animate custom nodes..."
682
+
683
+ # ComfyUI-WanVideoWrapper (WanVideoAnimateEmbeds, WanVideoSampler, etc.)
684
+ wan_nodes_dir = f"{comfy_dir}/custom_nodes/ComfyUI-WanVideoWrapper"
685
+ exists = (await _ssh_exec_r(f"test -d {wan_nodes_dir} && echo EXISTS || echo MISSING")).strip()
686
+ if exists != "EXISTS":
687
+ await _ssh_exec_r(f"cd {comfy_dir}/custom_nodes && git clone --depth 1 https://github.com/kijai/ComfyUI-WanVideoWrapper.git", timeout=120)
688
+ await _ssh_exec_r(f"cd {wan_nodes_dir} && pip install -r requirements.txt 2>&1 | tail -5", timeout=300)
689
+
690
+ # ComfyUI-VideoHelperSuite (VHS_LoadVideo, VHS_VideoCombine)
691
+ vhs_dir = f"{comfy_dir}/custom_nodes/ComfyUI-VideoHelperSuite"
692
+ exists = (await _ssh_exec_r(f"test -d {vhs_dir} && echo EXISTS || echo MISSING")).strip()
693
+ if exists != "EXISTS":
694
+ await _ssh_exec_r(f"cd {comfy_dir}/custom_nodes && git clone --depth 1 https://github.com/Kosinkadink/ComfyUI-VideoHelperSuite.git", timeout=120)
695
+ await _ssh_exec_r(f"cd {vhs_dir} && pip install -r requirements.txt 2>&1 | tail -5", timeout=300)
696
+
697
+ # comfyui_controlnet_aux (DWPreprocessor for pose extraction)
698
+ aux_dir = f"{comfy_dir}/custom_nodes/comfyui_controlnet_aux"
699
+ exists = (await _ssh_exec_r(f"test -d {aux_dir} && echo EXISTS || echo MISSING")).strip()
700
+ if exists != "EXISTS":
701
+ await _ssh_exec_r(f"cd {comfy_dir}/custom_nodes && git clone --depth 1 https://github.com/Fannovel16/comfyui_controlnet_aux.git", timeout=120)
702
+ await _ssh_exec_r(f"cd {aux_dir} && pip install -r requirements.txt 2>&1 | tail -5", timeout=300)
703
+
704
+ # ComfyUI-KJNodes (ImageResizeKJv2 used in animate workflow)
705
+ kj_dir = f"{comfy_dir}/custom_nodes/ComfyUI-KJNodes"
706
+ exists = (await _ssh_exec_r(f"test -d {kj_dir} && echo EXISTS || echo MISSING")).strip()
707
+ if exists != "EXISTS":
708
+ await _ssh_exec_r(f"cd {comfy_dir}/custom_nodes && git clone --depth 1 https://github.com/kijai/ComfyUI-KJNodes.git", timeout=120)
709
+ await _ssh_exec_r(f"cd {kj_dir} && pip install -r requirements.txt 2>&1 | tail -5", timeout=300)
710
+
711
  # Symlink all LoRAs from volume
712
+ await _ssh_exec_r(f"ls /runpod-volume/loras/*.safetensors 2>/dev/null | while read f; do ln -sf \"$f\" {comfy_dir}/models/loras/; done")
713
 
714
  # Start ComfyUI in background (fire-and-forget — don't wait for output)
715
  _pod_state["setup_status"] = "Starting ComfyUI..."
 
778
  # Don't read stdout/stderr — just let it run
779
 
780
 
781
# --- Pre-download models to network volume (saves money during training) ---

# Module-level progress state for the background download task.
# Written by _download_models_task; read by the /download-models endpoints.
_download_state = {
    "status": "idle",  # idle, downloading, completed, failed
    "pod_id": None,    # RunPod id of the temporary download pod, if any
    "progress": "",    # human-readable progress message for UI polling
    "error": None,     # last error string when status == "failed"
}
789
+
790
+
791
class DownloadModelsRequest(BaseModel):
    """Request body for POST /download-models."""

    # Which model bundle to fetch — "wan22" or "wan22_animate" (see _download_models_task).
    # NOTE(review): "model_type" sits in pydantic v2's protected "model_" namespace and
    # may emit a warning — confirm against the project's pydantic version.
    model_type: str = "wan22"
    gpu_type: str = "NVIDIA GeForce RTX 3090"  # Cheapest GPU, just for downloading
794
+
795
+
796
+ @router.post("/download-models")
797
+ async def download_models_to_volume(request: DownloadModelsRequest):
798
+ """Pre-download model files to network volume using a cheap pod.
799
+
800
+ This saves expensive GPU time during training — models are cached on the
801
+ shared volume and reused across all future training/generation pods.
802
+ """
803
+ _get_api_key()
804
+
805
+ volume_id, volume_dc = _get_volume_config()
806
+ if not volume_id:
807
+ raise HTTPException(400, "No network volume configured (set RUNPOD_VOLUME_ID)")
808
+
809
+ if _download_state["status"] == "downloading":
810
+ return {"status": "already_downloading", "progress": _download_state["progress"]}
811
+
812
+ _download_state["status"] = "downloading"
813
+ _download_state["progress"] = "Creating cheap download pod..."
814
+ _download_state["error"] = None
815
+
816
+ asyncio.create_task(_download_models_task(request.model_type, request.gpu_type, volume_id, volume_dc))
817
+
818
+ return {"status": "started", "message": f"Downloading {request.model_type} models to volume (using {request.gpu_type})"}
819
+
820
+
821
+ @router.get("/download-models/status")
822
+ async def download_models_status():
823
+ """Check model download progress."""
824
+ return _download_state
825
+
826
+
827
async def _download_models_task(model_type: str, gpu_type: str, volume_id: str, volume_dc: str):
    """Background task: spin up cheap pod, download models, terminate.

    Flow:
      1. Create the cheapest available pod (walking a GPU fallback list) with
         the shared network volume attached.
      2. SSH in (root:runpod, set up by the container's docker_args), symlink
         the volume and install download tools.
      3. Download the model set for ``model_type`` ("wan22" or "wan22_animate"),
         skipping files already cached on the volume.
      4. Always terminate the pod in ``finally`` so it never keeps billing.

    Progress and errors are reported through the module-level ``_download_state``.
    """
    import paramiko
    ssh = None
    pod_id = None

    try:
        # Create cheap pod with network volume — try multiple GPU types if first unavailable
        pod_kwargs = {
            "container_disk_in_gb": 10,
            "ports": "22/tcp",
            "network_volume_id": volume_id,
            # Minimal container: just an SSH server with password login (root:runpod)
            "docker_args": "bash -c 'apt-get update && apt-get install -y openssh-server && mkdir -p /run/sshd && echo root:runpod | chpasswd && /usr/sbin/sshd -o PermitRootLogin=yes && sleep infinity'",
        }
        if volume_dc:
            pod_kwargs["data_center_id"] = volume_dc

        gpu_fallbacks = [
            gpu_type,
            "NVIDIA RTX A4000",
            "NVIDIA RTX A5000",
            "NVIDIA GeForce RTX 4090",
            "NVIDIA GeForce RTX 4080",
            "NVIDIA A100-SXM4-80GB",
        ]
        pod = None
        used_gpu = gpu_type
        for try_gpu in gpu_fallbacks:
            try:
                pod = await asyncio.to_thread(
                    runpod.create_pod,
                    f"model-download-{model_type}",
                    DOCKER_IMAGE,
                    try_gpu,
                    **pod_kwargs,
                )
                used_gpu = try_gpu
                logger.info("Download pod created with %s", try_gpu)
                break
            except Exception as e:
                # Supply-constrained GPU types are expected — fall through to the next one
                if "SUPPLY_CONSTRAINT" in str(e) or "no longer any instances" in str(e).lower():
                    logger.info("GPU %s unavailable, trying next...", try_gpu)
                    continue
                raise
        if pod is None:
            raise RuntimeError("No GPU available for download pod. Try again later.")
        pod_id = pod["id"]
        _download_state["pod_id"] = pod_id
        _download_state["progress"] = f"Pod created with {used_gpu} ({pod_id}), waiting for SSH..."

        # Wait up to 5 minutes for the pod to reach RUNNING and expose port 22
        ssh_host = ssh_port = None
        start = time.time()
        while time.time() - start < 300:
            try:
                p = await asyncio.to_thread(runpod.get_pod, pod_id)
                if p and p.get("desiredStatus") == "RUNNING":
                    for port in (p.get("runtime") or {}).get("ports") or []:
                        if port.get("privatePort") == 22:
                            ssh_host = port.get("ip")
                            ssh_port = port.get("publicPort")
                if ssh_host and ssh_port:
                    break
            except Exception:
                pass
            await asyncio.sleep(5)

        if not ssh_host:
            raise RuntimeError("Pod SSH not available after 5 min")

        # Connect SSH — sshd inside the pod can take a while to accept logins
        ssh = paramiko.SSHClient()
        ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        for attempt in range(20):
            try:
                await asyncio.to_thread(ssh.connect, ssh_host, port=int(ssh_port), username="root", password="runpod", timeout=10)
                break
            except Exception:
                if attempt == 19:
                    raise RuntimeError("SSH connection failed after 20 attempts")
                await asyncio.sleep(5)

        ssh.get_transport().set_keepalive(30)
        _download_state["progress"] = "SSH connected, setting up tools..."

        # Symlink volume so /workspace/models persists across pods
        await _ssh_exec_async(ssh, "mkdir -p /runpod-volume/models && rm -rf /workspace/models 2>/dev/null; ln -sf /runpod-volume/models /workspace/models")
        await _ssh_exec_async(ssh, "pip install huggingface_hub 2>&1 | tail -1", timeout=120)
        await _ssh_exec_async(ssh, "which aria2c || apt-get install -y aria2 2>&1 | tail -1", timeout=120)

        if model_type == "wan22":
            wan_dir = "/workspace/models/WAN2.2"
            await _ssh_exec_async(ssh, f"mkdir -p {wan_dir}")

            civitai_token = os.environ.get("CIVITAI_API_TOKEN", "")
            token_param = f"&token={civitai_token}" if civitai_token else ""

            # CivitAI Remix models (fp8)
            civitai_files = [
                ("Remix T2V High-noise", f"https://civitai.com/api/download/models/2424167?type=Model&format=SafeTensor&size=pruned{token_param}", f"{wan_dir}/wan22_remix_t2v_high_fp8.safetensors"),
                ("Remix T2V Low-noise", f"https://civitai.com/api/download/models/2424912?type=Model&format=SafeTensor&size=pruned{token_param}", f"{wan_dir}/wan22_remix_t2v_low_fp8.safetensors"),
            ]

            # HuggingFace models: (label, repo, repo-relative path, volume target)
            hf_files = [
                ("T5 text encoder (fp8)", "Comfy-Org/Wan_2.2_ComfyUI_Repackaged", "split_files/text_encoders/umt5_xxl_fp8_e4m3fn_scaled.safetensors", f"{wan_dir}/umt5_xxl_fp8_e4m3fn_scaled.safetensors"),
                ("VAE", "Comfy-Org/Wan_2.2_ComfyUI_Repackaged", "split_files/vae/wan_2.1_vae.safetensors", f"{wan_dir}/wan_2.1_vae.safetensors"),
            ]

            total = len(civitai_files) + len(hf_files)
            idx = 0

            for label, url, target in civitai_files:
                idx += 1
                exists = (await _ssh_exec_async(ssh, f"test -f {target} && echo EXISTS || echo MISSING")).strip()
                if exists == "EXISTS":
                    _download_state["progress"] = f"[{idx}/{total}] {label} already cached"
                    logger.info("WAN 2.2 %s already on volume", label)
                else:
                    _download_state["progress"] = f"[{idx}/{total}] Downloading {label} (~14GB)..."
                    await _ssh_exec_async(ssh, f"wget -q -O '{target}' '{url}'", timeout=1800)
                    # Sanity-check size: CivitAI returns a tiny HTML error page on auth failure
                    check = (await _ssh_exec_async(ssh, f"test -f {target} && stat -c%s {target} || echo 0")).strip()
                    if check == "0" or int(check) < 1000000:
                        raise RuntimeError(f"Failed to download {label}. Set CIVITAI_API_TOKEN for NSFW models.")
                    _download_state["progress"] = f"[{idx}/{total}] {label} downloaded"

            for label, repo, filename, target in hf_files:
                idx += 1
                exists = (await _ssh_exec_async(ssh, f"test -f {target} && echo EXISTS || echo MISSING")).strip()
                if exists == "EXISTS":
                    _download_state["progress"] = f"[{idx}/{total}] {label} already cached"
                    logger.info("WAN 2.2 %s already on volume", label)
                else:
                    _download_state["progress"] = f"[{idx}/{total}] Downloading {label}..."
                    # BUG FIX: interpolate the repo-relative filename into the resolve URL
                    # (previously a literal placeholder, so every HF download 404'd)
                    hf_url = f"https://huggingface.co/{repo}/resolve/main/{filename}"
                    fname = target.split("/")[-1]
                    tdir = "/".join(target.split("/")[:-1])
                    await _ssh_exec_async(ssh, f"aria2c -x 16 -s 16 -c -o '{fname}' --dir='{tdir}' '{hf_url}' 2>&1 | tail -3", timeout=1800)
                    check = (await _ssh_exec_async(ssh, f"test -f {target} && echo EXISTS || echo MISSING")).strip()
                    if check != "EXISTS":
                        raise RuntimeError(f"Failed to download {label}")
                    _download_state["progress"] = f"[{idx}/{total}] {label} downloaded"

            # Also pre-clone musubi-tuner to volume (for training)
            _download_state["progress"] = "Caching musubi-tuner to volume..."
            tuner_exists = (await _ssh_exec_async(ssh, "test -f /runpod-volume/musubi-tuner/pyproject.toml && echo EXISTS || echo MISSING")).strip()
            if tuner_exists != "EXISTS":
                await _ssh_exec_async(ssh, "cd /workspace && git clone --depth 1 https://github.com/kohya-ss/musubi-tuner.git && cp -r /workspace/musubi-tuner /runpod-volume/musubi-tuner", timeout=300)
                _download_state["progress"] = "musubi-tuner cached"
            else:
                _download_state["progress"] = "musubi-tuner already cached"

        elif model_type == "wan22_animate":
            animate_dir = "/workspace/models/WAN2.2-Animate"
            wan22_dir = "/workspace/models/WAN2.2"
            hf_base = "https://huggingface.co"
            await _ssh_exec_async(ssh, f"mkdir -p {animate_dir}")

            # Files to download: (label, url, target, timeout_s, min_bytes)
            wget_files = [
                (
                    "WAN 2.2 Animate model (~32GB)",
                    f"{hf_base}/Comfy-Org/Wan_2.2_ComfyUI_Repackaged/resolve/main/split_files/diffusion_models/wan2.2_animate_14B_bf16.safetensors",
                    f"{animate_dir}/wan2.2_animate_14B_bf16.safetensors",
                    7200,
                    30_000_000_000,  # 30GB min — partial downloads get resumed
                ),
                (
                    "UMT5 text encoder fp8 (~6.3GB)",
                    f"{hf_base}/Kijai/WanVideo_comfy/resolve/main/umt5-xxl-enc-fp8_e4m3fn.safetensors",
                    f"{animate_dir}/umt5-xxl-enc-fp8_e4m3fn.safetensors",
                    1800,
                    6_000_000_000,
                ),
                (
                    "VAE (~242MB)",
                    f"{hf_base}/Comfy-Org/Wan_2.2_ComfyUI_Repackaged/resolve/main/split_files/vae/wan_2.1_vae.safetensors",
                    f"{animate_dir}/wan_2.1_vae.safetensors",
                    300,
                    200_000_000,
                ),
                (
                    "CLIP Vision H (~2.4GB)",
                    f"{hf_base}/h94/IP-Adapter/resolve/main/models/image_encoder/model.safetensors",
                    f"{animate_dir}/clip_vision_h.safetensors",
                    900,
                    2_000_000_000,
                ),
            ]

            total = len(wget_files)

            for idx, (label, url, target, dl_timeout, min_bytes) in enumerate(wget_files, 1):
                # For T5 and VAE, reuse from wan22 dir if already present (and complete)
                wan22_candidate = f"{wan22_dir}/{target.split('/')[-1]}"
                reused = False
                # BUG FIX: label must match the one declared above
                # ("VAE (~242MB)", not "VAE (~1GB)") or VAE reuse never triggers
                if label in ("UMT5 text encoder fp8 (~6.3GB)", "VAE (~242MB)"):
                    wan22_size = (await _ssh_exec_async(ssh, f"stat -c%s {wan22_candidate} 2>/dev/null || echo 0")).strip()
                    if int(wan22_size) >= min_bytes:
                        _download_state["progress"] = f"[{idx}/{total}] {label} — reusing from WAN2.2 dir"
                        await _ssh_exec_async(ssh, f"ln -sf {wan22_candidate} {target} 2>/dev/null || cp {wan22_candidate} {target}")
                        reused = True

                if not reused:
                    size_str = (await _ssh_exec_async(ssh, f"stat -c%s {target} 2>/dev/null || echo 0")).strip()
                    if int(size_str) >= min_bytes:
                        _download_state["progress"] = f"[{idx}/{total}] {label} already cached"
                    else:
                        _download_state["progress"] = f"[{idx}/{total}] Downloading {label}..."
                        filename = target.split("/")[-1]
                        target_dir = "/".join(target.split("/")[:-1])
                        # Remove stale symlinks before downloading (can't resume through a symlink)
                        await _ssh_exec_async(ssh, f"test -L '{target}' && rm -f '{target}'; true")
                        # BUG FIX: pass the real output filename to aria2c
                        # (previously a literal placeholder, so the size check below always failed)
                        await _ssh_exec_async(
                            ssh,
                            f"aria2c -x 16 -s 16 -c -o '{filename}' --dir='{target_dir}' '{url}' 2>&1 | tail -3",
                            timeout=dl_timeout,
                        )
                        size_str = (await _ssh_exec_async(ssh, f"stat -c%s {target} 2>/dev/null || echo 0")).strip()
                        if int(size_str) < min_bytes:
                            raise RuntimeError(f"Failed to download {label} (size {size_str} < {min_bytes})")
                        _download_state["progress"] = f"[{idx}/{total}] {label} downloaded"

        _download_state["status"] = "completed"
        _download_state["progress"] = "All models downloaded to volume! Ready for training."
        logger.info("Model pre-download complete for %s", model_type)

    except Exception as e:
        _download_state["status"] = "failed"
        _download_state["error"] = str(e)
        _download_state["progress"] = f"Failed: {e}"
        logger.error("Model download failed: %s", e)

    finally:
        if ssh:
            try:
                ssh.close()
            except Exception:
                pass
        # Always terminate the download pod so it stops billing
        if pod_id:
            try:
                await asyncio.to_thread(runpod.terminate_pod, pod_id)
                logger.info("Download pod terminated: %s", pod_id)
            except Exception as e:
                logger.warning("Failed to terminate download pod: %s", e)
        _download_state["pod_id"] = None
1073
+
1074
+
1075
  @router.post("/stop")
1076
  async def stop_pod():
1077
  """Stop the GPU pod."""
 
1133
 
1134
  @router.post("/upload-lora")
1135
  async def upload_lora_to_pod(file: UploadFile = File(...)):
1136
+ """Upload a LoRA file directly to /runpod-volume/loras/ via SFTP so it persists."""
1137
+ import paramiko, io
1138
 
1139
  if _pod_state["status"] != "running":
1140
  raise HTTPException(400, "Pod not running - start it first")
 
1142
  if not file.filename.endswith(".safetensors"):
1143
  raise HTTPException(400, "Only .safetensors files supported")
1144
 
1145
+ ip = _pod_state.get("ip")
1146
+ port = _pod_state.get("ssh_port") or 22
1147
+ if not ip:
1148
+ raise HTTPException(500, "No SSH IP available")
1149
+
1150
+ content = await file.read()
1151
+ dest_path = f"/runpod-volume/loras/{file.filename}"
1152
+ comfy_link = f"/workspace/ComfyUI/models/loras/{file.filename}"
1153
+
1154
+ def _sftp_upload():
1155
+ client = paramiko.SSHClient()
1156
+ client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
1157
+ client.connect(ip, port=port, username="root", timeout=30)
1158
+ # Ensure dir exists
1159
+ client.exec_command("mkdir -p /runpod-volume/loras")[1].read()
1160
+ sftp = client.open_sftp()
1161
+ sftp.putfo(io.BytesIO(content), dest_path)
1162
+ sftp.close()
1163
+ # Symlink into ComfyUI
1164
+ client.exec_command(f"ln -sf {dest_path} {comfy_link}")[1].read()
1165
+ client.close()
1166
+
1167
  try:
1168
+ await asyncio.to_thread(_sftp_upload)
1169
+ logger.info("LoRA uploaded to volume: %s (%d bytes)", file.filename, len(content))
1170
+ return {"status": "uploaded", "filename": file.filename, "path": dest_path}
1171
+ except Exception as e:
1172
+ logger.error("LoRA upload failed: %s", e)
1173
+ raise HTTPException(500, f"Upload failed: {e}")
1174
 
 
 
 
 
 
1175
 
1176
+ @router.post("/upload-lora-local")
1177
+ async def upload_lora_from_local(local_path: str, filename: str | None = None):
1178
+ """Upload a LoRA from a local server path directly to the volume via SFTP."""
1179
+ import paramiko, io
1180
+ from pathlib import Path
1181
 
1182
+ if _pod_state["status"] != "running":
1183
+ raise HTTPException(400, "Pod not running - start it first")
1184
+
1185
+ src = Path(local_path)
1186
+ if not src.exists():
1187
+ raise HTTPException(404, f"Local file not found: {local_path}")
1188
+
1189
+ dest_name = filename or src.name
1190
+ if not dest_name.endswith(".safetensors"):
1191
+ raise HTTPException(400, "Only .safetensors files supported")
1192
+
1193
+ ip = _pod_state.get("ip")
1194
+ port = _pod_state.get("ssh_port") or 22
1195
+ dest_path = f"/runpod-volume/loras/{dest_name}"
1196
+ comfy_link = f"/workspace/ComfyUI/models/loras/{dest_name}"
1197
+
1198
+ def _sftp_upload():
1199
+ client = paramiko.SSHClient()
1200
+ client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
1201
+ client.connect(ip, port=port, username="root", timeout=30)
1202
+ client.exec_command("mkdir -p /runpod-volume/loras")[1].read()
1203
+ sftp = client.open_sftp()
1204
+ sftp.put(str(src), dest_path)
1205
+ sftp.close()
1206
+ client.exec_command(f"ln -sf {dest_path} {comfy_link}")[1].read()
1207
+ client.close()
1208
+
1209
+ try:
1210
+ await asyncio.to_thread(_sftp_upload)
1211
+ size_mb = src.stat().st_size / 1024 / 1024
1212
+ logger.info("LoRA uploaded from local: %s (%.1f MB)", dest_name, size_mb)
1213
+ return {"status": "uploaded", "filename": dest_name, "path": dest_path, "size_mb": round(size_mb, 1)}
1214
  except Exception as e:
1215
+ logger.error("Local LoRA upload failed: %s", e)
1216
  raise HTTPException(500, f"Upload failed: {e}")
1217
 
1218
 
 
1226
  seed: int = -1
1227
  lora_name: str | None = None
1228
  lora_strength: float = 0.85
1229
+ lora_name_2: str | None = None
1230
+ lora_strength_2: float = 0.85
1231
  character_id: str | None = None
1232
  template_id: str | None = None
1233
  content_rating: str = "sfw"
 
1250
  seed = request.seed if request.seed >= 0 else random.randint(0, 2**32 - 1)
1251
 
1252
  model_type = _pod_state.get("model_type", "flux2")
1253
+ if model_type == "wan22":
1254
+ workflow = _build_wan_t2i_workflow(
1255
+ prompt=request.prompt,
1256
+ negative_prompt=request.negative_prompt,
1257
+ width=request.width,
1258
+ height=request.height,
1259
+ steps=request.steps,
1260
+ cfg=request.cfg,
1261
+ seed=seed,
1262
+ lora_name=request.lora_name,
1263
+ lora_strength=request.lora_strength,
1264
+ lora_name_2=request.lora_name_2,
1265
+ lora_strength_2=request.lora_strength_2,
1266
+ )
1267
+ else:
1268
+ workflow = _build_flux_workflow(
1269
+ prompt=request.prompt,
1270
+ negative_prompt=request.negative_prompt,
1271
+ width=request.width,
1272
+ height=request.height,
1273
+ steps=request.steps,
1274
+ cfg=request.cfg,
1275
+ seed=seed,
1276
+ lora_name=request.lora_name,
1277
+ lora_strength=request.lora_strength,
1278
+ model_type=model_type,
1279
+ )
1280
 
1281
  comfyui_url = _get_comfyui_url()
1282
 
 
1581
  workflow["7"]["inputs"]["clip"] = ["20", 1]
1582
 
1583
  return workflow
1584
+
1585
+
1586
+ def _build_wan_t2i_workflow(
1587
+ prompt: str,
1588
+ negative_prompt: str,
1589
+ width: int,
1590
+ height: int,
1591
+ steps: int,
1592
+ cfg: float,
1593
+ seed: int,
1594
+ lora_name: str | None,
1595
+ lora_strength: float,
1596
+ lora_name_2: str | None = None,
1597
+ lora_strength_2: float = 0.85,
1598
+ ) -> dict:
1599
+ """Build a ComfyUI workflow for WAN 2.2 Remix — dual-DiT MoE split-step.
1600
+
1601
+ Based on the WAN 2.2 Remix workflow from CivitAI:
1602
+ - Two UNETLoaders: high-noise + low-noise Remix models (fp8)
1603
+ - wanBlockSwap on both (offloads blocks to CPU for 24GB GPUs)
1604
+ - ModelSamplingSD3 with shift=5 on both
1605
+ - Dual KSamplerAdvanced: high-noise runs first half, low-noise finishes
1606
+ - CLIPLoader (type=wan) + CLIPTextEncode for T5 text encoding
1607
+ - Standard VAELoader + VAEDecode
1608
+ - EmptyHunyuanLatentVideo for latent (1 frame = image, 81+ = video)
1609
+ """
1610
+ high_dit = "wan22_remix_t2v_high_fp8.safetensors"
1611
+ low_dit = "wan22_remix_t2v_low_fp8.safetensors"
1612
+ t5_name = "umt5_xxl_fp8_e4m3fn_scaled.safetensors"
1613
+ vae_name = "wan_2.1_vae.safetensors"
1614
+
1615
+ total_steps = steps # default 8
1616
+ split_step = total_steps // 2 # high-noise does first half, low-noise does rest
1617
+ shift = 5.0
1618
+ block_swap = 20 # blocks offloaded to CPU (0-40, higher = less VRAM)
1619
+
1620
+ workflow = {
1621
+ # ── Load high-noise DiT ──
1622
+ "1": {
1623
+ "class_type": "UNETLoader",
1624
+ "inputs": {
1625
+ "unet_name": high_dit,
1626
+ "weight_dtype": "fp8_e4m3fn",
1627
+ },
1628
+ },
1629
+ # ── Load low-noise DiT ──
1630
+ "2": {
1631
+ "class_type": "UNETLoader",
1632
+ "inputs": {
1633
+ "unet_name": low_dit,
1634
+ "weight_dtype": "fp8_e4m3fn",
1635
+ },
1636
+ },
1637
+ # ── wanBlockSwap on high-noise (VRAM optimization) ──
1638
+ "11": {
1639
+ "class_type": "wanBlockSwap",
1640
+ "inputs": {
1641
+ "model": ["1", 0],
1642
+ "blocks_to_swap": block_swap,
1643
+ "offload_img_emb": False,
1644
+ "offload_txt_emb": False,
1645
+ },
1646
+ },
1647
+ # ── wanBlockSwap on low-noise ──
1648
+ "12": {
1649
+ "class_type": "wanBlockSwap",
1650
+ "inputs": {
1651
+ "model": ["2", 0],
1652
+ "blocks_to_swap": block_swap,
1653
+ "offload_img_emb": False,
1654
+ "offload_txt_emb": False,
1655
+ },
1656
+ },
1657
+ # ── ModelSamplingSD3 shift on high-noise ──
1658
+ "13": {
1659
+ "class_type": "ModelSamplingSD3",
1660
+ "inputs": {
1661
+ "model": ["11", 0],
1662
+ "shift": shift,
1663
+ },
1664
+ },
1665
+ # ── ModelSamplingSD3 shift on low-noise ──
1666
+ "14": {
1667
+ "class_type": "ModelSamplingSD3",
1668
+ "inputs": {
1669
+ "model": ["12", 0],
1670
+ "shift": shift,
1671
+ },
1672
+ },
1673
+ # ── Load T5 text encoder ──
1674
+ "3": {
1675
+ "class_type": "CLIPLoader",
1676
+ "inputs": {
1677
+ "clip_name": t5_name,
1678
+ "type": "wan",
1679
+ },
1680
+ },
1681
+ # ── Positive prompt ──
1682
+ "6": {
1683
+ "class_type": "CLIPTextEncode",
1684
+ "inputs": {
1685
+ "text": prompt,
1686
+ "clip": ["3", 0],
1687
+ },
1688
+ },
1689
+ # ── Negative prompt ──
1690
+ "7": {
1691
+ "class_type": "CLIPTextEncode",
1692
+ "inputs": {
1693
+ "text": negative_prompt or "",
1694
+ "clip": ["3", 0],
1695
+ },
1696
+ },
1697
+ # ── VAE ──
1698
+ "4": {
1699
+ "class_type": "VAELoader",
1700
+ "inputs": {"vae_name": vae_name},
1701
+ },
1702
+ # ── Empty latent (1 frame = single image) ──
1703
+ "5": {
1704
+ "class_type": "EmptyHunyuanLatentVideo",
1705
+ "inputs": {
1706
+ "width": width,
1707
+ "height": height,
1708
+ "length": 1,
1709
+ "batch_size": 1,
1710
+ },
1711
+ },
1712
+ # ── KSamplerAdvanced #1: High-noise model (first half of steps) ──
1713
+ "15": {
1714
+ "class_type": "KSamplerAdvanced",
1715
+ "inputs": {
1716
+ "model": ["13", 0],
1717
+ "positive": ["6", 0],
1718
+ "negative": ["7", 0],
1719
+ "latent_image": ["5", 0],
1720
+ "add_noise": "enable",
1721
+ "noise_seed": seed,
1722
+ "steps": total_steps,
1723
+ "cfg": cfg,
1724
+ "sampler_name": "euler",
1725
+ "scheduler": "simple",
1726
+ "start_at_step": 0,
1727
+ "end_at_step": split_step,
1728
+ "return_with_leftover_noise": "enable",
1729
+ },
1730
+ },
1731
+ # ── KSamplerAdvanced #2: Low-noise model (second half of steps) ──
1732
+ "16": {
1733
+ "class_type": "KSamplerAdvanced",
1734
+ "inputs": {
1735
+ "model": ["14", 0],
1736
+ "positive": ["6", 0],
1737
+ "negative": ["7", 0],
1738
+ "latent_image": ["15", 0],
1739
+ "add_noise": "disable",
1740
+ "noise_seed": seed,
1741
+ "steps": total_steps,
1742
+ "cfg": cfg,
1743
+ "sampler_name": "euler",
1744
+ "scheduler": "simple",
1745
+ "start_at_step": split_step,
1746
+ "end_at_step": 10000,
1747
+ "return_with_leftover_noise": "disable",
1748
+ },
1749
+ },
1750
+ # ── VAE Decode ──
1751
+ "8": {
1752
+ "class_type": "VAEDecode",
1753
+ "inputs": {
1754
+ "samples": ["16", 0],
1755
+ "vae": ["4", 0],
1756
+ },
1757
+ },
1758
+ # ── Save Image ──
1759
+ "9": {
1760
+ "class_type": "SaveImage",
1761
+ "inputs": {
1762
+ "filename_prefix": "wan_remix_pod",
1763
+ "images": ["8", 0],
1764
+ },
1765
+ },
1766
+ }
1767
+
1768
+ # Add LoRA(s) to both models if specified — chained: DiT → LoRA1 → LoRA2 → Sampler
1769
+ if lora_name:
1770
+ # LoRA 1 (body) on high-noise and low-noise DiT
1771
+ workflow["20"] = {
1772
+ "class_type": "LoraLoader",
1773
+ "inputs": {
1774
+ "lora_name": lora_name,
1775
+ "strength_model": lora_strength,
1776
+ "strength_clip": 1.0,
1777
+ "model": ["13", 0],
1778
+ "clip": ["3", 0],
1779
+ },
1780
+ }
1781
+ workflow["21"] = {
1782
+ "class_type": "LoraLoader",
1783
+ "inputs": {
1784
+ "lora_name": lora_name,
1785
+ "strength_model": lora_strength,
1786
+ "strength_clip": 1.0,
1787
+ "model": ["14", 0],
1788
+ "clip": ["3", 0],
1789
+ },
1790
+ }
1791
+
1792
+ # Determine what the samplers and CLIP read from (LoRA2 if present, else LoRA1)
1793
+ high_model_out = ["20", 0]
1794
+ low_model_out = ["21", 0]
1795
+ clip_out = ["20", 1]
1796
+
1797
+ if lora_name_2:
1798
+ # LoRA 2 (face) chained after LoRA 1 on both models
1799
+ workflow["22"] = {
1800
+ "class_type": "LoraLoader",
1801
+ "inputs": {
1802
+ "lora_name": lora_name_2,
1803
+ "strength_model": lora_strength_2,
1804
+ "strength_clip": 1.0,
1805
+ "model": ["20", 0],
1806
+ "clip": ["20", 1],
1807
+ },
1808
+ }
1809
+ workflow["23"] = {
1810
+ "class_type": "LoraLoader",
1811
+ "inputs": {
1812
+ "lora_name": lora_name_2,
1813
+ "strength_model": lora_strength_2,
1814
+ "strength_clip": 1.0,
1815
+ "model": ["21", 0],
1816
+ "clip": ["21", 1],
1817
+ },
1818
+ }
1819
+ high_model_out = ["22", 0]
1820
+ low_model_out = ["23", 0]
1821
+ clip_out = ["22", 1]
1822
+
1823
+ # Rewire samplers and CLIP encoding
1824
+ workflow["15"]["inputs"]["model"] = high_model_out
1825
+ workflow["16"]["inputs"]["model"] = low_model_out
1826
+ workflow["6"]["inputs"]["clip"] = clip_out
1827
+ workflow["7"]["inputs"]["clip"] = clip_out
1828
+
1829
+ return workflow
src/content_engine/api/routes_ui.py CHANGED
@@ -1,172 +1,23 @@
1
- """Web UI route — serves the single-page dashboard with password protection."""
2
 
3
  from __future__ import annotations
4
 
5
- import hashlib
6
- import os
7
- import secrets
8
  from pathlib import Path
9
 
10
- from fastapi import APIRouter, Request, Form, HTTPException
11
- from fastapi.responses import HTMLResponse, Response, RedirectResponse
12
 
13
  router = APIRouter(tags=["ui"])
14
 
15
  UI_HTML_PATH = Path(__file__).parent / "ui.html"
16
 
17
- # Simple session storage (in-memory, resets on restart)
18
- _valid_sessions: set[str] = set()
19
-
20
- # Get password from environment variable
21
- APP_PASSWORD = os.environ.get("APP_PASSWORD", "")
22
-
23
-
24
- def _check_session(request: Request) -> bool:
25
- """Check if request has valid session."""
26
- if not APP_PASSWORD:
27
- return True # No password set, allow access
28
- session_token = request.cookies.get("session")
29
- return session_token in _valid_sessions
30
-
31
-
32
- LOGIN_HTML = """
33
- <!DOCTYPE html>
34
- <html lang="en">
35
- <head>
36
- <meta charset="UTF-8">
37
- <meta name="viewport" content="width=device-width, initial-scale=1.0">
38
- <title>Login - Content Engine</title>
39
- <style>
40
- *, *::before, *::after { box-sizing: border-box; margin: 0; padding: 0; }
41
- body {
42
- font-family: 'Segoe UI', -apple-system, system-ui, sans-serif;
43
- background: linear-gradient(135deg, #0a0a0f 0%, #1a1a2e 100%);
44
- color: #eee;
45
- min-height: 100vh;
46
- display: flex;
47
- align-items: center;
48
- justify-content: center;
49
- }
50
- .login-box {
51
- background: #1a1a28;
52
- border: 1px solid #2a2a3a;
53
- border-radius: 16px;
54
- padding: 40px;
55
- width: 100%;
56
- max-width: 400px;
57
- box-shadow: 0 20px 60px rgba(0,0,0,0.5);
58
- }
59
- h1 {
60
- font-size: 24px;
61
- margin-bottom: 8px;
62
- background: linear-gradient(135deg, #7c3aed, #ec4899);
63
- -webkit-background-clip: text;
64
- -webkit-text-fill-color: transparent;
65
- }
66
- .subtitle { color: #888; font-size: 14px; margin-bottom: 30px; }
67
- label { display: block; font-size: 13px; color: #888; margin-bottom: 6px; }
68
- input[type="password"] {
69
- width: 100%;
70
- padding: 12px 16px;
71
- border-radius: 8px;
72
- border: 1px solid #2a2a3a;
73
- background: #0a0a0f;
74
- color: #eee;
75
- font-size: 16px;
76
- margin-bottom: 20px;
77
- }
78
- input[type="password"]:focus { outline: none; border-color: #7c3aed; }
79
- button {
80
- width: 100%;
81
- padding: 14px;
82
- border-radius: 8px;
83
- border: none;
84
- background: linear-gradient(135deg, #7c3aed, #6d28d9);
85
- color: white;
86
- font-size: 16px;
87
- font-weight: 600;
88
- cursor: pointer;
89
- transition: transform 0.1s, box-shadow 0.2s;
90
- }
91
- button:hover { transform: translateY(-1px); box-shadow: 0 4px 20px rgba(124, 58, 237, 0.4); }
92
- .error { color: #ef4444; font-size: 13px; margin-bottom: 16px; }
93
- </style>
94
- </head>
95
- <body>
96
- <div class="login-box">
97
- <h1>Content Engine</h1>
98
- <p class="subtitle">Enter password to access</p>
99
- {{ERROR}}
100
- <form method="POST" action="/login">
101
- <label>Password</label>
102
- <input type="password" name="password" placeholder="Enter password" autofocus required>
103
- <button type="submit">Login</button>
104
- </form>
105
- </div>
106
- </body>
107
- </html>
108
- """
109
-
110
 
111
  @router.get("/", response_class=HTMLResponse)
112
- async def dashboard(request: Request):
113
  """Serve the main dashboard UI."""
114
- if not _check_session(request):
115
- return RedirectResponse(url="/login", status_code=302)
116
-
117
  content = UI_HTML_PATH.read_text(encoding="utf-8")
118
  return Response(
119
  content=content,
120
  media_type="text/html",
121
  headers={"Cache-Control": "no-cache, no-store, must-revalidate"},
122
  )
123
-
124
-
125
- @router.get("/login", response_class=HTMLResponse)
126
- async def login_page(request: Request, error: str = ""):
127
- """Show login page."""
128
- if not APP_PASSWORD:
129
- return RedirectResponse(url="/", status_code=302)
130
-
131
- if _check_session(request):
132
- return RedirectResponse(url="/", status_code=302)
133
-
134
- error_html = f'<p class="error">{error}</p>' if error else ""
135
- html = LOGIN_HTML.replace("{{ERROR}}", error_html)
136
- return Response(content=html, media_type="text/html")
137
-
138
-
139
- @router.post("/login")
140
- async def login_submit(password: str = Form(...)):
141
- """Handle login form submission."""
142
- if not APP_PASSWORD:
143
- return RedirectResponse(url="/", status_code=302)
144
-
145
- if password == APP_PASSWORD:
146
- # Create session token
147
- session_token = secrets.token_hex(32)
148
- _valid_sessions.add(session_token)
149
-
150
- response = RedirectResponse(url="/", status_code=302)
151
- response.set_cookie(
152
- key="session",
153
- value=session_token,
154
- httponly=True,
155
- max_age=86400 * 7, # 7 days
156
- samesite="lax",
157
- )
158
- return response
159
- else:
160
- return RedirectResponse(url="/login?error=Invalid+password", status_code=302)
161
-
162
-
163
- @router.get("/logout")
164
- async def logout(request: Request):
165
- """Log out and clear session."""
166
- session_token = request.cookies.get("session")
167
- if session_token in _valid_sessions:
168
- _valid_sessions.discard(session_token)
169
-
170
- response = RedirectResponse(url="/login", status_code=302)
171
- response.delete_cookie("session")
172
- return response
 
1
+ """Web UI route — serves the single-page dashboard."""
2
 
3
  from __future__ import annotations
4
 
 
 
 
5
  from pathlib import Path
6
 
7
+ from fastapi import APIRouter
8
+ from fastapi.responses import HTMLResponse, Response
9
 
10
  router = APIRouter(tags=["ui"])
11
 
12
  UI_HTML_PATH = Path(__file__).parent / "ui.html"
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  @router.get("/", response_class=HTMLResponse)
16
+ async def dashboard():
17
  """Serve the main dashboard UI."""
 
 
 
18
  content = UI_HTML_PATH.read_text(encoding="utf-8")
19
  return Response(
20
  content=content,
21
  media_type="text/html",
22
  headers={"Cache-Control": "no-cache, no-store, must-revalidate"},
23
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/content_engine/api/routes_video.py CHANGED
@@ -10,17 +10,10 @@ import time
10
  import uuid
11
  from pathlib import Path
12
 
 
13
  from fastapi import APIRouter, File, Form, HTTPException, UploadFile
14
  from pydantic import BaseModel
15
 
16
- # Optional RunPod import
17
- try:
18
- import runpod
19
- RUNPOD_AVAILABLE = True
20
- except ImportError:
21
- runpod = None
22
- RUNPOD_AVAILABLE = False
23
-
24
  logger = logging.getLogger(__name__)
25
 
26
  router = APIRouter(prefix="/api/video", tags=["video"])
@@ -47,6 +40,10 @@ def _get_pod_state():
47
  from content_engine.api.routes_pod import _pod_state
48
  return _pod_state
49
 
 
 
 
 
50
 
51
  class VideoGenerateRequest(BaseModel):
52
  prompt: str
@@ -93,9 +90,10 @@ async def generate_video(
93
  )
94
 
95
  try:
 
96
  async with httpx.AsyncClient(timeout=30) as client:
97
  # First upload the image to ComfyUI
98
- upload_url = f"http://{pod_state['ip']}:{pod_state['port']}/upload/image"
99
  files = {"image": (f"input_{job_id}.png", image_bytes, "image/png")}
100
  upload_resp = await client.post(upload_url, files=files)
101
 
@@ -116,7 +114,7 @@ async def generate_video(
116
  )
117
 
118
  # Submit workflow
119
- url = f"http://{pod_state['ip']}:{pod_state['port']}/prompt"
120
  resp = await client.post(url, json={"prompt": workflow})
121
  resp.raise_for_status()
122
 
@@ -364,6 +362,8 @@ async def _poll_wavespeed_video(poll_url: str, api_key: str, job_id: str, max_at
364
  if status == "failed":
365
  error_msg = data.get("error", "Unknown error")
366
  logger.error("WaveSpeed video job failed: %s", error_msg)
 
 
367
  return None
368
 
369
  outputs = data.get("outputs", [])
@@ -430,6 +430,11 @@ async def _generate_cloud_video(
430
  if negative_prompt:
431
  payload["negative_prompt"] = negative_prompt
432
 
 
 
 
 
 
433
  _video_jobs[job_id]["message"] = f"Calling WaveSpeed API ({wavespeed_model})..."
434
  logger.info("Calling WaveSpeed video API: %s", endpoint)
435
 
@@ -527,14 +532,14 @@ async def _poll_video_job(job_id: str, prompt_id: str):
527
  """Poll ComfyUI for video job completion."""
528
  import httpx
529
 
530
- pod_state = _get_pod_state()
531
  start = time.time()
532
- timeout = 600 # 10 minutes for video
 
533
 
534
  async with httpx.AsyncClient(timeout=60) as client:
535
  while time.time() - start < timeout:
536
  try:
537
- url = f"http://{pod_state['ip']}:{pod_state['port']}/history/{prompt_id}"
538
  resp = await client.get(url)
539
 
540
  if resp.status_code == 200:
@@ -573,7 +578,7 @@ async def _download_video(client, job_id: str, video_info: dict, pod_state: dict
573
  file_type = video_info.get("type", "output")
574
 
575
  # Download video
576
- view_url = f"http://{pod_state['ip']}:{pod_state['port']}/view"
577
  params = {"filename": filename, "type": file_type}
578
  if subfolder:
579
  params["subfolder"] = subfolder
@@ -627,10 +632,339 @@ async def get_video_file(filename: str):
627
  if not video_path.exists():
628
  raise HTTPException(404, "Video not found")
629
 
630
- media_type = "video/webm" if filename.endswith(".webm") else "image/webp"
 
 
 
 
 
631
  return FileResponse(video_path, media_type=media_type)
632
 
633
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
634
  def _build_wan_i2v_workflow(
635
  uploaded_filename: str = None,
636
  image_b64: str = None,
 
10
  import uuid
11
  from pathlib import Path
12
 
13
+ import runpod
14
  from fastapi import APIRouter, File, Form, HTTPException, UploadFile
15
  from pydantic import BaseModel
16
 
 
 
 
 
 
 
 
 
17
  logger = logging.getLogger(__name__)
18
 
19
  router = APIRouter(prefix="/api/video", tags=["video"])
 
40
  from content_engine.api.routes_pod import _pod_state
41
  return _pod_state
42
 
43
+ def _get_comfyui_url():
44
+ from content_engine.api.routes_pod import _get_comfyui_url as _gcurl
45
+ return _gcurl()
46
+
47
 
48
  class VideoGenerateRequest(BaseModel):
49
  prompt: str
 
90
  )
91
 
92
  try:
93
+ comfyui_url = _get_comfyui_url()
94
  async with httpx.AsyncClient(timeout=30) as client:
95
  # First upload the image to ComfyUI
96
+ upload_url = f"{comfyui_url}/upload/image"
97
  files = {"image": (f"input_{job_id}.png", image_bytes, "image/png")}
98
  upload_resp = await client.post(upload_url, files=files)
99
 
 
114
  )
115
 
116
  # Submit workflow
117
+ url = f"{comfyui_url}/prompt"
118
  resp = await client.post(url, json={"prompt": workflow})
119
  resp.raise_for_status()
120
 
 
362
  if status == "failed":
363
  error_msg = data.get("error", "Unknown error")
364
  logger.error("WaveSpeed video job failed: %s", error_msg)
365
+ _video_jobs[job_id]["status"] = "failed"
366
+ _video_jobs[job_id]["error"] = error_msg
367
  return None
368
 
369
  outputs = data.get("outputs", [])
 
430
  if negative_prompt:
431
  payload["negative_prompt"] = negative_prompt
432
 
433
+ # Grok Imagine Video uses duration (6 or 10s) instead of frame counts
434
+ if model == "grok-imagine-i2v":
435
+ num_frames = _video_jobs[job_id].get("num_frames", 81)
436
+ payload["duration"] = 10 if num_frames > 150 else 6
437
+
438
  _video_jobs[job_id]["message"] = f"Calling WaveSpeed API ({wavespeed_model})..."
439
  logger.info("Calling WaveSpeed video API: %s", endpoint)
440
 
 
532
  """Poll ComfyUI for video job completion."""
533
  import httpx
534
 
 
535
  start = time.time()
536
+ timeout = 1800 # 30 minutes for video (WAN 2.2 needs time to load 14B model first run)
537
+ comfyui_url = _get_comfyui_url()
538
 
539
  async with httpx.AsyncClient(timeout=60) as client:
540
  while time.time() - start < timeout:
541
  try:
542
+ url = f"{comfyui_url}/history/{prompt_id}"
543
  resp = await client.get(url)
544
 
545
  if resp.status_code == 200:
 
578
  file_type = video_info.get("type", "output")
579
 
580
  # Download video
581
+ view_url = f"{_get_comfyui_url()}/view"
582
  params = {"filename": filename, "type": file_type}
583
  if subfolder:
584
  params["subfolder"] = subfolder
 
632
  if not video_path.exists():
633
  raise HTTPException(404, "Video not found")
634
 
635
+ if filename.endswith(".webm"):
636
+ media_type = "video/webm"
637
+ elif filename.endswith(".mp4"):
638
+ media_type = "video/mp4"
639
+ else:
640
+ media_type = "image/webp"
641
  return FileResponse(video_path, media_type=media_type)
642
 
643
 
644
+ @router.post("/animate")
645
+ async def generate_video_animate(
646
+ image: UploadFile = File(...),
647
+ driving_video: UploadFile = File(...),
648
+ prompt: str = Form("a person dancing, smooth motion, high quality"),
649
+ negative_prompt: str = Form(""),
650
+ width: int = Form(832),
651
+ height: int = Form(480),
652
+ num_frames: int = Form(81),
653
+ fps: int = Form(16),
654
+ seed: int = Form(-1),
655
+ steps: int = Form(20),
656
+ cfg: float = Form(6.0),
657
+ bg_mode: str = Form("keep"), # keep | driving_video | auto
658
+ ):
659
+ """Generate a dance animation via WAN 2.2 Animate on RunPod ComfyUI pod.
660
+
661
+ Requires on the pod:
662
+ - models/diffusion_models/Wan2_2-Animate-14B_fp8_e4m3fn_scaled_KJ.safetensors
663
+ - models/vae/wan_2.1_vae.safetensors
664
+ - models/clip_vision/clip_vision_h.safetensors
665
+ - models/text_encoders/umt5-xxl-enc-bf16.safetensors
666
+ - Custom nodes: ComfyUI-WanVideoWrapper, ComfyUI-VideoHelperSuite, comfyui_controlnet_aux
667
+ """
668
+ import httpx
669
+ import random
670
+
671
+ pod_state = _get_pod_state()
672
+ if pod_state["status"] != "running":
673
+ raise HTTPException(400, "Pod not running — start it first in Status page")
674
+
675
+ job_id = str(uuid.uuid4())[:8]
676
+ seed = seed if seed >= 0 else random.randint(0, 2**32 - 1)
677
+
678
+ image_bytes = await image.read()
679
+ video_bytes = await driving_video.read()
680
+
681
+ try:
682
+ base_url = _get_comfyui_url()
683
+ async with httpx.AsyncClient(timeout=60) as client:
684
+
685
+ # Upload character reference image
686
+ img_resp = await client.post(
687
+ f"{base_url}/upload/image",
688
+ files={"image": (f"ref_{job_id}.png", image_bytes, "image/png")},
689
+ )
690
+ if img_resp.status_code != 200:
691
+ raise HTTPException(500, f"Failed to upload character image: {img_resp.text[:200]}")
692
+ img_filename = img_resp.json().get("name", f"ref_{job_id}.png")
693
+ logger.info("Uploaded character image: %s", img_filename)
694
+
695
+ # Upload driving video
696
+ vid_ext = "mp4"
697
+ if driving_video.filename and "." in driving_video.filename:
698
+ vid_ext = driving_video.filename.rsplit(".", 1)[-1].lower()
699
+ vid_resp = await client.post(
700
+ f"{base_url}/upload/image",
701
+ files={"image": (f"drive_{job_id}.{vid_ext}", video_bytes, "video/mp4")},
702
+ )
703
+ if vid_resp.status_code != 200:
704
+ raise HTTPException(500, f"Failed to upload driving video: {vid_resp.text[:200]}")
705
+ vid_filename = vid_resp.json().get("name", f"drive_{job_id}.{vid_ext}")
706
+ logger.info("Uploaded driving video: %s", vid_filename)
707
+
708
+ workflow = _build_wan_animate_workflow(
709
+ ref_image_filename=img_filename,
710
+ driving_video_filename=vid_filename,
711
+ prompt=prompt,
712
+ negative_prompt=negative_prompt,
713
+ width=width,
714
+ height=height,
715
+ num_frames=num_frames,
716
+ fps=fps,
717
+ seed=seed,
718
+ steps=steps,
719
+ cfg=cfg,
720
+ bg_mode=bg_mode,
721
+ )
722
+
723
+ resp = await client.post(f"{base_url}/prompt", json={"prompt": workflow})
724
+ if resp.status_code != 200:
725
+ logger.error("ComfyUI /prompt rejected workflow: %s", resp.text[:2000])
726
+ resp.raise_for_status()
727
+ prompt_id = resp.json()["prompt_id"]
728
+
729
+ _video_jobs[job_id] = {
730
+ "prompt_id": prompt_id,
731
+ "status": "running",
732
+ "seed": seed,
733
+ "started_at": time.time(),
734
+ "num_frames": num_frames,
735
+ "fps": fps,
736
+ "mode": "animate",
737
+ "message": "WAN 2.2 Animate submitted...",
738
+ }
739
+
740
+ logger.info("WAN Animate job started: %s -> %s", job_id, prompt_id)
741
+ asyncio.create_task(_poll_video_job(job_id, prompt_id))
742
+
743
+ return {
744
+ "job_id": job_id,
745
+ "status": "running",
746
+ "seed": seed,
747
+ "estimated_time": f"~{num_frames * 3} seconds",
748
+ }
749
+
750
+ except httpx.HTTPError as e:
751
+ logger.error("WAN Animate generation failed: %s", e)
752
+ raise HTTPException(500, f"Generation failed: {e}")
753
+
754
+
755
+ def _build_wan_animate_workflow(
756
+ ref_image_filename: str,
757
+ driving_video_filename: str,
758
+ prompt: str = "a person dancing, smooth motion",
759
+ negative_prompt: str = "",
760
+ width: int = 832,
761
+ height: int = 480,
762
+ num_frames: int = 81,
763
+ fps: int = 16,
764
+ seed: int = 42,
765
+ steps: int = 20,
766
+ cfg: float = 6.0,
767
+ bg_mode: str = "auto",
768
+ ) -> dict:
769
+ """Build ComfyUI API workflow for WAN 2.2 Animate (motion transfer from driving video).
770
+
771
+ Pipeline:
772
+ reference image -> CLIP encode + resize
773
+ driving video -> DWPreprocessor (pose skeleton)
774
+ both -> WanVideoAnimateEmbeds -> WanVideoSampler -> decode -> MP4
775
+
776
+ bg_mode options:
777
+ "keep" - use reference image as background (character's original background)
778
+ "driving_video" - use driving video frames as background
779
+ "auto" - no bg hint, model generates its own background
780
+ """
781
+ neg = negative_prompt or "blurry, static, low quality, watermark, text"
782
+
783
+ workflow = {
784
+ # VAE
785
+ "1": {
786
+ "class_type": "WanVideoVAELoader",
787
+ "inputs": {
788
+ "model_name": "wan_2.1_vae.safetensors",
789
+ "precision": "bf16",
790
+ },
791
+ },
792
+ # CLIP Vision
793
+ "2": {
794
+ "class_type": "CLIPVisionLoader",
795
+ "inputs": {"clip_name": "clip_vision_h.safetensors"},
796
+ },
797
+ # Diffusion model
798
+ "3": {
799
+ "class_type": "WanVideoModelLoader",
800
+ "inputs": {
801
+ "model": "wan2.2_animate_14B_bf16.safetensors",
802
+ "base_precision": "bf16",
803
+ "quantization": "fp8_e4m3fn",
804
+ "load_device": "offload_device",
805
+ "attention_mode": "sdpa",
806
+ },
807
+ },
808
+ # Load T5 text encoder
809
+ "4": {
810
+ "class_type": "LoadWanVideoT5TextEncoder",
811
+ "inputs": {
812
+ "model_name": "umt5-xxl-enc-fp8_e4m3fn.safetensors",
813
+ "precision": "bf16",
814
+ },
815
+ },
816
+ # Encode text prompts
817
+ "16": {
818
+ "class_type": "WanVideoTextEncode",
819
+ "inputs": {
820
+ "positive_prompt": prompt,
821
+ "negative_prompt": neg,
822
+ "t5": ["4", 0],
823
+ "force_offload": True,
824
+ },
825
+ },
826
+ # Load reference character image
827
+ "5": {
828
+ "class_type": "LoadImage",
829
+ "inputs": {"image": ref_image_filename},
830
+ },
831
+ # Resize to target resolution
832
+ "6": {
833
+ "class_type": "ImageResizeKJv2",
834
+ "inputs": {
835
+ "image": ["5", 0],
836
+ "width": width,
837
+ "height": height,
838
+ "upscale_method": "lanczos",
839
+ "keep_proportion": "pad_edge_pixel",
840
+ "pad_color": "0, 0, 0",
841
+ "crop_position": "top",
842
+ "divisible_by": 16,
843
+ },
844
+ },
845
+ # CLIP Vision encode reference
846
+ "7": {
847
+ "class_type": "WanVideoClipVisionEncode",
848
+ "inputs": {
849
+ "clip_vision": ["2", 0],
850
+ "image_1": ["6", 0],
851
+ "strength_1": 1.0,
852
+ "strength_2": 1.0,
853
+ "crop": "center",
854
+ "combine_embeds": "average",
855
+ "force_offload": True,
856
+ },
857
+ },
858
+ # Load driving video (dance moves)
859
+ "8": {
860
+ "class_type": "VHS_LoadVideo",
861
+ "inputs": {
862
+ "video": driving_video_filename,
863
+ "force_rate": fps,
864
+ "custom_width": 0,
865
+ "custom_height": 0,
866
+ "frame_load_cap": num_frames if num_frames > 0 else 0,
867
+ "skip_first_frames": 0,
868
+ "select_every_nth": 1,
869
+ },
870
+ },
871
+ # Extract pose skeleton from driving video
872
+ "9": {
873
+ "class_type": "DWPreprocessor",
874
+ "inputs": {
875
+ "image": ["8", 0],
876
+ "detect_hand": "disable",
877
+ "detect_body": "enable",
878
+ "detect_face": "disable",
879
+ "resolution": max(width, height),
880
+ "bbox_detector": "yolox_l.torchscript.pt",
881
+ "pose_estimator": "dw-ll_ucoco_384_bs5.torchscript.pt",
882
+ "scale_stick_for_xinsr_cn": "disable",
883
+ },
884
+ },
885
+ # Animate embeddings: combine ref image + pose + optional background
886
+ "10": {
887
+ "class_type": "WanVideoAnimateEmbeds",
888
+ "inputs": {
889
+ "vae": ["1", 0],
890
+ "clip_embeds": ["7", 0],
891
+ "ref_images": ["6", 0],
892
+ "pose_images": ["9", 0],
893
+ # bg_mode: "keep" = ref image bg, "driving_video" = video frames bg, "auto" = model decides
894
+ **({} if bg_mode == "auto" else {
895
+ "bg_images": ["6", 0] if bg_mode == "keep" else ["8", 0],
896
+ }),
897
+ "width": width,
898
+ "height": height,
899
+ # When num_frames==0 ("Match video"), link to GetImageSizeAndCount output slot 3
900
+ "num_frames": ["15", 3] if num_frames == 0 else num_frames,
901
+ "force_offload": True,
902
+ "frame_window_size": 77,
903
+ "colormatch": "disabled",
904
+ "pose_strength": 1.0,
905
+ "face_strength": 1.0,
906
+ },
907
+ },
908
+ # Diffusion sampler (no context_options — WanAnim handles looping internally)
909
+ "12": {
910
+ "class_type": "WanVideoSampler",
911
+ "inputs": {
912
+ "model": ["3", 0],
913
+ "image_embeds": ["10", 0],
914
+ "text_embeds": ["16", 0],
915
+ "steps": steps,
916
+ "cfg": cfg,
917
+ "shift": 5.0,
918
+ "seed": seed,
919
+ "force_offload": True,
920
+ "scheduler": "dpm++_sde",
921
+ "riflex_freq_index": 0,
922
+ "denoise_strength": 1.0,
923
+ },
924
+ },
925
+ # Decode latents to frames
926
+ "13": {
927
+ "class_type": "WanVideoDecode",
928
+ "inputs": {
929
+ "vae": ["1", 0],
930
+ "samples": ["12", 0],
931
+ "enable_vae_tiling": True,
932
+ "tile_x": 272,
933
+ "tile_y": 272,
934
+ "tile_stride_x": 144,
935
+ "tile_stride_y": 128,
936
+ },
937
+ },
938
+ # Combine frames into MP4
939
+ "14": {
940
+ "class_type": "VHS_VideoCombine",
941
+ "inputs": {
942
+ "images": ["13", 0],
943
+ "frame_rate": fps,
944
+ "loop_count": 0,
945
+ "filename_prefix": "WanAnimate",
946
+ "format": "video/h264-mp4",
947
+ "pix_fmt": "yuv420p",
948
+ "crf": 19,
949
+ "save_metadata": True,
950
+ "trim_to_audio": False,
951
+ "pingpong": False,
952
+ "save_output": True,
953
+ },
954
+ },
955
+ }
956
+
957
+ # "Match video" mode (num_frames=0): detect actual frame count from posed video
958
+ # GetImageSizeAndCount outputs: (IMAGE, width, height, count) — slot 3 = frame count
959
+ if num_frames == 0:
960
+ workflow["15"] = {
961
+ "class_type": "GetImageSizeAndCount",
962
+ "inputs": {"image": ["9", 0]},
963
+ }
964
+
965
+ return workflow
966
+
967
+
968
  def _build_wan_i2v_workflow(
969
  uploaded_filename: str = None,
970
  image_b64: str = None,
src/content_engine/api/ui.html CHANGED
@@ -898,7 +898,7 @@ select { cursor: pointer; }
898
 
899
  <div id="cloud-model-select" style="display:none">
900
  <label>Model</label>
901
- <select id="gen-cloud-model">
902
  <optgroup label="Recommended">
903
  <option value="seedream-4.5" selected>SeeDream v4.5 (Best)</option>
904
  <option value="gpt-image-1.5">GPT Image 1.5</option>
@@ -909,9 +909,14 @@ select { cursor: pointer; }
909
  <option value="seedream-3.1">SeeDream v3.1</option>
910
  </optgroup>
911
  <optgroup label="Fast">
 
 
912
  <option value="gpt-image-1-mini">GPT Image Mini</option>
913
  <option value="nano-banana">NanoBanana</option>
914
  </optgroup>
 
 
 
915
  <optgroup label="Other">
916
  <option value="kling-image-o3">Kling Image O3</option>
917
  <option value="wan-2.6">WAN 2.6</option>
@@ -922,6 +927,19 @@ select { cursor: pointer; }
922
  </select>
923
  </div>
924
 
 
 
 
 
 
 
 
 
 
 
 
 
 
925
  <div id="cloud-edit-model-select" style="display:none">
926
  <label>Model</label>
927
  <select id="gen-cloud-edit-model">
@@ -933,6 +951,7 @@ select { cursor: pointer; }
933
  <optgroup label="Multi-Reference (2+ images)">
934
  <option value="seedream-4.5-multi">SeeDream v4.5 Sequential (up to 3)</option>
935
  <option value="seedream-4-multi">SeeDream v4 Sequential (up to 3)</option>
 
936
  <option value="kling-o1-multi">Kling O1 (up to 10 refs)</option>
937
  <option value="qwen-multi-angle">Qwen Multi-Angle</option>
938
  </optgroup>
@@ -948,7 +967,6 @@ select { cursor: pointer; }
948
  <option value="wan-2.5-edit">WAN 2.5 Edit</option>
949
  <option value="wan-2.2-edit">WAN 2.2 Edit</option>
950
  <option value="qwen-edit-lora">Qwen Edit + LoRA</option>
951
- <option value="nano-banana-pro-edit">NanoBanana Pro Edit</option>
952
  <option value="kling-o3-edit">Kling O3 Edit</option>
953
  <option value="dreamina-3-edit">Dreamina v3 Edit</option>
954
  </optgroup>
@@ -964,16 +982,23 @@ select { cursor: pointer; }
964
  <span id="pod-status-indicator">Checking pod status...</span>
965
  </div>
966
  <label>Base Model</label>
967
- <select id="pod-model-select">
 
968
  <option value="flux">FLUX.2 Dev (Realistic)</option>
969
- <option value="wan-t2i">WAN 2.2 (Stylized/Anime)</option>
970
  </select>
971
- <label style="margin-top:8px">Your LoRA</label>
972
  <select id="pod-lora-select">
973
  <option value="">None (Base model only)</option>
974
  </select>
975
- <label style="margin-top:8px">LoRA Strength</label>
976
  <input type="number" id="pod-lora-strength" value="0.85" min="0" max="1.5" step="0.05" style="width:80px">
 
 
 
 
 
 
977
  <div style="font-size:11px;color:var(--text-secondary);margin-top:4px">
978
  Start the pod in Status page first.
979
  </div>
@@ -981,55 +1006,121 @@ select { cursor: pointer; }
981
 
982
  <!-- Image to Video settings -->
983
  <div id="img2video-section" style="display:none">
984
- <div class="section-title">Source Image</div>
985
- <div class="drop-zone" id="video-drop-zone" onclick="document.getElementById('video-file-input').click()">
986
- <svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="1.5"><path d="M21 15v4a2 2 0 0 1-2 2H5a2 2 0 0 1-2-2v-4"/><polyline points="17 8 12 3 7 8"/><line x1="12" y1="3" x2="12" y2="15"/></svg>
987
- <div>Drop or click to upload</div>
 
988
  </div>
989
- <input type="file" id="video-file-input" accept="image/*" style="display:none" onchange="handleVideoImage(this)">
990
- <div id="video-preview" style="display:none; margin-top:6px">
991
- <img id="video-preview-img" style="max-width:100%; max-height:100px; border-radius:6px">
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
992
  </div>
993
 
994
- <label style="margin-top:12px">Video Model</label>
995
- <select id="video-cloud-model">
996
- <optgroup label="Recommended">
997
- <option value="wan-2.6-i2v-pro" selected>WAN 2.6 Pro ($0.05/s)</option>
998
- <option value="wan-2.6-i2v-flash">WAN 2.6 Flash (Fast)</option>
999
- <option value="kling-o3-pro">Kling O3 Pro</option>
1000
- </optgroup>
1001
- <optgroup label="Premium (Higgsfield - requires API key)">
1002
- <option value="kling-3.0-pro">Kling 3.0 Pro (15s + Audio)</option>
1003
- <option value="kling-3.0">Kling 3.0</option>
1004
- <option value="sora-2-hf">Sora 2</option>
1005
- <option value="veo-3.1-hf">Veo 3.1</option>
1006
- </optgroup>
1007
- <optgroup label="Budget Friendly">
1008
- <option value="wan-2.2-i2v-720p">WAN 2.2 720p ($0.01/s)</option>
1009
- <option value="wan-2.2-i2v-1080p">WAN 2.2 1080p</option>
1010
- <option value="wan-2.5-i2v">WAN 2.5</option>
1011
- </optgroup>
1012
- <optgroup label="Cinematic">
1013
- <option value="higgsfield-dop">Higgsfield DoP (5s)</option>
1014
- <option value="seedance-1.5-pro">Seedance Pro</option>
1015
- <option value="dreamina-i2v-1080p">Dreamina 1080p</option>
1016
- </optgroup>
1017
- <optgroup label="Other">
1018
- <option value="kling-o3">Kling O3</option>
1019
- <option value="veo-3.1">Veo 3.1 (WaveSpeed)</option>
1020
- <option value="sora-2">Sora 2 (WaveSpeed)</option>
1021
- <option value="vidu-q3">Vidu Q3</option>
1022
- </optgroup>
1023
- </select>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1024
 
1025
- <label>Duration</label>
1026
- <select id="video-duration">
1027
- <option value="41">2s</option>
1028
- <option value="81" selected>3s</option>
1029
- <option value="121">5s</option>
1030
- <option value="241">10s</option>
1031
- <option value="361">15s</option>
1032
- </select>
1033
  </div>
1034
 
1035
  <!-- Reference image upload for img2img -->
@@ -1284,6 +1375,10 @@ select { cursor: pointer; }
1284
  </optgroup>
1285
  </select>
1286
  </div>
 
 
 
 
1287
  </div>
1288
  <div id="runpod-not-configured" style="display:none;margin-top:8px;padding:12px;background:rgba(239,68,68,0.08);border:1px solid var(--red);border-radius:8px;font-size:12px;color:var(--text-secondary)">
1289
  <div style="font-weight:600;color:var(--red);margin-bottom:4px">RunPod Not Configured</div>
@@ -1303,7 +1398,7 @@ select { cursor: pointer; }
1303
  <div>Drop images here or click to browse</div>
1304
  <div style="font-size:11px;margin-top:4px">Upload 20-50 images of the subject (min 5)</div>
1305
  </div>
1306
- <input type="file" id="train-file-input" accept="image/*" multiple style="display:none" onchange="handleTrainImages(this)">
1307
  <div id="train-image-count" style="font-size:12px;color:var(--text-secondary);margin-top:6px"></div>
1308
 
1309
  <!-- Caption editor: shown after images are uploaded -->
@@ -1422,9 +1517,12 @@ select { cursor: pointer; }
1422
  </div>
1423
  <div id="pod-controls" style="display:flex; gap:8px; align-items:center; flex-wrap:wrap">
1424
  <select id="pod-model-type" style="padding:8px 12px; border-radius:6px; background:var(--bg-primary); border:1px solid var(--border); color:var(--text-primary)">
 
1425
  <option value="flux2">FLUX.2 Dev (Realistic txt2img)</option>
1426
  <option value="flux1">FLUX.1 Dev (txt2img)</option>
1427
- <option value="wan22">WAN 2.2 (img2video)</option>
 
 
1428
  </select>
1429
  <select id="pod-gpu-select" style="padding:8px 12px; border-radius:6px; background:var(--bg-primary); border:1px solid var(--border); color:var(--text-primary)">
1430
  <optgroup label="48GB+ (FLUX.2 / Large models)">
@@ -1583,6 +1681,9 @@ let currentPage = 'generate';
1583
  let selectedRating = 'sfw';
1584
  let selectedBackend = 'pod';
1585
  let selectedVideoBackend = 'cloud';
 
 
 
1586
  let selectedMode = 'txt2img';
1587
  let templatesData = [];
1588
  let charactersData = [];
@@ -1832,6 +1933,43 @@ function clearPoseImage() {
1832
  document.getElementById('pose-file-input').value = '';
1833
  }
1834
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1835
  function handleVideoImage(input) {
1836
  if (input.files[0]) {
1837
  videoImageFile = input.files[0];
@@ -1866,15 +2004,50 @@ function clearVideoImage() {
1866
  }
1867
 
1868
  function handleTrainImages(input) {
1869
- trainImageFiles = Array.from(input.files);
1870
- updateTrainCount();
1871
- buildCaptionEditor();
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1872
  }
1873
 
1874
  function handleTrainDrop(files) {
1875
- trainImageFiles = Array.from(files).filter(f => f.type.startsWith('image/'));
1876
- updateTrainCount();
1877
- buildCaptionEditor();
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1878
  }
1879
 
1880
  function updateTrainCount() {
@@ -2144,6 +2317,12 @@ function selectBackend(chip, backend) {
2144
  updateCloudModelVisibility();
2145
  }
2146
 
 
 
 
 
 
 
2147
  function updateDimensions() {
2148
  const aspect = document.getElementById('gen-aspect').value;
2149
  const dimensions = {
@@ -2194,13 +2373,21 @@ function updateCloudModelVisibility() {
2194
  document.getElementById('cloud-model-select').style.display = (isCloud && !isImg2img) ? '' : 'none';
2195
  // Show edit cloud models when cloud + img2img
2196
  document.getElementById('cloud-edit-model-select').style.display = (isCloud && isImg2img) ? '' : 'none';
 
 
2197
  // Show pod settings when pod backend selected (not in video mode)
2198
  document.getElementById('pod-settings-section').style.display = isPod ? '' : 'none';
2199
  if (isPod) {
2200
  loadPodLorasForGeneration();
2201
- // Set FLUX.2 defaults (low CFG, 28 steps)
2202
- document.getElementById('gen-cfg').value = '2';
2203
- document.getElementById('gen-steps').value = '28';
 
 
 
 
 
 
2204
  // Auto-open Advanced section so CFG/steps are visible
2205
  const adv = document.querySelector('#local-settings-section details');
2206
  if (adv) adv.open = true;
@@ -2241,13 +2428,14 @@ async function loadPodLorasForGeneration() {
2241
  const loraRes = await fetch(API + '/api/pod/loras');
2242
  const loraData = await loraRes.json();
2243
 
2244
- loraSelect.innerHTML = '<option value="">None - Base FLUX model</option>';
 
 
2245
  if (loraData.loras && loraData.loras.length > 0) {
2246
  loraData.loras.forEach(lora => {
2247
- const opt = document.createElement('option');
2248
- opt.value = lora;
2249
- opt.text = lora.replace('.safetensors', '');
2250
- loraSelect.appendChild(opt);
2251
  });
2252
  }
2253
 
@@ -2280,6 +2468,31 @@ async function doGenerate() {
2280
  try {
2281
  // img2video mode — video generation
2282
  if (selectedMode === 'img2video') {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2283
  if (!videoImageFile) {
2284
  throw new Error('Please upload an image to animate');
2285
  }
@@ -2288,7 +2501,7 @@ async function doGenerate() {
2288
  formData.append('prompt', document.getElementById('gen-positive').value || 'smooth motion, high quality video');
2289
  formData.append('negative_prompt', document.getElementById('gen-negative').value || 'blurry, low quality, static');
2290
  formData.append('num_frames', document.getElementById('video-duration').value || '81');
2291
- formData.append('fps', document.getElementById('video-fps').value || '24');
2292
  formData.append('seed', document.getElementById('gen-seed').value || '-1');
2293
  formData.append('backend', selectedVideoBackend);
2294
 
@@ -2388,6 +2601,8 @@ async function doGenerate() {
2388
  height: parseInt(document.getElementById('gen-height').value) || 1024,
2389
  lora_name: document.getElementById('pod-lora-select')?.value || null,
2390
  lora_strength: parseFloat(document.getElementById('pod-lora-strength')?.value) || 0.85,
 
 
2391
  character_id: document.getElementById('gen-character').value || null,
2392
  template_id: document.getElementById('gen-template').value || null,
2393
  };
@@ -2406,6 +2621,8 @@ async function doGenerate() {
2406
  return;
2407
  }
2408
 
 
 
2409
  const body = {
2410
  character_id: document.getElementById('gen-character').value || null,
2411
  template_id: document.getElementById('gen-template').value || null,
@@ -2419,6 +2636,7 @@ async function doGenerate() {
2419
  width: parseInt(document.getElementById('gen-width').value) || 832,
2420
  height: parseInt(document.getElementById('gen-height').value) || 1216,
2421
  variables: variables,
 
2422
  };
2423
 
2424
  const endpoint = selectedBackend === 'cloud' ? '/api/generate/cloud' : '/api/generate';
@@ -2536,7 +2754,7 @@ async function pollForVideo(jobId) {
2536
  const preview = document.getElementById('preview-body');
2537
  const startTime = Date.now();
2538
 
2539
- for (let i = 0; i < 120; i++) { // Up to 6 minutes
2540
  await new Promise(r => setTimeout(r, 3000));
2541
 
2542
  try {
@@ -2573,19 +2791,28 @@ function showPreviewVideo(job) {
2573
  const preview = document.getElementById('preview-body');
2574
  preview.innerHTML = `
2575
  <div style="text-align:center;width:100%">
2576
- <video src="/api/video/${job.filename}" autoplay loop muted playsinline
2577
  style="max-width:100%;max-height:70vh;border-radius:8px;margin-bottom:12px"></video>
2578
- <div style="display:flex;gap:8px;justify-content:center;flex-wrap:wrap">
2579
  <span class="tag" style="background:var(--accent);color:white">Video</span>
2580
  <span class="tag" style="background:var(--bg-hover)">${job.num_frames} frames</span>
2581
  <span class="tag" style="background:var(--bg-hover)">${job.fps} fps</span>
 
2582
  </div>
2583
- <p style="color:var(--text-secondary);margin-top:8px;font-size:12px">Seed: ${job.seed || 'N/A'}</p>
2584
  <a href="/api/video/${job.filename}" download class="btn btn-secondary" style="margin-top:12px">Download Video</a>
2585
  </div>
2586
  `;
2587
  }
2588
 
 
 
 
 
 
 
 
 
2589
  // --- Batch ---
2590
  async function doBatch() {
2591
  const btn = document.getElementById('batch-btn');
@@ -2933,6 +3160,69 @@ function updateModelDefaults() {
2933
  break;
2934
  }
2935
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2936
  }
2937
 
2938
  function selectTrainBackend(chip, backend) {
 
898
 
899
  <div id="cloud-model-select" style="display:none">
900
  <label>Model</label>
901
+ <select id="gen-cloud-model" onchange="updateCloudLoraVisibility()">
902
  <optgroup label="Recommended">
903
  <option value="seedream-4.5" selected>SeeDream v4.5 (Best)</option>
904
  <option value="gpt-image-1.5">GPT Image 1.5</option>
 
909
  <option value="seedream-3.1">SeeDream v3.1</option>
910
  </optgroup>
911
  <optgroup label="Fast">
912
+ <option value="z-image-turbo">Z-Image Turbo (Fastest)</option>
913
+ <option value="z-image-turbo-lora">Z-Image Turbo + LoRA</option>
914
  <option value="gpt-image-1-mini">GPT Image Mini</option>
915
  <option value="nano-banana">NanoBanana</option>
916
  </optgroup>
917
+ <optgroup label="LoRA Support">
918
+ <option value="z-image-base-lora">Z-Image Base + LoRA ($0.012)</option>
919
+ </optgroup>
920
  <optgroup label="Other">
921
  <option value="kling-image-o3">Kling Image O3</option>
922
  <option value="wan-2.6">WAN 2.6</option>
 
927
  </select>
928
  </div>
929
 
930
+ <div id="cloud-lora-input" style="display:none">
931
+ <label>LoRA Path <span style="color:var(--text-secondary);font-weight:400">(HuggingFace repo or URL)</span></label>
932
+ <input type="text" id="cloud-lora-path" placeholder="e.g. username/my-character-lora"
933
+ style="width:100%;padding:8px;border-radius:6px;border:1px solid var(--border);background:var(--bg-primary);color:var(--text-primary);font-size:13px;box-sizing:border-box">
934
+ <div style="display:flex;align-items:center;gap:8px;margin-top:6px">
935
+ <label style="margin:0;flex-shrink:0">Strength</label>
936
+ <input type="range" id="cloud-lora-strength" min="0" max="2" step="0.05" value="1"
937
+ oninput="this.nextElementSibling.textContent=this.value"
938
+ style="flex:1">
939
+ <span style="font-size:12px;min-width:28px">1</span>
940
+ </div>
941
+ </div>
942
+
943
  <div id="cloud-edit-model-select" style="display:none">
944
  <label>Model</label>
945
  <select id="gen-cloud-edit-model">
 
951
  <optgroup label="Multi-Reference (2+ images)">
952
  <option value="seedream-4.5-multi">SeeDream v4.5 Sequential (up to 3)</option>
953
  <option value="seedream-4-multi">SeeDream v4 Sequential (up to 3)</option>
954
+ <option value="nano-banana-pro-multi">NanoBanana Pro (2 refs)</option>
955
  <option value="kling-o1-multi">Kling O1 (up to 10 refs)</option>
956
  <option value="qwen-multi-angle">Qwen Multi-Angle</option>
957
  </optgroup>
 
967
  <option value="wan-2.5-edit">WAN 2.5 Edit</option>
968
  <option value="wan-2.2-edit">WAN 2.2 Edit</option>
969
  <option value="qwen-edit-lora">Qwen Edit + LoRA</option>
 
970
  <option value="kling-o3-edit">Kling O3 Edit</option>
971
  <option value="dreamina-3-edit">Dreamina v3 Edit</option>
972
  </optgroup>
 
982
  <span id="pod-status-indicator">Checking pod status...</span>
983
  </div>
984
  <label>Base Model</label>
985
+ <select id="pod-model-select" onchange="updateVisibility()">
986
+ <option value="z_image">Z-Image Turbo (+ LoRA)</option>
987
  <option value="flux">FLUX.2 Dev (Realistic)</option>
988
+ <option value="wan22">WAN 2.2 T2V (txt2img + LoRA)</option>
989
  </select>
990
+ <label style="margin-top:8px">LoRA 1 <span style="color:var(--text-secondary);font-weight:400">(body)</span></label>
991
  <select id="pod-lora-select">
992
  <option value="">None (Base model only)</option>
993
  </select>
994
+ <label style="margin-top:6px">Strength</label>
995
  <input type="number" id="pod-lora-strength" value="0.85" min="0" max="1.5" step="0.05" style="width:80px">
996
+ <label style="margin-top:8px">LoRA 2 <span style="color:var(--text-secondary);font-weight:400">(face)</span></label>
997
+ <select id="pod-lora-select-2">
998
+ <option value="">None</option>
999
+ </select>
1000
+ <label style="margin-top:6px">Strength</label>
1001
+ <input type="number" id="pod-lora-strength-2" value="0.85" min="0" max="1.5" step="0.05" style="width:80px">
1002
  <div style="font-size:11px;color:var(--text-secondary);margin-top:4px">
1003
  Start the pod in Status page first.
1004
  </div>
 
1006
 
1007
  <!-- Image to Video settings -->
1008
  <div id="img2video-section" style="display:none">
1009
+
1010
+ <!-- Sub-mode: Image to Video vs Animate -->
1011
+ <div class="chips" id="video-submode-chips" style="margin-bottom:10px">
1012
+ <div class="chip selected" onclick="selectVideoSubMode(this,'i2v')">Image to Video</div>
1013
+ <div class="chip" onclick="selectVideoSubMode(this,'animate')">Animate (Dance)</div>
1014
  </div>
1015
+
1016
+ <!-- Standard Image-to-Video -->
1017
+ <div id="i2v-sub-section">
1018
+ <div class="section-title">Source Image</div>
1019
+ <div class="drop-zone" id="video-drop-zone" onclick="document.getElementById('video-file-input').click()">
1020
+ <svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="1.5"><path d="M21 15v4a2 2 0 0 1-2 2H5a2 2 0 0 1-2-2v-4"/><polyline points="17 8 12 3 7 8"/><line x1="12" y1="3" x2="12" y2="15"/></svg>
1021
+ <div>Drop or click to upload</div>
1022
+ </div>
1023
+ <input type="file" id="video-file-input" accept="image/*" style="display:none" onchange="handleVideoImage(this)">
1024
+ <div id="video-preview" style="display:none; margin-top:6px">
1025
+ <img id="video-preview-img" style="max-width:100%; max-height:100px; border-radius:6px">
1026
+ </div>
1027
+
1028
+ <label style="margin-top:12px">Video Model</label>
1029
+ <select id="video-cloud-model">
1030
+ <optgroup label="Recommended">
1031
+ <option value="wan-2.6-i2v-pro" selected>WAN 2.6 Pro ($0.05/s)</option>
1032
+ <option value="wan-2.6-i2v-flash">WAN 2.6 Flash (Fast)</option>
1033
+ <option value="kling-o3-pro">Kling O3 Pro</option>
1034
+ </optgroup>
1035
+ <optgroup label="Premium (Higgsfield - requires API key)">
1036
+ <option value="kling-3.0-pro">Kling 3.0 Pro (15s + Audio)</option>
1037
+ <option value="kling-3.0">Kling 3.0</option>
1038
+ <option value="sora-2-hf">Sora 2</option>
1039
+ <option value="veo-3.1-hf">Veo 3.1</option>
1040
+ </optgroup>
1041
+ <optgroup label="Budget Friendly">
1042
+ <option value="wan-2.2-i2v-720p">WAN 2.2 720p ($0.01/s)</option>
1043
+ <option value="wan-2.2-i2v-1080p">WAN 2.2 1080p</option>
1044
+ <option value="wan-2.5-i2v">WAN 2.5</option>
1045
+ </optgroup>
1046
+ <optgroup label="Cinematic">
1047
+ <option value="higgsfield-dop">Higgsfield DoP (5s)</option>
1048
+ <option value="seedance-1.5-pro">Seedance Pro</option>
1049
+ <option value="dreamina-i2v-1080p">Dreamina 1080p</option>
1050
+ </optgroup>
1051
+ <optgroup label="Other">
1052
+ <option value="kling-o3">Kling O3</option>
1053
+ <option value="grok-imagine-i2v">Grok Imagine Video (xAI)</option>
1054
+ <option value="veo-3.1">Veo 3.1 (WaveSpeed)</option>
1055
+ <option value="sora-2">Sora 2 (WaveSpeed)</option>
1056
+ <option value="vidu-q3">Vidu Q3</option>
1057
+ </optgroup>
1058
+ </select>
1059
+
1060
+ <label>Duration</label>
1061
+ <select id="video-duration">
1062
+ <option value="41">2s</option>
1063
+ <option value="81" selected>3s</option>
1064
+ <option value="121">5s</option>
1065
+ <option value="241">10s</option>
1066
+ <option value="361">15s</option>
1067
+ </select>
1068
  </div>
1069
 
1070
+ <!-- Animate (Dance) sub-section -->
1071
+ <div id="animate-sub-section" style="display:none">
1072
+ <div class="section-title">Character Image</div>
1073
+ <div class="drop-zone" id="animate-char-zone" onclick="document.getElementById('animate-char-input').click()">
1074
+ <svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="1.5"><path d="M21 15v4a2 2 0 0 1-2 2H5a2 2 0 0 1-2-2v-4"/><polyline points="17 8 12 3 7 8"/><line x1="12" y1="3" x2="12" y2="15"/></svg>
1075
+ <div>Character photo</div>
1076
+ </div>
1077
+ <input type="file" id="animate-char-input" accept="image/*" style="display:none" onchange="handleAnimateChar(this)">
1078
+
1079
+ <div class="section-title" style="margin-top:10px">Driving Video</div>
1080
+ <div class="drop-zone" id="animate-video-zone" onclick="document.getElementById('animate-video-input').click()">
1081
+ <svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="1.5"><rect x="2" y="2" width="20" height="20" rx="2"/><polygon points="10,8 16,12 10,16"/></svg>
1082
+ <div>Dance video (mp4)</div>
1083
+ </div>
1084
+ <input type="file" id="animate-video-input" accept="video/*" style="display:none" onchange="handleAnimateVideo(this)">
1085
+
1086
+ <label style="margin-top:10px">Resolution</label>
1087
+ <select id="animate-resolution">
1088
+ <option value="480x832">480×832 (portrait)</option>
1089
+ <option value="720x1280" selected>720×1280 (HD portrait)</option>
1090
+ <option value="1080x1920">1080×1920 (TikTok full HD ⚡ high VRAM)</option>
1091
+ <option value="832x480">832×480 (landscape)</option>
1092
+ <option value="1280x720">1280×720 (HD landscape)</option>
1093
+ <option value="512x512">512×512 (square)</option>
1094
+ </select>
1095
+
1096
+ <label>Background</label>
1097
+ <select id="animate-bg-mode">
1098
+ <option value="auto" selected>Auto (model decides)</option>
1099
+ <option value="driving_video">From driving video</option>
1100
+ <option value="keep">Keep (character image bg)</option>
1101
+ </select>
1102
+
1103
+ <label>Frames</label>
1104
+ <select id="animate-frames">
1105
+ <option value="0">Match video (auto)</option>
1106
+ <option value="25">25 (~1.5s)</option>
1107
+ <option value="49">49 (~3s)</option>
1108
+ <option value="81" selected>81 (~5s)</option>
1109
+ <option value="121">121 (~7.5s)</option>
1110
+ <option value="161">161 (~10s)</option>
1111
+ <option value="201">201 (~12.5s)</option>
1112
+ <option value="241">241 (~15s)</option>
1113
+ <option value="289">289 (~18s)</option>
1114
+ <option value="321">321 (~20s)</option>
1115
+ <option value="385">385 (~24s)</option>
1116
+ <option value="481">481 (~30s)</option>
1117
+ </select>
1118
+
1119
+ <div style="font-size:11px;color:var(--text-secondary);margin-top:6px">
1120
+ Runs on RunPod pod via WAN 2.2 Animate. Pod must be running with models installed.
1121
+ </div>
1122
+ </div>
1123
 
 
 
 
 
 
 
 
 
1124
  </div>
1125
 
1126
  <!-- Reference image upload for img2img -->
 
1375
  </optgroup>
1376
  </select>
1377
  </div>
1378
+ <div style="margin-top:8px;padding-top:8px;border-top:1px solid rgba(59,130,246,0.2)">
1379
+ <button class="btn btn-secondary btn-small" onclick="preDownloadModels()" id="btn-predownload">Pre-download models to volume</button>
1380
+ <span id="predownload-status" style="font-size:11px;margin-left:8px;color:var(--text-secondary)"></span>
1381
+ </div>
1382
  </div>
1383
  <div id="runpod-not-configured" style="display:none;margin-top:8px;padding:12px;background:rgba(239,68,68,0.08);border:1px solid var(--red);border-radius:8px;font-size:12px;color:var(--text-secondary)">
1384
  <div style="font-weight:600;color:var(--red);margin-bottom:4px">RunPod Not Configured</div>
 
1398
  <div>Drop images here or click to browse</div>
1399
  <div style="font-size:11px;margin-top:4px">Upload 20-50 images of the subject (min 5)</div>
1400
  </div>
1401
+ <input type="file" id="train-file-input" accept="image/*,.txt" multiple style="display:none" onchange="handleTrainImages(this)">
1402
  <div id="train-image-count" style="font-size:12px;color:var(--text-secondary);margin-top:6px"></div>
1403
 
1404
  <!-- Caption editor: shown after images are uploaded -->
 
1517
  </div>
1518
  <div id="pod-controls" style="display:flex; gap:8px; align-items:center; flex-wrap:wrap">
1519
  <select id="pod-model-type" style="padding:8px 12px; border-radius:6px; background:var(--bg-primary); border:1px solid var(--border); color:var(--text-primary)">
1520
+ <option value="z_image">Z-Image Turbo (txt2img + LoRA)</option>
1521
  <option value="flux2">FLUX.2 Dev (Realistic txt2img)</option>
1522
  <option value="flux1">FLUX.1 Dev (txt2img)</option>
1523
+ <option value="wan22">WAN 2.2 T2V (txt2img + LoRA)</option>
1524
+ <option value="wan22_i2v">WAN 2.2 I2V (img2video)</option>
1525
+ <option value="wan22_animate">WAN 2.2 Animate (Dance/Motion transfer)</option>
1526
  </select>
1527
  <select id="pod-gpu-select" style="padding:8px 12px; border-radius:6px; background:var(--bg-primary); border:1px solid var(--border); color:var(--text-primary)">
1528
  <optgroup label="48GB+ (FLUX.2 / Large models)">
 
1681
  let selectedRating = 'sfw';
1682
  let selectedBackend = 'pod';
1683
  let selectedVideoBackend = 'cloud';
1684
+ let videoSubMode = 'i2v';
1685
+ let animateCharFile = null;
1686
+ let animateDrivingVideoFile = null;
1687
  let selectedMode = 'txt2img';
1688
  let templatesData = [];
1689
  let charactersData = [];
 
1933
  document.getElementById('pose-file-input').value = '';
1934
  }
1935
 
1936
+ function selectVideoSubMode(chip, mode) {
1937
+ chip.parentElement.querySelectorAll('.chip').forEach(c => c.classList.remove('selected'));
1938
+ chip.classList.add('selected');
1939
+ videoSubMode = mode;
1940
+ document.getElementById('i2v-sub-section').style.display = mode === 'i2v' ? '' : 'none';
1941
+ document.getElementById('animate-sub-section').style.display = mode === 'animate' ? '' : 'none';
1942
+ }
1943
+
1944
+ function handleAnimateChar(input) {
1945
+ if (!input.files[0]) return;
1946
+ animateCharFile = input.files[0];
1947
+ const zone = document.getElementById('animate-char-zone');
1948
+ zone.classList.add('has-file');
1949
+ const reader = new FileReader();
1950
+ reader.onload = e => {
1951
+ zone.innerHTML = `
1952
+ <img src="${e.target.result}" style="max-height:120px;border-radius:6px">
1953
+ <div style="margin-top:4px;font-size:11px">${input.files[0].name}</div>
1954
+ <button class="btn btn-secondary btn-small" onclick="event.stopPropagation();animateCharFile=null;this.closest('.drop-zone').classList.remove('has-file');this.closest('.drop-zone').innerHTML='<svg viewBox=\\'0 0 24 24\\' fill=\\'none\\' stroke=\\'currentColor\\' stroke-width=\\'1.5\\'><path d=\\'M21 15v4a2 2 0 0 1-2 2H5a2 2 0 0 1-2-2v-4\\'/><polyline points=\\'17 8 12 3 7 8\\'/><line x1=\\'12\\' y1=\\'3\\' x2=\\'12\\' y2=\\'15\\'/></svg><div>Character photo</div>'" style="margin-top:6px">Remove</button>
1955
+ `;
1956
+ };
1957
+ reader.readAsDataURL(input.files[0]);
1958
+ }
1959
+
1960
+ function handleAnimateVideo(input) {
1961
+ if (!input.files[0]) return;
1962
+ animateDrivingVideoFile = input.files[0];
1963
+ const zone = document.getElementById('animate-video-zone');
1964
+ zone.classList.add('has-file');
1965
+ zone.innerHTML = `
1966
+ <svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="1.5"><rect x="2" y="2" width="20" height="20" rx="2"/><polygon points="10,8 16,12 10,16"/></svg>
1967
+ <div style="font-size:12px;margin-top:4px">${input.files[0].name}</div>
1968
+ <div style="font-size:11px;color:var(--text-secondary)">${(input.files[0].size/1024/1024).toFixed(1)} MB</div>
1969
+ <button class="btn btn-secondary btn-small" onclick="event.stopPropagation();animateDrivingVideoFile=null;this.closest('.drop-zone').classList.remove('has-file');this.closest('.drop-zone').innerHTML='<svg viewBox=\\'0 0 24 24\\' fill=\\'none\\' stroke=\\'currentColor\\' stroke-width=\\'1.5\\'><rect x=\\'2\\' y=\\'2\\' width=\\'20\\' height=\\'20\\' rx=\\'2\\'/><polygon points=\\'10,8 16,12 10,16\\'/></svg><div>Dance video (mp4)</div>'" style="margin-top:6px">Remove</button>
1970
+ `;
1971
+ }
1972
+
1973
  function handleVideoImage(input) {
1974
  if (input.files[0]) {
1975
  videoImageFile = input.files[0];
 
2004
  }
2005
 
2006
  function handleTrainImages(input) {
2007
+ const allFiles = Array.from(input.files);
2008
+ const imageFiles = allFiles.filter(f => f.type.startsWith('image/'));
2009
+ const txtFiles = allFiles.filter(f => f.name.endsWith('.txt'));
2010
+ trainImageFiles = imageFiles;
2011
+ // Auto-load captions from .txt files matching image filenames (e.g. 1.txt -> 1.png)
2012
+ if (txtFiles.length > 0) {
2013
+ const pending = txtFiles.map(tf => tf.text().then(text => {
2014
+ const baseName = tf.name.replace(/\.txt$/, '');
2015
+ const matchImg = imageFiles.find(img => img.name.replace(/\.[^.]+$/, '') === baseName);
2016
+ if (matchImg) trainCaptions[matchImg.name] = text.trim();
2017
+ }));
2018
+ Promise.all(pending).then(() => {
2019
+ updateTrainCount();
2020
+ buildCaptionEditor();
2021
+ const loaded = Object.keys(trainCaptions).length;
2022
+ if (loaded > 0) toast(`Loaded ${loaded} captions from .txt files`, 'success');
2023
+ });
2024
+ } else {
2025
+ updateTrainCount();
2026
+ buildCaptionEditor();
2027
+ }
2028
  }
2029
 
2030
  function handleTrainDrop(files) {
2031
+ const allFiles = Array.from(files);
2032
+ const imageFiles = allFiles.filter(f => f.type.startsWith('image/'));
2033
+ const txtFiles = allFiles.filter(f => f.name.endsWith('.txt'));
2034
+ trainImageFiles = imageFiles;
2035
+ if (txtFiles.length > 0) {
2036
+ const pending = txtFiles.map(tf => tf.text().then(text => {
2037
+ const baseName = tf.name.replace(/\.txt$/, '');
2038
+ const matchImg = imageFiles.find(img => img.name.replace(/\.[^.]+$/, '') === baseName);
2039
+ if (matchImg) trainCaptions[matchImg.name] = text.trim();
2040
+ }));
2041
+ Promise.all(pending).then(() => {
2042
+ updateTrainCount();
2043
+ buildCaptionEditor();
2044
+ const loaded = Object.keys(trainCaptions).length;
2045
+ if (loaded > 0) toast(`Loaded ${loaded} captions from .txt files`, 'success');
2046
+ });
2047
+ } else {
2048
+ updateTrainCount();
2049
+ buildCaptionEditor();
2050
+ }
2051
  }
2052
 
2053
  function updateTrainCount() {
 
2317
  updateCloudModelVisibility();
2318
  }
2319
 
2320
+ function updateCloudLoraVisibility() {
2321
+ const model = document.getElementById('gen-cloud-model')?.value || '';
2322
+ const loraInput = document.getElementById('cloud-lora-input');
2323
+ if (loraInput) loraInput.style.display = model.includes('-lora') ? '' : 'none';
2324
+ }
2325
+
2326
  function updateDimensions() {
2327
  const aspect = document.getElementById('gen-aspect').value;
2328
  const dimensions = {
 
2373
  document.getElementById('cloud-model-select').style.display = (isCloud && !isImg2img) ? '' : 'none';
2374
  // Show edit cloud models when cloud + img2img
2375
  document.getElementById('cloud-edit-model-select').style.display = (isCloud && isImg2img) ? '' : 'none';
2376
+ // Show LoRA input for z-image lora models
2377
+ updateCloudLoraVisibility();
2378
  // Show pod settings when pod backend selected (not in video mode)
2379
  document.getElementById('pod-settings-section').style.display = isPod ? '' : 'none';
2380
  if (isPod) {
2381
  loadPodLorasForGeneration();
2382
+ // Set defaults based on pod model type
2383
+ const podModel = document.getElementById('pod-model-select')?.value || '';
2384
+ if (podModel.startsWith('wan22')) {
2385
+ document.getElementById('gen-cfg').value = '1';
2386
+ document.getElementById('gen-steps').value = '8';
2387
+ } else {
2388
+ document.getElementById('gen-cfg').value = '2';
2389
+ document.getElementById('gen-steps').value = '28';
2390
+ }
2391
  // Auto-open Advanced section so CFG/steps are visible
2392
  const adv = document.querySelector('#local-settings-section details');
2393
  if (adv) adv.open = true;
 
2428
  const loraRes = await fetch(API + '/api/pod/loras');
2429
  const loraData = await loraRes.json();
2430
 
2431
+ const loraSelect2 = document.getElementById('pod-lora-select-2');
2432
+ loraSelect.innerHTML = '<option value="">None - Base model</option>';
2433
+ if (loraSelect2) loraSelect2.innerHTML = '<option value="">None</option>';
2434
  if (loraData.loras && loraData.loras.length > 0) {
2435
  loraData.loras.forEach(lora => {
2436
+ const label = lora.replace('.safetensors', '');
2437
+ loraSelect.appendChild(Object.assign(document.createElement('option'), { value: lora, text: label }));
2438
+ if (loraSelect2) loraSelect2.appendChild(Object.assign(document.createElement('option'), { value: lora, text: label }));
 
2439
  });
2440
  }
2441
 
 
2468
  try {
2469
  // img2video mode — video generation
2470
  if (selectedMode === 'img2video') {
2471
+
2472
+ // Animate (Dance) sub-mode — WAN 2.2 Animate on RunPod
2473
+ if (videoSubMode === 'animate') {
2474
+ if (!animateCharFile) throw new Error('Please upload a character image');
2475
+ if (!animateDrivingVideoFile) throw new Error('Please upload a driving dance video');
2476
+ const resParts = document.getElementById('animate-resolution').value.split('x');
2477
+ const formData = new FormData();
2478
+ formData.append('image', animateCharFile);
2479
+ formData.append('driving_video', animateDrivingVideoFile);
2480
+ formData.append('prompt', document.getElementById('gen-positive').value || 'a person dancing, smooth motion, high quality');
2481
+ formData.append('negative_prompt', document.getElementById('gen-negative').value || '');
2482
+ formData.append('width', resParts[0] || '832');
2483
+ formData.append('height', resParts[1] || '480');
2484
+ formData.append('num_frames', document.getElementById('animate-frames').value || '81');
2485
+ formData.append('bg_mode', document.getElementById('animate-bg-mode').value || 'keep');
2486
+ formData.append('seed', document.getElementById('gen-seed').value || '-1');
2487
+ const res = await fetch(API + '/api/video/animate', { method: 'POST', body: formData });
2488
+ const data = await res.json();
2489
+ if (!res.ok) throw new Error(data.detail || 'Animate generation failed');
2490
+ toast('Animation generating on RunPod (WAN 2.2 Animate)...', 'info');
2491
+ await pollForVideo(data.job_id);
2492
+ return;
2493
+ }
2494
+
2495
+ // Standard Image-to-Video
2496
  if (!videoImageFile) {
2497
  throw new Error('Please upload an image to animate');
2498
  }
 
2501
  formData.append('prompt', document.getElementById('gen-positive').value || 'smooth motion, high quality video');
2502
  formData.append('negative_prompt', document.getElementById('gen-negative').value || 'blurry, low quality, static');
2503
  formData.append('num_frames', document.getElementById('video-duration').value || '81');
2504
+ formData.append('fps', document.getElementById('video-fps')?.value || '24');
2505
  formData.append('seed', document.getElementById('gen-seed').value || '-1');
2506
  formData.append('backend', selectedVideoBackend);
2507
 
 
2601
  height: parseInt(document.getElementById('gen-height').value) || 1024,
2602
  lora_name: document.getElementById('pod-lora-select')?.value || null,
2603
  lora_strength: parseFloat(document.getElementById('pod-lora-strength')?.value) || 0.85,
2604
+ lora_name_2: document.getElementById('pod-lora-select-2')?.value || null,
2605
+ lora_strength_2: parseFloat(document.getElementById('pod-lora-strength-2')?.value) || 0.85,
2606
  character_id: document.getElementById('gen-character').value || null,
2607
  template_id: document.getElementById('gen-template').value || null,
2608
  };
 
2621
  return;
2622
  }
2623
 
2624
+ const cloudLoraPath = document.getElementById('cloud-lora-path')?.value?.trim();
2625
+ const cloudLoraStrength = parseFloat(document.getElementById('cloud-lora-strength')?.value) || 1.0;
2626
  const body = {
2627
  character_id: document.getElementById('gen-character').value || null,
2628
  template_id: document.getElementById('gen-template').value || null,
 
2636
  width: parseInt(document.getElementById('gen-width').value) || 832,
2637
  height: parseInt(document.getElementById('gen-height').value) || 1216,
2638
  variables: variables,
2639
+ loras: cloudLoraPath ? [{ name: cloudLoraPath, strength_model: cloudLoraStrength, strength_clip: cloudLoraStrength }] : [],
2640
  };
2641
 
2642
  const endpoint = selectedBackend === 'cloud' ? '/api/generate/cloud' : '/api/generate';
 
2754
  const preview = document.getElementById('preview-body');
2755
  const startTime = Date.now();
2756
 
2757
+ for (let i = 0; i < 600; i++) { // Up to 30 minutes
2758
  await new Promise(r => setTimeout(r, 3000));
2759
 
2760
  try {
 
2791
  const preview = document.getElementById('preview-body');
2792
  preview.innerHTML = `
2793
  <div style="text-align:center;width:100%">
2794
+ <video id="preview-video" src="/api/video/${job.filename}" autoplay loop controls playsinline
2795
  style="max-width:100%;max-height:70vh;border-radius:8px;margin-bottom:12px"></video>
2796
+ <div style="display:flex;gap:8px;justify-content:center;flex-wrap:wrap;margin-bottom:8px">
2797
  <span class="tag" style="background:var(--accent);color:white">Video</span>
2798
  <span class="tag" style="background:var(--bg-hover)">${job.num_frames} frames</span>
2799
  <span class="tag" style="background:var(--bg-hover)">${job.fps} fps</span>
2800
+ <button id="audio-toggle-btn" onclick="toggleVideoAudio()" style="padding:4px 12px;border-radius:6px;border:1px solid var(--border);background:var(--bg-secondary);color:var(--text-primary);cursor:pointer;font-size:12px">🔇 Unmute</button>
2801
  </div>
2802
+ <p style="color:var(--text-secondary);margin-top:4px;font-size:12px">Seed: ${job.seed || 'N/A'}</p>
2803
  <a href="/api/video/${job.filename}" download class="btn btn-secondary" style="margin-top:12px">Download Video</a>
2804
  </div>
2805
  `;
2806
  }
2807
 
2808
+ function toggleVideoAudio() {
2809
+ const video = document.getElementById('preview-video');
2810
+ const btn = document.getElementById('audio-toggle-btn');
2811
+ if (!video) return;
2812
+ video.muted = !video.muted;
2813
+ btn.textContent = video.muted ? '🔇 Unmute' : '🔊 Mute';
2814
+ }
2815
+
2816
  // --- Batch ---
2817
  async function doBatch() {
2818
  const btn = document.getElementById('batch-btn');
 
3160
  break;
3161
  }
3162
  }
3163
+
3164
+ // Auto-select GPU for models that need specific hardware
3165
+ const gpuSelect = document.getElementById('train-gpu-type');
3166
+ if (gpuSelect) {
3167
+ const modelType = model.model_type || '';
3168
+ if (modelType === 'wan22') {
3169
+ // WAN 2.2 needs A100 80GB
3170
+ for (let opt of gpuSelect.options) {
3171
+ if (opt.value.includes('A100-SXM4')) { opt.selected = true; break; }
3172
+ }
3173
+ } else if (modelType === 'flux2') {
3174
+ // FLUX.2 needs 48GB+ — default to A6000
3175
+ for (let opt of gpuSelect.options) {
3176
+ if (opt.value.includes('A6000')) { opt.selected = true; break; }
3177
+ }
3178
+ }
3179
+ }
3180
+ }
3181
+
3182
+ async function preDownloadModels() {
3183
+ const btn = document.getElementById('btn-predownload');
3184
+ const status = document.getElementById('predownload-status');
3185
+ const modelKey = document.getElementById('train-base-model').value;
3186
+ const model = trainingModels[modelKey];
3187
+ const modelType = model?.model_type || 'wan22';
3188
+
3189
+ btn.disabled = true;
3190
+ status.textContent = 'Starting download pod...';
3191
+ status.style.color = 'var(--blue)';
3192
+
3193
+ try {
3194
+ const res = await fetch(API + '/api/pod/download-models', {
3195
+ method: 'POST',
3196
+ headers: {'Content-Type': 'application/json'},
3197
+ body: JSON.stringify({model_type: modelType, gpu_type: 'NVIDIA GeForce RTX 3090'})
3198
+ });
3199
+ const data = await res.json();
3200
+ if (!res.ok) { throw new Error(data.detail || 'Failed'); }
3201
+
3202
+ // Poll for progress
3203
+ const poll = setInterval(async () => {
3204
+ try {
3205
+ const r = await fetch(API + '/api/pod/download-models/status');
3206
+ const d = await r.json();
3207
+ status.textContent = d.progress || d.status;
3208
+ if (d.status === 'completed') {
3209
+ clearInterval(poll);
3210
+ status.style.color = 'var(--green)';
3211
+ btn.disabled = false;
3212
+ btn.textContent = 'Models downloaded!';
3213
+ } else if (d.status === 'failed') {
3214
+ clearInterval(poll);
3215
+ status.style.color = 'var(--red)';
3216
+ status.textContent = 'Failed: ' + (d.error || 'unknown');
3217
+ btn.disabled = false;
3218
+ }
3219
+ } catch(e) { /* ignore poll errors */ }
3220
+ }, 5000);
3221
+ } catch(e) {
3222
+ status.textContent = 'Error: ' + e.message;
3223
+ status.style.color = 'var(--red)';
3224
+ btn.disabled = false;
3225
+ }
3226
  }
3227
 
3228
  function selectTrainBackend(chip, backend) {
src/content_engine/services/cloud_providers/wavespeed_provider.py CHANGED
@@ -27,17 +27,10 @@ import uuid
27
  from typing import Any
28
 
29
  import httpx
 
30
 
31
  from content_engine.services.cloud_providers.base import CloudGenerationResult, CloudProvider
32
 
33
- # Optional wavespeed SDK import
34
- try:
35
- from wavespeed import Client as WaveSpeedClient
36
- WAVESPEED_SDK_AVAILABLE = True
37
- except ImportError:
38
- WaveSpeedClient = None
39
- WAVESPEED_SDK_AVAILABLE = False
40
-
41
  logger = logging.getLogger(__name__)
42
 
43
  # Map friendly names to WaveSpeed model IDs (text-to-image)
@@ -53,6 +46,10 @@ MODEL_MAP = {
53
  # WAN (Alibaba)
54
  "wan-2.6": "alibaba/wan-2.6/text-to-image",
55
  "wan-2.5": "alibaba/wan-2.5/text-to-image",
 
 
 
 
56
  # Qwen (WaveSpeed)
57
  "qwen-image": "wavespeed-ai/qwen-image/text-to-image",
58
  # GPT Image (OpenAI)
@@ -96,6 +93,8 @@ VIDEO_MODEL_MAP = {
96
  "dreamina-i2v-720p": "bytedance/dreamina-v3.0/image-to-video-720p",
97
  # Sora (OpenAI)
98
  "sora-2": "openai/sora-2/image-to-video",
 
 
99
  # Vidu
100
  "vidu-q3": "vidu/q3-turbo/image-to-video",
101
  # Default
@@ -141,6 +140,8 @@ MULTI_REF_MODELS = {
141
  # SeeDream Sequential (up to 3 images for character consistency)
142
  "seedream-4.5-multi": "bytedance/seedream-v4.5/edit-sequential",
143
  "seedream-4-multi": "bytedance/seedream-v4/edit-sequential",
 
 
144
  # Kling O1 (up to 10 reference images)
145
  "kling-o1-multi": "kwaivgi/kling-o1/image-to-image",
146
  # Qwen Multi-Angle (multiple angles of same subject)
@@ -165,12 +166,7 @@ class WaveSpeedProvider(CloudProvider):
165
 
166
  def __init__(self, api_key: str):
167
  self._api_key = api_key
168
- self._client = None
169
- if WAVESPEED_SDK_AVAILABLE and WaveSpeedClient:
170
- try:
171
- self._client = WaveSpeedClient(api_key=api_key)
172
- except Exception as e:
173
- logger.warning("Failed to initialize WaveSpeed SDK: %s", e)
174
  self._http_client = httpx.AsyncClient(timeout=300)
175
 
176
  @property
@@ -390,29 +386,12 @@ class WaveSpeedProvider(CloudProvider):
390
  logger.info("Submitting to WaveSpeed model=%s", wavespeed_model)
391
 
392
  try:
393
- if self._client:
394
- # Use SDK if available
395
- output = self._client.run(
396
- wavespeed_model,
397
- payload,
398
- timeout=300.0,
399
- poll_interval=2.0,
400
- )
401
- else:
402
- # Fall back to direct HTTP API
403
- endpoint = f"{WAVESPEED_API_BASE}/{wavespeed_model}"
404
- payload["enable_sync_mode"] = True
405
- resp = await self._http_client.post(
406
- endpoint,
407
- json=payload,
408
- headers={
409
- "Authorization": f"Bearer {self._api_key}",
410
- "Content-Type": "application/json",
411
- },
412
- )
413
- resp.raise_for_status()
414
- output = resp.json()
415
-
416
  job_id = str(uuid.uuid4())
417
  self._last_result = {
418
  "job_id": job_id,
@@ -613,26 +592,20 @@ class WaveSpeedProvider(CloudProvider):
613
 
614
  async def is_available(self) -> bool:
615
  """Check if WaveSpeed API is reachable with valid credentials."""
616
- # Try SDK first if available
617
- if self._client:
618
- try:
619
- self._client.run(
620
- "wavespeed-ai/z-image/turbo",
621
- {"prompt": "test"},
622
- enable_sync_mode=True,
623
- timeout=10.0,
624
- )
625
- return True
626
- except Exception:
627
- pass
628
-
629
- # Fall back to HTTP health check
630
  try:
631
- resp = await self._http_client.get(
632
- "https://api.wavespeed.ai/api/v3/health",
633
- headers={"Authorization": f"Bearer {self._api_key}"},
 
634
  timeout=10.0,
635
  )
636
- return resp.status_code < 500
637
  except Exception:
638
- return False
 
 
 
 
 
 
 
 
27
  from typing import Any
28
 
29
  import httpx
30
+ from wavespeed import Client as WaveSpeedClient
31
 
32
  from content_engine.services.cloud_providers.base import CloudGenerationResult, CloudProvider
33
 
 
 
 
 
 
 
 
 
34
  logger = logging.getLogger(__name__)
35
 
36
  # Map friendly names to WaveSpeed model IDs (text-to-image)
 
46
  # WAN (Alibaba)
47
  "wan-2.6": "alibaba/wan-2.6/text-to-image",
48
  "wan-2.5": "alibaba/wan-2.5/text-to-image",
49
+ # Z-Image (WaveSpeed) — supports LoRA, ultra fast
50
+ "z-image-turbo": "wavespeed-ai/z-image/turbo",
51
+ "z-image-turbo-lora": "wavespeed-ai/z-image/turbo-lora",
52
+ "z-image-base-lora": "wavespeed-ai/z-image/base-lora",
53
  # Qwen (WaveSpeed)
54
  "qwen-image": "wavespeed-ai/qwen-image/text-to-image",
55
  # GPT Image (OpenAI)
 
93
  "dreamina-i2v-720p": "bytedance/dreamina-v3.0/image-to-video-720p",
94
  # Sora (OpenAI)
95
  "sora-2": "openai/sora-2/image-to-video",
96
+ # Grok (xAI)
97
+ "grok-imagine-i2v": "x-ai/grok-imagine-video/image-to-video",
98
  # Vidu
99
  "vidu-q3": "vidu/q3-turbo/image-to-video",
100
  # Default
 
140
  # SeeDream Sequential (up to 3 images for character consistency)
141
  "seedream-4.5-multi": "bytedance/seedream-v4.5/edit-sequential",
142
  "seedream-4-multi": "bytedance/seedream-v4/edit-sequential",
143
+ # NanoBanana Pro (Google) - multi-reference edit
144
+ "nano-banana-pro-multi": "google/nano-banana-pro/edit",
145
  # Kling O1 (up to 10 reference images)
146
  "kling-o1-multi": "kwaivgi/kling-o1/image-to-image",
147
  # Qwen Multi-Angle (multiple angles of same subject)
 
166
 
167
  def __init__(self, api_key: str):
168
  self._api_key = api_key
169
+ self._client = WaveSpeedClient(api_key=api_key)
 
 
 
 
 
170
  self._http_client = httpx.AsyncClient(timeout=300)
171
 
172
  @property
 
386
  logger.info("Submitting to WaveSpeed model=%s", wavespeed_model)
387
 
388
  try:
389
+ output = self._client.run(
390
+ wavespeed_model,
391
+ payload,
392
+ timeout=300.0,
393
+ poll_interval=2.0,
394
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
395
  job_id = str(uuid.uuid4())
396
  self._last_result = {
397
  "job_id": job_id,
 
592
 
593
  async def is_available(self) -> bool:
594
  """Check if WaveSpeed API is reachable with valid credentials."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
595
  try:
596
+ test = self._client.run(
597
+ "wavespeed-ai/z-image/turbo",
598
+ {"prompt": "test"},
599
+ enable_sync_mode=True,
600
  timeout=10.0,
601
  )
602
+ return True
603
  except Exception:
604
+ try:
605
+ resp = await self._http_client.get(
606
+ "https://api.wavespeed.ai/api/v3/health",
607
+ headers={"Authorization": f"Bearer {self._api_key}"},
608
+ )
609
+ return resp.status_code < 500
610
+ except Exception:
611
+ return False
src/content_engine/services/runpod_trainer.py CHANGED
@@ -472,6 +472,58 @@ print('Downloaded ae.safetensors')
472
 
473
  job._log("FLUX.2 Dev models ready")
474
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
475
  else:
476
  # SD 1.5 / SDXL / FLUX.1 — download single model file
477
  model_exists = (await self._ssh_exec(ssh, f"test -f /workspace/models/{hf_filename} && echo EXISTS || echo MISSING")).strip()
@@ -512,6 +564,8 @@ hf_hub_download('black-forest-labs/FLUX.1-dev', 'ae.safetensors', local_dir='/wo
512
 
513
  if model_type == "flux2":
514
  model_path = f"/workspace/models/FLUX.2-dev/flux2-dev.safetensors"
 
 
515
  else:
516
  model_path = f"/workspace/models/{hf_filename}"
517
 
@@ -529,47 +583,87 @@ resolution = [{resolution}, {resolution}]
529
  job._log("Created dataset.toml config")
530
 
531
  # musubi-tuner requires pre-caching latents and text encoder outputs
532
- flux2_dir = "/workspace/models/FLUX.2-dev"
533
- vae_path = f"{flux2_dir}/ae.safetensors"
534
- te_path = f"{flux2_dir}/text_encoder/model-00001-of-00010.safetensors"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
535
 
536
- job._log("Caching latents (VAE encoding)...")
537
- job.progress = 0.15
538
- self._schedule_db_save(job)
539
- cache_latents_cmd = (
540
- f"cd /workspace/musubi-tuner && PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True python src/musubi_tuner/flux_2_cache_latents.py"
541
- f" --dataset_config /workspace/dataset.toml"
542
- f" --vae {vae_path}"
543
- f" --model_version dev"
544
- f" --vae_dtype bfloat16"
545
- f" 2>&1 | tee /tmp/cache_latents.log; echo EXIT_CODE=${{PIPESTATUS[0]}}"
546
- )
547
- out = await self._ssh_exec(ssh, cache_latents_cmd, timeout=600)
548
- # Get last lines which have the real error
549
- last_lines = out.split('\n')[-30:]
550
- job._log('\n'.join(last_lines))
551
- if "EXIT_CODE=0" not in out:
552
- # Fetch the full error log
553
- err_log = await self._ssh_exec(ssh, "grep -i 'error\\|exception\\|traceback\\|failed' /tmp/cache_latents.log | tail -10")
554
- job._log(f"Cache error details: {err_log}")
555
- raise RuntimeError(f"Latent caching failed")
556
-
557
- job._log("Caching text encoder outputs (bf16)...")
558
- job.progress = 0.25
559
- self._schedule_db_save(job)
560
- cache_te_cmd = (
561
- f"cd /workspace/musubi-tuner && PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True"
562
- f" python src/musubi_tuner/flux_2_cache_text_encoder_outputs.py"
563
- f" --dataset_config /workspace/dataset.toml"
564
- f" --text_encoder {te_path}"
565
- f" --model_version dev"
566
- f" --batch_size 1"
567
- f" 2>&1; echo EXIT_CODE=$?"
568
- )
569
- out = await self._ssh_exec(ssh, cache_te_cmd, timeout=600)
570
- job._log(out[-500:] if out else "done")
571
- if "EXIT_CODE=0" not in out:
572
- raise RuntimeError(f"Text encoder caching failed: {out[-200:]}")
 
 
 
 
573
 
574
  # Build training command based on model type
575
  train_cmd = self._build_training_command(
@@ -689,6 +783,16 @@ resolution = [{resolution}, {resolution}]
689
  await self._ssh_exec(ssh, f"cp {remote_output} /runpod-volume/loras/{name}.safetensors")
690
  job._log(f"LoRA saved to volume: /runpod-volume/loras/{name}.safetensors")
691
 
 
 
 
 
 
 
 
 
 
 
692
  # Download locally (skip on HF Spaces — limited storage)
693
  if IS_HF_SPACES:
694
  job.output_path = f"/runpod-volume/loras/{name}.safetensors"
@@ -1098,6 +1202,48 @@ resolution = [{resolution}, {resolution}]
1098
 
1099
  return " ".join(args) + " 2>&1"
1100
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1101
  elif model_type == "flux":
1102
  # FLUX.1 training via sd-scripts
1103
  script = "flux_train_network.py"
 
472
 
473
  job._log("FLUX.2 Dev models ready")
474
 
475
+ elif model_type == "wan22":
476
+ # WAN 2.2 T2V — 4 model files stored in /workspace/models/WAN2.2/
477
+ wan_dir = "/workspace/models/WAN2.2"
478
+ await self._ssh_exec(ssh, f"mkdir -p {wan_dir}")
479
+
480
+ wan_files = {
481
+ "DiT low-noise": {
482
+ "path": f"{wan_dir}/wan2.2_t2v_low_noise_14B_fp16.safetensors",
483
+ "repo": "Comfy-Org/Wan_2.2_ComfyUI_Repackaged",
484
+ "filename": "split_files/diffusion_models/wan2.2_t2v_low_noise_14B_fp16.safetensors",
485
+ },
486
+ "DiT high-noise": {
487
+ "path": f"{wan_dir}/wan2.2_t2v_high_noise_14B_fp16.safetensors",
488
+ "repo": "Comfy-Org/Wan_2.2_ComfyUI_Repackaged",
489
+ "filename": "split_files/diffusion_models/wan2.2_t2v_high_noise_14B_fp16.safetensors",
490
+ },
491
+ "VAE": {
492
+ "path": f"{wan_dir}/Wan2.1_VAE.pth",
493
+ "repo": "Wan-AI/Wan2.1-I2V-14B-720P",
494
+ "filename": "Wan2.1_VAE.pth",
495
+ },
496
+ "T5 text encoder": {
497
+ "path": f"{wan_dir}/models_t5_umt5-xxl-enc-bf16.pth",
498
+ "repo": "Wan-AI/Wan2.1-I2V-14B-720P",
499
+ "filename": "models_t5_umt5-xxl-enc-bf16.pth",
500
+ },
501
+ }
502
+
503
+ for label, info in wan_files.items():
504
+ exists = (await self._ssh_exec(ssh, f"test -f {info['path']} && echo EXISTS || echo MISSING")).strip()
505
+ if exists == "EXISTS":
506
+ job._log(f"WAN 2.2 {label} already cached")
507
+ else:
508
+ job._log(f"Downloading WAN 2.2 {label}...")
509
+ await self._ssh_exec(ssh, f"""python -c "
510
+ from huggingface_hub import hf_hub_download
511
+ hf_hub_download('{info['repo']}', '{info['filename']}', local_dir='{wan_dir}')
512
+ # hf_hub_download puts files in subdirs matching the filename path — move to root
513
+ import os, shutil
514
+ downloaded = os.path.join('{wan_dir}', '{info['filename']}')
515
+ target = '{info['path']}'
516
+ if os.path.exists(downloaded) and downloaded != target:
517
+ shutil.move(downloaded, target)
518
+ print('Downloaded {label}')
519
+ " 2>&1 | tail -5""", timeout=1800)
520
+ # Verify
521
+ check = (await self._ssh_exec(ssh, f"test -f {info['path']} && echo EXISTS || echo MISSING")).strip()
522
+ if check != "EXISTS":
523
+ raise RuntimeError(f"Failed to download WAN 2.2 {label}")
524
+
525
+ job._log("WAN 2.2 models ready")
526
+
527
  else:
528
  # SD 1.5 / SDXL / FLUX.1 — download single model file
529
  model_exists = (await self._ssh_exec(ssh, f"test -f /workspace/models/{hf_filename} && echo EXISTS || echo MISSING")).strip()
 
564
 
565
  if model_type == "flux2":
566
  model_path = f"/workspace/models/FLUX.2-dev/flux2-dev.safetensors"
567
+ elif model_type == "wan22":
568
+ model_path = "/workspace/models/WAN2.2/wan2.2_t2v_low_noise_14B_fp16.safetensors"
569
  else:
570
  model_path = f"/workspace/models/{hf_filename}"
571
 
 
583
  job._log("Created dataset.toml config")
584
 
585
  # musubi-tuner requires pre-caching latents and text encoder outputs
586
+ if model_type == "wan22":
587
+ wan_dir = "/workspace/models/WAN2.2"
588
+ vae_path = f"{wan_dir}/Wan2.1_VAE.pth"
589
+ te_path = f"{wan_dir}/models_t5_umt5-xxl-enc-bf16.pth"
590
+
591
+ job._log("Caching WAN 2.2 latents (VAE encoding)...")
592
+ job.progress = 0.15
593
+ self._schedule_db_save(job)
594
+ cache_latents_cmd = (
595
+ f"cd /workspace/musubi-tuner && PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True"
596
+ f" python src/musubi_tuner/wan_cache_latents.py"
597
+ f" --dataset_config /workspace/dataset.toml"
598
+ f" --vae {vae_path}"
599
+ f" --vae_dtype bfloat16"
600
+ f" 2>&1 | tee /tmp/cache_latents.log; echo EXIT_CODE=${{PIPESTATUS[0]}}"
601
+ )
602
+ out = await self._ssh_exec(ssh, cache_latents_cmd, timeout=600)
603
+ last_lines = out.split('\n')[-30:]
604
+ job._log('\n'.join(last_lines))
605
+ if "EXIT_CODE=0" not in out:
606
+ err_log = await self._ssh_exec(ssh, "grep -i 'error\\|exception\\|traceback\\|failed' /tmp/cache_latents.log | tail -10")
607
+ job._log(f"Cache error details: {err_log}")
608
+ raise RuntimeError(f"WAN latent caching failed")
609
+
610
+ job._log("Caching WAN 2.2 text encoder outputs (T5)...")
611
+ job.progress = 0.25
612
+ self._schedule_db_save(job)
613
+ cache_te_cmd = (
614
+ f"cd /workspace/musubi-tuner && PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True"
615
+ f" python src/musubi_tuner/wan_cache_text_encoder_outputs.py"
616
+ f" --dataset_config /workspace/dataset.toml"
617
+ f" --t5 {te_path}"
618
+ f" --batch_size 16"
619
+ f" 2>&1; echo EXIT_CODE=$?"
620
+ )
621
+ out = await self._ssh_exec(ssh, cache_te_cmd, timeout=600)
622
+ job._log(out[-500:] if out else "done")
623
+ if "EXIT_CODE=0" not in out:
624
+ raise RuntimeError(f"WAN text encoder caching failed: {out[-200:]}")
625
 
626
+ else:
627
+ # FLUX.2 caching
628
+ flux2_dir = "/workspace/models/FLUX.2-dev"
629
+ vae_path = f"{flux2_dir}/ae.safetensors"
630
+ te_path = f"{flux2_dir}/text_encoder/model-00001-of-00010.safetensors"
631
+
632
+ job._log("Caching latents (VAE encoding)...")
633
+ job.progress = 0.15
634
+ self._schedule_db_save(job)
635
+ cache_latents_cmd = (
636
+ f"cd /workspace/musubi-tuner && PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True python src/musubi_tuner/flux_2_cache_latents.py"
637
+ f" --dataset_config /workspace/dataset.toml"
638
+ f" --vae {vae_path}"
639
+ f" --model_version dev"
640
+ f" --vae_dtype bfloat16"
641
+ f" 2>&1 | tee /tmp/cache_latents.log; echo EXIT_CODE=${{PIPESTATUS[0]}}"
642
+ )
643
+ out = await self._ssh_exec(ssh, cache_latents_cmd, timeout=600)
644
+ last_lines = out.split('\n')[-30:]
645
+ job._log('\n'.join(last_lines))
646
+ if "EXIT_CODE=0" not in out:
647
+ err_log = await self._ssh_exec(ssh, "grep -i 'error\\|exception\\|traceback\\|failed' /tmp/cache_latents.log | tail -10")
648
+ job._log(f"Cache error details: {err_log}")
649
+ raise RuntimeError(f"Latent caching failed")
650
+
651
+ job._log("Caching text encoder outputs (bf16)...")
652
+ job.progress = 0.25
653
+ self._schedule_db_save(job)
654
+ cache_te_cmd = (
655
+ f"cd /workspace/musubi-tuner && PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True"
656
+ f" python src/musubi_tuner/flux_2_cache_text_encoder_outputs.py"
657
+ f" --dataset_config /workspace/dataset.toml"
658
+ f" --text_encoder {te_path}"
659
+ f" --model_version dev"
660
+ f" --batch_size 1"
661
+ f" 2>&1; echo EXIT_CODE=$?"
662
+ )
663
+ out = await self._ssh_exec(ssh, cache_te_cmd, timeout=600)
664
+ job._log(out[-500:] if out else "done")
665
+ if "EXIT_CODE=0" not in out:
666
+ raise RuntimeError(f"Text encoder caching failed: {out[-200:]}")
667
 
668
  # Build training command based on model type
669
  train_cmd = self._build_training_command(
 
783
  await self._ssh_exec(ssh, f"cp {remote_output} /runpod-volume/loras/{name}.safetensors")
784
  job._log(f"LoRA saved to volume: /runpod-volume/loras/{name}.safetensors")
785
 
786
+ # Also save intermediate checkpoints (step 500, 1000, 1500, etc.)
787
+ checkpoint_files = (await self._ssh_exec(ssh, f"ls /workspace/output/{name}-step*.safetensors 2>/dev/null")).strip()
788
+ if checkpoint_files:
789
+ for ckpt in checkpoint_files.split("\n"):
790
+ ckpt = ckpt.strip()
791
+ if ckpt:
792
+ ckpt_name = ckpt.split("/")[-1]
793
+ await self._ssh_exec(ssh, f"cp {ckpt} /runpod-volume/loras/{ckpt_name}")
794
+ job._log(f"Checkpoint saved: /runpod-volume/loras/{ckpt_name}")
795
+
796
  # Download locally (skip on HF Spaces — limited storage)
797
  if IS_HF_SPACES:
798
  job.output_path = f"/runpod-volume/loras/{name}.safetensors"
 
1202
 
1203
  return " ".join(args) + " 2>&1"
1204
 
1205
+ elif model_type == "wan22":
1206
+ # WAN 2.2 T2V LoRA training via musubi-tuner
1207
+ wan_dir = "/workspace/models/WAN2.2"
1208
+ dit_low = f"{wan_dir}/wan2.2_t2v_low_noise_14B_fp16.safetensors"
1209
+ dit_high = f"{wan_dir}/wan2.2_t2v_high_noise_14B_fp16.safetensors"
1210
+
1211
+ network_mod = model_cfg.get("network_module", "networks.lora_wan")
1212
+ ts_sampling = model_cfg.get("timestep_sampling", "shift")
1213
+ discrete_shift = model_cfg.get("discrete_flow_shift", 5.0)
1214
+
1215
+ args = [
1216
+ "cd /workspace/musubi-tuner && PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True",
1217
+ "accelerate launch --num_cpu_threads_per_process 1 --mixed_precision fp16",
1218
+ "src/musubi_tuner/wan_train_network.py",
1219
+ "--task t2v-A14B",
1220
+ f"--dit {dit_low}",
1221
+ f"--dit_high_noise {dit_high}",
1222
+ "--dataset_config /workspace/dataset.toml",
1223
+ "--sdpa --mixed_precision fp16",
1224
+ "--gradient_checkpointing",
1225
+ f"--timestep_sampling {ts_sampling}",
1226
+ f"--discrete_flow_shift {discrete_shift}",
1227
+ f"--network_module {network_mod}",
1228
+ f"--network_dim={network_rank}",
1229
+ f"--network_alpha={network_alpha}",
1230
+ f"--optimizer_type={optimizer}",
1231
+ f"--learning_rate={learning_rate}",
1232
+ "--seed=42",
1233
+ "--output_dir=/workspace/output",
1234
+ f"--output_name={name}",
1235
+ ]
1236
+
1237
+ if max_train_steps:
1238
+ args.append(f"--max_train_steps={max_train_steps}")
1239
+ if save_every_n_steps:
1240
+ args.append(f"--save_every_n_steps={save_every_n_steps}")
1241
+ else:
1242
+ args.append(f"--max_train_epochs={num_epochs}")
1243
+ args.append(f"--save_every_n_epochs={save_every_n_epochs}")
1244
+
1245
+ return " ".join(args) + " 2>&1"
1246
+
1247
  elif model_type == "flux":
1248
  # FLUX.1 training via sd-scripts
1249
  script = "flux_train_network.py"