import spaces  # Kept first: ZeroGPU expects `spaces` to be imported before CUDA-using packages (verify against HF ZeroGPU docs).
import importlib
import os
import site
import sys

import requests

APP_DIR = os.path.dirname(os.path.abspath(__file__))
# Ensure the directory containing this file is on sys.path so the project-local
# packages imported below (core, utils, ui, scripts, comfy_integration) resolve.
if APP_DIR not in sys.path:
    sys.path.insert(0, APP_DIR)
    print(f"✅ Added project root '{APP_DIR}' to sys.path.")

# Set to True once apply_sage_attention_patch() has successfully patched ComfyUI.
SAGE_PATCH_APPLIED = False


def apply_sage_attention_patch():
    """Switch ComfyUI's attention selection to SageAttention if it is installed.

    On success, ``comfy.model_management.sage_attention_enabled`` is replaced by a
    function returning True and ``pytorch_attention_enabled`` by one returning
    False. Once the patch has succeeded, later calls return immediately.

    Returns:
        str: A status message. Failure messages are also printed.
    """
    global SAGE_PATCH_APPLIED
    if SAGE_PATCH_APPLIED:
        return "SageAttention patch already applied."
    try:
        from comfy import model_management
        import sageattention  # noqa: F401 -- imported only to confirm the package is installed
        print("--- [Runtime Patch] sageattention package found. Applying patch... ---")
        model_management.sage_attention_enabled = lambda: True
        model_management.pytorch_attention_enabled = lambda: False
        SAGE_PATCH_APPLIED = True
        return "✅ Successfully enabled SageAttention."
    except ImportError:
        SAGE_PATCH_APPLIED = False
        msg = "--- [Runtime Patch] ⚠️ sageattention package not found. Continuing with default attention. ---"
        print(msg)
        return msg
    except Exception as e:
        SAGE_PATCH_APPLIED = False
        msg = f"--- [Runtime Patch] ❌ An error occurred while applying SageAttention patch: {e} ---"
        print(msg)
        return msg


@spaces.GPU
def dummy_gpu_for_startup():
    """Minimal ``@spaces.GPU`` function that applies the SageAttention patch.

    NOTE(review): not called anywhere in this file; presumably it exists so the
    ZeroGPU runtime finds a GPU-decorated function at startup -- confirm.

    Returns:
        str: Always ``"Startup check passed."``.
    """
    print("--- [GPU Startup] Dummy function for startup check initiated. ---")
    patch_result = apply_sage_attention_patch()
    print(f"--- [GPU Startup] {patch_result} ---")
    print("--- [GPU Startup] Startup check passed. ---")
    return "Startup check passed."


def _clear_hf_token(cleared_msg, missing_msg=None):
    """Remove HF_TOKEN from this process's environment and log the outcome.

    Args:
        cleared_msg: Printed when the variable was present and has been removed.
        missing_msg: Printed when the variable was not set; nothing is printed if None.
    """
    # NOTE(review): huggingface_hub also honours the legacy HUGGING_FACE_HUB_TOKEN
    # variable, which is not cleared here -- confirm whether it can be set.
    if 'HF_TOKEN' in os.environ:
        del os.environ['HF_TOKEN']
        print(cleared_msg)
        # Report presence only: a secret's value must never be formatted into logs.
        state = "still set" if 'HF_TOKEN' in os.environ else "not set"
        print(f"--- [Startup] Verifying HF_TOKEN after clearing: {state}")
    elif missing_msg:
        print(missing_msg)


def _collect_private_downloads(config, category_to_dir_map):
    """Flatten the ``file`` section of private_file_list.yaml into (entry, dest_dir) pairs.

    Args:
        config: The parsed YAML document.
        category_to_dir_map: Category name -> destination model directory.

    Returns:
        list | None: The pairs to download, or None (after logging) when the
        document or its ``file`` section is missing or not a mapping.
    """
    files_section = config.get('file') if isinstance(config, dict) else None
    if not isinstance(files_section, dict):
        print("--- [Startup] Private model list is empty or malformed. Skipping. ---")
        return None

    files_to_download = []
    for category, files in files_section.items():
        dest_dir = category_to_dir_map.get(category)
        if not dest_dir:
            print(f"--- [Startup] ⚠️ Unknown category '{category}' in private_file_list.yaml. Skipping. ---")
            continue
        # Categories whose value is not a list are silently ignored.
        if isinstance(files, list):
            files_to_download.extend((file_info, dest_dir) for file_info in files)
    return files_to_download


def _fetch_private_file(file_info, dest_dir, download_fn):
    """Download one private_file_list.yaml entry and symlink it into ``dest_dir``.

    Args:
        file_info: One list entry; expected to be a mapping with ``filename``,
            ``repo_id`` and optionally ``repository_file_path`` (defaults to ``filename``).
        dest_dir: Model directory in which the symlink is created.
        download_fn: ``huggingface_hub.hf_hub_download``; returns the cached file path.

    Download and link failures are logged and swallowed, so one bad entry does
    not stop the remaining downloads.
    """
    if not isinstance(file_info, dict):
        print(f"--- [Startup] ⚠️ Skipping malformed entry in private_file_list.yaml: {file_info} ---")
        return
    filename = file_info.get("filename")
    repo_id = file_info.get("repo_id")
    repo_path = file_info.get("repository_file_path", filename)
    if not (isinstance(filename, str) and filename and repo_id):
        print(f"--- [Startup] ⚠️ Skipping malformed entry in private_file_list.yaml: {file_info} ---")
        return

    dest_path = os.path.join(dest_dir, filename)
    # A symlink whose target vanished (e.g. the HF cache was cleared) still
    # satisfies lexists(); treat it as missing so the model is fetched again.
    dangling_link = os.path.islink(dest_path) and not os.path.exists(dest_path)
    if os.path.lexists(dest_path) and not dangling_link:
        print(f"--- [Startup] ✅ Model '{filename}' already exists. Skipping download. ---")
        return

    print(f"--- [Startup] ⏳ Downloading '{filename}' from repo '{repo_id}'... ---")
    try:
        cached_path = download_fn(repo_id=repo_id, filename=repo_path)
        # Create the link's own parent, so filenames containing sub-directories work.
        os.makedirs(os.path.dirname(dest_path), exist_ok=True)
        if dangling_link:
            # Replace the broken link only after the new download succeeded.
            os.unlink(dest_path)
        os.symlink(cached_path, dest_path)
        print(f"--- [Startup] ✅ Successfully downloaded and linked '{filename}'. ---")
    except Exception as e:
        print(f"--- [Startup] ❌ ERROR: Failed to download '{filename}': {e}")
        print("--- [Startup] ❌ Please ensure your HF_TOKEN is set correctly and has access to the repository. ---")


def handle_private_downloads():
    """
    Checks for a private_file_list.yaml, downloads required models using HF_TOKEN,
    and then clears the token from the environment.

    The list is read from ``<APP_DIR>/yaml/private_file_list.yaml``. Each file is
    downloaded into the Hugging Face cache and symlinked into the model directory
    of its category. HF_TOKEN is removed from ``os.environ`` on every path that
    gets past the imports: no list, an empty, malformed or unreadable list, and
    after the downloads.
    """
    import yaml
    from huggingface_hub import hf_hub_download
    from core.settings import (
        DIFFUSION_MODELS_DIR, TEXT_ENCODERS_DIR, VAE_DIR, CHECKPOINT_DIR,
        LORA_DIR, CONTROLNET_DIR, MODEL_PATCHES_DIR, EMBEDDING_DIR
    )

    print("--- [Startup] Checking for private models to download... ---")
    private_list_path = os.path.join(APP_DIR, 'yaml', 'private_file_list.yaml')
    if not os.path.exists(private_list_path):
        print("--- [Startup] No private model list found. Skipping. ---")
        _clear_hf_token("--- [Startup] Cleared HF_TOKEN environment variable as it is no longer needed. ---")
        return

    category_to_dir_map = {
        "diffusion_models": DIFFUSION_MODELS_DIR,
        "text_encoders": TEXT_ENCODERS_DIR,
        "vae": VAE_DIR,
        "checkpoints": CHECKPOINT_DIR,
        "loras": LORA_DIR,
        "controlnet": CONTROLNET_DIR,
        "model_patches": MODEL_PATCHES_DIR,
        "embeddings": EMBEDDING_DIR,
    }

    # try/finally: the token is removed whether downloads succeed, are skipped, or raise.
    try:
        try:
            with open(private_list_path, 'r', encoding='utf-8') as f:
                private_files_config = yaml.safe_load(f)
        except (OSError, yaml.YAMLError) as e:
            # An unreadable list is skipped like an empty one instead of aborting startup.
            print(f"--- [Startup] ❌ Could not read private_file_list.yaml: {e} ---")
            return

        files_to_download = _collect_private_downloads(private_files_config, category_to_dir_map)
        if files_to_download is None:
            return
        if not files_to_download:
            print("--- [Startup] No private models configured for download. ---")
            return

        print(f"--- [Startup] Found {len(files_to_download)} private model(s) to download. Using HF_TOKEN if available. ---")
        for file_info, dest_dir in files_to_download:
            _fetch_private_file(file_info, dest_dir, hf_hub_download)
    finally:
        _clear_hf_token(
            "--- [Startup] ✅ Cleared HF_TOKEN environment variable. ---",
            "--- [Startup] Note: HF_TOKEN environment variable was not set. Private downloads may fail without it. ---",
        )


def _check_all_model_urls(model_map, file_download_map, invalid_model_urls):
    """HEAD-check the Hugging Face download URL of every model component.

    Args:
        model_map: Display name -> 4-tuple whose second item is a mapping
            (possibly empty) whose values are component filenames.
        file_download_map: Filename -> dict with ``repo_id`` and optional
            ``repository_file_path`` (defaults to the filename).
        invalid_model_urls: Dict updated in place. ``invalid_model_urls[name] = True``
            is set for each model with an unreachable component URL or one
            returning status >= 400.

    Components without a ``repo_id`` are not checked. Checking a model stops at
    its first failing component.
    """
    print("--- [Setup] Checking all model URL validity (one-time check) ---")
    # One session for all checks so connections to huggingface.co are reused.
    with requests.Session() as session:
        for display_name, (_, components, _, _) in model_map.items():
            if not components:
                continue
            for filename in components.values():
                download_info = file_download_map.get(filename, {})
                repo_id = download_info.get('repo_id')
                if not repo_id:
                    continue
                repo_file_path = download_info.get('repository_file_path', filename)
                url = f"https://huggingface.co/{repo_id}/resolve/main/{repo_file_path}"
                try:
                    response = session.head(url, timeout=5, allow_redirects=True)
                except requests.RequestException as e:
                    print(f"❌ URL check failed for '{display_name}' component '{filename}': {e}")
                    invalid_model_urls[display_name] = True
                    break
                if response.status_code >= 400:
                    print(f"❌ Invalid URL for '{display_name}' component '{filename}': {url} (Status: {response.status_code})")
                    invalid_model_urls[display_name] = True
                    break
    print("--- [Setup] ✅ Finished checking model URLs. ---")


def main():
    """Prepare the environment, initialise ComfyUI and launch the Gradio UI.

    Runs, in order:
    1. Private downloads, which also remove HF_TOKEN.
    2. The SageAttention install.
    3. A site-packages and import-cache refresh.
    4. The ComfyUI and UI imports.
    5. URL checks, preprocessor map building and the server launch on 0.0.0.0:7860.
    """
    from utils.app_utils import print_welcome_message
    from scripts import build_sage_attention

    print_welcome_message()
    # Handle downloads that require authentication first.
    handle_private_downloads()

    print("--- [Setup] Attempting to build and install SageAttention... ---")
    try:
        build_sage_attention.install_sage_attention()
        print("--- [Setup] ✅ SageAttention installation process finished. ---")
    except Exception as e:
        print(f"--- [Setup] ❌ SageAttention installation failed: {e}. Continuing with default attention. ---")

    print("--- [Setup] Reloading site-packages to detect newly installed packages... ---")
    try:
        site.main()
        print("--- [Setup] ✅ Site-packages reloaded. ---")
    except Exception as e:
        print(f"--- [Setup] ⚠️ Warning: Could not fully reload site-packages: {e} ---")
    # site.main() only (re)adds site directories to sys.path. The import system
    # also caches directory listings, so flush them; otherwise a package installed
    # by install_sage_attention() moments ago may not be importable yet.
    importlib.invalidate_caches()

    from comfy_integration import setup as setup_comfyui
    from utils.app_utils import (
        build_preprocessor_model_map, build_preprocessor_parameter_map
    )
    from core import shared_state
    from core.settings import ALL_MODEL_MAP, ALL_FILE_DOWNLOAD_MAP

    print("--- Starting Application Setup ---")
    setup_comfyui.initialize_comfyui()
    _check_all_model_urls(ALL_MODEL_MAP, ALL_FILE_DOWNLOAD_MAP, shared_state.INVALID_MODEL_URLS)

    print("--- Building ControlNet preprocessor maps ---")
    from core.generation_logic import build_reverse_map
    build_reverse_map()
    build_preprocessor_model_map()
    build_preprocessor_parameter_map()
    print("--- ✅ ControlNet preprocessor setup complete. ---")

    print("--- Environment configured. Proceeding with module imports. ---")
    from ui.layout import build_ui
    from ui.events import attach_event_handlers

    print(f"✅ Working directory is stable: {os.getcwd()}")
    demo = build_ui(attach_event_handlers)
    print("--- Launching Gradio Interface ---")
    demo.queue().launch(server_name="0.0.0.0", server_port=7860)


if __name__ == "__main__":
    main()