import time
from typing import Union
import gradio
from jinja2 import Template
from modules import script_callbacks, scripts
from modules.processing import Processed, StableDiffusionProcessing
from modules.ui_components import InputAccordion
from lib.custom_text_overlay.align import Position
from lib.custom_text_overlay.drawText import drawText
from lib.custom_text_overlay.logger import logger
from lib.custom_text_overlay.options import getOption, onUiSettings
from src.custom_text_overlay.extension import extensionId, extensionTitle
# Delimiters that wrap a placeholder name inside a text template.
templateBracketLeft, templateBracketRight = '{{', '}}'

# Base font size in pixels for a 1024x1024 image (or any image with the same
# pixel count); the drawing code scales it for other image sizes.
normalFontSize = 32

# Name of the script hook that actually draws the overlay.
hookType = 'postprocess_image'

# Attribute names looked up on the replacement sources to fill placeholders.
keysFromImg = [
    'cfg_scale',
    'width',
    'height',
    'seed',
    'subseed',
    'prompt',
    'negative_prompt',
    'steps',
]

# Per-image array attribute -> placeholder that receives the entry belonging
# to the image currently being processed.
specialReplacements = dict(
    all_seeds='seed',
    all_subseeds='subseed',
    all_prompts='prompt',
    all_negative_prompts='negative_prompt',
)
class CustomTextOverlay(scripts.Script):
    """Web UI script that draws templated text onto generated images.

    Up to nine text slots can be configured, numbered 1-9 in reading order
    (top left ... bottom right). Each slot has an enable checkbox and a
    template string. Placeholders in the template are resolved either by
    Jinja2 or by plain ``{{key}}`` substitution, depending on the
    ``template_engine`` option.
    """

    def title(self):
        """Return the extension title."""
        return extensionTitle

    def show(self, is_img2img):
        """Return ``scripts.AlwaysVisible`` for both txt2img and img2img."""
        return scripts.AlwaysVisible

    def ui(self, is_img2img):
        """Build the Gradio controls and return them as a flat list.

        The returned order is: the enable toggle, the nine style controls,
        the nine per-slot checkboxes, then the nine per-slot template
        textboxes. This matches the positional parameters that ``process``,
        ``postprocess_image`` and ``postprocess`` declare after their leading
        processing arguments.
        """
        minWidth = 200
        templateEngine = getOption('template_engine')
        useExamples = getOption('examples')

        def getTemplateInput(positionLabel: str, defaultValue: str = '') -> tuple[gradio.Checkbox, gradio.Textbox]:
            """Create the checkbox/textbox pair for one text slot."""
            if not useExamples:
                # Example templates are only pre-filled when the option is on.
                defaultValue = ''
            checkbox = gradio.Checkbox(label=f'{positionLabel} text', value=True)
            textbox = gradio.Textbox(label=f'{positionLabel} text template', value=defaultValue, lines=1, show_label=False)
            return (checkbox, textbox)

        if templateEngine == 'jinja2':
            timeTemplate = "{{ ('%.1f'|format(time)).rstrip('0').rstrip('.') }} sec\n{{ ('%.1f'|format(steps / time)).rstrip('0').rstrip('.') }} steps/s"
        else:
            timeTemplate = '{{time}}s'
        seedTemplate = 'Seed {{seed}}'

        # (label, example template) for every slot, laid out as a 3x3 grid in
        # slot order 1..9.
        positionGrid = (
            (('Top left', timeTemplate), ('Top', ''), ('Top right', '')),
            (('Left', ''), ('Center', ''), ('Right', '')),
            (('Bottom left', seedTemplate), ('Bottom', ''), ('Bottom right', '')),
        )
        textEnabledInputs = []
        textTemplateInputs = []

        with InputAccordion(False, label=extensionTitle, elem_id=self.elem_id(extensionId)) as enabled:
            with gradio.Accordion('Text', open=True):
                for gridRow in positionGrid:
                    with gradio.Row():
                        for positionLabel, exampleTemplate in gridRow:
                            with gradio.Column(min_width=minWidth):
                                checkbox, textbox = getTemplateInput(positionLabel, exampleTemplate)
                            textEnabledInputs.append(checkbox)
                            textTemplateInputs.append(textbox)
            with gradio.Accordion('Style'):
                # NOTE(review): the three heading literals below reached this
                # file as plain text split over several lines (invalid inside
                # a single-quoted string); any HTML markup they once carried
                # is not visible here. The visible characters are kept as-is.
                gradio.HTML('\nGeneral\n')
                with gradio.Row():
                    with gradio.Column(min_width=minWidth):
                        textColor = gradio.ColorPicker(label='Text color', value='#ffffff')
                        textScale = gradio.Slider(minimum=20, maximum=300, step=10, label='Text scale', value=120)
                    with gradio.Column(min_width=minWidth):
                        paddingScale = gradio.Slider(minimum=0, maximum=200, step=5, label='Padding scale', value=25)
                        marginScale = gradio.Slider(minimum=0, maximum=200, step=5, label='Margin scale', value=0)
                gradio.HTML('Outline\n')
                with gradio.Row():
                    outlineScale = gradio.Slider(minimum=0, maximum=25, step=1, label='Outline scale', value=12)
                    outlineColor = gradio.ColorPicker(label='Outline color', value='#000000')
                    outlineOpacity = gradio.Slider(minimum=0, maximum=100, step=5, label='Outline opacity', value=100)
                gradio.HTML('Background Box\n')
                with gradio.Row():
                    backgroundColor = gradio.ColorPicker(label='Background color', value='#000000')
                    backgroundOpacity = gradio.Slider(minimum=0, maximum=100, step=5, label='Background opacity', value=0)

        return [
            enabled,
            textScale,
            textColor,
            backgroundColor,
            backgroundOpacity,
            paddingScale,
            marginScale,
            outlineScale,
            outlineColor,
            outlineOpacity,
            *textEnabledInputs,
            *textTemplateInputs,
        ]

    def process(
        self, processing: StableDiffusionProcessing, enabled: bool,
        textScale: int, textColor: str, backgroundColor: str, backgroundOpacity: int,
        paddingScale: int, marginScale: int, outlineScale: int, outlineColor: str, outlineOpacity: int,
        textEnabled1: bool, textEnabled2: bool, textEnabled3: bool,
        textEnabled4: bool, textEnabled5: bool, textEnabled6: bool,
        textEnabled7: bool, textEnabled8: bool, textEnabled9: bool,
        textTemplate1: str, textTemplate2: str, textTemplate3: str,
        textTemplate4: str, textTemplate5: str, textTemplate6: str,
        textTemplate7: str, textTemplate8: str, textTemplate9: str,
    ):
        """Store the start timestamp used later to compute the ``time`` value.

        Does nothing when the extension is disabled. All parameters other
        than ``enabled`` are accepted only to match the control list returned
        by ``ui``.
        """
        if not enabled:
            return
        self.startTime = time.perf_counter()

    def collectReplacements(
        self,
        staticReplacements: Union[dict, None] = None,
        replacementSources: Union[dict, None] = None,
        imageIndex: int = 0,
        timeSeconds: float = 0,
    ):
        """Build the placeholder -> value table for one image.

        Args:
            staticReplacements: Entries copied into the table first. Omitting
                it or passing ``None`` means no static entries.
            replacementSources: Mapping of source name to an object whose
                attributes provide placeholder values. Omitting it or passing
                ``None`` means no sources.
            imageIndex: Index used for ``image_index`` and for picking the
                per-image entry out of the ``all_*`` arrays. ``None`` skips
                both.
            timeSeconds: Elapsed time. ``None`` omits the time entries.

        Returns:
            A new dict; the arguments are not modified.
        """
        # ``None`` defaults avoid sharing one mutable dict across calls.
        staticReplacements = {} if staticReplacements is None else staticReplacements
        replacementSources = {} if replacementSources is None else replacementSources

        templateEngine = getOption('template_engine')
        logger.info(templateEngine)
        replacements = self.makeReplacementTable(staticReplacements, keysFromImg, replacementSources)

        if timeSeconds is not None:
            if templateEngine == 'jinja2':
                # Jinja templates format the raw number themselves.
                replacements['time'] = timeSeconds
            else:
                # The basic engine has no formatting, so pre-rendered
                # precisions are offered under separate keys.
                replacements['time.00'] = f'{timeSeconds:.2f}'
                replacements['time.0'] = f'{timeSeconds:.1f}'
                replacements['time'] = int(timeSeconds)

        if imageIndex is not None:
            replacements['image_index'] = imageIndex

        if templateEngine == 'jinja2':
            # Expose the source objects themselves to Jinja templates.
            replacements.update(replacementSources)

        if imageIndex is not None:
            # Replace batch-wide values with the entry for this image; the
            # first source that has a long enough array wins.
            for arrayKey, singleKey in specialReplacements.items():
                for replacementSource in replacementSources.values():
                    value = getattr(replacementSource, arrayKey, None)
                    if value is not None and imageIndex < len(value):
                        replacements[singleKey] = str(value[imageIndex])
                        break
        return replacements

    def makeReplacementTable(self, baseReplacements: dict, needles: list, haystacks: dict):
        """Return a copy of ``baseReplacements`` extended with one entry per needle.

        For each name in ``needles`` the haystack objects are searched in
        order for an attribute of that name whose value is not ``None``; the
        first hit is used. A needle found nowhere maps to ``'?'``.
        """
        replacements = baseReplacements.copy()
        for needle in needles:
            for haystackName, haystack in haystacks.items():
                value = getattr(haystack, needle, None)
                if value is not None:
                    logger.debug(f'{templateBracketLeft}{needle}{templateBracketRight} = {haystackName}.{needle}')
                    replacements[needle] = value
                    break
            else:
                # for/else: reached only when no haystack supplied the needle.
                logger.debug(f'{templateBracketLeft}{needle}{templateBracketRight} = ?')
                replacements[needle] = '?'
        return replacements

    def applyReplacements(self, templateString: str, replacements: dict):
        """Render ``templateString`` with the engine named by the ``template_engine`` option.

        The template is stripped of surrounding whitespace before rendering.
        A debug line is logged only when rendering changed the text.
        """
        templateEngine = getOption('template_engine')
        if templateEngine == 'jinja2':
            templateHandler = self.applyReplacementsJinja
        else:
            templateHandler = self.applyReplacementsBasic
        inputTemplate = templateString.strip()
        output = templateHandler(inputTemplate, replacements)
        if inputTemplate == output:
            return inputTemplate
        logger.debug(f'Resolving template “{inputTemplate}” to “{output}” with engine {templateEngine} and keys {", ".join(replacements.keys())}')
        return output

    def applyReplacementsJinja(self, templateString: str, replacements: dict):
        """Render ``templateString`` as a Jinja2 template with ``replacements`` as context."""
        return Template(templateString).render(replacements)

    def applyReplacementsBasic(self, templateString: str, replacements: dict):
        """Substitute every ``{{key}}`` occurrence with its replacement value.

        Numbers are converted with ``str``; lists are joined with ``', '``.
        Values of any other non-string type are skipped, leaving their
        placeholder untouched.
        """
        for key, value in replacements.items():
            if isinstance(value, (int, float)):
                value = str(value)
            elif isinstance(value, list):
                # Items are stringified first so lists of numbers (e.g.
                # seeds) do not make ``str.join`` raise TypeError.
                value = ', '.join(str(item) for item in value)
            if not isinstance(value, str):
                continue
            templateString = templateString.replace(f'{templateBracketLeft}{key}{templateBracketRight}', value)
        return templateString

    @staticmethod
    def _activeTemplates(enabledFlags, templates) -> list[tuple[int, str]]:
        """Return ``(slot number, stripped template)`` for each enabled, non-blank slot.

        Slot numbers are 1-based and follow the order of the inputs.
        """
        return [
            (slot, template.strip())
            for slot, (isEnabled, template) in enumerate(zip(enabledFlags, templates), start=1)
            if isEnabled and template.strip() != ''
        ]

    def _drawOverlays(
        self, image, activeTemplates, replacements: dict, *,
        textScale, textColor, backgroundColor, backgroundOpacity,
        paddingScale, marginScale, outlineScale, outlineColor, outlineOpacity,
    ):
        """Draw every active template onto ``image`` and return the resulting image.

        Each template is rendered with ``replacements`` and passed to
        ``drawText`` together with sizes derived from the current image; the
        value ``drawText`` returns becomes the image for the next slot.
        """
        for slot, template in activeTemplates:
            text = self.applyReplacements(template, replacements)
            # Recomputed per slot from whatever image the previous drawText
            # call returned.
            pixels = image.width * image.height
            # NOTE(review): this scales with pixel count (area), not edge
            # length, so text size changes faster than image width does.
            # Confirm this is intended.
            fontSize = int((textScale / 100) * (normalFontSize * (pixels / 1048576)))
            margin = int((marginScale / 100) * fontSize)
            padding = int((paddingScale / 100) * fontSize)
            outline = int((outlineScale / 100) * fontSize)
            image = drawText(
                image,
                text=text,
                fontSize=fontSize,
                textColor=textColor,
                position=Position(slot),
                backgroundColor=backgroundColor,
                backgroundOpacity=backgroundOpacity,
                margin=margin,
                padding=padding,
                outline=outline,
                outlineColor=outlineColor,
                outlineOpacity=outlineOpacity,
            )
        return image

    def postprocess_image(
        self, processing: StableDiffusionProcessing, processed, enabled: bool,
        textScale: int, textColor: str, backgroundColor: str, backgroundOpacity: int,
        paddingScale: int, marginScale: int, outlineScale: int, outlineColor: str, outlineOpacity: int,
        textEnabled1: bool, textEnabled2: bool, textEnabled3: bool,
        textEnabled4: bool, textEnabled5: bool, textEnabled6: bool,
        textEnabled7: bool, textEnabled8: bool, textEnabled9: bool,
        textTemplate1: str, textTemplate2: str, textTemplate3: str,
        textTemplate4: str, textTemplate5: str, textTemplate6: str,
        textTemplate7: str, textTemplate8: str, textTemplate9: str,
    ):
        """Draw the overlays onto ``processed.image``.

        Returns without doing anything when the extension is disabled, when
        ``hookType`` is not ``'postprocess_image'``, or when no slot is both
        enabled and non-blank. The elapsed time is measured since the stored
        start timestamp, which is then reset to "now".
        """
        if not enabled or hookType != 'postprocess_image':
            return
        activeTemplates = self._activeTemplates(
            (textEnabled1, textEnabled2, textEnabled3, textEnabled4, textEnabled5, textEnabled6, textEnabled7, textEnabled8, textEnabled9),
            (textTemplate1, textTemplate2, textTemplate3, textTemplate4, textTemplate5, textTemplate6, textTemplate7, textTemplate8, textTemplate9),
        )
        if not activeTemplates:
            return

        endTime = time.perf_counter()
        generationTimeSeconds = endTime - self.startTime
        self.startTime = endTime

        furtherReplacementsSources = {
            'img': processed.image,
            'processed': processed,
            'processing': processing,
        }
        # The ``all_*`` arrays hold one entry per generated image, so the
        # index must include the position inside the current batch as well as
        # the batch iteration; ``iteration`` alone is only right for batch
        # size 1. NOTE(review): ``batch_size`` and ``batch_index`` are
        # presumed to be set by the web UI's processing loop; the fallbacks
        # cover objects that lack them. Confirm against the web UI version
        # in use.
        batchSize = getattr(processing, 'batch_size', 1) or 1
        batchIndex = getattr(processing, 'batch_index', 0) or 0
        imageIndex = processing.iteration * batchSize + batchIndex

        replacements = self.collectReplacements(
            timeSeconds=generationTimeSeconds,
            replacementSources=furtherReplacementsSources,
            imageIndex=imageIndex,
        )
        processed.image = self._drawOverlays(
            processed.image, activeTemplates, replacements,
            textScale=textScale, textColor=textColor,
            backgroundColor=backgroundColor, backgroundOpacity=backgroundOpacity,
            paddingScale=paddingScale, marginScale=marginScale,
            outlineScale=outlineScale, outlineColor=outlineColor, outlineOpacity=outlineOpacity,
        )

    def postprocess(
        self, processing: StableDiffusionProcessing, processed: Processed, enabled: bool,
        textScale: int, textColor: str, backgroundColor: str, backgroundOpacity: int,
        paddingScale: int, marginScale: int, outlineScale: int, outlineColor: str, outlineOpacity: int,
        textEnabled1: bool, textEnabled2: bool, textEnabled3: bool,
        textEnabled4: bool, textEnabled5: bool, textEnabled6: bool,
        textEnabled7: bool, textEnabled8: bool, textEnabled9: bool,
        textTemplate1: str, textTemplate2: str, textTemplate3: str,
        textTemplate4: str, textTemplate5: str, textTemplate6: str,
        textTemplate7: str, textTemplate8: str, textTemplate9: str,
    ):
        """Draw the overlays onto every image in ``processed.images`` from ``index_of_first_image`` on.

        Only active when ``hookType`` is ``'postprocess'``; otherwise, or when
        disabled or when no slot is both enabled and non-blank, it returns
        immediately. All images share one elapsed-time value measured since
        the stored start timestamp.
        """
        if not enabled or hookType != 'postprocess':
            return
        activeTemplates = self._activeTemplates(
            (textEnabled1, textEnabled2, textEnabled3, textEnabled4, textEnabled5, textEnabled6, textEnabled7, textEnabled8, textEnabled9),
            (textTemplate1, textTemplate2, textTemplate3, textTemplate4, textTemplate5, textTemplate6, textTemplate7, textTemplate8, textTemplate9),
        )
        if not activeTemplates:
            return

        endTime = time.perf_counter()
        generationTimeSeconds = endTime - self.startTime

        firstImageIndex = processed.index_of_first_image
        for imageIndex, img in enumerate(processed.images[firstImageIndex:]):
            processedImageIndex = imageIndex + firstImageIndex
            staticReplacements = {
                'processed_image_index': processedImageIndex,
                'real_image_index': imageIndex,
            }
            furtherReplacementsSources = {
                'img': img,
                'processed': processed,
                'processing': processing,
            }
            replacements = self.collectReplacements(
                staticReplacements=staticReplacements,
                replacementSources=furtherReplacementsSources,
                imageIndex=imageIndex,
                timeSeconds=generationTimeSeconds,
            )
            processed.images[processedImageIndex] = self._drawOverlays(
                img, activeTemplates, replacements,
                textScale=textScale, textColor=textColor,
                backgroundColor=backgroundColor, backgroundOpacity=backgroundOpacity,
                paddingScale=paddingScale, marginScale=marginScale,
                outlineScale=outlineScale, outlineColor=outlineColor, outlineOpacity=outlineOpacity,
            )
# Runs at import time: hands `onUiSettings` to the web UI's `on_ui_settings`
# hook (presumably so this extension's options show up in the settings UI;
# see lib.custom_text_overlay.options for what the callback does).
script_callbacks.on_ui_settings(onUiSettings)