| from __future__ import annotations
|
| import argparse
|
| import logging
|
| import os
|
| import re
|
| import tempfile
|
| import zipfile
|
| from dataclasses import dataclass
|
| from functools import cached_property
|
| from pathlib import Path
|
| from typing import TypedDict
|
|
|
| import requests
|
| from typing_extensions import NotRequired
|
| from comfy.cli_args import DEFAULT_VERSION_STRING
|
|
|
|
|
# Timeout passed as `timeout=` to every `requests.get` call in this module
# (the requests library interprets the value in seconds).
REQUEST_TIMEOUT = 10
|
|
|
|
|
class Asset(TypedDict):
    """A single downloadable file attached to a release.

    Only the keys this module actually reads are declared.
    """

    # File name of the asset; the downloader looks for the one named "dist.zip".
    name: str
    # URL the asset is fetched from.
    url: str
|
|
|
|
|
class Release(TypedDict):
    """One release entry, as decoded from the JSON that the GitHub releases
    endpoint queried by ``FrontEndProvider`` returns.
    """

    # NOTE(review): fields without a comment are declared but not read in this
    # module; their meaning presumably follows GitHub's release schema — confirm.
    id: int
    # Matched against the requested version, with or without a leading "v".
    tag_name: str
    name: str
    prerelease: bool
    created_at: str
    published_at: str
    body: str
    # May be absent; scanned for the "dist.zip" asset when downloading.
    assets: NotRequired[list[Asset]]
|
|
|
|
|
@dataclass
class FrontEndProvider:
    """A GitHub repository whose releases are looked up through the GitHub API.

    Attributes:
        owner: Account or organization segment of the repository path.
        repo: Repository name segment of the repository path.
    """

    owner: str
    repo: str

    @property
    def folder_name(self) -> str:
        """Return the identifier ``<owner>_<repo>`` for this provider."""
        owner, repo = self.owner, self.repo
        return f"{owner}_{repo}"

    @property
    def release_url(self) -> str:
        """Return the GitHub API URL that lists this repository's releases."""
        return f"https://api.github.com/repos/{self.owner}/{self.repo}/releases"

    @cached_property
    def all_releases(self) -> list[Release]:
        """Fetch every release of the repository, following pagination links.

        The result is cached on the instance after the first access.

        Raises:
            requests.HTTPError: If any page responds with an error status.
        """
        collected: list[Release] = []
        page_url = self.release_url
        while page_url:
            page = requests.get(page_url, timeout=REQUEST_TIMEOUT)
            page.raise_for_status()
            collected += page.json()

            # Keep going while the response advertises a "next" page link.
            page_url = page.links["next"]["url"] if "next" in page.links else None
        return collected

    @cached_property
    def latest_release(self) -> Release:
        """Fetch the release served by the ``/latest`` endpoint (cached).

        Raises:
            requests.HTTPError: If the request responds with an error status.
        """
        reply = requests.get(f"{self.release_url}/latest", timeout=REQUEST_TIMEOUT)
        reply.raise_for_status()
        return reply.json()

    def get_release(self, version: str) -> Release:
        """Return the release matching ``version``.

        Args:
            version: Either ``"latest"`` or a version that must equal a
                release tag, with or without a leading ``v``.

        Raises:
            ValueError: If no release carries a matching tag.
        """
        if version == "latest":
            return self.latest_release

        accepted_tags = (version, f"v{version}")
        for candidate in self.all_releases:
            if candidate["tag_name"] in accepted_tags:
                return candidate
        raise ValueError(f"Version {version} not found in releases")
|
|
|
|
|
def download_release_asset_zip(release: Release, destination_path: str) -> None:
    """Download dist.zip from github release.

    The archive is streamed to an anonymous temporary file in fixed-size
    chunks, so it is never held in memory as a whole, and is then extracted
    into ``destination_path``.

    Args:
        release: Release entry whose ``assets`` list is searched for an asset
            named ``dist.zip``.
        destination_path: Directory the archive contents are extracted into.

    Raises:
        ValueError: If the release has no ``dist.zip`` asset (or its URL is
            empty).
        requests.HTTPError: If the download responds with an error status.
        zipfile.BadZipFile: If the downloaded payload is not a valid zip file.
    """
    asset_url = next(
        (
            asset["url"]
            for asset in release.get("assets", [])
            if asset["name"] == "dist.zip"
        ),
        None,
    )
    if not asset_url:
        raise ValueError("dist.zip not found in the release assets")

    # Ask the API for the raw asset bytes instead of its JSON description.
    headers = {"Accept": "application/octet-stream"}
    chunk_size = 64 * 1024

    with tempfile.TemporaryFile() as tmp_file:
        # stream=True defers the body download; the context manager makes sure
        # the connection is released even if writing fails part-way.
        with requests.get(
            asset_url,
            headers=headers,
            allow_redirects=True,
            timeout=REQUEST_TIMEOUT,
            stream=True,
        ) as response:
            response.raise_for_status()
            for chunk in response.iter_content(chunk_size=chunk_size):
                tmp_file.write(chunk)

        # Rewind so ZipFile reads the archive from its beginning.
        tmp_file.seek(0)

        with zipfile.ZipFile(tmp_file, "r") as zip_ref:
            zip_ref.extractall(destination_path)
|
|
|
|
|
class FrontendManager:
    """Resolve a frontend version string to a local directory of web assets.

    The default frontend lives in ``DEFAULT_FRONTEND_PATH``. Any other
    version is downloaded from a GitHub release into a per-provider,
    per-version folder below ``CUSTOM_FRONTENDS_ROOT``.
    """

    DEFAULT_FRONTEND_PATH = str(Path(__file__).parents[1] / "web")
    CUSTOM_FRONTENDS_ROOT = str(Path(__file__).parents[1] / "web_custom_versions")

    @classmethod
    def parse_version_string(cls, value: str) -> tuple[str, str, str]:
        """
        Split a version string of the form ``<owner>/<repo>@<version>``.

        ``<version>`` is either ``latest`` or three dot-separated numbers
        with an optional leading ``v`` (e.g. ``1.2.3`` or ``v1.2.3``).

        Args:
            value (str): The version string to parse.

        Returns:
            tuple[str, str, str]: A tuple containing the repository owner,
            the repository name and the version.

        Raises:
            argparse.ArgumentTypeError: If the version string is invalid.
        """
        VERSION_PATTERN = r"^([a-zA-Z0-9][a-zA-Z0-9-]{0,38})/([a-zA-Z0-9_.-]+)@(v?\d+\.\d+\.\d+|latest)$"
        match_result = re.match(VERSION_PATTERN, value)
        if match_result is None:
            raise argparse.ArgumentTypeError(f"Invalid version string: {value}")

        return match_result.group(1, 2, 3)

    @classmethod
    def init_frontend_unsafe(cls, version_string: str) -> str:
        """
        Initializes the frontend for the specified version.

        If the target folder for the resolved release does not exist yet, it
        is created and the release's ``dist.zip`` is downloaded into it. When
        that download or extraction fails, the folder is removed again so a
        later call retries instead of returning an empty or partial frontend.

        Args:
            version_string (str): The version string.

        Returns:
            str: The path to the initialized frontend.

        Raises:
            Exception: If there is an error during the initialization process.
            main error source might be request timeout or invalid URL.
        """
        # Function-scope import keeps this block self-contained.
        import shutil

        if version_string == DEFAULT_VERSION_STRING:
            return cls.DEFAULT_FRONTEND_PATH

        repo_owner, repo_name, version = cls.parse_version_string(version_string)
        provider = FrontEndProvider(repo_owner, repo_name)
        release = provider.get_release(version)

        # Folder names use the tag without its leading "v".
        semantic_version = release["tag_name"].lstrip("v")
        web_root = str(
            Path(cls.CUSTOM_FRONTENDS_ROOT) / provider.folder_name / semantic_version
        )
        if not os.path.exists(web_root):
            os.makedirs(web_root, exist_ok=True)
            logging.info(
                "Downloading frontend(%s) version(%s) to (%s)",
                provider.folder_name,
                semantic_version,
                web_root,
            )
            logging.debug(release)
            try:
                download_release_asset_zip(release, destination_path=web_root)
            except BaseException:
                # The folder's existence is what marks a version as installed,
                # so a failed download must not leave it behind. The error is
                # re-raised unchanged after cleanup.
                shutil.rmtree(web_root, ignore_errors=True)
                raise
        return web_root

    @classmethod
    def init_frontend(cls, version_string: str) -> str:
        """
        Initializes the frontend with the specified version string.

        Any error raised while resolving or downloading the requested version
        is logged, and the default frontend is used instead.

        Args:
            version_string (str): The version string to initialize the frontend with.

        Returns:
            str: The path of the initialized frontend.
        """
        try:
            return cls.init_frontend_unsafe(version_string)
        except Exception as e:
            # Deliberate best-effort boundary: never fail startup over a frontend.
            logging.error("Failed to initialize frontend: %s", e)
            logging.info("Falling back to the default frontend.")
            return cls.DEFAULT_FRONTEND_PATH
|
|
|