# PoemForSmallFThings / NAIA / wd14_test.py
# Uploaded by baqu2213 ("Upload 2 files", commit 03f219a, verified).
import tkinter as tk
from tkinter import filedialog, messagebox, scrolledtext, ttk
from PIL import Image, ImageTk, ImageGrab
import threading
import os
import numpy as np
import pandas as pd
import cv2
import tempfile
import io
import json
import sys
import os
def resource_path(relative_path):
    """Return the path of a bundled resource such as ``model.onnx``.

    In a PyInstaller build the resources live in the temporary extraction
    folder stored in ``sys._MEIPASS``. Otherwise the directory that contains
    this script is tried first, so the app also finds its resources when it is
    started from a different working directory; if the resource is not there,
    the current working directory is used as before.

    Args:
        relative_path: Path of the resource relative to the base folder.

    Returns:
        The joined path. Its existence is not guaranteed.
    """
    bundle_dir = getattr(sys, "_MEIPASS", None)
    if bundle_dir is not None:
        return os.path.join(bundle_dir, relative_path)

    script_file = globals().get("__file__")
    if script_file:
        candidate = os.path.join(os.path.dirname(os.path.abspath(script_file)), relative_path)
        if os.path.exists(candidate):
            return candidate

    # Fallback: resolve against the current working directory.
    return os.path.join(os.path.abspath("."), relative_path)
class WD14TaggerGUI:
    """Tkinter GUI that tags an image with a WD14 ONNX model and filters the tags.

    Images can be loaded from a file dialog, drag & drop (when tkinterdnd2 is
    installed), the clipboard (image data, copied files, an http(s) URL or an
    image file path) and via the window-wide Ctrl+V binding. Preferences are
    saved as JSON to ``SETTINGS_FILE``, a path relative to the current working
    directory.
    """

    # File extensions accepted for pasted / copied image paths.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.webp', '.bmp', '.tiff')

    def __init__(self, root):
        self.root = root
        self.root.title("WD14 Auto Prompt Generator (aiart)")
        self.root.geometry("700x870")
        self.SETTINGS_FILE = "tagger_app_settings.json"
        self.current_image_path = None
        self.session = None
        self.tags_df = None
        # Id of the most recent extraction request. Results of superseded
        # background runs are discarded so they cannot overwrite newer ones.
        self._extraction_job = 0

        # Tk variables backing the settings widgets.
        self.general_threshold_var = tk.DoubleVar(value=0.5)
        self.char_threshold_var = tk.DoubleVar(value=0.85)
        self.rm_c_var = tk.BooleanVar()
        self.rm_color_var = tk.BooleanVar()
        self.rm_clothes_var = tk.BooleanVar()
        self.webui_mode_var = tk.BooleanVar()
        self.instant_inference_var = tk.BooleanVar(value=False)
        # Expanded/collapsed state of the optional panels.
        self.show_removed_tags_var = tk.BooleanVar(value=True)
        self.show_ignore_tags_var = tk.BooleanVar(value=False)
        self.prompt_active_var = tk.BooleanVar(value=False)

        self.init_filter_data()
        self.setup_ui()
        self.load_model()
        self.show_initial_preview()
        self.setup_drag_drop()
        self.load_settings()
        self.root.protocol("WM_DELETE_WINDOW", self.on_closing)
        # Window-wide Ctrl+V: paste an image, an image URL or an image path.
        self.root.bind('<Control-v>', self.handle_paste)

    # ------------------------------------------------------------------ data

    def init_filter_data(self):
        """Load the tag lists used by the filters.

        Tries the NAIA ``tagbag`` / ``character_dictionary`` modules and falls
        back to small built-in lists when they cannot be imported. Every entry
        is normalised to spaces instead of underscores, because
        apply_tag_filters compares against space-separated tags.
        """
        try:
            from tagbag import bag_of_tags, clothes_list
            from character_dictionary import character_dictionary as cd
            self.bag_of_tags = bag_of_tags
            self.clothes_list = clothes_list
            self.character_keys = list(cd.keys()) if isinstance(cd, dict) else []
        except ImportError as e:
            print(f"⚠ NAIA 파일 import 실패: {e}")
            print("기본 데이터를 사용합니다.")
            self.bag_of_tags = [
                'smile', 'blush', 'open_mouth', 'closed_eyes', 'wink', 'frown',
                'twintails', 'ponytail', 'braid', 'long_hair', 'short_hair',
                'large_breasts', 'small_breasts', 'cleavage', 'navel'
            ]
            self.clothes_list = [
                'dress', 'skirt', 'shirt', 'jacket', 'uniform', 'school_uniform'
            ]
            self.character_keys = []
        self.bag_of_tags = self._normalize_tags(self.bag_of_tags)
        self.clothes_list = self._normalize_tags(self.clothes_list)
        self.character_keys = self._normalize_tags(self.character_keys)
        self.colors = ['black', 'white', 'blond', 'silver', 'gray', 'yellow',
                       'blue', 'purple', 'red', 'pink', 'brown', 'orange',
                       'green', 'aqua', 'gradient']

    @staticmethod
    def _normalize_tags(tags):
        """Return *tags* as a list with '_' replaced by ' ' (non-strings kept as-is)."""
        return [t.replace('_', ' ') if isinstance(t, str) else t for t in tags]

    # ------------------------------------------------------- collapsible UI

    def toggle_removed_tags(self):
        """Show or hide the removed-tags box according to its checkbox."""
        if self.show_removed_tags_var.get():
            self.removed_tags_text.grid(row=3, column=0, sticky=(tk.W, tk.E, tk.N, tk.S))
        else:
            self.removed_tags_text.grid_forget()

    def toggle_ignore_tags(self):
        """Show or hide the ignore-tags box and its save button."""
        if self.show_ignore_tags_var.get():
            self.ignore_tags_text.grid(row=1, column=0, columnspan=2, sticky=(tk.W, tk.E, tk.N, tk.S))
            self.save_settings_btn.grid(row=1, column=2, sticky="ne", padx=(10, 0))
        else:
            self.ignore_tags_text.grid_forget()
            self.save_settings_btn.grid_forget()

    def toggle_prompts(self):
        """Checkbox hook for the prefix/suffix prompts; currently a no-op.

        The prompt notebook stays visible either way. ``prompt_active_var`` is
        only consulted when tags are filtered (apply_tag_filters).
        """
        pass

    # -------------------------------------------------------------- layout

    def setup_ui(self):
        """Create all widgets and lay them out with grid."""
        main_frame = ttk.Frame(self.root, padding="10")
        main_frame.grid(row=0, column=0, sticky=(tk.W, tk.E, tk.N, tk.S))

        # Image loading buttons.
        load_frame = ttk.Frame(main_frame)
        load_frame.grid(row=0, column=0, columnspan=2, sticky=(tk.W, tk.E), pady=(0, 10))
        self.load_btn = ttk.Button(load_frame, text="이미지 불러오기", command=self.load_image)
        self.load_btn.grid(row=0, column=0, padx=(0, 5))
        self.clipboard_btn = ttk.Button(load_frame, text="이미지 불러오기 (클립보드)",
                                        command=self.load_from_clipboard)
        self.clipboard_btn.grid(row=0, column=1)

        # Preview (left) and settings (right).
        content_frame = ttk.Frame(main_frame)
        content_frame.grid(row=1, column=0, columnspan=2, sticky=(tk.W, tk.E, tk.N, tk.S), pady=(0, 10))
        self.preview_frame = ttk.LabelFrame(content_frame, text="미리보기", padding="10")
        self.preview_frame.grid(row=0, column=0, sticky=(tk.W, tk.E, tk.N, tk.S), padx=(0, 2))
        self.image_label = ttk.Label(self.preview_frame, text="")
        self.image_label.grid(row=0, column=0)

        settings_frame = ttk.LabelFrame(content_frame, text="설정", padding="10")
        settings_frame.grid(row=0, column=1, sticky=(tk.W, tk.E, tk.N, tk.S), padx=(2, 0))
        ttk.Label(settings_frame, text="태그 임계값:").grid(row=0, column=0, sticky=tk.W)
        ttk.Spinbox(settings_frame, from_=0.0, to=1.0, increment=0.05,
                    textvariable=self.general_threshold_var, width=15).grid(row=0, column=1, pady=(0, 10))
        ttk.Label(settings_frame, text="캐릭터 임계값:").grid(row=1, column=0, sticky=tk.W)
        ttk.Spinbox(settings_frame, from_=0.0, to=1.0, increment=0.05,
                    textvariable=self.char_threshold_var, width=15).grid(row=1, column=1, pady=(0, 10))

        extract_control_frame = ttk.Frame(settings_frame)
        extract_control_frame.grid(row=2, column=0, columnspan=2, pady=10)
        self.extract_btn = ttk.Button(extract_control_frame, text="태그 추출",
                                      command=self.extract_tags, state='disabled')
        self.extract_btn.pack(side=tk.LEFT, padx=(0, 10))
        self.instant_inference_cb = ttk.Checkbutton(extract_control_frame, text="즉시 추론",
                                                    variable=self.instant_inference_var)
        self.instant_inference_cb.pack(side=tk.LEFT)

        filter_frame = ttk.LabelFrame(settings_frame, text="태그 필터링", padding="5")
        filter_frame.grid(row=3, column=0, columnspan=2, sticky=(tk.W, tk.E), pady=(10, 0))
        ttk.Checkbutton(filter_frame, text="캐릭터 특징+이름 제거",
                        variable=self.rm_c_var).grid(row=0, column=0, sticky=tk.W, pady=2)
        ttk.Checkbutton(filter_frame, text="의류 색상 제거",
                        variable=self.rm_color_var).grid(row=1, column=0, sticky=tk.W, pady=2)
        ttk.Checkbutton(filter_frame, text="의상 제거 (비추천)",
                        variable=self.rm_clothes_var).grid(row=2, column=0, sticky=tk.W, pady=2)
        ttk.Checkbutton(filter_frame, text="WEBUI/Comfy 모드",
                        variable=self.webui_mode_var).grid(row=3, column=0, sticky=tk.W, pady=2)

        # Prefix / suffix prompts.
        prompt_frame = ttk.Frame(settings_frame)
        prompt_frame.grid(row=4, column=0, columnspan=2, sticky="ew", pady=(15, 0))
        self.prompt_active_cb = ttk.Checkbutton(prompt_frame, text="선행/후행 프롬프트 활성",
                                                variable=self.prompt_active_var,
                                                command=self.toggle_prompts)
        self.prompt_active_cb.grid(row=0, column=0, sticky="w")

        self.prompt_notebook = ttk.Notebook(prompt_frame)
        self.prompt_notebook.grid(row=1, column=0, columnspan=2, sticky="ew", pady=(5, 0))

        prefix_frame = ttk.Frame(self.prompt_notebook, padding=5)
        self.prompt_notebook.add(prefix_frame, text="선행 프롬프트")
        self.prefix_text = scrolledtext.ScrolledText(prefix_frame, height=3, width=30)
        self.prefix_text.pack(fill="both", expand=True)

        suffix_frame = ttk.Frame(self.prompt_notebook, padding=5)
        self.prompt_notebook.add(suffix_frame, text="후행 프롬프트")
        self.suffix_text = scrolledtext.ScrolledText(suffix_frame, height=3, width=30)
        self.suffix_text.pack(fill="both", expand=True)

        self.toggle_prompts()

        # Result area.
        result_frame = ttk.LabelFrame(main_frame, text="추출된 태그", padding="10")
        result_frame.grid(row=2, column=0, columnspan=2, sticky=(tk.W, tk.E, tk.N, tk.S), pady=(10, 0))
        self.result_text = scrolledtext.ScrolledText(result_frame, height=8)
        self.result_text.grid(row=0, column=0, sticky=(tk.W, tk.E, tk.N, tk.S))
        # Prefix/suffix prompt tags are drawn grey.
        self.result_text.tag_config("prompt", foreground="grey")

        copy_frame = ttk.Frame(result_frame)
        copy_frame.grid(row=1, column=0, pady=(10, 5), sticky=(tk.W, tk.E))
        ttk.Button(copy_frame, text="전체 복사", command=self.copy_all_tags).grid(row=0, column=0, padx=(0, 5))
        ttk.Button(copy_frame, text="2번 태그부터 복사", command=self.copy_from_second).grid(row=0, column=1, padx=5)
        ttk.Button(copy_frame, text="3번 태그부터 복사", command=self.copy_from_third).grid(row=0, column=2, padx=5)

        # Collapsible removed-tags box (gridded by toggle_removed_tags).
        self.removed_tags_cb = ttk.Checkbutton(result_frame, text="제거된 태그",
                                               variable=self.show_removed_tags_var,
                                               command=self.toggle_removed_tags)
        self.removed_tags_cb.grid(row=2, column=0, sticky=tk.W, pady=(10, 2))
        self.removed_tags_text = scrolledtext.ScrolledText(result_frame, height=5)
        # _update_results inserts user-ignored tags with the "ignored" text tag;
        # style it so they stand apart from tags removed by the filters.
        self.removed_tags_text.tag_config("ignored", foreground="grey")

        # Collapsible ignore-tags box with its own save button.
        ignore_frame = ttk.Frame(main_frame, padding="0")
        ignore_frame.grid(row=3, column=0, columnspan=2, sticky=(tk.W, tk.E), pady=(10, 0))
        self.ignore_tags_cb = ttk.Checkbutton(ignore_frame, text="무시할 태그 (쉼표로 구분)",
                                              variable=self.show_ignore_tags_var,
                                              command=self.toggle_ignore_tags)
        self.ignore_tags_cb.grid(row=0, column=0, sticky=tk.W)
        self.ignore_tags_text = scrolledtext.ScrolledText(ignore_frame, height=4)
        self.save_settings_btn = ttk.Button(ignore_frame, text="설정 저장", command=self.save_settings)
        ignore_frame.columnconfigure(1, weight=1)

        self.toggle_removed_tags()
        self.toggle_ignore_tags()

        # Resizing behaviour. main_frame rows 2 and 3 get no weight, so
        # they size to their content.
        self.root.columnconfigure(0, weight=1)
        self.root.rowconfigure(0, weight=1)
        main_frame.columnconfigure(0, weight=1)
        content_frame.columnconfigure(0, weight=1)
        content_frame.columnconfigure(1, weight=1)
        result_frame.columnconfigure(0, weight=1)
        result_frame.rowconfigure(0, weight=1)

    # ------------------------------------------------------------ settings

    def on_closing(self):
        """Save settings, then always close the window, even if saving fails."""
        try:
            self.save_settings()
        finally:
            self.root.destroy()

    @staticmethod
    def _get_var(var, default):
        """Return ``var.get()``, or *default* if the widget text cannot be parsed."""
        try:
            return var.get()
        except (tk.TclError, ValueError):
            return default

    def save_settings(self):
        """Write the current UI settings to SETTINGS_FILE as UTF-8 JSON.

        An unparsable threshold in a spinbox falls back to its default value
        instead of raising. Otherwise closing the window would be blocked.
        """
        settings = {
            'general_threshold': self._get_var(self.general_threshold_var, 0.5),
            'char_threshold': self._get_var(self.char_threshold_var, 0.85),
            'rm_c': self.rm_c_var.get(),
            'rm_color': self.rm_color_var.get(),
            'rm_clothes': self.rm_clothes_var.get(),
            'webui_mode': self.webui_mode_var.get(),
            'instant_inference': self.instant_inference_var.get(),
            'ignore_tags': self.ignore_tags_text.get(1.0, tk.END).strip(),
            'show_removed_tags': self.show_removed_tags_var.get(),
            'show_ignore_tags': self.show_ignore_tags_var.get(),
            'prompt_active': self.prompt_active_var.get(),
            'prefix_prompt': self.prefix_text.get(1.0, tk.END).strip(),
            'suffix_prompt': self.suffix_text.get(1.0, tk.END).strip()
        }
        try:
            with open(self.SETTINGS_FILE, 'w', encoding='utf-8') as f:
                # The file is UTF-8, so Korean prompts are written readably.
                json.dump(settings, f, indent=4, ensure_ascii=False)
        except Exception as e:
            messagebox.showerror("저장 실패", f"설정 저장 중 오류가 발생했습니다: {e}")

    def load_settings(self):
        """Restore settings from SETTINGS_FILE, if it exists, and refresh the panels."""
        try:
            if os.path.exists(self.SETTINGS_FILE):
                with open(self.SETTINGS_FILE, 'r', encoding='utf-8') as f:
                    settings = json.load(f)
                self.general_threshold_var.set(settings.get('general_threshold', 0.5))
                self.char_threshold_var.set(settings.get('char_threshold', 0.85))
                self.rm_c_var.set(settings.get('rm_c', False))
                self.rm_color_var.set(settings.get('rm_color', False))
                self.rm_clothes_var.set(settings.get('rm_clothes', False))
                self.webui_mode_var.set(settings.get('webui_mode', False))
                self.instant_inference_var.set(settings.get('instant_inference', False))
                self.ignore_tags_text.delete(1.0, tk.END)
                self.ignore_tags_text.insert(1.0, settings.get('ignore_tags', ''))
                # Collapse/expand state, then apply it to the UI.
                self.show_removed_tags_var.set(settings.get('show_removed_tags', True))
                self.show_ignore_tags_var.set(settings.get('show_ignore_tags', False))
                self.toggle_removed_tags()
                self.toggle_ignore_tags()
                self.prompt_active_var.set(settings.get('prompt_active', False))
                self.prefix_text.delete(1.0, tk.END)
                self.prefix_text.insert(1.0, settings.get('prefix_prompt', ''))
                self.suffix_text.delete(1.0, tk.END)
                self.suffix_text.insert(1.0, settings.get('suffix_prompt', ''))
                self.toggle_prompts()
        except Exception as e:
            messagebox.showwarning("설정 로드 실패",
                                   f"설정 파일({self.SETTINGS_FILE})을 불러오는 데 실패했습니다: {e}")

    # ------------------------------------------------------- image sources

    @classmethod
    def _as_image_path(cls, text):
        """Return *text* with whitespace/quotes stripped if it looks like an image path, else None.

        Explorer's "Copy as path" wraps the path in double quotes, so the
        quotes must be removed before the extension check.
        """
        candidate = text.strip().strip('"')
        return candidate if candidate.lower().endswith(cls.IMAGE_EXTENSIONS) else None

    @classmethod
    def _first_image_path(cls, paths):
        """Return the first existing image file in *paths*, or None."""
        for path in paths:
            if isinstance(path, str) and path.lower().endswith(cls.IMAGE_EXTENSIONS) and os.path.exists(path):
                return path
        return None

    def _set_current_image(self, path):
        """Make *path* the current image.

        Shows the preview and enables extraction when the model is loaded.
        Extraction starts at once when instant inference is on.
        """
        self.current_image_path = path
        self.show_preview(path)
        if self.session is not None:
            self.extract_btn.config(state='normal')
            if self.instant_inference_var.get():
                self.extract_tags()

    def handle_paste(self, event):
        """Handle Ctrl+V by loading an image from the clipboard when possible.

        The clipboard is checked in this order:

        * image data;
        * a list of copied files;
        * an http(s) URL;
        * an existing image file path.

        Returns "break" when an image was handled. Otherwise returns None, so
        other bindings (normal text pasting) proceed.
        """
        try:
            clipboard_content = ImageGrab.grabclipboard()
            if isinstance(clipboard_content, Image.Image):
                self.load_from_clipboard()
                return "break"
            # Copied files are returned as a list of filenames.
            if isinstance(clipboard_content, list):
                path = self._first_image_path(clipboard_content)
                if path:
                    self._load_image_from_path(path)
                    return "break"
        except (ValueError, TypeError):
            pass
        except Exception as e:
            print(f"클립보드 이미지 확인 중 오류: {e}")

        try:
            clipboard_text = self.root.clipboard_get()
        except tk.TclError:
            # No text on the clipboard.
            return None

        try:
            if isinstance(clipboard_text, str):
                url = clipboard_text.strip()
                if url.startswith(('http://', 'https://')):
                    self._load_image_from_url(url)
                    return "break"
                path = self._as_image_path(clipboard_text)
                if path and os.path.exists(path):
                    self._load_image_from_path(path)
                    return "break"
        except Exception as e:
            messagebox.showerror("붙여넣기 오류", f"처리 중 오류 발생: {e}")
            return "break"
        # Not an image: let the default paste behaviour run.
        return None

    def _load_image_from_url(self, url):
        """Download *url* into a per-process temp PNG and make it the current image."""
        try:
            img = self.download_image_from_url(url)
            # PNG cannot store CMYK/YCbCr (e.g. some JPEGs), so convert those.
            if img.mode in ('CMYK', 'YCbCr'):
                img = img.convert('RGB')
            temp_file = os.path.join(tempfile.gettempdir(), f"wd14_url_{os.getpid()}.png")
            img.save(temp_file)
            self._set_current_image(temp_file)
        except Exception as e:
            messagebox.showerror("오류", f"웹 이미지 로드 실패: {str(e)}")

    def load_model(self):
        """Load the ONNX model and the tag table; leaves ``session`` as None on failure."""
        try:
            import onnxruntime as ort
            model_path = resource_path("model.onnx")
            tags_path = resource_path("selected_tags.csv")
            self.session = ort.InferenceSession(model_path, providers=['CPUExecutionProvider'])
            self.tags_df = pd.read_csv(tags_path)
        except Exception as e:
            messagebox.showerror("오류", f"모델 로드 실패: {str(e)}")

    def show_initial_preview(self):
        """Show a blank 374x374 white placeholder with a drag & drop hint."""
        try:
            white_image = Image.new('RGB', (374, 374), 'white')
            photo = ImageTk.PhotoImage(white_image)
            self.image_label.config(image=photo, text="이미지를 Drag&Drop\n해주세요")
            # Keep a reference so Tk does not drop the image.
            self.image_label.image = photo
        except Exception:
            self.image_label.config(text="이미지를 Drag&Drop\n해주세요")

    def setup_drag_drop(self):
        """Register the preview label as a tkinterdnd2 drop target, if available."""
        try:
            from tkinterdnd2 import TkinterDnD, DND_ALL
            self.TkdndVersion = TkinterDnD._require(self.root)
            self.image_label.drop_target_register(DND_ALL)
            self.image_label.dnd_bind('<<Drop>>', self.on_drop)
            self.image_label.dnd_bind('<<DragEnter>>', self.on_drag_enter)
            self.image_label.dnd_bind('<<DragLeave>>', self.on_drag_leave)
        except ImportError:
            print("tkinterdnd2가 설치되지 않음. pip install tkinterdnd2로 설치하면 드래그앤드롭을 사용할 수 있습니다.")
        except Exception as e:
            print(f"Drag & Drop 초기화 오류: {e}")

    def on_drag_enter(self, event):
        """Change the preview title while something is dragged over it."""
        try:
            self.preview_frame.config(text="미리보기 (이미지를 놓아주세요!)")
        except tk.TclError:
            pass

    def on_drag_leave(self, event):
        """Restore the preview title when the drag leaves."""
        try:
            self.preview_frame.config(text="미리보기")
        except tk.TclError:
            pass

    def download_image_from_url(self, url, timeout=15):
        """Fetch *url* with requests and return it as a PIL image.

        Args:
            url: http(s) address of the image.
            timeout: Seconds to wait for the server. This runs on the Tk
                thread, so without a timeout a stalled server would freeze
                the GUI.

        Raises:
            requests.HTTPError: For non-2xx responses (via raise_for_status).
        """
        import requests
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        return Image.open(io.BytesIO(response.content))

    @staticmethod
    def _first_dropped_item(data):
        """Return the first item of a drop payload.

        tkdnd wraps paths containing spaces in braces. Dropping several
        files yields e.g. ``{C:/a b.png} {C:/c d.png}``, so stripping braces
        from the whole string would produce an invalid path.
        """
        data = data.strip()
        if data.startswith('{'):
            end = data.find('}')
            if end != -1:
                return data[1:end]
        return data

    def on_drop(self, event):
        """Handle a drop (or a string passed directly): load a URL or a local file."""
        try:
            if isinstance(event, str):
                dropped = event
            else:
                dropped = self._first_dropped_item(event.data)

            if not dropped:
                messagebox.showwarning("경고", "드롭된 데이터를 읽을 수 없습니다.")
                return
            if dropped.startswith(('http://', 'https://')):
                self._load_image_from_url(dropped)
                return
            if dropped.startswith('blob:'):
                messagebox.showwarning("Blob URL 오류",
                                       "NAI 홈페이지에서 생성한 이미지는 즉시 복사할 수 없습니다.\n"
                                       "이미지를 다운로드한 후 드래그 & 드롭해주세요.")
                return

            file_path = dropped.replace("\\", '/')
            if os.path.exists(file_path):
                self._set_current_image(file_path)
            else:
                messagebox.showerror("오류", f"유효하지 않은 파일 경로입니다: '{file_path}'")
        except Exception as e:
            messagebox.showerror("오류", f"파일 드롭 오류: {str(e)}")
        finally:
            # Every outcome restores the normal preview title.
            self.preview_frame.config(text="미리보기")

    def _load_image_from_path(self, path):
        """Load the image at *path*, or warn if the file does not exist."""
        if os.path.exists(path):
            self._set_current_image(path)
        else:
            messagebox.showwarning("경로 오류", f"파일을 찾을 수 없습니다: {path}")

    def load_from_clipboard(self):
        """Load clipboard image data (saved to a temp PNG) or the first copied image file."""
        try:
            content = ImageGrab.grabclipboard()
            if isinstance(content, Image.Image):
                temp_file = os.path.join(tempfile.gettempdir(), f"wd14_clipboard_{os.getpid()}.png")
                # Alpha is dropped for clipboard images. NOTE(review): presumably
                # because some clipboard bitmaps carry an unusable alpha channel;
                # confirm before removing.
                if content.mode == 'RGBA':
                    content = content.convert('RGB')
                content.save(temp_file, 'PNG')
                self._set_current_image(temp_file)
                return
            # Files copied in a file manager are returned as a list of filenames.
            if isinstance(content, list):
                path = self._first_image_path(content)
                if path:
                    self._set_current_image(path)
                    return
            messagebox.showinfo("정보", "클립보드에 이미지가 없습니다.")
        except Exception as e:
            messagebox.showerror("오류", f"클립보드 불러오기 오류: {str(e)}")

    def load_image(self):
        """Let the user pick an image file and make it the current image."""
        file_path = filedialog.askopenfilename(
            title="이미지 선택",
            filetypes=[("이미지 파일", "*.jpg *.jpeg *.png *.bmp *.tiff *.webp")]
        )
        if file_path:
            self._set_current_image(file_path)

    def show_preview(self, image_path):
        """Show *image_path* padded to a white square and shrunk to fit 374x374."""
        try:
            # Close the file right away; an open handle can block overwriting
            # the temp file on the next paste on Windows.
            with Image.open(image_path) as opened:
                image = opened.convert('RGB')
            size = max(image.size)
            square_image = Image.new('RGB', (size, size), 'white')
            paste_x = (size - image.width) // 2
            paste_y = (size - image.height) // 2
            square_image.paste(image, (paste_x, paste_y))
            square_image.thumbnail((374, 374), Image.Resampling.LANCZOS)
            photo = ImageTk.PhotoImage(square_image)
            self.image_label.config(image=photo, text="")
            self.image_label.image = photo
        except Exception as e:
            self.image_label.config(text=f"이미지 로드 실패: {str(e)}")

    # ----------------------------------------------------------- inference

    def extract_tags(self):
        """Start tag extraction for the current image on a background thread."""
        if not self.current_image_path or not self.session:
            return
        self._extraction_job += 1
        job_id = self._extraction_job
        self.extract_btn.config(state='disabled', text='추출 중...')
        self.result_text.delete(1.0, tk.END)
        self.result_text.insert(tk.END, "태그 추출 중...")
        # The removed-tags box is left DISABLED after results are shown, and
        # Tk ignores deletes on a disabled Text widget.
        self.removed_tags_text.config(state=tk.NORMAL)
        self.removed_tags_text.delete(1.0, tk.END)
        self.removed_tags_text.config(state=tk.DISABLED)
        # Snapshot the path: the user may load another image meanwhile.
        thread = threading.Thread(target=self._extract_tags_thread,
                                  args=(self.current_image_path, job_id),
                                  daemon=True)
        thread.start()

    def _run_inference(self, image_path):
        """Run the model on *image_path* and return ``{tag_name: confidence}``.

        Touches only the image file, ``session`` and ``tags_df``, with no Tk
        objects, so it is safe to call from a worker thread.
        """
        with Image.open(image_path) as input_image:
            image = input_image.convert('RGBA')
        # The model input is unpacked as (N, H, W, C); index 1 is the size.
        _, height, _, _ = self.session.get_inputs()[0].shape
        # Composite onto white so transparent areas do not turn black.
        canvas = Image.new('RGBA', image.size, 'WHITE')
        canvas.paste(image, mask=image)
        image = canvas.convert('RGB')
        # Reverse the channel order (RGB -> BGR) before feeding the model.
        image = np.asarray(image)[:, :, ::-1]
        image = self.make_square(image, height)
        image = self.smart_resize(image, height)
        image = np.expand_dims(image.astype(np.float32), 0)

        input_name = self.session.get_inputs()[0].name
        label_name = self.session.get_outputs()[0].name
        confidences = self.session.run([label_name], {input_name: image})[0]

        # zip() stops at the shorter of the two, like the old bounds check.
        names = self.tags_df['name'].tolist()
        tags_with_conf = [(name, float(conf)) for name, conf in zip(names, confidences[0])]
        # The first four outputs are skipped. NOTE(review): presumably the
        # rating tags of selected_tags.csv; confirm.
        return dict(tags_with_conf[4:])

    def _extract_tags_thread(self, image_path=None, job_id=None):
        """Worker-thread body: run inference and hand the outcome to the Tk thread.

        Thresholds, filters and widget updates are all handled in
        _on_extraction_done, scheduled with ``root.after``. Tkinter objects
        must not be used from this thread.
        """
        if image_path is None:
            image_path = self.current_image_path
        try:
            tags = self._run_inference(image_path)
        except Exception as e:
            self.root.after(0, self._on_extraction_done, job_id, None, e)
        else:
            self.root.after(0, self._on_extraction_done, job_id, tags, None)

    def _on_extraction_done(self, job_id, tags, error):
        """Tk-thread callback: format and display a finished extraction.

        Results of a superseded job are ignored. The newer job re-enables
        the button when it finishes. Errors are shown in the result box
        instead of crashing the callback.
        """
        if job_id is not None and job_id != self._extraction_job:
            return
        result_tuple = None
        if error is None:
            try:
                result_tuple = self._format_results(tags)
            except Exception as e:
                error = e
        if error is not None:
            result_tuple = ([(f"태그 추출 오류: {str(error)}", 'original')], [])
        self._update_results(result_tuple)

    # ------------------------------------------------------------ filtering

    def _read_filter_options(self):
        """Read the widget state that apply_tag_filters depends on (Tk thread only)."""
        return {
            'ignore_tags': self.ignore_tags_text.get(1.0, tk.END).strip(),
            'rm_c': self.rm_c_var.get(),
            'rm_color': self.rm_color_var.get(),
            'rm_clothes': self.rm_clothes_var.get(),
            'webui_mode': self.webui_mode_var.get(),
            'prompt_active': self.prompt_active_var.get(),
            'prefix': self.prefix_text.get(1.0, tk.END).strip(),
            'suffix': self.suffix_text.get(1.0, tk.END).strip(),
        }

    def apply_tag_filters(self, tags_list, options=None):
        """Apply the ignore list, the optional filters and the prefix/suffix prompts.

        Args:
            tags_list: Raw model tag names (underscores allowed).
            options: Optional dict with the keys returned by
                _read_filter_options. When omitted, the current widget state
                is read, which requires the Tk thread.

        Returns:
            A tuple ``(final_tags_with_type, user_removed, other_removed)``:

            * ``final_tags_with_type``: a list of ``(tag, 'original' | 'prompt')``
              pairs, using spaces instead of underscores.
            * ``user_removed``: raw tags removed by the user's ignore list.
            * ``other_removed``: raw tags removed by the other filters.

            Both removed lists keep the input order.
        """
        import re

        if options is None:
            options = self._read_filter_options()

        original_tags = list(tags_list)
        # All comparisons use spaces instead of underscores.
        tags_to_process = [tag.replace('_', ' ') for tag in original_tags]
        final_indices = list(range(len(tags_to_process)))
        user_removed_originals = []
        other_removed_originals = []

        def remove_where(predicate, sink):
            # Drop every remaining tag matching predicate, recording its raw name in sink.
            nonlocal final_indices
            kept = []
            for i in final_indices:
                if predicate(tags_to_process[i]):
                    sink.append(original_tags[i])
                else:
                    kept.append(i)
            final_indices = kept

        # 1. User-defined ignore list.
        ignore_list_str = options['ignore_tags']
        if ignore_list_str:
            user_ignored_tags = {t.strip().replace('_', ' ') for t in ignore_list_str.split(',') if t.strip()}
            remove_where(lambda t: t in user_ignored_tags, user_removed_originals)

        # 2. Optional filters. Sets are built once for O(1) lookups.
        bag = set(self.bag_of_tags)
        if options['rm_c']:
            remove_where(lambda t: t in bag or self._is_character_name(t), other_removed_originals)
        if options['rm_color'] and self.colors:
            # Whole-word match: a plain substring test would hit "red" inside
            # "covered" or "layered". Eye/hair/pupil colours are kept.
            color_pattern = re.compile(r'\b(?:' + '|'.join(map(re.escape, self.colors)) + r')\b')
            body_parts = (" eyes", " hair", " pupils")
            remove_where(
                lambda t: color_pattern.search(t) is not None and not any(p in t for p in body_parts),
                other_removed_originals)
        if options['rm_clothes']:
            clothes = set(self.clothes_list)
            remove_where(lambda t: t in clothes and t not in bag, other_removed_originals)

        # 3. Final tags, with parentheses escaped for WebUI/Comfy prompt syntax.
        final_tags_with_space = [tags_to_process[i] for i in final_indices]
        if options['webui_mode']:
            final_tags_formatted = [t.replace('(', '\\(').replace(')', '\\)') for t in final_tags_with_space]
        else:
            final_tags_formatted = final_tags_with_space

        # 4. Prefix prompts go right after the last person-count tag among the
        #    first four tags (or at the front); suffix prompts go at the end.
        final_tags_with_type = [(tag, 'original') for tag in final_tags_formatted]
        if options['prompt_active']:
            prefix_str = options['prefix']
            if prefix_str:
                prefix_tags_with_type = [(t.strip(), 'prompt') for t in prefix_str.split(',') if t.strip()]
                person_tags = {'1girl', '2girls', '3girls', '4girls', '5girls', '6+girls',
                               '1boy', '2boys', '3boys', '4boys', '5boys', '6+boys'}
                last_person_tag_index = -1
                for i in range(min(len(final_tags_with_type), 4)):
                    if final_tags_with_type[i][0] in person_tags:
                        last_person_tag_index = i
                insert_at = last_person_tag_index + 1
                final_tags_with_type[insert_at:insert_at] = prefix_tags_with_type
            suffix_str = options['suffix']
            if suffix_str:
                final_tags_with_type.extend((t.strip(), 'prompt') for t in suffix_str.split(',') if t.strip())

        # Order-preserving de-duplication, unlike list(set(...)).
        return final_tags_with_type, user_removed_originals, list(dict.fromkeys(other_removed_originals))

    def _is_character_name(self, tag):
        """Return True if *tag* looks like a character name.

        A tag matches if it is in ``character_keys``, or if it passes the
        heuristic: contains '_' and an uppercase letter, has no person/body
        words, and is longer than 3 characters. apply_tag_filters passes
        space-separated tags, so there only the ``character_keys`` lookup can
        match.
        """
        if hasattr(self, 'character_keys') and tag in self.character_keys:
            return True
        if ('_' in tag and any(c.isupper() for c in tag) and
                not tag.startswith('1') and
                not any(word in tag.lower() for word in ['girl', 'boy', 'solo', 'multiple']) and
                not any(word in tag.lower() for word in ['hair', 'eye', 'dress', 'shirt']) and
                len(tag) > 3):
            return True
        return False

    def _format_results(self, tags):
        """Apply the confidence thresholds, then apply_tag_filters.

        Person-count tags need confidence >= 0.9. Tags with an uppercase letter
        and '_' use the character threshold; all others use the general
        threshold. NOTE(review): if the model's tag names are all lowercase,
        the character branch never fires; confirm against selected_tags.csv.

        Reads Tk variables and widgets, so call it from the Tk main thread.

        Args:
            tags: Mapping of raw tag name -> confidence.

        Returns:
            ``(final_tags_with_type, removed_for_display)`` as consumed by
            _update_results.
        """
        general_threshold = self.general_threshold_var.get()
        char_threshold = self.char_threshold_var.get()
        person_tags = {'1girl', '2girls', '3girls', '4girls', '5girls', '6+girls',
                       '1boy', '2boys', '3boys', '4boys', '5boys', '6+boys',
                       'multiple_girls', 'multiple_boys', 'solo', 'no_humans'}
        filtered_tags = []
        for tag, conf in tags.items():
            if tag in person_tags:
                if conf >= 0.9:
                    filtered_tags.append(tag)
            elif any(c.isupper() for c in tag) and '_' in tag and conf >= char_threshold:
                filtered_tags.append(tag)
            elif conf >= general_threshold:
                filtered_tags.append(tag)

        final_tags_with_type, user_removed, other_removed = self.apply_tag_filters(filtered_tags)

        removed_for_display = [(tag, 'ignored') for tag in user_removed]
        removed_for_display += [(tag, 'other') for tag in other_removed if tag not in user_removed]
        return final_tags_with_type, removed_for_display

    # -------------------------------------------------------------- copying

    def copy_all_tags(self):
        """Copy the whole result text to the clipboard."""
        content = self.result_text.get(1.0, tk.END).strip()
        if content:
            self.root.clipboard_clear()
            self.root.clipboard_append(content)
        else:
            messagebox.showwarning("복사 실패", "복사할 태그가 없습니다.")

    def copy_from_second(self):
        """Copy the result tags starting from the 2nd one."""
        self._copy_from_index(1, "2번 태그부터")

    def copy_from_third(self):
        """Copy the result tags starting from the 3rd one."""
        self._copy_from_index(2, "3번 태그부터")

    def _copy_from_index(self, start_index, description):
        """Copy the comma-separated result tags from *start_index* on.

        *description* is only used by the (disabled) success message.
        """
        content = self.result_text.get(1.0, tk.END).strip()
        if content:
            tags_list = [tag.strip() for tag in content.split(',')]
            if len(tags_list) > start_index:
                result = ', '.join(tags_list[start_index:])
                self.root.clipboard_clear()
                self.root.clipboard_append(result)
            else:
                messagebox.showwarning("복사 실패", f"태그가 {start_index + 1}개 미만입니다.")
        else:
            messagebox.showwarning("복사 실패", "복사할 태그가 없습니다.")

    def _update_results(self, result_tuple):
        """Render results and re-enable the extract button (Tk thread).

        Args:
            result_tuple: ``(main_tags, removed_tags)``, both lists of
                ``(tag, kind)`` pairs.

                * Main-tag kinds are ``'original'`` or ``'prompt'`` (drawn
                  grey). A bare string in place of ``main_tags`` is shown as
                  a single message.
                * Removed-tag kinds are ``'ignored'`` or ``'other'``.
        """
        main_tags_info, removed_tags_info = result_tuple
        if isinstance(main_tags_info, str):
            main_tags_info = [(main_tags_info, 'original')]
        if isinstance(removed_tags_info, str):
            removed_tags_info = []

        self.result_text.delete(1.0, tk.END)
        for i, (tag, tag_type) in enumerate(main_tags_info or []):
            if i > 0:
                self.result_text.insert(tk.END, ', ')
            if tag_type == 'prompt':
                self.result_text.insert(tk.END, tag, 'prompt')
            else:
                self.result_text.insert(tk.END, tag)

        self.removed_tags_text.config(state=tk.NORMAL)
        self.removed_tags_text.delete(1.0, tk.END)
        for i, (tag, tag_type) in enumerate(removed_tags_info or []):
            display_tag = tag.replace('_', ' ')
            if i > 0:
                self.removed_tags_text.insert(tk.END, ', ')
            if tag_type == 'ignored':
                self.removed_tags_text.insert(tk.END, display_tag, 'ignored')
            else:
                self.removed_tags_text.insert(tk.END, display_tag)
        self.removed_tags_text.config(state=tk.DISABLED)
        self.extract_btn.config(state='normal', text='태그 추출')

    # -------------------------------------------------------- preprocessing

    def make_square(self, img, target_size):
        """Pad an (H, W, C) array with white to a centred square.

        The side is at least *target_size*.
        """
        old_size = img.shape[:2]
        desired_size = max(max(old_size), target_size)
        delta_w = desired_size - old_size[1]
        delta_h = desired_size - old_size[0]
        top, bottom = delta_h // 2, delta_h - (delta_h // 2)
        left, right = delta_w // 2, delta_w - (delta_w // 2)
        color = [255, 255, 255]
        return cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)

    def smart_resize(self, img, size):
        """Resize a square image to size x size.

        Uses INTER_AREA when shrinking and INTER_CUBIC when enlarging. Only
        ``shape[0]`` is checked, so the input is expected to be square (the
        output of make_square).
        """
        if img.shape[0] > size:
            img = cv2.resize(img, (size, size), interpolation=cv2.INTER_AREA)
        elif img.shape[0] < size:
            img = cv2.resize(img, (size, size), interpolation=cv2.INTER_CUBIC)
        return img
def main():
    """Build the Tk root window and run the tagger GUI until it is closed.

    A tkinterdnd2 root is used when the package can be imported, so drag &
    drop works; otherwise a plain ``tk.Tk`` root is created and a notice is
    printed.
    """
    try:
        from tkinterdnd2 import TkinterDnD as dnd_module
        window = dnd_module.Tk()
    except ImportError:
        print("tkinterdnd2가 설치되지 않아 Drag & Drop 기능이 비활성화됩니다.")
        window = tk.Tk()
    gui = WD14TaggerGUI(window)
    window.mainloop()


if __name__ == "__main__":
    main()