Spaces:
Runtime error
Runtime error
#!/usr/bin/env python
# pylint: disable=no-member
"""Generate batches of images from prompts and upscale them.

Parameters: run with `--help`.
The default workflow loops until `--max` images have been generated (`--max 0` loops
forever) and prints stats when it finishes or is interrupted:
1. choose a sampler: look up all available samplers and pick one at random (unless `--sampler` is given)
2. generate a dynamic prompt based on styles, embeddings, places, artists and suffixes
3. optionally beautify the prompt (`--beautify`)
4. generate a batch of images (batch size x iterations, e.g. 3x3)
5. create an image grid
6. upscale images, optionally with face restoration (`--faces`)
"""
| import argparse | |
| import asyncio | |
| import base64 | |
| import io | |
| import json | |
| import logging | |
| import math | |
| import os | |
| import pathlib | |
| import secrets | |
| import time | |
| import sys | |
| import importlib | |
| from random import randrange | |
| from PIL import Image | |
| from PIL.ExifTags import TAGS | |
| from PIL.TiffImagePlugin import ImageFileDirectory_v2 | |
| from sdapi import close, get, interrupt, post, session | |
| from util import Map, log, safestring | |
# Module-level state shared by the pipeline stages below.
sd = dict()  # run configuration; args() (or generate(options=...)) rebinds it to a Map
random = dict()  # dynamic-prompt pools; args() rebinds it to a Map loaded from --random (shadows the stdlib module name)
avg = dict()  # sampler name -> seconds per image from the latest iteration that used it
stats = Map(dict(images=0, wall=0, generate=0, upscale=0))  # cumulative counters logged on exit
def grid(data):
    """Tile a batch of generated images into one contact-sheet JPEG.

    data: result of generate(); reads data.image (PIL images) and data.info
    returns: the grid image (also saved under sd.paths.grid) when more than one image
             was generated; otherwise data.image is returned unchanged and nothing is saved
    """
    tiles = data.image
    if len(tiles) <= 1:
        return tiles
    count = len(tiles)
    tile_w, tile_h = tiles[0].size
    # near-square layout: rows ~ sqrt(count), enough columns to hold every tile
    n_rows = round(math.sqrt(count))
    n_cols = math.ceil(count / n_rows)
    sheet = Image.new('RGB', size = (n_cols * tile_w, n_rows * tile_h), color = 'black')
    for idx, tile in enumerate(tiles):
        row, col = divmod(idx, n_cols)
        sheet.paste(tile, box=(col * tile_w, row * tile_h))
    short = data.info.prompt[:96] # limit prompt part of filename to 96 chars
    file_name = safestring(f'{data.info.all_seeds[0]:0>9} {short}') + '.jpg'
    target = os.path.join(sd.paths.root, sd.paths.grid, file_name)
    log.info({ 'grid': { 'name': target, 'size': sheet.size, 'images': count } })
    sheet.save(target, 'JPEG', exif = exif(data.info, None, 'grid'), optimize = True, quality = 70)
    return sheet
def exif(info, i = None, op = 'generate'):
    """Return an EXIF payload whose ImageDescription tag records how an image was made.

    info: generation info Map from txt2img (prompt, negative_prompt, all_seeds, steps, ...)
    i: index of the image within the batch; when None, every seed in info.all_seeds is listed
    op: 'generate', 'upscale' or 'grid'; 'upscale' appends face-restoration/upscaler details
        from sd.upscale, 'grid' appends the number of tiled images
    returns: bytes suitable for the exif= argument of PIL's Image.save
    """
    if len(info.all_seeds) > 0 and i is not None:
        seeds = [info.all_seeds[i]]
    else:
        seeds = info.all_seeds
    seed_text = ', '.join(map(str, seeds)) # e.g. "123, 456"
    description = '{prompt} | negative {negative_prompt} | seed {s} | steps {steps} | cfgscale {cfg_scale} | sampler {sampler_name} | batch {batch_size} | timestamp {job_timestamp} | model {model} | vae {vae}'.format(s = seed_text, model = sd.options['sd_model_checkpoint'], vae = sd.options['sd_vae'], **info) # pylint: disable=consider-using-f-string
    if op == 'upscale':
        settings = sd.upscale
        if settings.gfpgan_visibility > 0:
            description += ' | faces gfpgan'
        if settings.codeformer_visibility > 0:
            description += ' | faces codeformer'
        for upscaler in (settings.upscaler_1, settings.upscaler_2):
            if upscaler != 'None':
                description += ' | upscale {resize}x {upscaler}'.format(resize = settings.upscaling_resize, upscaler = upscaler) # pylint: disable=consider-using-f-string
    if op == 'grid':
        description += ' | grid {num}'.format(num = sd.generate.batch_size * sd.generate.n_iter) # pylint: disable=consider-using-f-string
    tag_ids = {name: tag for tag, name in TAGS.items()} # tag name -> numeric EXIF tag id
    directory = ImageFileDirectory_v2()
    directory[tag_ids['ImageDescription']] = description
    stream = io.BytesIO()
    directory.save(stream)
    return b'Exif\x00\x00' + stream.getvalue()
def randomize(lst):
    """Return a random element of lst (via secrets.choice), or '' when lst is empty or None.

    Treating None like an empty pool lets a missing prompt pool simply blank out its
    placeholder instead of crashing with TypeError from len(None).
    """
    if not lst:
        return ''
    return secrets.choice(lst)
def prompt(params): # generate dynamic prompt or use one if provided
    """Build sd.generate.prompt and sd.generate.negative_prompt from CLI params.

    params: parsed argparse namespace from args()
    The prompt/negative come from params unless they are 'dynamic', in which case a random
    entry of the template's 'prompts'/'negative' pool is used. Placeholders <embedding>,
    <artist>, <style>, <suffix> and <place> are then replaced by the matching CLI value, or
    by a random pick from the template pool when that value is 'random'.
    returns: the final prompt (also stored in sd.generate.prompt)
    """
    def pick(value, wildcard, pool):
        # random may still be the initial plain {} (no template loaded when --prompt is given),
        # so use .get() instead of attribute access; a missing pool yields ''
        return value if value != wildcard else randomize(random.get(pool) or [])

    sd.generate.prompt = pick(params.prompt, 'dynamic', 'prompts')
    sd.generate.negative_prompt = pick(params.negative, 'dynamic', 'negative')
    placeholders = (
        ('<embedding>', params.embedding, 'embeddings'),
        ('<artist>', params.artist, 'artists'),
        ('<style>', params.style, 'styles'),
        ('<suffix>', params.suffix, 'suffixes'),
        ('<place>', params.place, 'places'), # was params.suffix: copy-paste bug
    )
    for placeholder, value, pool in placeholders:
        sd.generate.prompt = sd.generate.prompt.replace(placeholder, pick(value, 'random', pool))
    if params.prompts or params.debug:
        log.info({ 'random initializers': random })
    if params.prompt == 'dynamic':
        log.info({ 'dynamic prompt': sd.generate.prompt })
    return sd.generate.prompt
def sampler(params, options): # find sampler
    """Choose the sampler for the next txt2img call and store it in sd.generate.sampler_name.

    params.sampler == 'random' picks any registered sampler; otherwise the first registered
    sampler whose name starts with params.sampler is used (prefix match).
    options: result of init(); options.samplers lists the registered sampler names
    returns: the chosen sampler name
    Exits the process when no registered sampler matches params.sampler.
    """
    if params.sampler == 'random':
        sd.generate.sampler_name = randomize(options.samplers)
        log.info({ 'random sampler': sd.generate.sampler_name })
        return sd.generate.sampler_name
    found = [name for name in options.samplers if name.startswith(params.sampler)]
    if not found:
        # report what was requested, not whatever sampler_name currently holds
        log.error({ 'sampler error': params.sampler, 'available': options.samplers })
        sys.exit()
    sd.generate.sampler_name = found[0]
    return sd.generate.sampler_name
async def generate(prompt = None, options = None, quiet = False): # pylint: disable=redefined-outer-name
    """Run one txt2img request and save each returned image as a JPEG.

    prompt: optional prompt that overrides sd.generate.prompt for this call
    options: optional full configuration; when given it replaces the module-level sd
    quiet: when True, skip the per-request and per-image info logging
    returns: Map with 'name' (saved file paths), 'image' (PIL images), 'b64' (base64 strings
             as returned by the API, reused by upscale()) and 'info' (parsed generation info);
             an empty Map when the API response contains an 'error' key
    """
    global sd # pylint: disable=global-statement
    if options:
        sd = Map(options)
    if prompt is not None:
        sd.generate.prompt = prompt
    if not quiet:
        log.info({ 'generate': sd.generate })
    if sd.get('options', None) is None:
        sd['options'] = await get('/sdapi/v1/options') # exif() reads model/vae names from here
    data = await post('/sdapi/v1/txt2img', sd.generate)
    if 'error' in data:
        # 'reason' is not guaranteed to accompany 'error'; don't turn an API error into a KeyError
        log.error({ 'generate': data['error'], 'reason': data.get('reason') })
        return Map({})
    info = Map(json.loads(data['info']))
    log.debug({ 'info': info })
    short = info.prompt[:96] # limit prompt part of filename to 96 chars
    names = []
    b64s = []
    images = []
    for i, b64 in enumerate(data['images']):
        b64s.append(b64)
        # keep only the payload after an optional 'data:image/...;base64,' prefix
        image = Image.open(io.BytesIO(base64.b64decode(b64.split(',', 1)[-1])))
        name = '{seed:0>9} {short}'.format(short = short, seed = info.all_seeds[i]) # pylint: disable=consider-using-f-string
        f = os.path.join(sd.paths.root, sd.paths.generate, safestring(name) + '.jpg')
        names.append(f)
        images.append(image)
        if not quiet:
            log.info({ 'image': { 'name': f, 'size': image.size } })
        image.save(f, 'JPEG', exif = exif(info, i), optimize = True, quality = 70)
    return Map({ 'name': names, 'image': images, 'b64': b64s, 'info': info })
async def upscale(data):
    """Upscale every generated image via /sdapi/v1/extra-single-image and save it as a JPEG.

    Does nothing beyond setting data.upscaled = [] when sd.upscale.upscaling_resize <= 1.
    Each result is written to <sd.paths.root>/<sd.paths.upscale>/ under the same file name
    generate() used. An image whose response carries no 'image' is logged and skipped.
    data: result of generate(); reads data.name, data.b64 and data.info
    returns: data, with data.upscaled holding the upscaled PIL images
    """
    data.upscaled = []
    if sd.upscale.upscaling_resize <= 1:
        return data
    sd.upscale.image = '' # clear any previous payload so the request template logs without image data
    log.info({ 'upscale': sd.upscale })
    for i, source in enumerate(data.name):
        # rebuild the path from the basename: str.replace() on the full path would also rewrite
        # any other occurrence of the generate folder name (e.g. inside sd.paths.root)
        f = os.path.join(sd.paths.root, sd.paths.upscale, os.path.basename(source))
        sd.upscale.image = data.b64[i]
        res = await post('/sdapi/v1/extra-single-image', sd.upscale)
        if not res or 'image' not in res:
            log.error({ 'upscale error': f, 'response': res })
            continue
        # keep only the payload after an optional 'data:image/...;base64,' prefix
        image = Image.open(io.BytesIO(base64.b64decode(res['image'].split(',', 1)[-1])))
        data.upscaled.append(image)
        log.info({ 'image': { 'name': f, 'size': image.size } })
        image.save(f, 'JPEG', exif = exif(data.info, i, 'upscale'), optimize = True, quality = 70)
    return data
async def init():
    """Query server capabilities, apply sd.options, and create the output folders.

    Resolves sd.options.sd_model_checkpoint against the registered model titles (exact
    match first, otherwise the first title starting with the configured value), calls
    interrupt(), posts sd.options, and re-reads the server options for logging.
    returns: Map with flags, models, samplers, upscalers, restorers, options and queue
    Exits the process when no registered model matches the configured checkpoint.
    """
    options = Map({})
    options.flags = await get('/sdapi/v1/cmd-flags')
    log.debug({ 'flags': options.flags })
    data = await get('/sdapi/v1/sd-models')
    options.models = [obj['title'] for obj in data]
    log.debug({ 'registered models': options.models })
    checkpoint = sd.options.sd_model_checkpoint
    if checkpoint not in options.models:
        # no exact title match: accept the first registered title that starts with the configured name
        matches = [title for title in options.models if title.startswith(checkpoint)]
        if not matches:
            log.error({ 'model error': checkpoint, 'available': options.models })
            sys.exit()
        sd.options.sd_model_checkpoint = matches[0]
    data = await get('/sdapi/v1/samplers')
    options.samplers = [obj['name'] for obj in data]
    log.debug({ 'registered samplers': options.samplers })
    data = await get('/sdapi/v1/upscalers')
    options.upscalers = [obj['name'] for obj in data]
    log.debug({ 'registered upscalers': options.upscalers })
    data = await get('/sdapi/v1/face-restorers')
    options.restorers = [obj['name'] for obj in data]
    log.debug({ 'registered face restorers': options.restorers })
    await interrupt()
    await post('/sdapi/v1/options', sd.options)
    options.options = await get('/sdapi/v1/options')
    log.info({ 'target models': { 'diffuser': options.options['sd_model_checkpoint'], 'vae': options.options['sd_vae'] } })
    log.info({ 'paths': sd.paths })
    options.queue = await get('/queue/status')
    log.info({ 'queue': options.queue })
    root = pathlib.Path(sd.paths.root)
    root.mkdir(parents = True, exist_ok = True)
    for folder in (sd.paths.generate, sd.paths.upscale, sd.paths.grid):
        (root / folder).mkdir(parents = True, exist_ok = True)
    return options
def _load_map(filename, home, what):
    """Load a JSON file into a Map and return it; exits the process on failure.

    filename: path as given on the command line; tried as-is first, then relative to home
    home: directory containing the running script (fallback search location)
    what: label used in log keys, e.g. 'config' or 'random template'
    """
    for candidate in (filename, os.path.join(home, filename)):
        if os.path.isfile(candidate):
            break
    else:
        log.error({ f'{what} file not found': filename })
        sys.exit()
    try:
        with open(candidate, 'r', encoding='utf-8') as f:
            loaded = Map(json.load(f))
    except Exception as e: # pylint: disable=broad-except
        log.error({ f'{what} error': filename, 'exception': e })
        sys.exit()
    log.debug({ what: loaded })
    return loaded


def args(): # parse cmd arguments
    """Parse the command line, load config and random-prompt template, and apply CLI overrides.

    Rebinds the module-level sd (from --config) and, when --prompt is 'dynamic', random
    (from --random). Both files are looked up as given, then next to the script.
    CLI values override the config only when they differ from their argparse default
    (non-empty string, positive number, non-'random' sampler, or a set flag).
    returns: the parsed argparse namespace
    """
    global sd # pylint: disable=global-statement
    global random # pylint: disable=global-statement
    parser = argparse.ArgumentParser(description = 'sd pipeline')
    parser.add_argument('--config', type = str, default = 'generate.json', required = False, help = 'configuration file')
    parser.add_argument('--random', type = str, default = 'random.json', required = False, help = 'prompt file with randomized sections')
    parser.add_argument('--max', type = int, default = 1, required = False, help = 'maximum number of generated images')
    parser.add_argument('--prompt', type = str, default = 'dynamic', required = False, help = 'prompt')
    parser.add_argument('--negative', type = str, default = 'dynamic', required = False, help = 'negative prompt')
    parser.add_argument('--artist', type = str, default = 'random', required = False, help = 'artist style, used to guide dynamic prompt when prompt is not provided')
    parser.add_argument('--embedding', type = str, default = 'random', required = False, help = 'use embedding, used to guide dynamic prompt when prompt is not provided')
    parser.add_argument('--style', type = str, default = 'random', required = False, help = 'image style, used to guide dynamic prompt when prompt is not provided')
    parser.add_argument('--suffix', type = str, default = 'random', required = False, help = 'style suffix, used to guide dynamic prompt when prompt is not provided')
    parser.add_argument('--place', type = str, default = 'random', required = False, help = 'place locator, used to guide dynamic prompt when prompt is not provided')
    parser.add_argument('--faces', default = False, action='store_true', help = 'restore faces during upscaling')
    parser.add_argument('--steps', type = int, default = 0, required = False, help = 'number of steps')
    parser.add_argument('--batch', type = int, default = 0, required = False, help = 'batch size, limited by gpu vram')
    parser.add_argument('--n', type = int, default = 0, required = False, help = 'number of iterations')
    parser.add_argument('--cfg', type = float, default = 0, required = False, help = 'classifier free guidance scale') # float so e.g. 7.5 is accepted
    parser.add_argument('--sampler', type = str, default = 'random', required = False, help = 'sampler')
    parser.add_argument('--seed', type = int, default = 0, required = False, help = 'seed, default is random')
    parser.add_argument('--upscale', type = int, default = 0, required = False, help = 'upscale factor, disabled if 0')
    parser.add_argument('--model', type = str, default = '', required = False, help = 'diffusion model')
    parser.add_argument('--vae', type = str, default = '', required = False, help = 'vae model')
    parser.add_argument('--path', type = str, default = '', required = False, help = 'output path')
    parser.add_argument('--width', type = int, default = 0, required = False, help = 'width')
    parser.add_argument('--height', type = int, default = 0, required = False, help = 'height')
    parser.add_argument('--beautify', default = False, action='store_true', help = 'beautify prompt')
    parser.add_argument('--prompts', default = False, action='store_true', help = 'print dynamic prompt templates')
    parser.add_argument('--debug', default = False, action='store_true', help = 'print extra debug information')
    params = parser.parse_args()
    if params.debug:
        log.setLevel(logging.DEBUG)
    log.debug({ 'debug': True })
    log.debug({ 'args': params.__dict__ })
    home = pathlib.Path(sys.argv[0]).parent
    sd = _load_map(params.config, home, 'config')
    if params.prompt == 'dynamic':
        log.info({ 'prompt template': params.random })
        random = _load_map(params.random, home, 'random template')
    _dynamic = prompt(params)
    sd.paths.root = params.path if params.path != '' else sd.paths.root
    # --faces is a store_true flag (never None): only override the config when it is set
    sd.generate.restore_faces = True if params.faces else sd.generate.restore_faces
    sd.generate.seed = params.seed if params.seed > 0 else sd.generate.seed
    sd.generate.sampler_name = params.sampler if params.sampler != 'random' else sd.generate.sampler_name
    sd.generate.batch_size = params.batch if params.batch > 0 else sd.generate.batch_size
    sd.generate.cfg_scale = params.cfg if params.cfg > 0 else sd.generate.cfg_scale
    sd.generate.n_iter = params.n if params.n > 0 else sd.generate.n_iter
    sd.generate.width = params.width if params.width > 0 else sd.generate.width
    sd.generate.height = params.height if params.height > 0 else sd.generate.height
    sd.generate.steps = params.steps if params.steps > 0 else sd.generate.steps
    sd.upscale.upscaling_resize = params.upscale if params.upscale > 0 else sd.upscale.upscaling_resize
    sd.upscale.codeformer_visibility = 1 if params.faces else sd.upscale.codeformer_visibility
    sd.options.sd_vae = params.vae if params.vae != '' else sd.options.sd_vae
    sd.options.sd_model_checkpoint = params.model if params.model != '' else sd.options.sd_model_checkpoint
    sd.upscale.upscaler_1 = 'SwinIR_4x' if params.upscale > 1 else sd.upscale.upscaler_1
    if sd.generate.cfg_scale == 0:
        sd.generate.cfg_scale = randrange(5, 10) # neither config nor CLI set a scale: pick one in 5..9
    return params
async def main():
    """Run the generation loop until --max images exist (0 = until interrupted) or generate() fails.

    Each iteration: build a prompt, optionally beautify it, pick a sampler, generate,
    save a grid, upscale, then accumulate timings into the module-level stats/avg and log progress.
    """
    params = args()
    sess = await session()
    if sess is None:
        await close()
        sys.exit()
    options = await init()
    iteration = 0
    while True:
        iteration += 1
        log.info('')
        log.info({ 'iteration': iteration, 'batch': sd.generate.batch_size, 'n': sd.generate.n_iter, 'total': sd.generate.n_iter * sd.generate.batch_size })
        dynamic = prompt(params)
        if params.beautify:
            try:
                promptist = importlib.import_module('modules.promptist')
                sd.generate.prompt = promptist.beautify(dynamic)
            except Exception as e: # pylint: disable=broad-except
                log.error({ 'beautify': e }) # beautify is optional: keep the un-beautified prompt
        scheduler = sampler(params, options)
        t0 = time.perf_counter()
        data = await generate() # Map with name/image/b64/info; empty Map on API error
        if 'image' not in data:
            break
        count = len(data.image)
        stats.images += count
        t1 = time.perf_counter()
        if count > 0:
            avg[scheduler] = (t1 - t0) / count
        stats.generate += t1 - t0
        _image = grid(data)
        data = await upscale(data)
        t2 = time.perf_counter()
        stats.upscale += t2 - t1
        stats.wall += t2 - t0
        its = sd.generate.steps / ((t1 - t0) / count) if count > 0 else 0
        avg_time = round((t1 - t0) / count) if count > 0 else 0
        log.info({ 'time' : { 'wall': round(t1 - t0), 'average': avg_time, 'upscale': round(t2 - t1), 'its': round(its, 2) } })
        # --max 0 means "no limit": avoid dividing by zero in the progress percentage
        progress = round(100 * stats.images / params.max, 1) if params.max != 0 else 0
        log.info({ 'generated': stats.images, 'max': params.max, 'progress': progress })
        if params.max != 0 and stats.images >= params.max:
            break
if __name__ == '__main__':
    # Run the pipeline; always report stats and close the API session exactly once
    # (normal finish, error, or Ctrl+C).
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        asyncio.run(interrupt()) # presumably asks the server to stop the in-flight job (see sdapi)
        log.info({ 'interrupt': True })
    finally:
        log.info({ 'sampler performance': avg })
        log.info({ 'stats' : stats })
        asyncio.run(close()) # single close for every exit path; previously ran twice on interrupt