Vanwise committed on
Commit
3cb3e08
·
1 Parent(s): 74c5e9a

Upload 3 files

Browse files
Files changed (3) hide show
  1. _main.ipynb +86 -0
  2. _start.ipynb +43 -0
  3. apply_func.py +285 -0
_main.ipynb ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "# start.sh\n",
10
+ "import os\n",
11
+ "import subprocess\n",
12
+ "print(\"\\nEasy Diffusion - v3\\n\")\n",
13
+ "current_dir = os.getcwd()\n",
14
+ "update_branch = os.path.isfile(f\"{current_dir}/sd-ui-files/scripts/get_config.py\") and subprocess.run(f\"python {current_dir}/sd-ui-files/scripts/get_config.py --default=main update_branch\", shell=True, capture_output=True, text=True).stdout.strip() or \"main\"\n",
15
+ "\n",
16
+ "if os.path.isfile(f\"{current_dir}/install_status.txt\") and \"sd_ui_git_cloned\" in open(f\"{current_dir}/install_status.txt\").read():\n",
17
+ " print(f\"Easy Diffusion's git repository was already installed. Updating from {update_branch}..\")\n",
18
+ " os.chdir(\"sd-ui-files\")\n",
19
+ " subprocess.run(\"git add -A .\", shell=True, capture_output=True, text=True)\n",
20
+ " subprocess.run(\"git stash\", shell=True, capture_output=True, text=True)\n",
21
+ " subprocess.run(\"git reset --hard\", shell=True, capture_output=True, text=True)\n",
22
+ " subprocess.run(f\"git -c advice.detachedHead=false checkout {update_branch}\", shell=True, capture_output=True, text=True)\n",
23
+ " subprocess.run(\"git pull\", shell=True, capture_output=True, text=True)\n",
24
+ " os.chdir(\"..\")\n",
25
+ "else:\n",
26
+ " print(\"\\nDownloading Easy Diffusion..\\n\")\n",
27
+ " print(f\"Using the {update_branch} channel\\n\")\n",
28
+ " if subprocess.run(f\"git clone -b {update_branch} https://github.com/easydiffusion/easydiffusion.git sd-ui-files\", shell=True, capture_output=True, text=True).returncode == 0:\n",
29
+ " with open(f\"{current_dir}/sd-ui-files/scripts/install_status.txt\", \"a\") as status_file:\n",
30
+ " status_file.write(\"sd_ui_git_cloned\\n\")\n",
31
+ " else:\n",
32
+ " print(\"git clone failed\")\n",
33
+ "\n",
34
+ "os.chdir(\"sd-ui-files\")\n",
35
+ "\n",
36
+ "def fix_script():\n",
37
+ " file_path = f\"{current_dir}/sd-ui-files/scripts/check_modules.py\"\n",
38
+ "\n",
39
+ " with open(file_path, \"r\") as file:\n",
40
+ " content = file.read()\n",
41
+ "\n",
42
+ " # 将 python3.8 修改为 python3.10\n",
43
+ " content = content.replace(\"python3.8\", \"python3.10\")\n",
44
+ "\n",
45
+ " # Replace 'os.environ[\"PYTHONPATH\"]' with 'colab_fix'\n",
46
+ " content = content.replace('os.environ[\"PYTHONPATH\"]', 'colab_fix')\n",
47
+ "\n",
48
+ " # 删除 os.chdir(\"stable-diffusion\")\n",
49
+ " content = content.replace('os.chdir(\"stable-diffusion\")', \"\")\n",
50
+ "\n",
51
+ " # 将修改后的内容写回文件\n",
52
+ " with open(file_path, \"w\") as file:\n",
53
+ " file.write(content)\n",
54
+ "\n",
55
+ " print(\"check_modules.py file has been repaired.\")\n",
56
+ "\n",
57
+ "fix_script()\n",
58
+ "\n",
59
+ "!python ./scripts/check_modules.py #安装依赖&环境\n",
60
+ "\n",
61
+ "\n",
62
+ "!curl -Lo /content/models/stable-diffusion/chilloutmix_NiPrunedFp32Fix-chonghui.safetensors https://huggingface.co/spaces/weo1101/111/resolve/main/chilloutmix_NiPrunedFp32Fix-inpainting.inpainting.safetensors\n",
63
+ "!curl -Lo /content/models/stable-diffusion/chilloutmix_NiPrunedFp32Fix.safetensors https://huggingface.co/spaces/weo1101/111/resolve/main/chilloutmix_NiPrunedFp32Fix-inpainting.inpainting.safetensors\n",
64
+ "\n",
65
+ "!curl -Lo /content/apply_func.py https://github.com/Van-wise/sd-colab/raw/main/easydiffusion/apply_func.py\n",
66
+ "!rm -rf /usr/local/lib/python3.10/dist-packages/pytorch_lightning/utilities/apply_func.py\n",
67
+ "!cp /content/apply_func.py /usr/local/lib/python3.10/dist-packages/pytorch_lightning/utilities/apply_func.py\n",
68
+ "\n",
69
+ "from pathlib import Path\n",
70
+ "os.environ[\"SD_UI_PATH\"] = str(Path(Path.cwd(), \"ui\"))\n",
71
+ "os.environ[\"INSTALL_ENV_DIR\"] = \"/usr/local\"\n",
72
+ "print(str(Path(Path.cwd(), \"ui\")))\n",
73
+ "\n",
74
+ "os.chdir(f\"{current_dir}/sd-ui-files\")\n",
75
+ "!python ./scripts/check_modules.py --launch-uvicorn"
76
+ ]
77
+ }
78
+ ],
79
+ "metadata": {
80
+ "language_info": {
81
+ "name": "python"
82
+ }
83
+ },
84
+ "nbformat": 4,
85
+ "nbformat_minor": 2
86
+ }
_start.ipynb ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import threading\n",
10
+ "import time\n",
11
+ "import os\n",
12
+ "if not os.path.exists('/content/key/id_rsa'):\n",
13
+ " !mkdir -p /content/key\n",
14
+ " !ssh-keygen -t rsa -b 4096 -N \"\" -f /content/key/id_rsa\n",
15
+ " !chmod 600 /content/key/id_rsa\n",
16
+ "\n",
17
+ "def tunnel():\n",
18
+ " time.sleep(1)\n",
19
+ " from pyngrok import ngrok\n",
20
+ " ngrok_tunnel = ngrok.connect(9000, \"http\")\n",
21
+ " from pycloudflared import try_cloudflare\n",
22
+ " cloudflare_url = try_cloudflare(9000, verbose=False)\n",
23
+ " print(ngrok_tunnel)\n",
24
+ " print(cloudflare_url)\n",
25
+ " time.sleep(1)\n",
26
+ " !ssh -R 80:127.0.0.1:9000 -o StrictHostKeyChecking=no -i /content/key/id_rsa remote.moe\n",
27
+ "\n",
28
+ "threading.Thread(target=tunnel, daemon=True).start()\n",
29
+ "\n",
30
+ "os.chdir(\"sd-ui-files\")\n",
31
+ "!python ./scripts/check_modules.py --launch-uvicorn\n",
32
+ "#!python ./scripts//ui/main.py"
33
+ ]
34
+ }
35
+ ],
36
+ "metadata": {
37
+ "language_info": {
38
+ "name": "python"
39
+ }
40
+ },
41
+ "nbformat": 4,
42
+ "nbformat_minor": 2
43
+ }
apply_func.py ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright The PyTorch Lightning team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import dataclasses
16
+ import operator
17
+ from abc import ABC
18
+ from collections import OrderedDict
19
+ from collections.abc import Mapping, Sequence
20
+ from copy import copy
21
+ from functools import partial
22
+ from typing import Any, Callable, Optional, Union
23
+
24
+ import numpy as np
25
+ import torch
26
+
27
+ from pytorch_lightning.utilities.exceptions import MisconfigurationException
28
+ from pytorch_lightning.utilities.imports import _compare_version, _TORCHTEXT_AVAILABLE
29
+ from pytorch_lightning.utilities.imports import _module_available
30
+
31
+ if _TORCHTEXT_AVAILABLE:
32
+ if _module_available("torchtext.legacy.data"):
33
+ from torchtext.legacy.data import Batch
34
+ else:
35
+ from torchtext.data import Batch
36
+ else:
37
+ Batch = type(None)
38
+
39
+
40
def to_dtype_tensor(value, dtype: torch.dtype = None, device: torch.device = None):
    """Build a tensor of ``dtype`` on ``device`` from a plain Python value.

    Raises:
        MisconfigurationException: if ``device`` was not supplied.
    """
    if device is None:
        raise MisconfigurationException("device (torch.device) should be provided.")
    return torch.tensor(value, dtype=dtype, device=device)
44
+
45
+
46
def from_numpy(value, device: torch.device = None):
    """Wrap a ``numpy.ndarray`` as a tensor and move it to ``device``.

    Raises:
        MisconfigurationException: if ``device`` was not supplied.
    """
    if device is None:
        raise MisconfigurationException("device (torch.device) should be provided.")
    return torch.from_numpy(value).to(device)
50
+
51
+
52
# Ordered (python type, converter) pairs applied sequentially by ``convert_to_tensors``.
# NOTE: the order matters — ``bool`` must precede ``int`` because ``bool`` is a
# subclass of ``int``; after the bool pass those values are already tensors and
# the int pass leaves them untouched.
CONVERSION_DTYPES = [
    # bool -> uint8 as bool -> torch.bool triggers RuntimeError: Unsupported data type for NCCL process group
    (bool, partial(to_dtype_tensor, dtype=torch.uint8)),
    (int, partial(to_dtype_tensor, dtype=torch.int)),
    (float, partial(to_dtype_tensor, dtype=torch.float)),
    (np.ndarray, from_numpy),
]
59
+
60
+
61
+ def _is_namedtuple(obj: object) -> bool:
62
+ # https://github.com/pytorch/pytorch/blob/v1.8.1/torch/nn/parallel/scatter_gather.py#L4-L8
63
+ return isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields")
64
+
65
+
66
+ def _is_dataclass_instance(obj):
67
+ # https://docs.python.org/3/library/dataclasses.html#module-level-decorators-classes-and-functions
68
+ return dataclasses.is_dataclass(obj) and not isinstance(obj, type)
69
+
70
+
71
+ def apply_to_collection(
72
+ data: Any,
73
+ dtype: Union[type, tuple],
74
+ function: Callable,
75
+ *args,
76
+ wrong_dtype: Optional[Union[type, tuple]] = None,
77
+ include_none: bool = True,
78
+ **kwargs
79
+ ) -> Any:
80
+ """
81
+ Recursively applies a function to all elements of a certain dtype.
82
+
83
+ Args:
84
+ data: the collection to apply the function to
85
+ dtype: the given function will be applied to all elements of this dtype
86
+ function: the function to apply
87
+ *args: positional arguments (will be forwarded to calls of ``function``)
88
+ wrong_dtype: the given function won't be applied if this type is specified and the given collections
89
+ is of the ``wrong_dtype`` even if it is of type ``dtype``
90
+ include_none: Whether to include an element if the output of ``function`` is ``None``.
91
+ **kwargs: keyword arguments (will be forwarded to calls of ``function``)
92
+
93
+ Returns:
94
+ The resulting collection
95
+ """
96
+ # Breaking condition
97
+ if isinstance(data, dtype) and (wrong_dtype is None or not isinstance(data, wrong_dtype)):
98
+ return function(data, *args, **kwargs)
99
+
100
+ elem_type = type(data)
101
+
102
+ # Recursively apply to collection items
103
+ if isinstance(data, Mapping):
104
+ out = []
105
+ for k, v in data.items():
106
+ v = apply_to_collection(
107
+ v, dtype, function, *args, wrong_dtype=wrong_dtype, include_none=include_none, **kwargs
108
+ )
109
+ if include_none or v is not None:
110
+ out.append((k, v))
111
+ return elem_type(OrderedDict(out))
112
+
113
+ is_namedtuple = _is_namedtuple(data)
114
+ is_sequence = isinstance(data, Sequence) and not isinstance(data, str)
115
+ if is_namedtuple or is_sequence:
116
+ out = []
117
+ for d in data:
118
+ v = apply_to_collection(
119
+ d, dtype, function, *args, wrong_dtype=wrong_dtype, include_none=include_none, **kwargs
120
+ )
121
+ if include_none or v is not None:
122
+ out.append(v)
123
+ return elem_type(*out) if is_namedtuple else elem_type(out)
124
+
125
+ if _is_dataclass_instance(data):
126
+ out = {}
127
+ for field in data.__dataclass_fields__:
128
+ v = apply_to_collection(
129
+ getattr(data, field),
130
+ dtype,
131
+ function,
132
+ *args,
133
+ wrong_dtype=wrong_dtype,
134
+ include_none=include_none,
135
+ **kwargs
136
+ )
137
+ if include_none or v is not None:
138
+ out[field] = v
139
+ return elem_type(**out)
140
+
141
+ # data is neither of dtype, nor a collection
142
+ return data
143
+
144
+
145
def apply_to_collections(
    data1: Optional[Any],
    data2: Optional[Any],
    dtype: Union[type, tuple],
    function: Callable,
    *args,
    wrong_dtype: Optional[Union[type, tuple]] = None,
    **kwargs
) -> Any:
    """
    Zips two collections and applies a function to their items of a certain dtype.

    Args:
        data1: The first collection
        data2: The second collection
        dtype: the given function will be applied to all elements of this dtype
        function: the function to apply
        *args: positional arguments (will be forwarded to calls of ``function``)
        wrong_dtype: the given function won't be applied if this type is specified and the given collections
            is of the ``wrong_dtype`` even if it is of type ``dtype``
        **kwargs: keyword arguments (will be forwarded to calls of ``function``)

    Returns:
        The resulting collection

    Raises:
        AssertionError:
            If sequence collections have different data sizes.
    """
    if data1 is None and data2 is not None:
        # in case they were passed reversed
        data1, data2 = data2, None

    collection_type = type(data1)
    both = data2 is not None

    # Leaf case: both sides present and data1 matches the requested dtype.
    if isinstance(data1, dtype) and both and (wrong_dtype is None or not isinstance(data1, wrong_dtype)):
        return function(data1, data2, *args, **kwargs)

    if isinstance(data1, Mapping) and both:
        # use union because we want to fail if a key does not exist in both
        merged = {
            key: apply_to_collections(data1[key], data2[key], dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs)
            for key in data1.keys() | data2.keys()
        }
        return collection_type(merged)

    named = _is_namedtuple(data1)
    sequence_like = isinstance(data1, Sequence) and not isinstance(data1, str)
    if (named or sequence_like) and both:
        assert len(data1) == len(data2), "Sequence collections have different sizes"
        paired = [
            apply_to_collections(left, right, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs)
            for left, right in zip(data1, data2)
        ]
        return collection_type(*paired) if named else collection_type(paired)

    # Only one collection available: fall back to the single-collection traversal.
    return apply_to_collection(data1, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs)
204
+
205
+
206
class TransferableDataType(ABC):
    """A custom type for data that can be moved to a torch device via ``.to(...)``.

    Any object exposing a callable ``to`` attribute is treated as an instance.

    Example:
        >>> isinstance(dict, TransferableDataType)
        False
        >>> isinstance(torch.rand(2, 3), TransferableDataType)
        True
        >>> class CustomObject:
        ...     def __init__(self):
        ...         self.x = torch.rand(2, 2)
        ...     def to(self, device):
        ...         self.x = self.x.to(device)
        ...         return self
        >>> isinstance(CustomObject(), TransferableDataType)
        True
    """

    @classmethod
    def __subclasshook__(cls, subclass):
        # Only answer for this exact ABC; subclasses get the default behavior.
        if cls is not TransferableDataType:
            return NotImplemented
        # Duck-typed membership: anything with a callable ``to`` qualifies.
        return callable(getattr(subclass, "to", None))
230
+
231
+
232
def move_data_to_device(batch: Any, device: torch.device):
    """
    Transfers a collection of data to the given device. Any object that defines a method
    ``to(device)`` will be moved and all other objects in the collection will be left untouched.

    Args:
        batch: A tensor or collection of tensors or anything that has a method `.to(...)`.
            See :func:`apply_to_collection` for a list of supported collection types.
        device: The device to which the data should be moved

    Return:
        the same collection but with all contained tensors residing on the new device.

    See Also:
        - :meth:`torch.Tensor.to`
        - :class:`torch.device`
    """

    def batch_to(data):
        # torchtext ``Batch`` objects get special treatment: move field by field.
        if _TORCHTEXT_AVAILABLE and isinstance(data, Batch):
            # Shallow copy because each Batch has a reference to Dataset which contains all examples
            shallow = copy(data)
            for name, field in data.dataset.fields.items():
                if field is None:
                    continue
                shallow_field = move_data_to_device(getattr(data, name), device)
                setattr(shallow, name, shallow_field)
            return shallow

        # non_blocking only applies to real tensors; other ``to``-capable objects
        # may not accept the keyword.
        to_kwargs = {"non_blocking": True} if isinstance(data, torch.Tensor) else {}
        moved = data.to(device, **to_kwargs)
        if moved is not None:
            return moved
        # user wrongly implemented the `TransferableDataType` and forgot to return `self`.
        return data

    transferable = (TransferableDataType, Batch) if _TORCHTEXT_AVAILABLE else TransferableDataType
    return apply_to_collection(batch, dtype=transferable, function=batch_to)
272
+
273
+
274
def convert_to_tensors(data: Any, device: torch.device) -> Any:
    """Convert Python scalars and numpy arrays inside ``data`` to contiguous
    tensors on ``device``.

    Raises:
        MisconfigurationException: if ``device`` is None.
    """
    if device is None:
        raise MisconfigurationException("`torch.device` should be provided.")

    # Run every registered (python type -> tensor) conversion, in declared order.
    for source_type, converter in CONVERSION_DTYPES:
        data = apply_to_collection(data, source_type, converter, device=device)

    def _contiguous_on_device(tensor: torch.Tensor, device: torch.device) -> torch.Tensor:
        # Existing tensors must also land on the target device, contiguously.
        return tensor.to(device).contiguous()

    data = apply_to_collection(data, torch.Tensor, _contiguous_on_device, device=device)
    return data