File size: 13,708 Bytes
c336648
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
import time
from typing import Union

import gradio
from jinja2 import Template
from modules import script_callbacks, scripts
from modules.processing import Processed, StableDiffusionProcessing
from modules.ui_components import InputAccordion

from lib.custom_text_overlay.align import Position
from lib.custom_text_overlay.drawText import drawText
from lib.custom_text_overlay.logger import logger
from lib.custom_text_overlay.options import getOption, onUiSettings
from src.custom_text_overlay.extension import extensionId, extensionTitle

# Placeholder delimiters of the basic (non-Jinja2) template engine; also used in debug log output.
templateBracketLeft = '{{'
templateBracketRight = '}}'

# Reference font size for a 1024x1024 image (or any image with the same megapixel count).
normalFontSize = 32
# Selects which WebUI hook draws the overlay: 'postprocess_image' or 'postprocess'.
hookType = 'postprocess_image'

# Attribute names looked up on the replacement sources and exposed as template keys.
keysFromImg = [
  'cfg_scale',
  'width',
  'height',
  'seed',
  'subseed',
  'prompt',
  'negative_prompt',
  'steps',
]

# Per-job list attribute -> template key it fills for the current image index.
specialReplacements = dict(
  all_seeds='seed',
  all_subseeds='subseed',
  all_prompts='prompt',
  all_negative_prompts='negative_prompt',
)

class CustomTextOverlay(scripts.Script):
  """Always-visible WebUI script that draws up to nine templated text overlays onto generated images.

  Each overlay slot maps to a position on a 3x3 grid (slot 1 = top left ... slot 9 = bottom right,
  see ``Position``). Templates are resolved either with the basic ``{{key}}`` engine or with Jinja2,
  depending on the ``template_engine`` option.
  """

  def title(self):
    """Return the display name of this script."""
    return extensionTitle

  def show(self, is_img2img):
    """Return ``scripts.AlwaysVisible`` so the script is shown regardless of the tab."""
    return scripts.AlwaysVisible

  def ui(self, is_img2img):
    """Build the settings accordion and return its components.

    The order of the returned list must match the extra parameters of ``process``,
    ``postprocess_image`` and ``postprocess``.
    """
    minWidth = 200
    templateEngine = getOption('template_engine')
    useExamples = getOption('examples')

    def getTemplateInput(positionLabel: str, defaultValue: str = '') -> tuple[gradio.Checkbox, gradio.Textbox]:
      if not useExamples:
        defaultValue = ''
      checkbox = gradio.Checkbox(label=f'{positionLabel} text', value=True)
      textbox = gradio.Textbox(label=f'{positionLabel} text template', value=defaultValue, lines=1, show_label=False)
      return (checkbox, textbox)

    with InputAccordion(False, label=extensionTitle, elem_id=self.elem_id(extensionId)) as enabled:
      with gradio.Accordion('Text', open=True):
        with gradio.Row():
          with gradio.Column(min_width=minWidth):
            timeTemplate = "{{ ('%.1f'|format(time)).rstrip('0').rstrip('.') }} sec\n{{ ('%.1f'|format(steps / time)).rstrip('0').rstrip('.') }} steps/s" if templateEngine == 'jinja2' else '{{time}}s'
            (textEnabled1, textTemplate1) = getTemplateInput('Top left', timeTemplate)
          with gradio.Column(min_width=minWidth):
            (textEnabled2, textTemplate2) = getTemplateInput('Top')
          with gradio.Column(min_width=minWidth):
            (textEnabled3, textTemplate3) = getTemplateInput('Top right')
        with gradio.Row():
          with gradio.Column(min_width=minWidth):
            (textEnabled4, textTemplate4) = getTemplateInput('Left')
          with gradio.Column(min_width=minWidth):
            (textEnabled5, textTemplate5) = getTemplateInput('Center')
          with gradio.Column(min_width=minWidth):
            (textEnabled6, textTemplate6) = getTemplateInput('Right')
        with gradio.Row():
          with gradio.Column(min_width=minWidth):
            seedTemplate = 'Seed {{seed}}'
            (textEnabled7, textTemplate7) = getTemplateInput('Bottom left', seedTemplate)
          with gradio.Column(min_width=minWidth):
            (textEnabled8, textTemplate8) = getTemplateInput('Bottom')
          with gradio.Column(min_width=minWidth):
            (textEnabled9, textTemplate9) = getTemplateInput('Bottom right')
      with gradio.Accordion('Style'):
        gradio.HTML('<p><h2>General</h2><hr style="margin-top: 0; margin-bottom: 1em"/></p>')
        with gradio.Row():
          with gradio.Column(min_width=minWidth):
            textColor = gradio.ColorPicker(label='Text color', value='#ffffff')
            textScale = gradio.Slider(minimum=20, maximum=300, step=10, label='Text scale', value=120)
          with gradio.Column(min_width=minWidth):
            paddingScale = gradio.Slider(minimum=0, maximum=200, step=5, label='Padding scale', value=25)
            marginScale = gradio.Slider(minimum=0, maximum=200, step=5, label='Margin scale', value=0)
        gradio.HTML('<p style="margin-top: 2em"><h2>Outline</h2><hr style="margin-top: 0; margin-bottom: 1em"/></p>')
        with gradio.Row():
          outlineScale = gradio.Slider(minimum=0, maximum=25, step=1, label='Outline scale', value=12)
          outlineColor = gradio.ColorPicker(label='Outline color', value='#000000')
          outlineOpacity = gradio.Slider(minimum=0, maximum=100, step=5, label='Outline opacity', value=100)
        gradio.HTML('<p style="margin-top: 2em"><h2>Background Box</h2><hr style="margin-top: 0; margin-bottom: 1em"/></p>')
        with gradio.Row():
          backgroundColor = gradio.ColorPicker(label='Background color', value='#000000')
          backgroundOpacity = gradio.Slider(minimum=0, maximum=100, step=5, label='Background opacity', value=0)

    return [enabled, textScale, textColor, backgroundColor, backgroundOpacity, paddingScale, marginScale, outlineScale, outlineColor, outlineOpacity, textEnabled1, textEnabled2, textEnabled3, textEnabled4, textEnabled5, textEnabled6, textEnabled7, textEnabled8, textEnabled9, textTemplate1, textTemplate2, textTemplate3, textTemplate4, textTemplate5, textTemplate6, textTemplate7, textTemplate8, textTemplate9]

  def process(self, processing: StableDiffusionProcessing, enabled: bool, textScale: int, textColor: str, backgroundColor: str, backgroundOpacity: int, paddingScale: int, marginScale: int, outlineScale: int, outlineColor: str, outlineOpacity: int, textEnabled1: bool, textEnabled2: bool, textEnabled3: bool, textEnabled4: bool, textEnabled5: bool, textEnabled6: bool, textEnabled7: bool, textEnabled8: bool, textEnabled9: bool, textTemplate1: str, textTemplate2: str, textTemplate3: str, textTemplate4: str, textTemplate5: str, textTemplate6: str, textTemplate7: str, textTemplate8: str, textTemplate9: str):
    """Record the start time so the postprocess hooks can report the elapsed generation time."""
    if not enabled:
      return
    self.startTime = time.perf_counter()

  def collectReplacements(self, staticReplacements: Union[dict, None] = None, replacementSources: Union[dict, None] = None, imageIndex: int = 0, timeSeconds: float = 0):
    """Build the key -> value table used to resolve templates for one image.

    Args:
      staticReplacements: Entries copied into the result first; the dict itself is not modified.
      replacementSources: Named objects searched (in insertion order) for the ``keysFromImg``
        attributes. With the Jinja2 engine they are also exposed to templates under their names.
      imageIndex: Index into per-image list attributes such as ``all_seeds``. ``None`` skips both
        those lookups and the ``image_index`` entry.
      timeSeconds: Elapsed time in seconds; ``None`` omits the ``time`` entries.

    Returns:
      A new dict of replacements.
    """
    # None defaults instead of {} so no dict object is shared between calls.
    staticReplacements = {} if staticReplacements is None else staticReplacements
    replacementSources = {} if replacementSources is None else replacementSources
    templateEngine = getOption('template_engine')
    replacements = self.makeReplacementTable(staticReplacements, keysFromImg, replacementSources)
    if timeSeconds is not None:
      if templateEngine == 'jinja2':
        replacements['time'] = timeSeconds
      else:
        # The basic engine has no formatting filters, so offer pre-formatted precisions.
        replacements['time.00'] = f'{timeSeconds:.2f}'
        replacements['time.0'] = f'{timeSeconds:.1f}'
        replacements['time'] = int(timeSeconds)
    if imageIndex is not None:
      replacements['image_index'] = imageIndex
    if templateEngine == 'jinja2':
      # Lets Jinja2 templates read attributes directly, e.g. {{ processing.<attribute> }}.
      replacements.update(replacementSources)
    if imageIndex is None:
      return replacements
    # Per-image list values override the scalar found above (e.g. all_seeds[i] replaces seed).
    for arrayKey, singleKey in specialReplacements.items():
      for replacementSource in replacementSources.values():
        values = getattr(replacementSource, arrayKey, None)
        if values is not None and 0 <= imageIndex < len(values):
          replacements[singleKey] = str(values[imageIndex])
          break
    return replacements

  def makeReplacementTable(self, baseReplacements: dict, needles: list, haystacks: dict):
    """Return a copy of ``baseReplacements`` with one entry added per attribute name in ``needles``.

    The first haystack (in dict order) with a non-``None`` attribute of that name supplies the value.
    If no haystack has one, the value is ``'?'``.
    """
    replacements = baseReplacements.copy()
    for needle in needles:
      for haystackName, haystack in haystacks.items():
        value = getattr(haystack, needle, None)
        if value is not None:
          logger.debug(f'{templateBracketLeft}{needle}{templateBracketRight} = {haystackName}.{needle}')
          replacements[needle] = value
          break
      else:
        logger.debug(f'{templateBracketLeft}{needle}{templateBracketRight} = ?')
        replacements[needle] = '?'
    return replacements

  def applyReplacements(self, templateString: str, replacements: dict):
    """Resolve a (stripped) template with the engine selected by the ``template_engine`` option."""
    templateEngine = getOption('template_engine')
    templateHandler = self.applyReplacementsJinja if templateEngine == 'jinja2' else self.applyReplacementsBasic
    inputTemplate = templateString.strip()
    output = templateHandler(inputTemplate, replacements)
    if inputTemplate == output:
      return inputTemplate
    logger.debug(f'Resolving template “{inputTemplate}” to “{output}” with engine {templateEngine} and keys {", ".join(replacements.keys())}')
    return output

  def applyReplacementsJinja(self, templateString: str, replacements: dict):
    """Render ``templateString`` as a Jinja2 template with ``replacements`` as the context."""
    jinjaTemplate = Template(templateString)
    return jinjaTemplate.render(replacements)

  def applyReplacementsBasic(self, templateString: str, replacements: dict):
    """Substitute ``{{key}}`` placeholders by plain string replacement.

    Numbers are converted with ``str``, and lists/tuples are joined with ``', '``. Values of any other
    non-string type are skipped, so their placeholders stay in the output.
    """
    for key, value in replacements.items():
      if isinstance(value, (int, float)):
        value = str(value)
      elif isinstance(value, (list, tuple)):
        # Convert items first: str.join() raises TypeError on non-string items (e.g. int seeds).
        value = ', '.join(str(item) for item in value)
      if not isinstance(value, str):
        continue
      templateString = templateString.replace(f'{templateBracketLeft}{key}{templateBracketRight}', value)
    return templateString

  @staticmethod
  def _enabledTemplates(enabledFlags, templates) -> list[tuple[int, str]]:
    """Return ``(slotNumber, strippedTemplate)`` for every enabled, non-blank slot (slots are 1-based)."""
    return [
      (slotNumber, template.strip())
      for slotNumber, (isEnabled, template) in enumerate(zip(enabledFlags, templates), start=1)
      if isEnabled and template and template.strip() != ''
    ]

  def _elapsedSeconds(self, now: float) -> float:
    """Return seconds since the start time recorded by ``process``, or 0 if none was recorded."""
    return now - getattr(self, 'startTime', now)

  @staticmethod
  def _imageIndex(processing) -> int:
    """Return the current image's index into the job-wide per-image lists (``all_seeds`` etc.).

    ``processing.iteration`` on its own is the batch number. Using it directly gives every image of a
    batch the seed/prompt of the batch's first image whenever the batch size is > 1. When the WebUI
    exposes ``batch_index``, combine both; otherwise keep the previous iteration-only behavior.
    NOTE(review): assumes the per-image lists are laid out batch by batch
    (index = iteration * batch_size + batch_index), as in AUTOMATIC1111's processing loop.
    """
    batchIndex = getattr(processing, 'batch_index', None)
    if batchIndex is None:
      return processing.iteration
    return processing.iteration * processing.batch_size + batchIndex

  def _drawTemplates(self, image, templates: list[tuple[int, str]], replacements: dict, textScale: int, textColor: str, backgroundColor: str, backgroundOpacity: int, paddingScale: int, marginScale: int, outlineScale: int, outlineColor: str, outlineOpacity: int):
    """Resolve each ``(slotNumber, template)`` pair, draw it onto ``image`` and return the final image."""
    for slotNumber, template in templates:
      text = self.applyReplacements(template, replacements)
      pixels = image.width * image.height
      # Font size scales with the pixel count relative to 1024x1024 (1048576 px).
      # Clamp to 1 so very small images never request a zero-sized font.
      fontSize = max(1, int((textScale / 100) * (normalFontSize * (pixels / 1048576))))
      margin = int((marginScale / 100) * fontSize)
      padding = int((paddingScale / 100) * fontSize)
      outline = int((outlineScale / 100) * fontSize)
      image = drawText(image, text=text, fontSize=fontSize, textColor=textColor, position=Position(slotNumber), backgroundColor=backgroundColor, backgroundOpacity=backgroundOpacity, margin=margin, padding=padding, outline=outline, outlineColor=outlineColor, outlineOpacity=outlineOpacity)
    return image

  def postprocess_image(self, processing: StableDiffusionProcessing, processed, enabled: bool, textScale: int, textColor: str, backgroundColor: str, backgroundOpacity: int, paddingScale: int, marginScale: int, outlineScale: int, outlineColor: str, outlineOpacity: int, textEnabled1: bool, textEnabled2: bool, textEnabled3: bool, textEnabled4: bool, textEnabled5: bool, textEnabled6: bool, textEnabled7: bool, textEnabled8: bool, textEnabled9: bool, textTemplate1: str, textTemplate2: str, textTemplate3: str, textTemplate4: str, textTemplate5: str, textTemplate6: str, textTemplate7: str, textTemplate8: str, textTemplate9: str):
    """Draw the overlays onto a single image; active only when ``hookType`` is 'postprocess_image'.

    ``processed.image`` is read and then replaced with the annotated image.
    """
    if not enabled or hookType != 'postprocess_image':
      return
    templates = self._enabledTemplates(
      (textEnabled1, textEnabled2, textEnabled3, textEnabled4, textEnabled5, textEnabled6, textEnabled7, textEnabled8, textEnabled9),
      (textTemplate1, textTemplate2, textTemplate3, textTemplate4, textTemplate5, textTemplate6, textTemplate7, textTemplate8, textTemplate9),
    )
    if not templates:
      return
    endTime = time.perf_counter()
    generationTimeSeconds = self._elapsedSeconds(endTime)
    # Restart the clock so the next image reports only the time since this one.
    self.startTime = endTime
    furtherReplacementsSources = {
      'img': processed.image,
      'processed': processed,
      'processing': processing,
    }
    replacements = self.collectReplacements(timeSeconds=generationTimeSeconds, replacementSources=furtherReplacementsSources, imageIndex=self._imageIndex(processing))
    processed.image = self._drawTemplates(processed.image, templates, replacements, textScale, textColor, backgroundColor, backgroundOpacity, paddingScale, marginScale, outlineScale, outlineColor, outlineOpacity)

  def postprocess(self, processing: StableDiffusionProcessing, processed: Processed, enabled: bool, textScale: int, textColor: str, backgroundColor: str, backgroundOpacity: int, paddingScale: int, marginScale: int, outlineScale: int, outlineColor: str, outlineOpacity: int, textEnabled1: bool, textEnabled2: bool, textEnabled3: bool, textEnabled4: bool, textEnabled5: bool, textEnabled6: bool, textEnabled7: bool, textEnabled8: bool, textEnabled9: bool, textTemplate1: str, textTemplate2: str, textTemplate3: str, textTemplate4: str, textTemplate5: str, textTemplate6: str, textTemplate7: str, textTemplate8: str, textTemplate9: str):
    """Draw the overlays onto every image in ``processed.images``; active only when ``hookType`` is 'postprocess'.

    Images before ``processed.index_of_first_image`` are left untouched.
    """
    if not enabled or hookType != 'postprocess':
      return
    templates = self._enabledTemplates(
      (textEnabled1, textEnabled2, textEnabled3, textEnabled4, textEnabled5, textEnabled6, textEnabled7, textEnabled8, textEnabled9),
      (textTemplate1, textTemplate2, textTemplate3, textTemplate4, textTemplate5, textTemplate6, textTemplate7, textTemplate8, textTemplate9),
    )
    if not templates:
      return
    generationTimeSeconds = self._elapsedSeconds(time.perf_counter())

    firstImageIndex = processed.index_of_first_image
    # Iterate over a slice (a copy), so writing back into processed.images below is safe.
    for imageIndex, img in enumerate(processed.images[firstImageIndex:]):
      processedImageIndex = imageIndex + firstImageIndex
      replacements = self.collectReplacements(
        staticReplacements={
          'processed_image_index': processedImageIndex,
          'real_image_index': imageIndex,
        },
        replacementSources={
          'img': img,
          'processed': processed,
          'processing': processing,
        },
        imageIndex=imageIndex,
        timeSeconds=generationTimeSeconds,
      )
      processed.images[processedImageIndex] = self._drawTemplates(img, templates, replacements, textScale, textColor, backgroundColor, backgroundOpacity, paddingScale, marginScale, outlineScale, outlineColor, outlineOpacity)

# Register onUiSettings (from lib.custom_text_overlay.options) as the WebUI's on_ui_settings callback;
# presumably this adds the extension's options (e.g. 'template_engine', 'examples') to the Settings UI.
script_callbacks.on_ui_settings(onUiSettings)