File size: 9,412 Bytes
c009d4f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
import spaces
import os
import sys
import requests
import site

# Absolute path of the directory holding this script; treated as the project root.
APP_DIR = os.path.split(os.path.abspath(__file__))[0]

# Put the project root at the front of sys.path so the project-local packages
# imported later (core, utils, ui, scripts, comfy_integration) resolve from here.
if APP_DIR not in sys.path:
    sys.path.insert(0, APP_DIR)
    print(f"βœ… Added project root '{APP_DIR}' to sys.path.")

# Flipped to True by apply_sage_attention_patch() once the patch succeeds, so
# repeated calls become no-ops.
SAGE_PATCH_APPLIED = False

def apply_sage_attention_patch():
    """
    Make ComfyUI use SageAttention instead of PyTorch attention.

    Overrides ``comfy.model_management.sage_attention_enabled`` so it always
    returns True and ``pytorch_attention_enabled`` so it always returns False.
    The ``sageattention`` package is imported only to confirm it is installed.
    Once the patch succeeds the module-level ``SAGE_PATCH_APPLIED`` flag is set,
    and later calls return immediately.

    Returns:
        str: A human-readable status message. Errors derived from ``Exception``
        are caught, printed and reported through this message instead of raised.
    """
    global SAGE_PATCH_APPLIED
    if SAGE_PATCH_APPLIED:
        return "SageAttention patch already applied."

    # Remember which import is in progress. An ImportError raised while importing
    # ComfyUI itself (or one of its own dependencies) must not be reported as a
    # missing sageattention package.
    importing = "comfy"
    try:
        from comfy import model_management
        importing = "sageattention"
        import sageattention  # noqa: F401 -- imported only as a presence check

        print("--- [Runtime Patch] sageattention package found. Applying patch... ---")
        model_management.sage_attention_enabled = lambda: True
        model_management.pytorch_attention_enabled = lambda: False
    except ImportError as e:
        SAGE_PATCH_APPLIED = False
        if importing == "comfy":
            msg = (
                "--- [Runtime Patch] ComfyUI 'comfy.model_management' could not be "
                f"imported ({e}). Continuing with default attention. ---"
            )
        else:
            msg = "--- [Runtime Patch] ⚠️ sageattention package not found. Continuing with default attention. ---"
        print(msg)
        return msg
    except Exception as e:
        SAGE_PATCH_APPLIED = False
        msg = f"--- [Runtime Patch] ❌ An error occurred while applying SageAttention patch: {e} ---"
        print(msg)
        return msg

    SAGE_PATCH_APPLIED = True
    return "βœ… Successfully enabled SageAttention."

@spaces.GPU
def dummy_gpu_for_startup():
    """
    Trivial ``@spaces.GPU``-decorated function that applies the SageAttention patch.

    Logs a start line, runs ``apply_sage_attention_patch()`` and logs its status
    message, then logs completion. Presumably intended as a startup check for the
    GPU decorator; nothing in this file shows where it is called.

    Returns:
        str: Always ``"Startup check passed."``.
    """
    log_prefix = "--- [GPU Startup]"
    print(f"{log_prefix} Dummy function for startup check initiated. ---")
    # The patch runs (and does its own printing) before this line is printed.
    print(f"{log_prefix} {apply_sage_attention_patch()} ---")
    print(f"{log_prefix} Startup check passed. ---")
    return "Startup check passed."

def _clear_hf_token(cleared_msg, missing_msg=None):
    """
    Remove ``HF_TOKEN`` from ``os.environ`` and log the outcome.

    Args:
        cleared_msg: Printed when the variable was present and has been removed.
        missing_msg: Printed when the variable was not set. If None, nothing is
            printed in that case.
    """
    if os.environ.pop('HF_TOKEN', None) is not None:
        print(cleared_msg)
        print(f"--- [Startup] Verifying HF_TOKEN after clearing: {os.environ.get('HF_TOKEN')}")
    elif missing_msg:
        print(missing_msg)


def _download_private_file(file_info, dest_dir, hf_hub_download):
    """
    Download one private_file_list.yaml entry and symlink it into ``dest_dir``.

    Args:
        file_info: One list item from the YAML file. Expected to be a mapping with
            ``filename``, ``repo_id`` and an optional ``repository_file_path``.
        dest_dir: Directory where the ``filename`` symlink is created.
        hf_hub_download: The ``huggingface_hub.hf_hub_download`` callable.

    Malformed entries and download or link failures are printed and skipped.
    They are never raised, so one bad entry does not stop the others.
    """
    if not isinstance(file_info, dict):
        print(f"--- [Startup] ⚠️ Skipping malformed entry in private_file_list.yaml: {file_info} ---")
        return

    filename = file_info.get("filename")
    repo_id = file_info.get("repo_id")
    # An explicit `repository_file_path: null` falls back to filename, like a missing key.
    repo_path = file_info.get("repository_file_path") or filename

    if not all([filename, repo_id]):
        print(f"--- [Startup] ⚠️ Skipping malformed entry in private_file_list.yaml: {file_info} ---")
        return

    dest_path = os.path.join(dest_dir, filename)
    # os.path.exists follows symlinks. A dangling link (its cached target is gone)
    # therefore counts as missing and is repaired below instead of skipped forever.
    if os.path.exists(dest_path):
        print(f"--- [Startup] βœ… Model '{filename}' already exists. Skipping download. ---")
        return

    print(f"--- [Startup] ⏳ Downloading '{filename}' from repo '{repo_id}'... ---")
    try:
        cached_path = hf_hub_download(repo_id=repo_id, filename=repo_path)
        os.makedirs(dest_dir, exist_ok=True)
        if os.path.islink(dest_path):
            # Remove the stale dangling link, otherwise os.symlink raises FileExistsError.
            os.remove(dest_path)
        os.symlink(cached_path, dest_path)
        print(f"--- [Startup] βœ… Successfully downloaded and linked '{filename}'. ---")
    except Exception as e:
        print(f"--- [Startup] ❌ ERROR: Failed to download '{filename}': {e}")
        print("--- [Startup] ❌ Please ensure your HF_TOKEN is set correctly and has access to the repository. ---")


def handle_private_downloads():
    """
    Download the private models listed in ``yaml/private_file_list.yaml``, then
    remove ``HF_TOKEN`` from the process environment.

    The list is parsed as::

        file:
          <category>:            # a key of category_to_dir_map below
            - filename: ...
              repo_id: ...
              repository_file_path: ...   # optional, defaults to filename

    Each file is fetched with ``hf_hub_download`` (which uses HF_TOKEN if it is
    set) and symlinked into the model directory for its category. Files already
    present are skipped. Unreadable or malformed lists and individual download
    failures are reported and skipped rather than raised.

    ``HF_TOKEN`` is deleted from ``os.environ`` on every exit path, so it is no
    longer visible through the environment to code that runs afterwards.
    """
    import yaml
    from huggingface_hub import hf_hub_download
    from core.settings import (
        DIFFUSION_MODELS_DIR, TEXT_ENCODERS_DIR, VAE_DIR, CHECKPOINT_DIR,
        LORA_DIR, CONTROLNET_DIR, MODEL_PATCHES_DIR, EMBEDDING_DIR
    )

    print("--- [Startup] Checking for private models to download... ---")
    private_list_path = os.path.join(APP_DIR, 'yaml', 'private_file_list.yaml')

    if not os.path.exists(private_list_path):
        print("--- [Startup] No private model list found. Skipping. ---")
        _clear_hf_token("--- [Startup] Cleared HF_TOKEN environment variable as it is no longer needed. ---")
        return

    try:
        try:
            with open(private_list_path, 'r', encoding='utf-8') as f:
                private_files_config = yaml.safe_load(f)
        except (OSError, yaml.YAMLError) as e:
            # A broken list should not abort app startup; the finally still clears the token.
            print(f"--- [Startup] ERROR: Could not read private_file_list.yaml: {e}. Skipping. ---")
            return

        file_section = (
            private_files_config.get('file')
            if isinstance(private_files_config, dict)
            else None
        )
        # Also rejects `file: null` and non-mapping values, which would crash .items().
        if not isinstance(file_section, dict):
            print("--- [Startup] Private model list is empty or malformed. Skipping. ---")
            return

        category_to_dir_map = {
            "diffusion_models": DIFFUSION_MODELS_DIR,
            "text_encoders": TEXT_ENCODERS_DIR,
            "vae": VAE_DIR,
            "checkpoints": CHECKPOINT_DIR,
            "loras": LORA_DIR,
            "controlnet": CONTROLNET_DIR,
            "model_patches": MODEL_PATCHES_DIR,
            "embeddings": EMBEDDING_DIR,
        }

        files_to_download = []
        for category, files in file_section.items():
            dest_dir = category_to_dir_map.get(category)
            if not dest_dir:
                print(f"--- [Startup] ⚠️ Unknown category '{category}' in private_file_list.yaml. Skipping. ---")
                continue
            if isinstance(files, list):
                files_to_download.extend((file_info, dest_dir) for file_info in files)

        if not files_to_download:
            print("--- [Startup] No private models configured for download. ---")
            return

        print(f"--- [Startup] Found {len(files_to_download)} private model(s) to download. Using HF_TOKEN if available. ---")

        for file_info, dest_dir in files_to_download:
            _download_private_file(file_info, dest_dir, hf_hub_download)

    finally:
        _clear_hf_token(
            "--- [Startup] βœ… Cleared HF_TOKEN environment variable. ---",
            missing_msg="--- [Startup] Note: HF_TOKEN environment variable was not set. Private downloads may fail without it. ---",
        )

def _probe_model_url(session, url, timeout):
    """
    Send one HEAD request (following redirects) to ``url``.

    Returns:
        tuple: ``(status_code, None)`` when a response was received, or
        ``(None, exc)`` when the request failed with a ``requests.RequestException``.
    """
    try:
        response = session.head(url, timeout=timeout, allow_redirects=True)
    except requests.RequestException as e:
        return None, e
    return response.status_code, None


def _check_all_model_urls_on_startup(all_model_map, all_file_download_map, invalid_model_urls, timeout=5):
    """
    Flag models whose component files cannot be reached on the Hugging Face Hub.

    Each component filename of every model is looked up in
    ``all_file_download_map``. Components that have a ``repo_id`` are probed at
    ``https://huggingface.co/<repo_id>/resolve/main/<repository_file_path>``.
    The first component of a model that answers with HTTP status >= 400, or whose
    request fails, marks that model with ``invalid_model_urls[display_name] = True``.
    The model's remaining components are then skipped.

    All probes share one ``requests.Session``, and each distinct URL is probed only
    once, because the same component file may be listed under several models.
    """
    print("--- [Setup] Checking all model URL validity (one-time check) ---")
    probe_results = {}  # url -> (status_code, error) from _probe_model_url
    with requests.Session() as session:
        for display_name, model_info in all_model_map.items():
            # NOTE(review): assumes model_info is a 4-tuple whose second item maps
            # component keys to filenames -- confirm against core.settings.
            _, components, _, _ = model_info
            if not components:
                continue

            for filename in components.values():
                download_info = all_file_download_map.get(filename, {})
                repo_id = download_info.get('repo_id')
                if not repo_id:
                    continue

                repo_file_path = download_info.get('repository_file_path', filename)
                url = f"https://huggingface.co/{repo_id}/resolve/main/{repo_file_path}"
                if url not in probe_results:
                    probe_results[url] = _probe_model_url(session, url, timeout)
                status_code, error = probe_results[url]

                if error is not None:
                    print(f"❌ URL check failed for '{display_name}' component '{filename}': {error}")
                elif status_code >= 400:
                    print(f"❌ Invalid URL for '{display_name}' component '{filename}': {url} (Status: {status_code})")
                else:
                    continue
                # One unreachable component is enough to mark the whole model invalid.
                invalid_model_urls[display_name] = True
                break
    print("--- [Setup] βœ… Finished checking model URLs. ---")


def main():
    """
    Prepare the environment, initialize ComfyUI, and launch the Gradio UI.

    Steps, in order:
      1. Print the welcome banner.
      2. Download private models (this also clears HF_TOKEN).
      3. Try to build and install SageAttention.
      4. Re-run ``site.main()`` so newly installed packages are picked up.
      5. Initialize ComfyUI.
      6. Check model download URLs.
      7. Build the ControlNet preprocessor maps.
      8. Build the Gradio app and launch it on 0.0.0.0:7860.

    Failures in the SageAttention install and the site reload are logged and
    tolerated.
    """
    from utils.app_utils import print_welcome_message
    from scripts import build_sage_attention

    print_welcome_message()

    # Handle downloads that require authentication first.
    handle_private_downloads()

    print("--- [Setup] Attempting to build and install SageAttention... ---")
    try:
        build_sage_attention.install_sage_attention()
        print("--- [Setup] βœ… SageAttention installation process finished. ---")
    except Exception as e:
        print(f"--- [Setup] ❌ SageAttention installation failed: {e}. Continuing with default attention. ---")

    print("--- [Setup] Reloading site-packages to detect newly installed packages... ---")
    try:
        site.main()
        print("--- [Setup] βœ… Site-packages reloaded. ---")
    except Exception as e:
        print(f"--- [Setup] ⚠️  Warning: Could not fully reload site-packages: {e} ---")

    # These imports are deliberately deferred until after the site-packages reload above.
    from comfy_integration import setup as setup_comfyui
    from utils.app_utils import (
        build_preprocessor_model_map,
        build_preprocessor_parameter_map
    )
    from core import shared_state
    from core.settings import ALL_MODEL_MAP, ALL_FILE_DOWNLOAD_MAP

    print("--- Starting Application Setup ---")

    setup_comfyui.initialize_comfyui()

    _check_all_model_urls_on_startup(
        ALL_MODEL_MAP, ALL_FILE_DOWNLOAD_MAP, shared_state.INVALID_MODEL_URLS
    )

    print("--- Building ControlNet preprocessor maps ---")
    from core.generation_logic import build_reverse_map
    build_reverse_map()
    build_preprocessor_model_map()
    build_preprocessor_parameter_map()
    print("--- βœ… ControlNet preprocessor setup complete. ---")

    print("--- Environment configured. Proceeding with module imports. ---")
    from ui.layout import build_ui
    from ui.events import attach_event_handlers

    print(f"βœ… Working directory is stable: {os.getcwd()}")

    demo = build_ui(attach_event_handlers)

    print("--- Launching Gradio Interface ---")
    # 0.0.0.0 listens on all interfaces so the server is reachable from outside the container.
    demo.queue().launch(server_name="0.0.0.0", server_port=7860)


if __name__ == "__main__":
    # Run the setup-and-launch sequence only when executed as a script, not on import.
    main()