Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python | |
| #pylint: disable=redefined-outer-name | |
| """ | |
| helper methods that create an HTTP session with a managed connection pool | |
| and provide async HTTP get/post methods plus several helper methods | |
| """ | |
| import io | |
| import os | |
| import sys | |
| import ssl | |
| import base64 | |
| import asyncio | |
| import logging | |
| import aiohttp | |
| import requests | |
| import urllib3 | |
| from PIL import Image | |
| from util import Map, log | |
| from rich import print # pylint: disable=redefined-builtin | |
# Connection settings, overridable through environment variables.
sd_url = os.environ.get('SDAPI_URL', "http://127.0.0.1:7860") # api url root
sd_username = os.environ.get('SDAPI_USR') # basic-auth user; auth is only used when both user and password are set
sd_password = os.environ.get('SDAPI_PWD') # basic-auth password

# TLS verification is switched off for every client used by this module:
# urllib3 warnings about unverified requests are silenced, and the stdlib
# default-context factory is replaced by the unverified one.
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
ssl.create_default_context = ssl._create_unverified_context # pylint: disable=protected-access

# default value is 5 minutes, we need longer for training
timeout = aiohttp.ClientTimeout(total=None, sock_connect=10, sock_read=None)

# Shared module state.
use_session = True # when False, result() closes the shared session after a failed request
sess = None # shared aiohttp session, created on first use by session()
quiet = False # when True, result()/resultsync() do not log request errors
# Pick the selector-based policy on Windows when this Python provides it,
# otherwise fall back to the platform default policy.
if sys.platform == "win32" and hasattr(asyncio, "WindowsSelectorEventLoopPolicy"):
    BaseThreadPolicy = asyncio.WindowsSelectorEventLoopPolicy
else:
    BaseThreadPolicy = asyncio.DefaultEventLoopPolicy


class AnyThreadEventLoopPolicy(BaseThreadPolicy):
    """Event loop policy that creates a loop on demand instead of failing."""

    def get_event_loop(self) -> asyncio.AbstractEventLoop:
        """Return the current event loop, creating and registering one if the base policy refuses."""
        try:
            current = super().get_event_loop()
        except (RuntimeError, AssertionError):
            # the base policy has no loop for this thread: make one and remember it
            created = self.new_event_loop()
            self.set_event_loop(created)
            return created
        return current


# Install the policy process-wide as soon as the module is imported.
asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
def authsync():
    """Build basic-auth credentials for the synchronous `requests` helpers.

    Returns:
        A `requests.auth.HTTPBasicAuth` object when both `sd_username` and
        `sd_password` are set, otherwise `None`.
    """
    if sd_username is None or sd_password is None:
        return None
    return requests.auth.HTTPBasicAuth(sd_username, sd_password)
def auth():
    """Build basic-auth credentials for the asynchronous `aiohttp` session.

    Returns:
        An `aiohttp.BasicAuth` object when both `sd_username` and
        `sd_password` are set, otherwise `None`.
    """
    if sd_username is None or sd_password is None:
        return None
    return aiohttp.BasicAuth(sd_username, sd_password)
async def result(req):
    """Convert an aiohttp response into the value handed back to callers.

    Args:
        req: response object; its `status`, `reason`, `url` and `json()` are used.

    Returns:
        For status 200: the decoded JSON body, returned as-is when it is a
        list, as `{}` when it is `None`, and wrapped in `Map` otherwise.
        For any other status: a `Map` with `error`, `reason` and `url` keys.
    """
    if req.status == 200:
        payload = await req.json()
        if isinstance(payload, list):
            parsed = payload
        elif payload is None:
            parsed = {}
        else:
            parsed = Map(payload)
        log.debug({ 'request': req.status, 'url': req.url, 'reason': req.reason })
        return parsed

    # failed request: report it unless silenced
    if not quiet:
        log.error({ 'request error': req.status, 'reason': req.reason, 'url': req.url })
    # outside of session mode the shared session is closed after a failure
    if not use_session and sess is not None:
        await sess.close()
    return Map({ 'error': req.status, 'reason': req.reason, 'url': req.url })
def resultsync(req: requests.Response):
    """Convert a `requests` response into the value handed back to callers.

    Args:
        req: completed response from the `requests` library.

    Returns:
        For status 200: the decoded JSON body, returned as-is when it is a
        list, as `{}` when it is `None`, and wrapped in `Map` otherwise.
        For any other status: a `Map` with `error`, `reason` and `url` keys.
    """
    if req.status_code == 200:
        payload = req.json()
        if isinstance(payload, list):
            parsed = payload
        elif payload is None:
            parsed = {}
        else:
            parsed = Map(payload)
        log.debug({ 'request': req.status_code, 'url': req.url, 'reason': req.reason })
        return parsed

    # failed request: report it unless silenced
    if not quiet:
        log.error({ 'request error': req.status_code, 'reason': req.reason, 'url': req.url })
    return Map({ 'error': req.status_code, 'reason': req.reason, 'url': req.url })
async def get(endpoint: str, json: dict = None):
    """Send an async GET request through the shared session.

    Args:
        endpoint: request path; the shared session is created with
            `base_url = sd_url`, so this is relative to the API root.
        json: optional JSON-serializable body sent with the request.

    Returns:
        Whatever `result()` produces for the response, or `{}` when the
        request itself raised.
    """
    global sess # pylint: disable=global-statement
    # A session that exists but was already closed (by close(), or by result()
    # when use_session is False) cannot be reused; create a new one instead.
    if sess is None or sess.closed:
        sess = await session()
    try:
        # ssl=False disables certificate checks; it replaces the deprecated verify_ssl=False
        async with sess.get(url=endpoint, json=json, ssl=False) as req:
            return await result(req)
    except Exception as err:
        log.error({ 'session': err })
        return {}
def getsync(endpoint: str, json: dict = None):
    """Send a blocking GET request to `sd_url` + `endpoint`.

    Args:
        endpoint: path appended to `sd_url`.
        json: optional JSON-serializable body sent with the request.

    Returns:
        Whatever `resultsync()` produces for the response, or `{}` when the
        request raised.
    """
    try:
        response = requests.get(f'{sd_url}{endpoint}', json=json, verify=False, auth=authsync()) # pylint: disable=missing-timeout
        return resultsync(response)
    except Exception as err:
        log.error({ 'session': err })
        return {}
async def post(endpoint: str, json: dict = None):
    """Send an async POST request using a freshly created session.

    Args:
        endpoint: request path; the session is created with
            `base_url = sd_url`, so this is relative to the API root.
        json: optional JSON-serializable body sent with the request.

    Returns:
        Whatever `result()` produces for the response, or `{}` when the
        request itself raised.
    """
    global sess # pylint: disable=global-statement
    # Every POST deliberately runs on a new session: an open one is closed
    # first, then replaced, rather than being reused.
    if sess and not sess.closed:
        await sess.close()
    sess = await session()
    try:
        # ssl=False disables certificate checks; it replaces the deprecated verify_ssl=False
        async with sess.post(url=endpoint, json=json, ssl=False) as req:
            return await result(req)
    except Exception as err:
        log.error({ 'session': err })
        return {}
def postsync(endpoint: str, json: dict = None):
    """Send a blocking POST request to `sd_url` + `endpoint`.

    Exceptions raised by `requests` are not caught here and propagate to the
    caller.

    Args:
        endpoint: path appended to `sd_url`.
        json: optional JSON-serializable body sent with the request.

    Returns:
        Whatever `resultsync()` produces for the response.
    """
    response = requests.post(f'{sd_url}{endpoint}', json=json, verify=False, auth=authsync()) # pylint: disable=missing-timeout
    return resultsync(response)
async def interrupt():
    """Ask the server to interrupt the current job, if any job is reported.

    Returns:
        The response of the interrupt call when the progress state reports
        `job_count > 0`, otherwise `{ 'interrupt': 'idle' }`.
    """
    status = await get('/sdapi/v1/progress?skip_current_image=true')
    busy = 'state' in status and status.state.job_count > 0
    if not busy:
        log.debug({ 'interrupt': 'idle' })
        return { 'interrupt': 'idle' }
    log.debug({ 'interrupt': status.state })
    response = await post('/sdapi/v1/interrupt')
    await asyncio.sleep(1)
    return response
def interruptsync():
    """Blocking variant of `interrupt()`.

    Returns:
        The response of the interrupt call when the progress state reports
        `job_count > 0`, otherwise `{ 'interrupt': 'idle' }`.
    """
    status = getsync('/sdapi/v1/progress?skip_current_image=true')
    busy = 'state' in status and status.state.job_count > 0
    if not busy:
        log.debug({ 'interrupt': 'idle' })
        return { 'interrupt': 'idle' }
    log.debug({ 'interrupt': status.state })
    return postsync('/sdapi/v1/interrupt')
async def progress():
    """Fetch the progress status and decode the preview image when present.

    When the response carries a non-null `current_image`, it is base64-decoded
    and replaced by a PIL image. Decoding is best-effort: on failure the
    response is returned unchanged and the reason is logged at debug level.

    Returns:
        The progress response as produced by `get()`.
    """
    res = await get('/sdapi/v1/progress?skip_current_image=false')
    try:
        if res is not None and res.get('current_image', None) is not None:
            res.current_image = Image.open(io.BytesIO(base64.b64decode(res['current_image'])))
    except Exception as err:
        # keep going without a decoded preview, but leave a trace of why
        log.debug({ 'progress': 'current image not decoded', 'error': err })
    log.debug({ 'progress': res })
    return res
def progresssync():
    """Fetch the progress status with a blocking request, skipping the preview image.

    Returns:
        The progress response as produced by `getsync()`.
    """
    status = getsync('/sdapi/v1/progress?skip_current_image=true')
    log.debug({ 'progress': status })
    return status
def get_log():
    """Fetch the server log and echo each returned item at debug level.

    Returns:
        The response of the log endpoint as produced by `getsync()`.
    """
    entries = getsync('/sdapi/v1/log')
    for entry in entries:
        log.debug(entry)
    return entries
def get_info():
    """Fetch the full system status and print it together with the request duration.

    Returns:
        The response of the status endpoint as produced by `getsync()`.
    """
    import time
    started = time.time()
    info = getsync('/sdapi/v1/system-info/status?full=true&refresh=true')
    finished = time.time()
    # duration is the wall-clock time of the request, in milliseconds
    report = { 'duration': 1000 * round(finished-started, 3), **info }
    print(report)
    return info
def options():
    """Fetch server options and command-line flags with blocking requests.

    Returns:
        A dict with the options response under `options` and the cmd-flags
        response under `flags`.
    """
    return {
        'options': getsync('/sdapi/v1/options'),
        'flags': getsync('/sdapi/v1/cmd-flags'),
    }
def shutdown():
    """Post to the shutdown endpoint; any exception is logged at info level, not raised."""
    try:
        postsync('/sdapi/v1/shutdown')
    except Exception as err:
        log.info({ 'shutdown': err })
async def session():
    """Create a new aiohttp client session and store it in the module-level `sess`.

    The session uses `sd_url` as its base URL and the credentials returned by
    `auth()`. Any session previously stored in `sess` is replaced, not closed.

    Returns:
        The newly created `aiohttp.ClientSession`.
    """
    global sess # pylint: disable=global-statement
    # default value is 5 minutes, we need longer for training
    client_timeout = aiohttp.ClientTimeout(total=None, sock_connect=10, sock_read=None)
    sess = aiohttp.ClientSession(timeout=client_timeout, base_url=sd_url, auth=auth())
    log.debug({ 'sdapi': 'session created', 'endpoint': sd_url })
    return sess
async def close():
    """Close the shared aiohttp session, if one exists, and forget it.

    After closing, the module-level `sess` is reset to `None` so that it never
    refers to a closed session.
    """
    global sess # pylint: disable=global-statement
    if sess is not None:
        await asyncio.sleep(0) # yield to the event loop once before closing
        await sess.close()
        sess = None
        log.debug({ 'sdapi': 'session closed', 'endpoint': sd_url })
# Command-line entry point: the arguments select one of the helpers above;
# anything else is treated as an endpoint path and fetched with getsync().
if __name__ == "__main__":
    sys.argv.pop(0) # drop the script name so only user arguments remain
    log.setLevel(logging.DEBUG)
    if not sys.argv:
        # without this guard the fallback branch would fail on sys.argv[0]
        log.error({ 'sdapi': 'no command or endpoint specified' })
        sys.exit(1)
    if 'interrupt' in sys.argv:
        asyncio.run(interrupt())
    elif 'progress' in sys.argv:
        asyncio.run(progress())
    elif 'progresssync' in sys.argv:
        progresssync()
    elif 'options' in sys.argv:
        opt = options()
        log.debug({ 'options' })
        import json
        print(json.dumps(opt['options'], indent = 2))
        log.debug({ 'cmd-flags' })
        print(json.dumps(opt['flags'], indent = 2))
    elif 'log' in sys.argv:
        get_log()
    elif 'info' in sys.argv:
        get_info()
    elif 'shutdown' in sys.argv:
        shutdown()
    else:
        # first argument is used as the endpoint path
        res = getsync(sys.argv[0])
        print(res)
    asyncio.run(close(), debug=True)
    asyncio.run(asyncio.sleep(0.5))