File size: 4,780 Bytes
50261d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
// Register a callback through onUiLoaded that adds a format button for the
// txt2img and img2img prompt areas, in that order.
onUiLoaded(() => {
	for (const tab of ['txt2img', 'img2img']) {
		addFormatButton(tab);
	}
});

/**
 * Append a format button next to the `#<type>_clear_prompt` element.
 *
 * Does nothing when a `#<type>_format_prompt` element already exists, or when
 * the clear-prompt element (or its parent node) cannot be found.
 *
 * @param {string} type - Id prefix of the prompt area, e.g. 'txt2img'.
 */
function addFormatButton(type) {
	const buttonId = `${type}_format_prompt`;
	const existing = gradioApp().querySelector(`#${buttonId}`);
	const anchor = gradioApp().querySelector(`#${type}_clear_prompt`);

	const canInsert = !existing && anchor && anchor.parentNode;
	if (canInsert) {
		// The new button becomes the last child of the anchor's parent.
		anchor.parentNode.append(createFormatButton(buttonId, type));
	}
}

/**
 * Build the format button element.
 *
 * @param {string} id - Value for the button's `id`.
 * @param {string} type - Prompt-area prefix passed to formatPrompts on click.
 * @returns {HTMLButtonElement} The configured, not yet attached, button.
 */
function createFormatButton(id, type) {
	// Properties are assigned in the listed order.
	const wand = Object.assign(document.createElement('button'), {
		id,
		type: 'button',
		innerHTML: '🪄',
		title: 'Format prompt~🪄',
		className: 'lg secondary gradio-button tool svelte-cmf5ev',
	});
	const onClick = () => formatPrompts(type);
	wand.addEventListener('click', onClick);
	return wand;
}

/**
 * Format the positive and negative prompt textareas of one prompt area in place.
 *
 * For each of `#<type>_prompt` and `#<type>_neg_prompt`, the textarea's value is
 * replaced with the output of formatPrompt and an input event is dispatched on it.
 *
 * @param {string} type - Id prefix of the prompt area, e.g. 'txt2img'.
 */
function formatPrompts(type) {
	for (const kind of ['_prompt', '_neg_prompt']) {
		const textarea = gradioApp().querySelector(`#${type + kind} > label > textarea`);
		// querySelector returns null when the box is not in the DOM; skip it so the
		// other box is still formatted instead of aborting with a TypeError.
		if (!textarea) {
			continue;
		}
		textarea.value = formatPrompt(textarea.value);
		dispatchInputEvent(textarea);
	}
}

/**
 * Dispatch an 'input' event on the given element.
 *
 * The event carries an own `target` property set to the element before it is
 * dispatched.
 *
 * @param {EventTarget} target - Element to dispatch the event on.
 */
function dispatchInputEvent(target) {
	// Object.defineProperty returns the object it was given, so the event can be
	// created and tagged in one expression.
	const evt = Object.defineProperty(new Event('input'), 'target', { value: target });
	target.dispatchEvent(evt);
}

/**
 * Round a number to four decimal places.
 *
 * @param {number} value - Number to round.
 * @returns {number} `value` rounded via Math.round at a scale of 10000.
 */
function round(value) {
	const scale = 10000;
	const scaled = Math.round(value * scale);
	return scaled / scale;
}

/**
 * Replace the fullwidth colon and parentheses with their ASCII counterparts.
 *
 * @param {string} str - Prompt text.
 * @returns {string} The text with U+FF1A, U+FF08 and U+FF09 replaced by
 *   ':', '(' and ')' respectively; all other characters are unchanged.
 */
function convertStr(str) {
	// The fullwidth characters are written as \u escapes so they survive
	// re-encoding or Unicode normalization of this file. A bare "(" or ")" in
	// these positions is not a valid regular expression.
	return str
		.replace(/\uFF1A/g, ':')
		.replace(/\uFF08/g, '(')
		.replace(/\uFF09/g, ')');
}

/**
 * Split a prompt string into a flat list of trimmed, non-empty items.
 *
 * The string is first cut into top-level bracket groups (`()`, `<>`, `[]`) and
 * the text between them. Each piece is then split on commas that are not
 * inside any bracket. Items that are empty or consist only of commas and
 * whitespace are dropped, and items containing '<' are moved after all other
 * items while otherwise keeping their relative order.
 *
 * @param {string} str - Prompt text.
 * @returns {string[]} The items, in the order described above.
 */
function convertStr2Array(str) {
	const OPENERS = '(<[';
	const CLOSERS = ')>]';

	// Cut the text into chunks: each top-level bracket group becomes one chunk,
	// and so does each stretch of text between groups.
	const splitByBracket = (text) => {
		const chunks = [];
		let start = 0;
		let depth = 0;
		for (let i = 0; i < text.length; i++) {
			const isOpener = OPENERS.includes(text[i]);
			if (!isOpener && !CLOSERS.includes(text[i])) {
				continue;
			}
			// A bracket seen at depth 0 ends the plain-text stretch before it.
			if (depth === 0 && i > start) {
				chunks.push(text.substring(start, i));
				start = i;
			}
			depth += isOpener ? 1 : -1;
			// Back at depth 0: the group that began at `start` is complete.
			if (depth === 0) {
				chunks.push(text.substring(start, i + 1));
				start = i + 1;
			}
		}
		if (start < text.length) {
			chunks.push(text.substring(start));
		}
		return chunks;
	};

	// Split on commas that sit outside every bracket. Nesting depth is counted
	// (not toggled) so that a group such as "((a, b))" stays in one piece.
	const splitByComma = (text) => {
		const parts = [];
		let start = 0;
		let depth = 0;
		for (let i = 0; i < text.length; i++) {
			const ch = text[i];
			if (ch === ',' && depth === 0) {
				parts.push(text.substring(start, i).trim());
				start = i + 1;
			} else if (OPENERS.includes(ch)) {
				depth++;
			} else if (CLOSERS.includes(ch) && depth > 0) {
				// A stray closer never drives the depth below zero, so it cannot
				// suppress splitting for the rest of the chunk.
				depth--;
			}
		}
		parts.push(text.substring(start).trim());
		return parts;
	};

	// ASCII comma, whitespace, and the fullwidth comma (U+FF0C, written as an
	// escape so it survives re-encoding of this file).
	const separatorsOnly = /^[,\s\uFF0C]+$/;

	return splitByBracket(str)
		.flatMap((chunk) => splitByComma(chunk))
		.filter((item) => item !== '' && !separatorsOnly.test(item))
		// 0 for plain items, 1 for items containing '<': the latter sort last.
		.sort((a, b) => Number(a.includes('<')) - Number(b.includes('<')));
}

/**
 * Clean up each prompt item and join everything into one comma-separated string.
 *
 * Items containing '<' are kept exactly as they are. Every other item has its
 * whitespace runs collapsed, fullwidth commas turned into ASCII commas, quote
 * characters removed and comma spacing normalized; it is then re-split with
 * convertStr2Array and re-joined with ', '.
 *
 * @param {string[]} array - Prompt items.
 * @returns {string} The cleaned items joined with ', '.
 */
function convertArray2Str(array) {
	const cleaned = array.map((item) => {
		if (item.includes('<')) {
			return item;
		}
		const normalized = item
			.replace(/\s+/g, ' ')
			// U+FF0C (fullwidth comma) is written as an escape so it survives
			// re-encoding of this file; replacing an ASCII comma with itself did
			// nothing. The second alternative matches only the literal
			// three-character sequence ".|。".
			// NOTE(review): that may have been meant as two alternatives, but
			// matching a lone "." would rewrite decimal weights such as 1.2, so
			// it is deliberately left as written.
			.replace(/\uFF0C|\.\|。/g, ',')
			.replace(/“|‘|”|"|\/'/g, '')
			// Collapse ", " and ",," to "," before re-expanding every comma to ", ".
			.replace(/, /g, ',')
			.replace(/,,/g, ',')
			.replace(/,/g, ', ');
		return convertStr2Array(normalized).join(', ');
	});
	return cleaned.join(', ');
}

/**
 * Normalize a prompt and rewrite `{...}` / `[...]` emphasis as `(text:weight)`.
 *
 * The input is cleaned with convertStr, convertStr2Array and convertArray2Str.
 * Text enclosed by `{` has its weight multiplied by 1.05 per level, and text
 * enclosed by `[` has it multiplied by 1 / 1.05 per level; every weight is
 * passed through round after each multiplication. A bracket that is never
 * closed applies from where it opened to the end of the prompt, and a closing
 * bracket with no matching opener is ignored. Adjacent runs with equal weight
 * are merged; runs whose weight is exactly 1 are emitted without parentheses.
 *
 * @param {string} input - Raw prompt text.
 * @returns {string} The formatted prompt.
 */
function formatPrompt(input) {
	const re_attention = /\{|\[|\}|\]|[^{}[\]]+/gmu;
	const curly_bracket_multiplier = 1.05;
	const square_bracket_multiplier = 1 / 1.05;

	const text = convertArray2Str(convertStr2Array(convertStr(input)));

	// Per opening bracket: positions in `res` where a still-open bracket started.
	const brackets = {
		'{': { stack: [], multiplier: curly_bracket_multiplier },
		'[': { stack: [], multiplier: square_bracket_multiplier },
	};
	const openerFor = { '}': '{', ']': '[' };

	// [text, weight] pairs in prompt order.
	const res = [];

	function multiply_range(start_position, multiplier) {
		for (let pos = start_position; pos < res.length; pos++) {
			res[pos][1] = round(res[pos][1] * multiplier);
		}
	}

	for (const [word] of text.matchAll(re_attention)) {
		// Compare against the bracket characters directly. Probing the lookup
		// object with `in` also matches inherited names, so plain text such as
		// "constructor" or "toString" would be taken for an opening bracket.
		if (word === '{' || word === '[') {
			brackets[word].stack.push(res.length);
		} else if (word === '}' || word === ']') {
			const bracket = brackets[openerFor[word]];
			if (bracket.stack.length > 0) {
				multiply_range(bracket.stack.pop(), bracket.multiplier);
			}
		} else {
			res.push([word, 1.0]);
		}
	}

	// Brackets still open at the end apply through the end of the prompt.
	for (const opener of Object.keys(brackets)) {
		const { stack, multiplier } = brackets[opener];
		for (const pos of stack) {
			multiply_range(pos, multiplier);
		}
	}

	// Merge neighbouring runs that ended up with the same weight.
	const merged = [];
	for (const [word, weight] of res) {
		const previous = merged[merged.length - 1];
		if (previous !== undefined && previous[1] === weight) {
			previous[0] += word;
		} else {
			merged.push([word, weight]);
		}
	}

	let result = '';
	for (const [word, weight] of merged) {
		result += weight === 1.0 ? word : `(${word}:${weight.toString()})`;
	}
	return result;
}