# NOTE(review): removed scraper artifacts (file-size header, git blame
# hashes, and a line-number dump) that were not part of the Python source.
import os
import sys
# CRITICAL: Import spaces FIRST before any CUDA initialization
# (HF Spaces' ZeroGPU requires `spaces` to be imported before torch/CUDA;
# import is best-effort so the script still runs outside Spaces).
try:
    import spaces
except ImportError:
    pass
sys.stdout.flush()
import functools
# Deliberate shadowing of the builtin: make every print() flush immediately
# so log output streams in real time (HF Spaces / Colab buffer stdout).
print = functools.partial(print, flush=True)
# Imported only to fail fast at startup if these tokenizer/text deps are
# missing — presumably required by the pipelines below; TODO confirm.
import ftfy
import sentencepiece
# Project-local modules (not on PyPI): core generation facade, background
# removal, style transfer, and the Gradio UI builder.
from FlowFacade import FlowFacade
from BackgroundEngine import BackgroundEngine
from style_transfer import StyleTransferEngine
from ui_manager import UIManager
def preload_models():
    """
    Warm the Hugging Face model cache at startup.

    Fallback for when the Space's YAML ``preload_from_hub`` section did
    not run. No-ops outside the HF Spaces environment, and skips the
    download entirely when the hub cache already holds the models.
    Failures are non-fatal: models then download lazily on first use.
    """
    # Only relevant on HF Spaces (SPACE_ID is set there).
    if not os.environ.get('SPACE_ID'):
        return

    hub_cache = os.path.expanduser("~/.cache/huggingface/hub")
    if os.path.exists(hub_cache):
        entries = os.listdir(hub_cache)
        already_cached = any(
            "wan2.2" in entry.lower() or "models--kijai" in entry.lower()
            for entry in entries
        )
        if already_cached:
            print("✓ Models already cached (YAML preload worked)")
            return

    print("→ Pre-caching models to disk (first-time setup)...")
    print(" This may take 2-3 minutes, please wait...")
    try:
        # Imported lazily: these are heavy and only needed for the warmup.
        from diffusers import WanTransformer3DModel
        from transformers import AutoModelForCausalLM, AutoTokenizer
        from huggingface_hub import hf_hub_download
        import torch

        wan_repo = "cbensimon/Wan2.2-I2V-A14B-bf16-Diffusers"
        print(" [1/4] Downloading video model transformer...")
        WanTransformer3DModel.from_pretrained(
            wan_repo,
            subfolder='transformer',
            torch_dtype=torch.bfloat16,
        )
        print(" [2/4] Downloading video model transformer_2...")
        WanTransformer3DModel.from_pretrained(
            wan_repo,
            subfolder='transformer_2',
            torch_dtype=torch.bfloat16,
        )
        print(" [3/4] Downloading Lightning LoRA...")
        hf_hub_download(
            "Kijai/WanVideo_comfy",
            "Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors"
        )
        print(" [4/4] Downloading text model (optional)...")
        AutoModelForCausalLM.from_pretrained(
            "Qwen/Qwen2.5-0.5B-Instruct",
            torch_dtype=torch.bfloat16,
        )
        AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
        print("✓ All models cached successfully!")
        print(" Future users will load instantly from cache")
    except Exception as e:
        # Best-effort warmup: a failure here only delays the download.
        print(f"⚠ Pre-cache warning: {e}")
        print(" Models will download on first generation instead")
def check_environment():
    """
    Verify that all required third-party packages are importable.

    When a required package is missing, prints pip install instructions
    and exits with status 1. Missing optional packages are reported only
    when the DEBUG environment variable is set.
    """
    required = [
        "torch", "transformers", "diffusers", "gradio", "PIL",
        "accelerate", "numpy", "ftfy", "sentencepiece"
    ]
    optional = {
        "torchao": "INT8/FP8 quantization",
        "xformers": "Memory efficient attention",
        "aoti": "AoT compilation"
    }

    def _importable(name):
        # Import check by name; mirrors what the app will do at runtime.
        try:
            __import__(name)
            return True
        except ImportError:
            return False

    missing_packages = [name for name in required if not _importable(name)]
    missing_optional = [
        f"{name} ({description})"
        for name, description in optional.items()
        if not _importable(name)
    ]

    if missing_packages:
        print("\n❌ Missing required packages:", ", ".join(missing_packages))
        print("\nInstall commands:")
        print("!pip install torch==2.9.0 torchvision==0.24.0 torchaudio==2.9.0 --index-url https://download.pytorch.org/whl/cu126")
        print("!pip install diffusers>=0.32.0 transformers>=4.46.0 accelerate gradio pillow numpy spaces ftfy sentencepiece protobuf imageio-ffmpeg")
        print("!pip install torchao xformers")
        sys.exit(1)

    # Only show missing optional in debug mode
    if missing_optional and os.environ.get('DEBUG'):
        print("⚠ Optional packages missing:", ", ".join(missing_optional))
def main():
check_environment()
preload_models()
try:
facade = FlowFacade()
background_engine = BackgroundEngine()
style_engine = StyleTransferEngine()
ui_manager = UIManager(facade, background_engine, style_engine)
interface = ui_manager.create_interface()
is_colab = 'google.colab' in sys.modules
print("✓ Ready")
interface.launch(
share=is_colab,
server_name="0.0.0.0",
server_port=None,
show_error=True
)
except KeyboardInterrupt:
print("\n⚠ Shutdown requested")
if 'facade' in locals():
facade.cleanup()
sys.exit(0)
except Exception as e:
print(f"\n❌ Startup error: {str(e)}")
import traceback
traceback.print_exc()
sys.exit(1)
if __name__ == "__main__":
main()