File size: 5,245 Bytes
e00eceb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
from __future__ import annotations

from typing import TypedDict
import os
import folder_paths
import glob
from aiohttp import web
import hashlib


class Source:
    """String constants identifying where a subgraph entry was discovered."""

    # Used for files found under a custom node pack's ``subgraphs`` directory.
    custom_node = "custom_node"
    # Used for files found in the ``blueprints`` directory.
    templates = "templates"

class SubgraphEntry(TypedDict):
    """Dictionary shape describing one discovered subgraph file."""

    source: str
    """
    Where the subgraph came from: one of the ``Source`` constants
    ("custom_node" or "templates").
    """
    path: str
    """
    Path of the subgraph file, normalized to forward slashes.
    For custom nodes it looks like
    <custom_nodes_folder>/<custom_node_dir>/subgraphs/<name>.json
    """
    name: str
    """
    File name of the subgraph without its directory or extension.
    """
    info: CustomNodeSubgraphEntryInfo
    """
    Extra information about the subgraph; currently the node pack name.
    """
    data: str
    """
    Text content of the subgraph file. Only filled in once the file has been
    read, so freshly scanned entries do not carry this key yet.
    """

class CustomNodeSubgraphEntryInfo(TypedDict):
    """Extra metadata attached to a subgraph entry under its ``info`` key."""

    node_pack: str
    """Name of the node pack the subgraph belongs to."""

class SubgraphManager:
    """Discover subgraph JSON files and expose them through HTTP routes.

    Two locations are scanned: ``*/subgraphs/*.json`` inside every folder
    registered as ``custom_nodes`` and ``*.json`` inside the ``blueprints``
    directory that sits next to this module's parent directory. Each scan
    result is cached until a caller passes ``force_reload=True``.
    """

    def __init__(self):
        # None means "not scanned yet"; an empty dict is a valid cached result.
        self.cached_custom_node_subgraphs: dict[str, SubgraphEntry] | None = None
        self.cached_blueprint_subgraphs: dict[str, SubgraphEntry] | None = None

    def _create_entry(self, file: str, source: str, node_pack: str) -> tuple[str, SubgraphEntry]:
        """Create a subgraph entry from a file path. Expects normalized path (forward slashes).

        Returns:
            ``(entry_id, entry)`` where ``entry_id`` is the SHA-256 hex digest
            of ``source`` concatenated with ``file``. The entry has no
            ``data`` key yet; see ``load_entry_data``.
        """
        entry_id = hashlib.sha256(f"{source}{file}".encode()).hexdigest()
        entry: SubgraphEntry = {
            "source": source,
            "name": os.path.splitext(os.path.basename(file))[0],
            "path": file,
            "info": {"node_pack": node_pack},
        }
        return entry_id, entry

    def _find_subgraph_files(self, base_dir: str, relative_pattern: str) -> list[str]:
        """Return sorted, forward-slash paths under ``base_dir`` matching ``relative_pattern``.

        ``base_dir`` is taken literally: it is escaped so that characters such
        as ``[`` in an install path are not interpreted as glob syntax. Only
        ``relative_pattern`` may contain wildcards. A missing ``base_dir``
        simply yields an empty list.
        """
        pattern = os.path.join(glob.escape(base_dir), relative_pattern)
        # glob order depends on the filesystem; sort so listings are stable.
        return sorted(path.replace('\\', '/') for path in glob.glob(pattern))

    async def load_entry_data(self, entry: SubgraphEntry) -> SubgraphEntry:
        """Read the file at ``entry['path']`` into ``entry['data']`` and return the same entry.

        The entry is modified in place, so when it comes from one of the
        caches the loaded text stays cached with it.

        Raises:
            OSError: if the file cannot be opened or read.
        """
        with open(entry['path'], 'r', encoding='utf-8') as f:
            entry['data'] = f.read()
        return entry

    async def sanitize_entry(self, entry: SubgraphEntry | None, remove_data=False) -> SubgraphEntry | None:
        """Return a shallow copy of ``entry`` that is safe to send to a client.

        The ``path`` key is always dropped; ``data`` is dropped as well when
        ``remove_data`` is true. ``None`` is passed through unchanged.
        """
        if entry is None:
            return None
        # Copy first so the cached entry keeps its path (and data).
        entry = entry.copy()
        entry.pop('path', None)
        if remove_data:
            entry.pop('data', None)
        return entry

    async def sanitize_entries(self, entries: dict[str, SubgraphEntry], remove_data=False) -> dict[str, SubgraphEntry]:
        """Return a new dict with every entry passed through ``sanitize_entry``."""
        return {
            key: await self.sanitize_entry(entry, remove_data)
            for key, entry in entries.items()
        }

    async def get_custom_node_subgraphs(self, loadedModules, force_reload=False) -> dict[str, SubgraphEntry]:
        """Load subgraphs from custom nodes.

        Args:
            loadedModules: accepted for interface compatibility; not used here.
            force_reload: rescan the folders even if a cached result exists.

        Returns:
            Mapping of entry id to entry for every
            ``<custom_nodes folder>/<pack>/subgraphs/*.json`` file.
        """
        if not force_reload and self.cached_custom_node_subgraphs is not None:
            return self.cached_custom_node_subgraphs

        subgraphs_dict: dict[str, SubgraphEntry] = {}
        for folder in folder_paths.get_folder_paths("custom_nodes"):
            for file in self._find_subgraph_files(folder, "*/subgraphs/*.json"):
                # .../<pack>/subgraphs/<name>.json -> third component from the end
                node_pack = "custom_nodes." + file.split('/')[-3]
                entry_id, entry = self._create_entry(file, Source.custom_node, node_pack)
                subgraphs_dict[entry_id] = entry

        self.cached_custom_node_subgraphs = subgraphs_dict
        return subgraphs_dict

    async def get_blueprint_subgraphs(self, force_reload=False) -> dict[str, SubgraphEntry]:
        """Load subgraphs from the blueprints directory.

        The directory is ``blueprints`` inside the parent of the directory
        containing this module. If it does not exist the result is empty.
        """
        if not force_reload and self.cached_blueprint_subgraphs is not None:
            return self.cached_blueprint_subgraphs

        subgraphs_dict: dict[str, SubgraphEntry] = {}
        blueprints_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'blueprints')

        # No existence check needed: globbing a missing directory matches nothing.
        for file in self._find_subgraph_files(blueprints_dir, "*.json"):
            entry_id, entry = self._create_entry(file, Source.templates, "comfyui")
            subgraphs_dict[entry_id] = entry

        self.cached_blueprint_subgraphs = subgraphs_dict
        return subgraphs_dict

    async def get_all_subgraphs(self, loadedModules, force_reload=False) -> dict[str, SubgraphEntry]:
        """Get all subgraphs from all sources (custom nodes and blueprints).

        Returns a new dict on every call; the entry objects themselves are
        shared with the per-source caches.
        """
        custom_node_subgraphs = await self.get_custom_node_subgraphs(loadedModules, force_reload)
        blueprint_subgraphs = await self.get_blueprint_subgraphs(force_reload)
        return {**custom_node_subgraphs, **blueprint_subgraphs}

    async def get_subgraph(self, id: str, loadedModules) -> SubgraphEntry | None:
        """Get a specific subgraph by ID from any source.

        The file content is read on first access and kept on the cached
        entry. Returns ``None`` when no entry has the given id.
        """
        entry = (await self.get_all_subgraphs(loadedModules)).get(id)
        if entry is not None and entry.get('data') is None:
            await self.load_entry_data(entry)
        return entry

    def add_routes(self, routes, loadedModules):
        """Register the subgraph endpoints on ``routes``.

        * ``GET /global_subgraphs`` - all entries, without ``path`` or ``data``.
        * ``GET /global_subgraphs/{id}`` - one entry including ``data`` but
          without ``path``; an unknown id is serialized as JSON ``null``.
        """
        @routes.get("/global_subgraphs")
        async def get_global_subgraphs(request):
            subgraphs_dict = await self.get_all_subgraphs(loadedModules)
            return web.json_response(await self.sanitize_entries(subgraphs_dict, remove_data=True))

        @routes.get("/global_subgraphs/{id}")
        async def get_global_subgraph(request):
            subgraph_id = request.match_info.get("id", None)
            subgraph = await self.get_subgraph(subgraph_id, loadedModules)
            return web.json_response(await self.sanitize_entry(subgraph))