| import sys | |
| import launch | |
| import platform | |
| import os | |
| import shutil | |
| import site | |
| import glob | |
| import re | |
# Directory containing this script; the sd-scripts checkout is kept in its
# "kohya_ss" subdirectory.
dirname = os.path.dirname(__file__)
repo_dir = os.path.join(dirname, "kohya_ss")
# Matches the start of a version specifier (with optional leading whitespace),
# e.g. "==1.0", ">=2", "~=0.2", "!=4.0".
_VERSION_SPECIFIER = re.compile(r"\s*(?:===|==|~=|!=|<=|>=|<|>)")


def _unpinned_requirements(text):
    """Return the requirement names found in *text* with version pins removed.

    Blank lines, comment lines and the bare ``.`` entry (the repository
    itself, which the caller installs explicitly) are skipped.  Trailing
    inline comments (`` # ...``) are dropped.
    """
    names = []
    for raw_line in text.splitlines():
        line = raw_line.split(" #", 1)[0].strip()
        if not line or line.startswith("#") or line == ".":
            continue
        names.append(_VERSION_SPECIFIER.split(line, maxsplit=1)[0])
    return names


def _copy_windows_bitsandbytes_files():
    """Copy the files in ``bitsandbytes_windows`` into installed bitsandbytes packages."""
    for src in glob.glob(os.path.join(repo_dir, "bitsandbytes_windows", "*")):
        filename = os.path.basename(src)
        for site_dir in site.getsitepackages():
            # main.py goes into the cuda_setup subpackage; everything else
            # goes into the package root.
            if filename == "main.py":
                outfile = os.path.join(site_dir, "bitsandbytes", "cuda_setup", filename)
            else:
                outfile = os.path.join(site_dir, "bitsandbytes", filename)
            # Only copy into site-packages that already contain the target directory.
            if not os.path.exists(os.path.dirname(outfile)):
                continue
            shutil.copy(src, outfile)


def prepare_environment():
    """Install torch, the sd-scripts checkout and its Python dependencies.

    Configuration is read from the environment variables ``TORCH_COMMAND``,
    ``SD_SCRIPTS_REPO``, ``SD_SCRIPTS_BRANCH`` and ``REQS_FILE``.  The flags
    ``--skip-install``, ``--disable-strict-version``,
    ``--skip-torch-cuda-test``, ``--skip-checkout-repo``, ``--update``,
    ``--reinstall-xformers`` and ``--reinstall-torch`` are extracted from
    ``sys.argv`` via ``launch.extract_arg``; ``--xformers`` and ``--ngrok``
    are only checked for presence.

    Returns ``None``.  Nothing is installed when ``--skip-install`` is given.
    """
    torch_command = os.environ.get(
        "TORCH_COMMAND",
        "pip install torch==2.0.0+cu118 torchvision==0.15.1+cu118 --extra-index-url https://download.pytorch.org/whl/cu118",
    )
    sd_scripts_repo = os.environ.get(
        "SD_SCRIPTS_REPO", "https://github.com/kohya-ss/sd-scripts.git"
    )
    sd_scripts_branch = os.environ.get("SD_SCRIPTS_BRANCH", "main")
    requirements_file = os.environ.get("REQS_FILE", "requirements.txt")

    sys.argv, skip_install = launch.extract_arg(sys.argv, "--skip-install")
    sys.argv, disable_strict_version = launch.extract_arg(
        sys.argv, "--disable-strict-version"
    )
    sys.argv, skip_torch_cuda_test = launch.extract_arg(
        sys.argv, "--skip-torch-cuda-test"
    )
    sys.argv, skip_checkout_repo = launch.extract_arg(sys.argv, "--skip-checkout-repo")
    sys.argv, update = launch.extract_arg(sys.argv, "--update")
    sys.argv, reinstall_xformers = launch.extract_arg(sys.argv, "--reinstall-xformers")
    sys.argv, reinstall_torch = launch.extract_arg(sys.argv, "--reinstall-torch")
    xformers = "--xformers" in sys.argv
    ngrok = "--ngrok" in sys.argv

    if skip_install:
        return

    # The pinned torch build is only installed in strict-version mode.
    if (
        reinstall_torch
        or not launch.is_installed("torch")
        or not launch.is_installed("torchvision")
    ) and not disable_strict_version:
        launch.run(
            f'"{launch.python}" -m {torch_command}',
            "Installing torch and torchvision",
            "Couldn't install torch",
        )

    if not skip_torch_cuda_test:
        launch.run_python(
            "import torch; assert torch.cuda.is_available(), 'Torch is not able to use GPU; add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'"
        )

    if (not launch.is_installed("xformers") or reinstall_xformers) and xformers:
        launch.run_pip("install xformers --pre", "xformers")

    if update and os.path.exists(repo_dir):
        launch.run(f'cd "{repo_dir}" && {launch.git} fetch --prune')
        # NOTE(review): this always resets to origin/main, regardless of
        # SD_SCRIPTS_BRANCH - confirm that is intended for non-main branches.
        launch.run(f'cd "{repo_dir}" && {launch.git} reset --hard origin/main')
    elif not os.path.exists(repo_dir):
        launch.run(f'{launch.git} clone {sd_scripts_repo} "{repo_dir}"')

    if not skip_checkout_repo:
        launch.run(f'cd "{repo_dir}" && {launch.git} checkout {sd_scripts_branch}')

    if not launch.is_installed("gradio"):
        launch.run_pip("install gradio==3.16.2", "gradio")

    if not launch.is_installed("pyngrok") and ngrok:
        launch.run_pip("install pyngrok", "ngrok")

    if platform.system() == "Linux":
        if not launch.is_installed("triton"):
            launch.run_pip("install triton", "triton")

    if disable_strict_version:
        # Install the same packages without their version pins, plus the
        # repository itself.
        requirements_path = os.path.join(repo_dir, requirements_file)
        with open(requirements_path, "r", encoding="utf-8") as f:
            names = _unpinned_requirements(f.read())
        # Quote every requirement separately so that pip receives one
        # argument per package rather than a single space-joined string.
        quoted_names = " ".join(f'"{name}"' for name in names)
        launch.run_pip(
            f'install {quoted_names} "{repo_dir}"',
            "requirements for kohya sd-scripts",
        )
    else:
        # Honour REQS_FILE here as well; it defaults to requirements.txt.
        launch.run(
            f'cd "{repo_dir}" && "{launch.python}" -m pip install -r "{requirements_file}"',
            desc="Installing requirements for kohya sd-scripts",
            errdesc="Couldn't install requirements for kohya sd-scripts",
        )

    if platform.system() == "Windows":
        _copy_windows_bitsandbytes_files()
# Run the installation only when this file is executed as a script.
if __name__ == "__main__":
    prepare_environment()