Spaces:
Running on Zero
Running on Zero
| import os | |
| import json | |
| import time | |
| import uuid | |
| import shutil | |
| import subprocess | |
| import threading | |
| import requests | |
| from io import BytesIO | |
| from pathlib import Path | |
# Load .env for local development (silently skipped when python-dotenv is not
# installed; load_dotenv itself tolerates a missing .env file).
try:
    from dotenv import load_dotenv

    load_dotenv()
except ImportError:
    pass

# Compatibility shim — newer huggingface_hub releases removed ``HfFolder``,
# which the old gradio version used here still imports. Install a minimal
# stand-in only when the real class is absent.
import huggingface_hub as _hfhub

if not hasattr(_hfhub, "HfFolder"):

    class _HfFolder:
        """Minimal replacement for the removed ``huggingface_hub.HfFolder``.

        Both methods are static so they work whether called on the class
        (``HfFolder.get_token()``) or on an instance.
        """

        @staticmethod
        def get_token():
            # Delegate to the modern top-level token lookup.
            return _hfhub.get_token()

        @staticmethod
        def save_token(token):
            # Intentionally a no-op: this shim does not persist tokens.
            pass

    _hfhub.HfFolder = _HfFolder
| import spaces | |
| import gradio as gr | |
| from PIL import Image | |
| from huggingface_hub import hf_hub_download | |
# ---------------------------------------------------------------------------
# Paths — everything lives under /data (persistent storage) when that volume
# exists; otherwise under /home/user.
# ---------------------------------------------------------------------------
BASE_DIR = "/home/user"
if os.path.exists("/data"):
    BASE_DIR = "/data"

COMFYUI_DIR = BASE_DIR + "/ComfyUI"
COMFYUI_INPUT = COMFYUI_DIR + "/input"
COMFYUI_OUTPUT = COMFYUI_DIR + "/output"
COMFYUI_MODELS = COMFYUI_DIR + "/models"
COMFYUI_CUSTOM_NODES = COMFYUI_DIR + "/custom_nodes"
COMFYUI_URL = "http://127.0.0.1:8188"

# ---------------------------------------------------------------------------
# Custom nodes: folder name under custom_nodes/ -> git clone URL
# ---------------------------------------------------------------------------
_GITHUB = "https://github.com/"
CUSTOM_NODES = {
    folder: _GITHUB + slug
    for folder, slug in (
        ("ComfyUI-GGUF", "city96/ComfyUI-GGUF"),
        ("masquerade-nodes-comfyui", "BadCafeCode/masquerade-nodes-comfyui"),
        ("ComfyUI-KJNodes", "kijai/ComfyUI-KJNodes"),
        ("ComfyUI_LayerStyle_Advance", "chflame163/ComfyUI_LayerStyle_Advance"),
        ("comfyui-sam3", "PozzettiAndrea/ComfyUI-SAM3"),
    )
}


# ---------------------------------------------------------------------------
# Models: Hub repo/file to fetch and where it goes under ComfyUI's models/ dir
# ---------------------------------------------------------------------------
def _model_spec(repo_id, filename, subdir, revision=None):
    """Build one MODELS entry; the local file keeps the Hub file's base name."""
    spec = {"repo_id": repo_id, "filename": filename}
    if revision is not None:
        spec["revision"] = revision
    spec["dest"] = "/".join((COMFYUI_MODELS, subdir, filename.rsplit("/", 1)[-1]))
    return spec


MODELS = [
    _model_spec("unsloth/FLUX.2-klein-4B-GGUF", "flux-2-klein-4b-Q8_0.gguf", "unet"),
    _model_spec(
        "Comfy-Org/z_image_turbo",
        "split_files/text_encoders/qwen_3_4b.safetensors",
        "text_encoders",
    ),
    _model_spec("Comfy-Org/flux2-dev", "split_files/vae/flux2-vae.safetensors", "vae"),
    _model_spec(
        "p1atdev/auraflow-v0.3-pvc-style-lora",
        "aura-pvc-2-_00010e_074520s.safetensors",
        "loras",
        revision="cafeee8ab8681ab679944b4e75ab0bdc4bdec6f7",
    ),
]
| # --------------------------------------------------------------------------- | |
| # Helpers | |
| # --------------------------------------------------------------------------- | |
def run(cmd: str, **kwargs):
    """Echo *cmd* as a ``$ ...`` trace line, then execute it through the shell.

    Extra keyword arguments are forwarded to ``subprocess.run``.
    Raises ``subprocess.CalledProcessError`` when the command exits non-zero.
    """
    print("$", cmd)
    subprocess.run(cmd, check=True, shell=True, **kwargs)
def download_model(model: dict):
    """Fetch one model file from the Hugging Face Hub into its ComfyUI folder.

    ``model`` must provide ``repo_id``, ``filename`` and ``dest`` keys and may
    provide ``revision``. Nothing is downloaded when ``dest`` already exists.

    The cached Hub file is first copied to a ``<dest>.part`` sibling and then
    renamed into place with ``os.replace``. An interrupted copy therefore never
    leaves a truncated file at ``dest``. Such a file would otherwise pass the
    existence check on every later start and never be repaired.
    """
    dest = Path(model["dest"])
    if dest.exists():
        print(f" already exists: {dest.name}")
        return
    dest.parent.mkdir(parents=True, exist_ok=True)
    print(f" downloading {dest.name} ...")
    kwargs = {"repo_id": model["repo_id"], "filename": model["filename"]}
    if "revision" in model:
        kwargs["revision"] = model["revision"]
    cached = hf_hub_download(**kwargs)

    partial = dest.with_name(dest.name + ".part")
    try:
        shutil.copy(cached, partial)
        # Same directory => same filesystem, so the rename is atomic.
        os.replace(partial, dest)
    except BaseException:
        # Don't leave a half-written temp file behind (also on Ctrl-C).
        partial.unlink(missing_ok=True)
        raise
    print(f" saved → {dest}")
| # --------------------------------------------------------------------------- | |
| # Setup — split into two parts: | |
| # setup_env() : clone repos, install packages, download models (no GPU needed) | |
| # start_comfyui(): launch ComfyUI subprocess (must run inside @spaces.GPU) | |
| # --------------------------------------------------------------------------- | |
def setup_env():
    """Prepare ComfyUI on disk: clone repos, install requirements, fetch models.

    Needs no GPU and is safe to re-run. Git clones and model downloads are
    skipped when their targets already exist.

    Requirements are (re)installed on every call, including for repos cloned
    on an earlier start. NOTE(review): the clones live under BASE_DIR, which
    may be the persistent /data volume. pip, however, installs into the
    running interpreter's environment, which is presumably not persisted
    across Space restarts. Skipping pip for a reused clone could therefore
    leave ComfyUI without its dependencies. pip returns quickly when
    everything is already satisfied.

    Raises:
        subprocess.CalledProcessError: if a git or pip command fails.
    """
    import shlex
    import sys

    quote = shlex.quote
    # Use the interpreter running this app, so packages go into its
    # environment rather than whichever `pip` happens to be on PATH.
    pip_install = f"{quote(sys.executable)} -m pip install -q --break-system-packages -r"

    # 1. ComfyUI itself
    if not Path(f"{COMFYUI_DIR}/main.py").exists():
        print("=== Cloning ComfyUI ===")
        run(f"git clone --depth 1 https://github.com/comfyanonymous/ComfyUI {quote(COMFYUI_DIR)}")
    print("=== Installing ComfyUI requirements ===")
    comfy_reqs = f"{COMFYUI_DIR}/requirements.txt"
    run(f"{pip_install} {quote(comfy_reqs)}")

    # 2. Custom nodes
    print("=== Installing custom nodes ===")
    os.makedirs(COMFYUI_CUSTOM_NODES, exist_ok=True)
    for name, url in CUSTOM_NODES.items():
        node_dir = Path(COMFYUI_CUSTOM_NODES) / name
        if not node_dir.exists():
            print(f" cloning {name}")
            run(f"git clone --depth 1 {quote(url)} {quote(str(node_dir))}")
        req = node_dir / "requirements.txt"
        if req.exists():
            run(f"{pip_install} {quote(str(req))}")

    # 3. Models
    print("=== Downloading models ===")
    for model in MODELS:
        download_model(model)

    # 4. Config: relax the manager's network/security settings if its
    #    config file is present.
    config_path = Path(f"{COMFYUI_DIR}/user/__manager/config.ini")
    if config_path.exists():
        content = config_path.read_text()
        content = content.replace("network_mode = public", "network_mode = personal_cloud")
        content = content.replace("security_level = strict", "security_level = normal")
        config_path.write_text(content)
_comfyui_started = False
_comfyui_proc = None  # Popen handle of the ComfyUI server once launched


def start_comfyui():
    """Start the ComfyUI server subprocess and wait until it answers HTTP.

    Must be called from within a @spaces.GPU context. It is idempotent:
    once a call has succeeded, later calls return immediately.

    Raises:
        RuntimeError: if the process exits during startup, or if
            ``/system_stats`` has not returned HTTP 200 after 60 polls
            (roughly 3 minutes). In the timeout case the half-started
            process is terminated first.
    """
    global _comfyui_started, _comfyui_proc
    if _comfyui_started:
        return

    import sys

    print("=== Starting ComfyUI ===")
    os.makedirs(COMFYUI_INPUT, exist_ok=True)
    os.makedirs(COMFYUI_OUTPUT, exist_ok=True)
    # Argument list instead of a shell string, so `proc` is ComfyUI itself and
    # can be terminated directly. Loopback binding is sufficient because this
    # app only talks to COMFYUI_URL (127.0.0.1); a bare `--listen` would
    # expose the unauthenticated ComfyUI API on every interface.
    proc = subprocess.Popen(
        [sys.executable, f"{COMFYUI_DIR}/main.py", "--listen", "127.0.0.1", "--port", "8188"],
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
    )
    _comfyui_proc = proc

    def _stream(p):
        # Forward ComfyUI's combined stdout/stderr to our stdout; draining the
        # pipe also keeps the child from blocking on a full pipe buffer.
        for line in p.stdout:
            print(line.decode(errors="replace"), end="")

    threading.Thread(target=_stream, args=(proc,), daemon=True).start()

    print("Waiting for ComfyUI...")
    for _ in range(60):
        if proc.poll() is not None:
            # Crashed on startup: fail fast instead of polling a dead server.
            _comfyui_proc = None
            raise RuntimeError(f"ComfyUI exited during startup with code {proc.returncode}")
        try:
            if requests.get(f"{COMFYUI_URL}/system_stats", timeout=3).status_code == 200:
                print("ComfyUI ready!")
                _comfyui_started = True
                return
        except requests.RequestException:
            pass  # not listening yet
        time.sleep(3)

    # Don't leave a half-started server behind: a later retry would likely
    # collide with it on port 8188.
    proc.terminate()
    try:
        proc.wait(timeout=10)
    except subprocess.TimeoutExpired:
        proc.kill()
    _comfyui_proc = None
    raise RuntimeError("ComfyUI failed to start within 3 minutes")
| # --------------------------------------------------------------------------- | |
| # Inference | |
| # --------------------------------------------------------------------------- | |
# Upper bound on how long generate() waits for one ComfyUI job.
_GENERATION_TIMEOUT_S = 600
# Seconds between /history polls.
_POLL_INTERVAL_S = 3
# Timeout for each individual HTTP request to the local ComfyUI server.
_HTTP_TIMEOUT_S = 30


def _save_inputs(target_img, clothing_img):
    """Write both user images into ComfyUI's input folder; return their file names."""
    uid = uuid.uuid4().hex[:8]
    target_name = f"target_{uid}.png"
    clothing_name = f"clothing_{uid}.png"
    Image.fromarray(target_img).save(f"{COMFYUI_INPUT}/{target_name}")
    Image.fromarray(clothing_img).save(f"{COMFYUI_INPUT}/{clothing_name}")
    return target_name, clothing_name


def _build_workflow(target_name, clothing_name):
    """Load workflow_api.json and inject the input file names plus a fresh seed."""
    with open("workflow_api.json", encoding="utf-8") as f:
        workflow = json.load(f)
    # Node ids are specific to workflow_api.json: nodes 76 and 81 take the
    # input image names and node 104 the noise seed. Re-check if it changes.
    workflow["76"]["inputs"]["image"] = target_name
    workflow["81"]["inputs"]["image"] = clothing_name
    workflow["104"]["inputs"]["noise_seed"] = int(time.time() * 1000) % (2 ** 32)
    return workflow


def _submit(workflow):
    """Queue the workflow on ComfyUI's /prompt endpoint and return its prompt id."""
    resp = requests.post(
        f"{COMFYUI_URL}/prompt",
        json={"prompt": workflow, "client_id": uuid.uuid4().hex},
        timeout=_HTTP_TIMEOUT_S,
    )
    resp.raise_for_status()
    return resp.json()["prompt_id"]


def _wait_for_history(prompt_id, progress):
    """Poll /history until *prompt_id* appears; return its history entry.

    Raises gr.Error if the entry reports an error or the timeout expires.
    """
    started = time.monotonic()
    while True:
        history = requests.get(
            f"{COMFYUI_URL}/history/{prompt_id}", timeout=_HTTP_TIMEOUT_S
        ).json()
        if prompt_id in history:
            entry = history[prompt_id]
            status = entry.get("status", {})
            if status.get("status_str") == "error" or entry.get("error"):
                raise gr.Error(f"Generation failed: {entry.get('error', 'unknown error')}")
            return entry
        elapsed = int(time.monotonic() - started)
        if elapsed > _GENERATION_TIMEOUT_S:
            # Previously this loop could spin forever if ComfyUI died mid-job.
            raise gr.Error(f"Generation timed out after {elapsed}s.")
        progress(min(0.9, 0.1 + elapsed / 150 * 0.8), desc=f"Generating... ({elapsed}s)")
        time.sleep(_POLL_INTERVAL_S)


def _fetch_first_image(entry):
    """Download the first image listed in the entry's outputs; None if there is none."""
    for node_output in entry.get("outputs", {}).values():
        images = node_output.get("images")
        if not images:
            continue  # skip nodes without (or with an empty) image list
        img_info = images[0]
        resp = requests.get(
            f"{COMFYUI_URL}/view",
            params={
                "filename": img_info["filename"],
                "subfolder": img_info.get("subfolder", ""),
                "type": img_info.get("type", "output"),
            },
            timeout=_HTTP_TIMEOUT_S,
        )
        resp.raise_for_status()
        return Image.open(BytesIO(resp.content))
    return None


def generate(target_img, clothing_img, progress=gr.Progress(track_tqdm=True)):
    """Swap the outfit in *target_img* for the one in *clothing_img* via ComfyUI.

    Args:
        target_img: image array of the person (passed to ``Image.fromarray``).
        clothing_img: image array of the clothing reference.
        progress: Gradio progress tracker.

    Returns:
        PIL.Image.Image: the first image produced by the workflow.

    Raises:
        gr.Error: missing inputs, failed or timed-out job, or no output image.
    """
    # Validate before the potentially slow ComfyUI startup.
    if target_img is None or clothing_img is None:
        raise gr.Error("Please upload both images before generating.")
    # NOTE(review): comments in this file say ComfyUI must be started inside a
    # @spaces.GPU context, but this function carries no @spaces.GPU
    # decorator. Confirm whether that is intentional.
    start_comfyui()

    target_name, clothing_name = _save_inputs(target_img, clothing_img)
    workflow = _build_workflow(target_name, clothing_name)

    progress(0.05, desc="Submitting to ComfyUI...")
    prompt_id = _submit(workflow)

    progress(0.1, desc="Generating — this takes 1–2 minutes...")
    entry = _wait_for_history(prompt_id, progress)

    image = _fetch_first_image(entry)
    if image is None:
        raise gr.Error("No output image was returned by ComfyUI.")
    progress(1.0, desc="Done!")
    return image
# ---------------------------------------------------------------------------
# Access code — the same code must appear somewhere visible in the CV.
# ---------------------------------------------------------------------------
# Read from the SECRET_CODE environment variable; None when it is unset.
ACCESS_CODE = os.environ.get("SECRET_CODE")
# CV file offered by the gate screen's download button.
CV_FILE = "Sofia_Metelitsa_CV.pdf"
# ---------------------------------------------------------------------------
# One-time startup work that needs no GPU: clone repos, install packages and
# download models. ComfyUI itself is only launched by the first generate()
# call.
# ---------------------------------------------------------------------------
setup_env()  # blocking; runs before the UI below is built
# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
with gr.Blocks(title="Cloth Swap", theme=gr.themes.Soft()) as demo:
    # -- Gate screen (visible first) --------------------------------------
    with gr.Column(visible=True) as gate:
        gr.Markdown(
            """
            # Cloth Swap — Virtual Try-On
            *AI-powered outfit swapping built with FLUX.2 Kontext + ComfyUI*

            ---

            ### To access this tool, download my CV.
            The access code will appear here after you download it.
            """
        )
        dl_btn = gr.DownloadButton(
            label="Download CV",
            value=CV_FILE,
            variant="secondary",
            size="lg",
        )
        code_reveal = gr.Markdown(visible=False)
        gr.Markdown("---")
        code_input = gr.Textbox(
            label="Enter access code from CV",
            placeholder="Access code",
            type="password",
            max_lines=1,
        )
        unlock_btn = gr.Button("Unlock →", variant="primary")
        error_msg = gr.Markdown(visible=False)

    # -- Main app (hidden until unlocked) ---------------------------------
    with gr.Column(visible=False) as main_app:
        gr.Markdown(
            """
            # Cloth Swap
            Upload a photo of a **person** and a **clothing reference**.
            The AI swaps the outfit while preserving pose, lighting, and expression.
            """
        )
        with gr.Row():
            target_input = gr.Image(label="Person (target)", type="numpy", height=400)
            clothing_input = gr.Image(label="Clothing reference", type="numpy", height=400)
        btn = gr.Button("✨ Generate", variant="primary", size="lg")
        output = gr.Image(label="Result", height=500)
        gr.Markdown("*Generation takes ~1–2 minutes.*")
        # NOTE(review): the gate only toggles visibility. This event stays
        # registered and reachable by API clients whether or not the code was
        # entered. Confirm that is acceptable.
        btn.click(
            fn=generate,
            inputs=[target_input, clothing_input],
            outputs=output,
        )

    # -- Logic ---------------------------------------------------------------
    def _expected_code():
        """Return the configured access code without surrounding whitespace ("" if unset)."""
        return (ACCESS_CODE or "").strip()

    def reveal_code():
        """Called when the user clicks Download CV; reveals the access code on screen."""
        expected = _expected_code()
        if not expected:
            # SECRET_CODE unset: don't display a meaningless "None" code.
            return gr.update(
                visible=True,
                value="*No access code is configured on this server.*",
            )
        return gr.update(visible=True, value=f"**Your access code:** `{expected}`")

    def check_code(entered):
        """Return visibility updates for (gate, main_app, error_msg).

        Fails closed when no access code is configured, so an unset
        SECRET_CODE never unlocks the app.
        """
        expected = _expected_code()
        if not expected:
            return (
                gr.update(visible=True),
                gr.update(visible=False),
                gr.update(visible=True, value="Access code is not configured on the server."),
            )
        if (entered or "").strip() == expected:
            return (
                gr.update(visible=False),  # hide gate
                gr.update(visible=True),  # show app
                gr.update(visible=False),  # hide error
            )
        return (
            gr.update(visible=True),
            gr.update(visible=False),
            gr.update(visible=True, value="Incorrect code — check your CV and try again."),
        )

    dl_btn.click(fn=reveal_code, outputs=[code_reveal])
    unlock_btn.click(
        fn=check_code,
        inputs=[code_input],
        outputs=[gate, main_app, error_msg],
    )
# The HF base image ships websockets v13+, which removed the legacy API that
# the installed uvicorn's default WebSocket backend uses. Point uvicorn's
# "auto" WebSocket protocol at the wsproto implementation before
# demo.launch() triggers the import of uvicorn.protocols.websockets.auto.
import sys as _sys
from types import ModuleType as _ModuleType

_WS_AUTO_NAME = "uvicorn.protocols.websockets.auto"
_ws_auto = _ModuleType(_WS_AUTO_NAME)
# Remember what was registered before (normally nothing), so it can be put
# back if the redirect cannot be completed.
_ws_auto_prev = _sys.modules.get(_WS_AUTO_NAME)
# The stub is registered before importing wsproto_impl, presumably so that any
# import of the auto module along the way resolves to the stub.
_sys.modules[_WS_AUTO_NAME] = _ws_auto
try:
    from uvicorn.protocols.websockets.wsproto_impl import WSProtocol as _WSP

    _ws_auto.AutoWebSocketsProtocol = _WSP
    print("uvicorn → wsproto WebSocket backend active")
except Exception as _e:
    # An empty stub left in sys.modules would make uvicorn's later lookup of
    # AutoWebSocketsProtocol fail, so restore the previous state instead.
    if _ws_auto_prev is None:
        _sys.modules.pop(_WS_AUTO_NAME, None)
    else:
        _sys.modules[_WS_AUTO_NAME] = _ws_auto_prev
    print(f"WARNING: wsproto backend setup failed: {_e}")
# With server_name="0.0.0.0", gradio builds local_url as "http://0.0.0.0:7860/".
# That is a bind address, not a connection target, so gradio's url_ok() check
# reports a false negative even though the server is running. Replace the
# check with one that always succeeds so gradio doesn't block on it.
def _url_always_ok(url):
    """Stand-in for gradio.networking.url_ok that accepts every URL."""
    return True


try:
    import gradio.networking as _gnet

    _gnet.url_ok = _url_always_ok
except Exception as _e:
    print(f"url_ok patch failed: {_e}")
| # Starlette 0.36+ changed TemplateResponse(name, context) β TemplateResponse(request, name, context). | |
| # Gradio 4.44.0 still uses the old signature, so "index.html" ends up as `request` | |
| # and the context dict ends up as `name`. Patch to restore the old behaviour. | |
| try: | |
| import starlette.templating as _st | |
| _orig_TR = _st.Jinja2Templates.TemplateResponse | |
| def _compat_TR(self, *args, **kwargs): | |
| if args and isinstance(args[0], str): | |
| # Old-style call: TemplateResponse("template.html", context_dict, ...) | |
| name = args[0] | |
| context = args[1] if len(args) > 1 else kwargs.pop("context", {}) | |
| request = context.get("request") | |
| return _orig_TR(self, request, name, context, **kwargs) | |
| return _orig_TR(self, *args, **kwargs) | |
| _st.Jinja2Templates.TemplateResponse = _compat_TR | |
| except Exception as _e: | |
| print(f"starlette TemplateResponse patch failed: {_e}") | |
# Jinja2 3.1.4 bug: when globals is non-empty, LRUCache is keyed with an
# unhashable dict and TemplateResponse crashes. Treat such a TypeError as a
# cache miss on read (KeyError) and skip caching on write, so those templates
# are loaded fresh every time.
try:
    import jinja2.utils as _jutils

    _LRU = _jutils.LRUCache
    _orig_gi, _orig_si = _LRU.__getitem__, _LRU.__setitem__

    def _safe_gi(self, key):
        """Cache read: an unhashable key is reported as a miss."""
        try:
            return _orig_gi(self, key)
        except TypeError:
            raise KeyError(key)

    def _safe_si(self, key, value):
        """Cache write: silently skipped for an unhashable key."""
        try:
            _orig_si(self, key, value)
        except TypeError:
            pass

    _LRU.__getitem__, _LRU.__setitem__ = _safe_gi, _safe_si
except Exception as _e:
    print(f"jinja2 LRUCache patch failed: {_e}")
# gradio_client bug: _json_schema_to_python_type() cannot handle boolean
# schemas (e.g. additionalProperties: true). Wrap this internal recursive
# helper so any non-dict schema maps to "Any" instead of raising
# APIInfoParseError.
try:
    import gradio_client.utils as _gcu

    _orig_j2p = _gcu._json_schema_to_python_type

    def _safe_j2p(schema, defs=None):
        # Presumably the helper's recursive calls look it up by this module
        # attribute name, so nested bool schemas are routed here as well.
        return _orig_j2p(schema, defs) if isinstance(schema, dict) else "Any"

    _gcu._json_schema_to_python_type = _safe_j2p
except Exception as _e:
    print(f"gradio_client patch failed: {_e}")
# Listen on all interfaces, show exceptions in the UI and hide the
# "Use via API" page.
demo.launch(show_api=False, show_error=True, server_name="0.0.0.0")