File size: 5,871 Bytes
9c8d2ad
 
 
 
 
 
 
 
4bbe159
9c8d2ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d44afdb
31f2f07
9c8d2ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
195e56f
9c8d2ad
 
 
 
 
 
 
195e56f
9c8d2ad
 
 
 
 
 
3c18c09
9c8d2ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d44afdb
 
 
 
9c8d2ad
4bbe159
 
 
9c8d2ad
4bbe159
9c8d2ad
 
4bbe159
 
 
 
 
 
 
 
195e56f
4bbe159
 
 
 
48e3224
4bbe159
 
 
 
 
 
 
 
 
 
 
 
 
 
48e3224
 
4bbe159
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
# Prediction interface for Cog ⚙️
# https://github.com/replicate/cog/blob/main/docs/python.md

from cog import BasePredictor, Input, Path

import os
import sys
import signal
import time
import re
from typing import Dict, List, Any

import logging
# Drop "xformers" log records whose message contains the Triton-availability
# notice; every other record from that logger passes through unchanged.
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())

from modules import errors
from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call

import torch

# Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors.
# The untouched version string is preserved in torch.__long_version__, and
# torch.__version__ is replaced by the first run of digits and dots that ends in a digit.
# NOTE(review): presumably this has to run before the `modules` imports below,
# which may inspect torch.__version__ -- keep the ordering unless confirmed otherwise.
if ".dev" in torch.__version__ or "+git" in torch.__version__:
    torch.__long_version__ = torch.__version__
    torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)

from modules import shared, devices, ui_tempdir, extra_networks_hypernet, extra_networks
from modules.api.api import encode_pil_to_base64
import modules.codeformer_model as codeformer
import modules.face_restoration
import modules.gfpgan_model as gfpgan
import modules.img2img

import modules.lowvram
import modules.paths
import modules.scripts
import modules.sd_hijack
import modules.sd_models
import modules.sd_vae
import modules.txt2img
import modules.script_callbacks
import modules.textual_inversion.textual_inversion
import modules.progress

import modules.ui
from modules import modelloader, extensions
from modules.shared import cmd_opts, opts
import modules.hypernetworks.hypernetwork

from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images


def initialize():
    """Run the one-time web UI backend start-up sequence.

    In order, this:

    1. lists extensions and cleans up / sets up the model loaders
       (Stable Diffusion, CodeFormer, GFPGAN);
    2. registers built-in upscalers, loads scripts, loads upscalers and
       refreshes the VAE list;
    3. loads the Stable Diffusion checkpoint;
    4. records the loaded checkpoint title in the shared options and wires
       option-change callbacks for the checkpoint, VAE and temp dir;
    5. reloads hypernetworks, initialises the extra-network registry and
       fires the ``before_ui`` script callback.

    Raises:
        SystemExit: with status 1 when the Stable Diffusion checkpoint
            cannot be loaded (the error is reported on stderr first).
    """
    extensions.list_extensions()
    modelloader.cleanup_models()
    modules.sd_models.setup_model()
    codeformer.setup_model(cmd_opts.codeformer_models_path)
    gfpgan.setup_model(cmd_opts.gfpgan_models_path)

    modelloader.list_builtin_upscalers()
    modules.scripts.load_scripts()
    modelloader.load_upscalers()
    modules.sd_vae.refresh_vae_list()

    try:
        modules.sd_models.load_model()
    except Exception as e:
        errors.display(e, "loading stable diffusion model")
        print("", file=sys.stderr)
        print("Stable diffusion model failed to load, exiting", file=sys.stderr)
        # sys.exit rather than the bare `exit` helper: the latter is injected
        # by the `site` module for interactive sessions and is not guaranteed
        # to exist in every interpreter configuration.
        sys.exit(1)

    # Keep the stored option in sync with the checkpoint that was actually loaded.
    shared.opts.data["sd_model_checkpoint"] = shared.sd_model.sd_checkpoint_info.title
    # The lambdas defer attribute lookup until the option really changes.
    shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()))
    shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
    shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
    shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)

    shared.reload_hypernetworks()
    extra_networks.initialize()
    extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet())
    modules.script_callbacks.before_ui_callback()
    # NOTE: the SIGINT handler below is intentionally disabled; it is kept
    # only as a record of the original behaviour.
    # make the program just exit at ctrl+c without waiting for anything
    # def sigint_handler(sig, frame):
    #     print(f'Interrupted with signal {sig} in {frame}')
    #     os._exit(0)

    # signal.signal(signal.SIGINT, sigint_handler)

class Predictor(BasePredictor):
    """Cog predictor that renders a single text-to-image sample."""

    def setup(self):
        """Load the model into memory to make running multiple predictions efficient"""
        initialize()

    def predict(
        self,
        prompt: str = Input(description="prompt en", default="lora:koreanDollLikeness_v15:0.66, best quality, ultra high res, (photorealistic:1.4), 1girl, beige sweater, black choker, smile, laughing, bare shoulders, solo focus, ((full body), (brown hair:1), looking at viewer"),
        negative_prompt: str = Input(description="negative prompt", default="paintings, sketches, (worst quality:2), (low quality:2), (normal quality:2), lowres, normal quality, ((monochrome)), ((grayscale)), skin spots, acnes, skin blemishes, age spot, glans, (ugly:1.331), (duplicate:1.331), (morbid:1.21), (mutilated:1.21), (tranny:1.331), mutated hands, (poorly drawn hands:1.331), blurry, 3hands,4fingers,3arms, bad anatomy, missing fingers, extra digit, fewer digits, cropped, jpeg artifacts,poorly drawn face,mutation,deformed"),
        sampler_name: str = Input(description="sampler name", default="DPM++ 2M Karras", choices=["DPM++ SDE Karras", "DPM++ 2M Karras", "DPM++ 2S a Karras", "DPM2 a Karras", "DPM2 Karras", "LMS Karras", "DPM adaptive", "DPM fast", "DPM++ SDE", "DPM++ 2M", "DPM++ 2S a", "DPM2 a", "DPM2", "Heun", "LMS", "Euler", "Euler a"]),
        steps: int = Input(description="steps", default=20),
        cfg_scale: int = Input(description="cfg scale", default=8),
        width: int = Input(description="width", default=512),
        height: int = Input(description="height", default=768),
        enable_hr: bool = Input(description="Generate high resolution version", default=False),
        seed: int = Input(description="seed", default=-1),
    ) -> Path:
        """Run a single prediction on the model.

        Builds a txt2img processing request from the inputs, runs it, and
        saves the first resulting image as a PNG.

        Args:
            prompt: Positive prompt text.
            negative_prompt: Negative prompt text.
            sampler_name: Sampler to use; one of the listed ``choices``.
            steps: Number of sampling steps.
            cfg_scale: CFG scale passed to the pipeline.
            width: Requested image width.
            height: Requested image height.
            enable_hr: Forwarded as ``enable_hr``; the upscaler is fixed to
                "R-ESRGAN 4x+".
            seed: Seed forwarded to the pipeline (default -1).

        Returns:
            Path to the PNG file, located in a freshly created temporary
            directory.

        Raises:
            RuntimeError: if processing finishes without producing any image.
        """
        # Function-scope import so the block brings its own dependency with it.
        import tempfile

        args = {
            # Only the explicitly saved PNG below should hit the disk.
            "do_not_save_samples": True,
            "do_not_save_grid": True,
            "outpath_samples": "./output",
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "sampler_name": sampler_name,
            "steps": steps,
            "cfg_scale": cfg_scale,
            "width": width,
            "height": height,
            "enable_hr": enable_hr,
            "hr_upscaler": "R-ESRGAN 4x+",
            "seed": seed,
        }
        p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)
        processed = process_images(p)
        if not processed.images:
            raise RuntimeError("Stable Diffusion processing returned no images")

        # A per-call temp directory prevents two predictions finishing within
        # the same second from overwriting each other, and keeps generated
        # files out of the working directory. The basename is unchanged.
        output_path = Path(tempfile.mkdtemp()) / f"{int(time.time())}.png"
        processed.images[0].save(fp=str(output_path), format="PNG")
        return output_path