ImageGen-FLUX.2 / app.py
RioShiina's picture
Upload folder using huggingface_hub
c009d4f verified
import spaces
import os
import sys
import requests
import site
# Make the project root importable regardless of the current working
# directory, so the local packages this file imports later (core, utils,
# scripts, comfy_integration, ui) resolve from the app's own folder.
APP_DIR = os.path.dirname(os.path.abspath(__file__))
if APP_DIR not in sys.path:
    sys.path.insert(0, APP_DIR)
    print(f"✅ Added project root '{APP_DIR}' to sys.path.")
# Process-wide flag: True once the SageAttention monkey-patch has been applied.
SAGE_PATCH_APPLIED = False


def apply_sage_attention_patch():
    """
    Route ComfyUI attention through SageAttention, at most once per process.

    When both ``comfy.model_management`` and the ``sageattention`` package can
    be imported, ``model_management.sage_attention_enabled`` is replaced by a
    function returning True, ``model_management.pytorch_attention_enabled`` by
    one returning False, and ``SAGE_PATCH_APPLIED`` is set to True. Later
    calls return immediately.

    Returns:
        A status message. Failure messages are also printed. Exceptions are
        not propagated. A missing ``sageattention`` package, any other failed
        import (reported with its real cause), or any other error leaves the
        default attention in place and is described in the returned message.
    """
    global SAGE_PATCH_APPLIED
    if SAGE_PATCH_APPLIED:
        return "SageAttention patch already applied."
    try:
        from comfy import model_management
        import sageattention  # noqa: F401 -- imported only to confirm it is installed
        print("--- [Runtime Patch] sageattention package found. Applying patch... ---")
        model_management.sage_attention_enabled = lambda: True
        model_management.pytorch_attention_enabled = lambda: False
    except ImportError as e:
        # Only blame sageattention when it is the module that is missing.
        # A failure to import comfy, or one of sageattention's own
        # dependencies, is reported with its actual error instead.
        if (e.name or "").partition(".")[0] == "sageattention":
            msg = "--- [Runtime Patch] ⚠️ sageattention package not found. Continuing with default attention. ---"
        else:
            msg = f"--- [Runtime Patch] ⚠️ Could not import a SageAttention prerequisite ({e}). Continuing with default attention. ---"
        print(msg)
        return msg
    except Exception as e:
        msg = f"--- [Runtime Patch] ❌ An error occurred while applying SageAttention patch: {e} ---"
        print(msg)
        return msg
    SAGE_PATCH_APPLIED = True
    return "✅ Successfully enabled SageAttention."
@spaces.GPU
def dummy_gpu_for_startup():
    """
    Minimal GPU-decorated entry point used as a startup check.

    Applies the SageAttention runtime patch (if possible) and logs its status
    message. NOTE(review): nothing in the code visible here calls this
    function; presumably it exists so the ``spaces`` runtime registers a GPU
    function at startup — confirm.

    Returns:
        The fixed string "Startup check passed.".
    """
    print("--- [GPU Startup] Dummy function for startup check initiated. ---")
    status = apply_sage_attention_patch()
    print(f"--- [GPU Startup] {status} ---")
    print("--- [GPU Startup] Startup check passed. ---")
    return "Startup check passed."
def _link_private_file(file_info, dest_dir):
    """
    Download one entry of private_file_list.yaml and symlink it into dest_dir.

    Args:
        file_info: One list item from the YAML file. It is expected to be a
            mapping with ``filename`` and ``repo_id`` and, optionally,
            ``repository_file_path`` (defaults to ``filename``).
        dest_dir: Model directory in which ``filename`` should appear.

    Malformed entries and download/link failures are printed and skipped.
    Nothing is raised for them.
    """
    from huggingface_hub import hf_hub_download
    if not isinstance(file_info, dict):
        print(f"--- [Startup] ⚠️ Skipping malformed entry in private_file_list.yaml: {file_info} ---")
        return
    filename = file_info.get("filename")
    repo_id = file_info.get("repo_id")
    repo_path = file_info.get("repository_file_path", filename)
    if not all([filename, repo_id]):
        print(f"--- [Startup] ⚠️ Skipping malformed entry in private_file_list.yaml: {file_info} ---")
        return
    dest_path = os.path.join(dest_dir, filename)
    # os.path.exists follows symlinks. A link whose cached target has been
    # removed therefore counts as missing and is re-downloaded below, instead
    # of being skipped forever.
    if os.path.exists(dest_path):
        print(f"--- [Startup] ✅ Model '{filename}' already exists. Skipping download. ---")
        return
    print(f"--- [Startup] ⏳ Downloading '{filename}' from repo '{repo_id}'... ---")
    try:
        cached_path = hf_hub_download(repo_id=repo_id, filename=repo_path)
        if os.path.islink(dest_path):
            os.unlink(dest_path)  # remove the dangling link before relinking
        # Use dirname(dest_path) rather than dest_dir so that filenames
        # containing sub-directories also get their parent folders created.
        os.makedirs(os.path.dirname(dest_path), exist_ok=True)
        os.symlink(cached_path, dest_path)
        print(f"--- [Startup] ✅ Successfully downloaded and linked '{filename}'. ---")
    except Exception as e:
        print(f"--- [Startup] ❌ ERROR: Failed to download '{filename}': {e}")
        print("--- [Startup] ❌ Please ensure your HF_TOKEN is set correctly and has access to the repository. ---")


def handle_private_downloads():
    """
    Checks for a private_file_list.yaml, downloads required models using HF_TOKEN,
    and then clears the token from the environment.

    Expected YAML layout (as read below)::

        file:
          <category>:                     # a key of category_to_dir_map
            - filename: ...               # name of the link in the category dir
              repo_id: ...                # Hugging Face repo to download from
              repository_file_path: ...   # optional, defaults to filename

    ``hf_hub_download`` is called without an explicit token, so authentication
    relies on huggingface_hub reading ``HF_TOKEN`` from the environment.

    An unreadable, unparsable or malformed list is reported and skipped.
    A failed download does not stop the remaining entries. On return,
    ``HF_TOKEN`` has been removed from ``os.environ``.
    """
    import yaml
    from core.settings import (
        DIFFUSION_MODELS_DIR, TEXT_ENCODERS_DIR, VAE_DIR, CHECKPOINT_DIR,
        LORA_DIR, CONTROLNET_DIR, MODEL_PATCHES_DIR, EMBEDDING_DIR
    )
    print("--- [Startup] Checking for private models to download... ---")
    private_list_path = os.path.join(APP_DIR, 'yaml', 'private_file_list.yaml')
    if not os.path.exists(private_list_path):
        print("--- [Startup] No private model list found. Skipping. ---")
        if 'HF_TOKEN' in os.environ:
            del os.environ['HF_TOKEN']
            print("--- [Startup] Cleared HF_TOKEN environment variable as it is no longer needed. ---")
            print(f"--- [Startup] Verifying HF_TOKEN after clearing: {os.environ.get('HF_TOKEN')}")
        return
    try:
        try:
            with open(private_list_path, 'r', encoding='utf-8') as f:
                private_files_config = yaml.safe_load(f)
        except (OSError, yaml.YAMLError) as e:
            # A broken list file should not abort application startup.
            print(f"--- [Startup] ❌ Could not read private_file_list.yaml: {e}. Skipping. ---")
            return
        # `file:` must map category names to entry lists. A missing key, null,
        # a list, or a non-mapping top level are all treated as malformed.
        if (not isinstance(private_files_config, dict)
                or not isinstance(private_files_config.get('file'), dict)):
            print("--- [Startup] Private model list is empty or malformed. Skipping. ---")
            return
        category_to_dir_map = {
            "diffusion_models": DIFFUSION_MODELS_DIR,
            "text_encoders": TEXT_ENCODERS_DIR,
            "vae": VAE_DIR,
            "checkpoints": CHECKPOINT_DIR,
            "loras": LORA_DIR,
            "controlnet": CONTROLNET_DIR,
            "model_patches": MODEL_PATCHES_DIR,
            "embeddings": EMBEDDING_DIR,
        }
        files_to_download = []
        for category, files in private_files_config['file'].items():
            dest_dir = category_to_dir_map.get(category)
            if not dest_dir:
                print(f"--- [Startup] ⚠️ Unknown category '{category}' in private_file_list.yaml. Skipping. ---")
                continue
            if files is None:
                continue  # category listed with no entries
            if not isinstance(files, list):
                print(f"--- [Startup] ⚠️ Entries for category '{category}' in private_file_list.yaml must be a list. Skipping. ---")
                continue
            files_to_download.extend((file_info, dest_dir) for file_info in files)
        if not files_to_download:
            print("--- [Startup] No private models configured for download. ---")
            return
        print(f"--- [Startup] Found {len(files_to_download)} private model(s) to download. Using HF_TOKEN if available. ---")
        for file_info, dest_dir in files_to_download:
            _link_private_file(file_info, dest_dir)
    finally:
        # Runs on every exit from the try block (normal return or exception),
        # so the token is removed from os.environ once downloads are done.
        if 'HF_TOKEN' in os.environ:
            del os.environ['HF_TOKEN']
            print("--- [Startup] ✅ Cleared HF_TOKEN environment variable. ---")
            print(f"--- [Startup] Verifying HF_TOKEN after clearing: {os.environ.get('HF_TOKEN')}")
        else:
            print("--- [Startup] Note: HF_TOKEN environment variable was not set. Private downloads may fail without it. ---")
def main():
    """
    Prepare the runtime, initialise ComfyUI and launch the Gradio interface.

    The steps run in a fixed order because later ones depend on earlier ones:
      1. print the welcome message and run the authenticated private-model
         downloads (handle_private_downloads also removes HF_TOKEN);
      2. try to build/install SageAttention, then re-run site setup and
         invalidate import caches so a freshly installed package can be found;
      3. import project modules, initialise ComfyUI, and flag models whose
         component URLs fail a HEAD check;
      4. build the ControlNet preprocessor maps, build the UI and launch it on
         0.0.0.0:7860.
    """
    import importlib
    from utils.app_utils import print_welcome_message
    from scripts import build_sage_attention
    print_welcome_message()
    # Handle downloads that require authentication first.
    handle_private_downloads()
    print("--- [Setup] Attempting to build and install SageAttention... ---")
    try:
        build_sage_attention.install_sage_attention()
        print("--- [Setup] ✅ SageAttention installation process finished. ---")
    except Exception as e:
        print(f"--- [Setup] ❌ SageAttention installation failed: {e}. Continuing with default attention. ---")
    print("--- [Setup] Reloading site-packages to detect newly installed packages... ---")
    try:
        site.main()
        print("--- [Setup] ✅ Site-packages reloaded. ---")
    except Exception as e:
        print(f"--- [Setup] ⚠️ Warning: Could not fully reload site-packages: {e} ---")
    # Path-based finders cache directory contents. Per the importlib docs,
    # modules created after the interpreter started may not be found until
    # the caches are invalidated.
    importlib.invalidate_caches()

    # Project imports deliberately happen only after the environment setup above.
    from comfy_integration import setup as setup_comfyui
    from utils.app_utils import (
        build_preprocessor_model_map,
        build_preprocessor_parameter_map
    )
    from core import shared_state
    from core.settings import ALL_MODEL_MAP, ALL_FILE_DOWNLOAD_MAP

    def check_all_model_urls_on_startup():
        """
        Flag models whose component files do not resolve on the Hugging Face Hub.

        Sends a HEAD request to each component's ``resolve/main`` URL. For the
        first failing component of a model, it sets
        ``shared_state.INVALID_MODEL_URLS[display_name] = True``. Each distinct
        URL is requested at most once, because several models can share
        component files.

        NOTE(review): HF_TOKEN has already been removed from the environment
        and no auth header is sent. Gated/private repositories will therefore
        probably answer 401 and be flagged as invalid — confirm this is
        intended.
        """
        print("--- [Setup] Checking all model URL validity (one-time check) ---")
        # url -> None (reachable) | ("status", code) | ("error", exception)
        url_outcomes = {}

        def probe(url):
            if url not in url_outcomes:
                try:
                    response = requests.head(url, timeout=5, allow_redirects=True)
                except requests.RequestException as e:
                    url_outcomes[url] = ("error", e)
                else:
                    url_outcomes[url] = (
                        ("status", response.status_code)
                        if response.status_code >= 400 else None
                    )
            return url_outcomes[url]

        for display_name, model_info in ALL_MODEL_MAP.items():
            # model_info is unpacked as a 4-tuple. Only its second item, a
            # mapping whose values are filenames, is used here.
            _, components, _, _ = model_info
            if not components:
                continue
            for filename in components.values():
                download_info = ALL_FILE_DOWNLOAD_MAP.get(filename, {})
                repo_id = download_info.get('repo_id')
                if not repo_id:
                    continue
                repo_file_path = download_info.get('repository_file_path', filename)
                url = f"https://huggingface.co/{repo_id}/resolve/main/{repo_file_path}"
                outcome = probe(url)
                if outcome is None:
                    continue
                kind, detail = outcome
                if kind == "status":
                    print(f"❌ Invalid URL for '{display_name}' component '{filename}': {url} (Status: {detail})")
                else:
                    print(f"❌ URL check failed for '{display_name}' component '{filename}': {detail}")
                shared_state.INVALID_MODEL_URLS[display_name] = True
                break  # one failing component is enough to flag the model
        print("--- [Setup] ✅ Finished checking model URLs. ---")

    print("--- Starting Application Setup ---")
    setup_comfyui.initialize_comfyui()
    check_all_model_urls_on_startup()
    print("--- Building ControlNet preprocessor maps ---")
    from core.generation_logic import build_reverse_map
    build_reverse_map()
    build_preprocessor_model_map()
    build_preprocessor_parameter_map()
    print("--- ✅ ControlNet preprocessor setup complete. ---")
    print("--- Environment configured. Proceeding with module imports. ---")
    from ui.layout import build_ui
    from ui.events import attach_event_handlers
    print(f"✅ Working directory is stable: {os.getcwd()}")
    demo = build_ui(attach_event_handlers)
    print("--- Launching Gradio Interface ---")
    demo.queue().launch(server_name="0.0.0.0", server_port=7860)


if __name__ == "__main__":
    main()