saliacoel commited on
Commit
2c8c6a2
·
verified ·
1 Parent(s): 271f489

Upload salia_hf_to_batch.py

Browse files
Files changed (1) hide show
  1. salia_hf_to_batch.py +222 -0
salia_hf_to_batch.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import zipfile
4
+ import shutil
5
+ import tempfile
6
+ from urllib.request import Request, urlopen
7
+ from urllib.error import HTTPError, URLError
8
+
9
+ import numpy as np
10
+ import torch
11
+ from PIL import Image
12
+
13
+ try:
14
+ import folder_paths # ComfyUI helper for temp dirs
15
+ except Exception:
16
+ folder_paths = None
17
+
18
+
19
+ def _get_cache_dir() -> str:
20
+ base_dir = None
21
+ if folder_paths is not None:
22
+ try:
23
+ base_dir = folder_paths.get_temp_directory()
24
+ except Exception:
25
+ base_dir = None
26
+
27
+ if not base_dir:
28
+ base_dir = tempfile.gettempdir()
29
+
30
+ cache_dir = os.path.join(base_dir, "hf_zip_cache")
31
+ os.makedirs(cache_dir, exist_ok=True)
32
+ return cache_dir
33
+
34
+
35
+ def _download_file(url: str, dest_path: str, timeout_sec: int = 60) -> None:
36
+ req = Request(url, headers={"User-Agent": "ComfyUI-HFZipLoader/1.0"})
37
+ try:
38
+ with urlopen(req, timeout=timeout_sec) as resp, open(dest_path, "wb") as out_f:
39
+ shutil.copyfileobj(resp, out_f)
40
+ except HTTPError as e:
41
+ raise ValueError(f"HTTP error while downloading: {url} (status={e.code})") from e
42
+ except URLError as e:
43
+ raise ValueError(f"Network error while downloading: {url} ({e.reason})") from e
44
+ except Exception as e:
45
+ raise ValueError(f"Unexpected error while downloading: {url} ({e})") from e
46
+
47
+
48
+ def _pil_to_tensor_rgb(pil_img: Image.Image) -> torch.Tensor:
49
+ """
50
+ Convert PIL image to ComfyUI IMAGE tensor: [H,W,3] float32 in [0..1].
51
+ """
52
+ if pil_img.mode != "RGB":
53
+ pil_img = pil_img.convert("RGB")
54
+
55
+ arr = np.asarray(pil_img, dtype=np.float32) / 255.0 # HWC
56
+ return torch.from_numpy(arr) # torch float32 HWC
57
+
58
+
59
+ class _ImageSizeMismatchError(ValueError):
60
+ """Raised when images in the zip do not share the same dimensions."""
61
+
62
+
63
+ def _alphanum_key(s: str):
64
+ """
65
+ Natural/alphanumeric sort key for filenames/paths.
66
+ Example: img_2.png comes before img_10.png.
67
+
68
+ Sorts by the full zip member name (including folders), case-insensitive.
69
+ """
70
+ s = (s or "").replace("\\", "/")
71
+ parts = re.split(r"(\d+)", s)
72
+
73
+ # Build a key composed of tagged tokens so Python never compares int vs str directly.
74
+ key = []
75
+ for p in parts:
76
+ if p.isdigit():
77
+ key.append((0, int(p)))
78
+ else:
79
+ key.append((1, p.lower()))
80
+ return key
81
+
82
+
83
+ def _load_images_from_zip(zip_path: str) -> torch.Tensor:
84
+ """
85
+ Forgiving loader:
86
+ - Accepts all filenames (any depth) in a zip
87
+ - Sorts members in alphanumeric (natural) order
88
+ - Tries to open each file as an image; skips files that PIL cannot read
89
+ - Enforces that all loaded images share the same dimensions
90
+
91
+ Returns:
92
+ [B,H,W,3] float32 in [0..1]
93
+ """
94
+ images = []
95
+ shapes = None
96
+ skipped = []
97
+
98
+ with zipfile.ZipFile(zip_path, "r") as zf:
99
+ members = [name for name in zf.namelist() if name and not name.endswith("/")]
100
+
101
+ if not members:
102
+ raise ValueError("ZIP is empty (no files found).")
103
+
104
+ members.sort(key=_alphanum_key)
105
+
106
+ for member_name in members:
107
+ try:
108
+ with zf.open(member_name) as fp:
109
+ with Image.open(fp) as im:
110
+ # Ensure image data is fully read while the zip file handle is still open
111
+ im.load()
112
+ t = _pil_to_tensor_rgb(im) # HWC, RGB, float32
113
+
114
+ if shapes is None:
115
+ shapes = tuple(t.shape)
116
+ else:
117
+ if tuple(t.shape) != shapes:
118
+ raise _ImageSizeMismatchError(
119
+ f"Image size mismatch in ZIP. Expected {shapes}, got {tuple(t.shape)} "
120
+ f"for {member_name}. All images must share the same dimensions."
121
+ )
122
+
123
+ images.append(t)
124
+
125
+ except _ImageSizeMismatchError:
126
+ # This is a hard error: the batch cannot be formed consistently.
127
+ raise
128
+ except Exception:
129
+ # Forgiving: ignore non-images, unreadable files, etc.
130
+ skipped.append(member_name)
131
+ continue
132
+
133
+ if not images:
134
+ raise ValueError(
135
+ "No loadable images found in ZIP. Ensure the archive contains valid image files "
136
+ "(png/jpg/webp/etc.)."
137
+ )
138
+
139
+ if skipped:
140
+ print(f"[HFLoadZipImageBatch] Skipped {len(skipped)} non-image/unreadable file(s) in ZIP.")
141
+
142
+ return torch.stack(images, dim=0) # BHWC
143
+
144
+
145
+ class HF_to_Batch:
146
+ """
147
+ Download public ZIP from Hugging Face resolve URL and output IMAGE batch.
148
+
149
+ URL format:
150
+ https://huggingface.co/{owner}/{repo}/resolve/{revision}/{index}.zip
151
+
152
+ Example:
153
+ owner=saliacoel, repo=pov_fs, revision=main, index=0
154
+ -> https://huggingface.co/saliacoel/pov_fs/resolve/main/0.zip
155
+ """
156
+
157
+ CATEGORY = "HuggingFace"
158
+ RETURN_TYPES = ("IMAGE", "STRING", "INT", "STRING")
159
+ RETURN_NAMES = ("images", "source_url", "count", "local_zip_path")
160
+ FUNCTION = "load"
161
+
162
+ @classmethod
163
+ def INPUT_TYPES(cls):
164
+ return {
165
+ "required": {
166
+ "repo": ("STRING", {"default": "pov_fs", "multiline": False}),
167
+ "index": ("INT", {"default": 0, "min": 0, "max": 1000000, "step": 1}),
168
+ },
169
+ "optional": {
170
+ "owner": ("STRING", {"default": "saliacoel", "multiline": False}),
171
+ "revision": ("STRING", {"default": "main", "multiline": False}),
172
+ "force_redownload": ("BOOLEAN", {"default": False}),
173
+ },
174
+ }
175
+
176
+ def load(
177
+ self,
178
+ repo: str,
179
+ index: int,
180
+ owner: str = "saliacoel",
181
+ revision: str = "main",
182
+ force_redownload: bool = False,
183
+ ):
184
+ repo = (repo or "").strip()
185
+ owner = (owner or "").strip()
186
+ revision = (revision or "").strip()
187
+
188
+ if not repo:
189
+ raise ValueError("repo must be a non-empty string (e.g., 'pov_fs' or 'car').")
190
+ if not owner:
191
+ raise ValueError("owner must be a non-empty string (e.g., 'saliacoel').")
192
+ if index is None or int(index) < 0:
193
+ raise ValueError("index must be an integer >= 0.")
194
+
195
+ index = int(index)
196
+
197
+ source_url = f"https://huggingface.co/{owner}/{repo}/resolve/{revision}/{index}.zip"
198
+
199
+ cache_dir = _get_cache_dir()
200
+ local_zip_path = os.path.join(cache_dir, f"{owner}__{repo}__{revision}__{index}.zip")
201
+
202
+ if (
203
+ force_redownload
204
+ or (not os.path.exists(local_zip_path))
205
+ or (os.path.getsize(local_zip_path) == 0)
206
+ ):
207
+ _download_file(source_url, local_zip_path)
208
+
209
+ images = _load_images_from_zip(local_zip_path)
210
+ count = int(images.shape[0])
211
+
212
+ print(f"[HFLoadZipImageBatch] Loaded {count} image(s) from {source_url}")
213
+ return (images, source_url, count, local_zip_path)
214
+
215
+
216
+ NODE_CLASS_MAPPINGS = {
217
+ "HF_to_Batch": HF_to_Batch,
218
+ }
219
+
220
+ NODE_DISPLAY_NAME_MAPPINGS = {
221
+ "HF_to_Batch": "HF_to_Batch",
222
+ }