import importlib.util import os import shlex import subprocess import sys commandline_args = os.environ.get("COMMANDLINE_ARGS", "") sys.argv += shlex.split(commandline_args) python = sys.executable git = os.environ.get("GIT", "git") index_url = os.environ.get("INDEX_URL", "") stored_commit_hash = None skip_install = False def run(command, desc=None, errdesc=None, custom_env=None): if desc is not None: print(desc) result = subprocess.run( command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, env=os.environ if custom_env is None else custom_env, ) if result.returncode != 0: message = f"""{errdesc or 'Error running command'}. Command: {command} Error code: {result.returncode} stdout: {result.stdout.decode(encoding="utf8", errors="ignore") if len(result.stdout)>0 else ''} stderr: {result.stderr.decode(encoding="utf8", errors="ignore") if len(result.stderr)>0 else ''} """ raise RuntimeError(message) return result.stdout.decode(encoding="utf8", errors="ignore") def check_run(command): result = subprocess.run( command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True ) return result.returncode == 0 def is_installed(package): try: spec = importlib.util.find_spec(package) except ModuleNotFoundError: return False return spec is not None def commit_hash(): global stored_commit_hash if stored_commit_hash is not None: return stored_commit_hash try: stored_commit_hash = run(f"{git} rev-parse HEAD").strip() except Exception: stored_commit_hash = "" return stored_commit_hash def run_pip(args, desc=None): if skip_install: return index_url_line = f" --index-url {index_url}" if index_url != "" else "" return run( f'"{python}" -m pip {args} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}", ) def run_python(code, desc=None, errdesc=None): return run(f'"{python}" -c "{code}"', desc, errdesc) def extract_arg(args, name): return [x for x in args if x != name], name in args def prepare_environment(): commit = commit_hash() print(f"Python {sys.version}") print(f"Commit hash: {commit}") torch_command = os.environ.get( "TORCH_COMMAND", "pip install torch torchaudio --extra-index-url https://download.pytorch.org/whl/cu118", ) sys.argv, skip_install = extract_arg(sys.argv, "--skip-install") if skip_install: return sys.argv, reinstall_torch = extract_arg(sys.argv, "--reinstall-torch") ngrok = "--ngrok" in sys.argv if reinstall_torch or not is_installed("torch") or not is_installed("torchaudio"): run( f'"{python}" -m {torch_command}', "Installing torch and torchaudio", "Couldn't install torch", ) if not is_installed("pyngrok") and ngrok: run_pip("install pyngrok", "ngrok") run( f'"{python}" -m pip install -r requirements.txt', desc=f"Installing requirements", errdesc=f"Couldn't install requirements", ) def start(): os.environ["PATH"] = ( os.path.join(os.path.dirname(__file__), "bin") + os.pathsep + os.environ.get("PATH", "") ) subprocess.run( [python, "webui.py", *sys.argv[1:]], ) if __name__ == "__main__": prepare_environment() start()