Spaces: Running on Zero
| import spaces | |
| import os | |
| import sys | |
| import requests | |
| import site | |
# Absolute path of the directory holding this entry script (the project root).
APP_DIR = os.path.split(os.path.abspath(__file__))[0]

# Put the project root at the front of the import path so the local packages
# imported later (core, utils, ui, scripts, comfy_integration) resolve
# regardless of the current working directory.
if APP_DIR not in sys.path:
    sys.path[:0] = [APP_DIR]
    print(f"β Added project root '{APP_DIR}' to sys.path.")
# True once ComfyUI's attention selectors have been overridden; makes
# apply_sage_attention_patch() a no-op on subsequent calls.
SAGE_PATCH_APPLIED = False


def apply_sage_attention_patch():
    """Force ComfyUI's model management to select SageAttention, if available.

    Replaces ``comfy.model_management.sage_attention_enabled`` with a lambda
    returning True and ``pytorch_attention_enabled`` with one returning False.

    Returns:
        str: A status message. Failure messages are also printed. This
        function does not raise: import failures and any other exception are
        reported in the returned message and leave ``SAGE_PATCH_APPLIED``
        False.
    """
    global SAGE_PATCH_APPLIED
    if SAGE_PATCH_APPLIED:
        return "SageAttention patch already applied."
    try:
        # Import the optional kernel package first, so that an ImportError
        # about ``comfy`` (or about one of sageattention's own dependencies)
        # is not misreported as "sageattention package not found".
        import sageattention  # noqa: F401  (imported only to probe availability)
        from comfy import model_management
        print("--- [Runtime Patch] sageattention package found. Applying patch... ---")
        model_management.sage_attention_enabled = lambda: True
        model_management.pytorch_attention_enabled = lambda: False
        SAGE_PATCH_APPLIED = True
        return "β Successfully enabled SageAttention."
    except ImportError as e:
        SAGE_PATCH_APPLIED = False
        # ImportError.name is the module that failed to import (may be None).
        missing_root = (e.name or "").split(".")[0]
        if missing_root in ("", "sageattention"):
            msg = "--- [Runtime Patch] β οΈ sageattention package not found. Continuing with default attention. ---"
        else:
            msg = f"--- [Runtime Patch] β οΈ Could not import '{e.name}' ({e}). Continuing with default attention. ---"
        print(msg)
        return msg
    except Exception as e:
        SAGE_PATCH_APPLIED = False
        msg = f"--- [Runtime Patch] β An error occurred while applying SageAttention patch: {e} ---"
        print(msg)
        return msg
# NOTE(review): `spaces` is imported at module top but no `@spaces.GPU`
# decorator is visible on this function in this view; if it is meant to be
# the ZeroGPU startup function, confirm the decorator is present.
def dummy_gpu_for_startup():
    """Run the startup check: apply the SageAttention patch and log progress.

    Returns:
        str: Always ``"Startup check passed."``.
    """
    print("--- [GPU Startup] Dummy function for startup check initiated. ---")
    print(f"--- [GPU Startup] {apply_sage_attention_patch()} ---")
    print("--- [GPU Startup] Startup check passed. ---")
    return "Startup check passed."
def _link_private_file(file_info, dest_dir, download_fn):
    """Download one private_file_list.yaml entry and symlink it into ``dest_dir``.

    Args:
        file_info: One list entry from the YAML file; expected to be a mapping
            with ``filename`` and ``repo_id`` and optionally
            ``repository_file_path`` (defaults to ``filename``). Anything else
            is reported and skipped.
        dest_dir: Directory in which the symlink is created (created if needed).
        download_fn: ``hf_hub_download``; must return the local cached path.

    Download and symlink errors are printed rather than raised, so the
    remaining entries are still processed.
    """
    if not isinstance(file_info, dict):
        print(f"--- [Startup] β οΈ Skipping malformed entry in private_file_list.yaml: {file_info} ---")
        return
    filename = file_info.get("filename")
    repo_id = file_info.get("repo_id")
    repo_path = file_info.get("repository_file_path", filename)
    if not all([filename, repo_id]):
        print(f"--- [Startup] β οΈ Skipping malformed entry in private_file_list.yaml: {file_info} ---")
        return
    dest_path = os.path.join(dest_dir, filename)
    # os.path.exists follows symlinks, so a dangling link (its cached target
    # was removed) is NOT treated as an existing model and gets re-downloaded.
    if os.path.exists(dest_path):
        print(f"--- [Startup] β Model '{filename}' already exists. Skipping download. ---")
        return
    print(f"--- [Startup] β³ Downloading '{filename}' from repo '{repo_id}'... ---")
    try:
        cached_path = download_fn(repo_id=repo_id, filename=repo_path)
        os.makedirs(dest_dir, exist_ok=True)
        if os.path.islink(dest_path):
            # Remove the broken link; otherwise os.symlink raises FileExistsError.
            os.remove(dest_path)
        os.symlink(cached_path, dest_path)
        print(f"--- [Startup] β Successfully downloaded and linked '{filename}'. ---")
    except Exception as e:
        print(f"--- [Startup] β ERROR: Failed to download '{filename}': {e}")
        print("--- [Startup] β Please ensure your HF_TOKEN is set correctly and has access to the repository. ---")


def handle_private_downloads():
    """Download the models listed in private_file_list.yaml, then clear HF_TOKEN.

    Reads ``<APP_DIR>/yaml/private_file_list.yaml``. Its top-level ``file``
    mapping goes from a category (``diffusion_models``, ``text_encoders``,
    ``vae``, ``checkpoints``, ``loras``, ``controlnet``, ``model_patches``,
    ``embeddings``) to a list of entries. Each entry is fetched with
    ``hf_hub_download`` and symlinked into the matching ``core.settings``
    directory. Unknown categories, malformed entries and an unreadable or
    malformed list file are reported and skipped rather than aborting startup.

    ``HF_TOKEN`` is deleted from ``os.environ`` before this function returns,
    whether or not anything was downloaded.
    """
    import yaml
    from huggingface_hub import hf_hub_download
    from core.settings import (
        DIFFUSION_MODELS_DIR, TEXT_ENCODERS_DIR, VAE_DIR, CHECKPOINT_DIR,
        LORA_DIR, CONTROLNET_DIR, MODEL_PATCHES_DIR, EMBEDDING_DIR
    )

    print("--- [Startup] Checking for private models to download... ---")
    private_list_path = os.path.join(APP_DIR, 'yaml', 'private_file_list.yaml')
    if not os.path.exists(private_list_path):
        print("--- [Startup] No private model list found. Skipping. ---")
        if 'HF_TOKEN' in os.environ:
            del os.environ['HF_TOKEN']
            print("--- [Startup] Cleared HF_TOKEN environment variable as it is no longer needed. ---")
            print(f"--- [Startup] Verifying HF_TOKEN after clearing: {os.environ.get('HF_TOKEN')}")
        return

    try:
        try:
            with open(private_list_path, 'r', encoding='utf-8') as f:
                private_files_config = yaml.safe_load(f)
        except (OSError, yaml.YAMLError) as e:
            # An unreadable or syntactically broken list must not abort startup.
            print(f"--- [Startup] β οΈ Could not read private_file_list.yaml: {e}. Skipping. ---")
            return

        # The document must be a mapping whose 'file' key is itself a mapping
        # (category -> list of entries); anything else cannot be iterated.
        file_section = (
            private_files_config.get('file')
            if isinstance(private_files_config, dict)
            else None
        )
        if not isinstance(file_section, dict):
            print("--- [Startup] Private model list is empty or malformed. Skipping. ---")
            return

        category_to_dir_map = {
            "diffusion_models": DIFFUSION_MODELS_DIR,
            "text_encoders": TEXT_ENCODERS_DIR,
            "vae": VAE_DIR,
            "checkpoints": CHECKPOINT_DIR,
            "loras": LORA_DIR,
            "controlnet": CONTROLNET_DIR,
            "model_patches": MODEL_PATCHES_DIR,
            "embeddings": EMBEDDING_DIR,
        }
        files_to_download = []
        for category, files in file_section.items():
            dest_dir = category_to_dir_map.get(category)
            if not dest_dir:
                print(f"--- [Startup] β οΈ Unknown category '{category}' in private_file_list.yaml. Skipping. ---")
                continue
            if isinstance(files, list):
                files_to_download.extend((file_info, dest_dir) for file_info in files)

        if not files_to_download:
            print("--- [Startup] No private models configured for download. ---")
            return

        print(f"--- [Startup] Found {len(files_to_download)} private model(s) to download. Using HF_TOKEN if available. ---")
        for file_info, dest_dir in files_to_download:
            _link_private_file(file_info, dest_dir, hf_hub_download)
    finally:
        # Runs on every path after the list file was found, including the
        # early returns above.
        if 'HF_TOKEN' in os.environ:
            del os.environ['HF_TOKEN']
            print("--- [Startup] β Cleared HF_TOKEN environment variable. ---")
            print(f"--- [Startup] Verifying HF_TOKEN after clearing: {os.environ.get('HF_TOKEN')}")
        else:
            print("--- [Startup] Note: HF_TOKEN environment variable was not set. Private downloads may fail without it. ---")
def main():
    """Prepare the environment, initialise ComfyUI and launch the Gradio UI.

    Runs, in order:
      1. private model downloads (which also clear HF_TOKEN);
      2. SageAttention build/install, then a site-packages re-scan plus an
         import-cache invalidation so a freshly installed package can be
         imported by the modules loaded afterwards;
      3. ComfyUI initialisation, a one-time model-URL validity check and the
         ControlNet preprocessor map construction;
      4. UI construction and ``launch`` on 0.0.0.0:7860.
    """
    import importlib

    from utils.app_utils import print_welcome_message
    from scripts import build_sage_attention

    print_welcome_message()
    # Handle downloads that require authentication first.
    handle_private_downloads()

    print("--- [Setup] Attempting to build and install SageAttention... ---")
    try:
        build_sage_attention.install_sage_attention()
        print("--- [Setup] β SageAttention installation process finished. ---")
    except Exception as e:
        print(f"--- [Setup] β SageAttention installation failed: {e}. Continuing with default attention. ---")

    print("--- [Setup] Reloading site-packages to detect newly installed packages... ---")
    try:
        site.main()
        print("--- [Setup] β Site-packages reloaded. ---")
    except Exception as e:
        print(f"--- [Setup] β οΈ Warning: Could not fully reload site-packages: {e} ---")
    # Path-based finders cache directory listings; files written by the
    # install step above may not be visible to `import` until these caches
    # are invalidated.
    importlib.invalidate_caches()

    from comfy_integration import setup as setup_comfyui
    from utils.app_utils import (
        build_preprocessor_model_map,
        build_preprocessor_parameter_map
    )
    from core import shared_state
    from core.settings import ALL_MODEL_MAP, ALL_FILE_DOWNLOAD_MAP

    def check_all_model_urls_on_startup():
        """Flag models whose component files are unreachable on the Hub.

        For each model in ALL_MODEL_MAP, every component filename with a
        ``repo_id`` in ALL_FILE_DOWNLOAD_MAP is probed with a HEAD request to
        its ``resolve/main`` URL. The first component answering with status
        >= 400, or failing with a request error, sets
        ``shared_state.INVALID_MODEL_URLS[display_name] = True`` and stops the
        checks for that model.
        """
        print("--- [Setup] Checking all model URL validity (one-time check) ---")
        # url -> status code (int) or the RequestException raised. Component
        # files shared by several models are therefore requested only once.
        probe_results = {}
        with requests.Session() as session:
            for display_name, model_info in ALL_MODEL_MAP.items():
                _, components, _, _ = model_info
                if not components:
                    continue
                for filename in components.values():
                    download_info = ALL_FILE_DOWNLOAD_MAP.get(filename, {})
                    repo_id = download_info.get('repo_id')
                    if not repo_id:
                        continue
                    repo_file_path = download_info.get('repository_file_path', filename)
                    url = f"https://huggingface.co/{repo_id}/resolve/main/{repo_file_path}"
                    if url not in probe_results:
                        try:
                            response = session.head(url, timeout=5, allow_redirects=True)
                            probe_results[url] = response.status_code
                        except requests.RequestException as e:
                            probe_results[url] = e
                    result = probe_results[url]
                    if isinstance(result, Exception):
                        print(f"β URL check failed for '{display_name}' component '{filename}': {result}")
                    elif result >= 400:
                        print(f"β Invalid URL for '{display_name}' component '{filename}': {url} (Status: {result})")
                    else:
                        continue
                    shared_state.INVALID_MODEL_URLS[display_name] = True
                    break
        print("--- [Setup] β Finished checking model URLs. ---")

    print("--- Starting Application Setup ---")
    setup_comfyui.initialize_comfyui()
    check_all_model_urls_on_startup()

    print("--- Building ControlNet preprocessor maps ---")
    from core.generation_logic import build_reverse_map
    build_reverse_map()
    build_preprocessor_model_map()
    build_preprocessor_parameter_map()
    print("--- β ControlNet preprocessor setup complete. ---")

    print("--- Environment configured. Proceeding with module imports. ---")
    from ui.layout import build_ui
    from ui.events import attach_event_handlers
    print(f"β Working directory is stable: {os.getcwd()}")
    demo = build_ui(attach_event_handlers)
    print("--- Launching Gradio Interface ---")
    demo.queue().launch(server_name="0.0.0.0", server_port=7860)


if __name__ == "__main__":
    main()