# Aniimage-1 / generate_hf.py
# 8BitStudio's picture
# Upload generate_hf.py
# 5dbc62b verified
"""
Aniimage Generator — Generate anime images from text prompts.
https://huggingface.co/8BitStudio/Aniimage-1
Usage:
pip install torch torchvision diffusers transformers safetensors pillow huggingface_hub
python generate_hf.py
"""
import os
import sys
import torch
import torch.nn.functional as F
import numpy as np
import tkinter as tk
from tkinter import ttk, simpledialog
from pathlib import Path
from PIL import Image, ImageTk, ImageEnhance, ImageFilter
from threading import Thread
# ── Paths ─────────────────────────────────────────────────────────────────────
# Every location is derived from the folder that contains this script.
SCRIPT_DIR = Path(__file__).resolve().parent
MODEL_DIR = SCRIPT_DIR.joinpath("models")
OUTPUT_DIR = SCRIPT_DIR.joinpath("generated")

# ── HuggingFace repo ─────────────────────────────────────────────────────────
HF_REPO_ID = "8BitStudio/Aniimage-1"

# ── UNet config (must match training) ─────────────────────────────────────────
UNET_CONFIG = {
    "sample_size": 32,
    "in_channels": 4,
    "out_channels": 4,
    "block_out_channels": (256, 512, 768, 1024),
    "layers_per_block": 2,
    "cross_attention_dim": 768,
    "attention_head_dim": 8,
    "down_block_types": (
        "DownBlock2D",
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "DownBlock2D",
    ),
    "up_block_types": (
        "UpBlock2D",
        "CrossAttnUpBlock2D",
        "CrossAttnUpBlock2D",
        "UpBlock2D",
    ),
}

# Pretrained companions loaded next to the UNet (see Generator.load_shared).
VAE_ID = "stabilityai/sd-vae-ft-mse"
CLIP_ID = "openai/clip-vit-large-patch14"

# Names understood by Generator._make_scheduler, in the order shown in the UI.
SCHEDULER_LIST = [
    "DPM++ 2M Karras",
    "DPM++ SDE Karras",
    "Euler a",
    "Euler",
    "DDIM",
]

# Pre-filled text of the negative-prompt field.
DEFAULT_NEGATIVE = ", ".join((
    "low quality", "ugly", "blurry", "distorted", "deformed", "bad anatomy",
    "bad proportions", "extra limbs", "missing limbs", "watermark", "text",
    "signature", "washed out", "flat colors", "manga panel", "disfigured",
    "poorly drawn", "jpeg artifacts", "cropped", "out of frame",
))
# ── Model discovery ───────────────────────────────────────────────────────────
def download_from_hf():
    """Download model weights from HuggingFace if not already cached.

    The weights (and, best-effort, ``config.json``) are copied from the
    huggingface_hub cache into ``MODEL_DIR / "Aniimage-1"``.

    Returns:
        The local model directory, or ``None`` when ``huggingface_hub`` is
        not installed.
    """
    try:
        from huggingface_hub import hf_hub_download
    except ImportError:
        print("Install huggingface_hub: pip install huggingface_hub")
        return None
    import shutil

    MODEL_DIR.mkdir(parents=True, exist_ok=True)
    aniimage_dir = MODEL_DIR / "Aniimage-1"
    weights_path = aniimage_dir / "diffusion_pytorch_model.safetensors"
    if weights_path.exists():
        print("Aniimage-1 weights already downloaded.")
        return aniimage_dir

    print(f"Downloading Aniimage-1 from {HF_REPO_ID}...")
    aniimage_dir.mkdir(parents=True, exist_ok=True)
    dl_weights = hf_hub_download(repo_id=HF_REPO_ID,
                                 filename="diffusion_pytorch_model.safetensors")

    # Copy to a temporary name and rename afterwards: the existence check
    # above treats any file at weights_path as a finished download, so an
    # interrupted copy must never leave a truncated file under that name.
    partial_path = weights_path.with_name(weights_path.name + ".part")
    try:
        shutil.copy2(dl_weights, partial_path)
        partial_path.replace(weights_path)
    finally:
        if partial_path.exists():
            partial_path.unlink()

    # config.json is optional (Generator.load_model builds the UNet from
    # UNET_CONFIG), so a failure here is reported but not fatal.
    try:
        dl_config = hf_hub_download(repo_id=HF_REPO_ID, filename="config.json")
        shutil.copy2(dl_config, aniimage_dir / "config.json")
    except Exception as exc:
        print(f"Warning: could not fetch config.json ({exc}); continuing without it.")

    print("Download complete!")
    return aniimage_dir
def find_models():
    """Find all available models.

    Scans the sub-folders of ``MODEL_DIR`` in sorted order and returns a list
    of ``(kind, folder_name, folder_path, "256")`` tuples, where ``kind`` is
    ``"safetensors"`` when the folder holds
    ``diffusion_pytorch_model.safetensors`` and ``"checkpoint"`` when it
    holds ``ema_unet.pt`` or ``unet.pt``. Folders with neither are skipped.
    """
    if not MODEL_DIR.exists():
        return []

    found = []
    for entry in sorted(MODEL_DIR.iterdir()):
        if not entry.is_dir():
            continue
        # safetensors takes precedence over the .pt checkpoints
        if (entry / "diffusion_pytorch_model.safetensors").exists():
            kind = "safetensors"
        elif (entry / "ema_unet.pt").exists() or (entry / "unet.pt").exists():
            kind = "checkpoint"
        else:
            continue
        found.append((kind, entry.name, entry, "256"))
    return found
# ── Theme ─────────────────────────────────────────────────────────────────────
# Colour palette shared by every widget; looked up as C["<role>"].
C = dict(
    bg="#111119",
    panel="#1b1b2f",
    card="#24243e",
    card_sel="#3a3a6e",
    border="#2e2e52",
    accent="#6c5ce7",
    accent_h="#8577ed",
    red="#e74c3c",
    green="#2ecc71",
    text="#eaeaea",
    text2="#a0a0b8",
    text3="#60607a",
    input="#16162a",
    input_fg="#dcdcf0",
)
class Generator:
    """Text-to-image sampler built from a VAE, a CLIP text encoder and a UNet.

    Models are loaded lazily: ``load_shared`` fetches the VAE/CLIP pair once,
    ``load_model`` builds the UNet from a local weights folder. ``generate``
    samples an image with classifier-free guidance; ``refine`` re-noises an
    existing image and denoises it again. Setting ``cancelled = True`` makes a
    running sampling loop stop at its next step.
    """

    DEFAULT_SCHEDULER = "DPM++ 2M Karras"

    def __init__(self, device="cuda"):
        """Select the compute device; no model is loaded yet."""
        # Anything other than an available CUDA device falls back to CPU.
        self.device = device if device == "cuda" and torch.cuda.is_available() else "cpu"
        self.vae = None
        self.text_encoder = None
        self.tokenizer = None
        self.unet = None
        self.scheduler = None
        # Defined up front so a scheduler chosen before the first model load
        # is remembered (generate()/refine() read this name).
        self.scheduler_name = self.DEFAULT_SCHEDULER
        self.loaded_checkpoint = None
        self.latent_size = 32
        self.output_size = 256
        self.cancelled = False

    def switch_device(self, new_device):
        """Move all loaded models to a new device."""
        new_device = new_device if new_device == "cuda" and torch.cuda.is_available() else "cpu"
        if new_device == self.device:
            return
        self.device = new_device
        if self.vae is not None:
            self.vae = self.vae.to(self.device)
        if self.text_encoder is not None:
            self.text_encoder = self.text_encoder.to(self.device)
        if self.unet is not None:
            self.unet = self.unet.to(self.device)
        self.loaded_checkpoint = None  # force reload on next generate
        print(f"Switched to {self.device.upper()}")

    def load_shared(self):
        """Load the VAE, tokenizer and text encoder once (no-op afterwards)."""
        if self.vae is not None:
            return
        from diffusers import AutoencoderKL
        from transformers import CLIPTextModel, CLIPTokenizer

        # Load into locals and publish only when everything succeeded;
        # otherwise a failed CLIP download would leave self.vae set and make
        # every later call return early with no tokenizer/text encoder.
        print("Loading VAE...")
        vae = AutoencoderKL.from_pretrained(VAE_ID).to(self.device)
        vae.eval()
        print("Loading CLIP text encoder...")
        tokenizer = CLIPTokenizer.from_pretrained(CLIP_ID)
        text_encoder = CLIPTextModel.from_pretrained(CLIP_ID).to(self.device)
        text_encoder.eval()

        self.tokenizer = tokenizer
        self.text_encoder = text_encoder
        self.vae = vae
        # Keep a scheduler picked via set_scheduler() before this first load
        # instead of resetting it to the default.
        if self.scheduler is None:
            self.scheduler = self._make_scheduler(self.scheduler_name)
        print("Shared models loaded.")

    def _make_scheduler(self, name="DPM++ 2M Karras"):
        """Build a fresh diffusers scheduler for a SCHEDULER_LIST name.

        Unknown names fall back to DDIM.
        """
        from diffusers import (DDIMScheduler, DPMSolverMultistepScheduler,
                               EulerAncestralDiscreteScheduler,
                               EulerDiscreteScheduler)
        base = dict(num_train_timesteps=1000, beta_schedule="scaled_linear",
                    prediction_type="epsilon")
        if name == "DPM++ 2M Karras":
            return DPMSolverMultistepScheduler(
                **base, algorithm_type="dpmsolver++",
                solver_order=2, use_karras_sigmas=True)
        elif name == "DPM++ SDE Karras":
            return DPMSolverMultistepScheduler(
                **base, algorithm_type="sde-dpmsolver++",
                use_karras_sigmas=True)
        elif name == "Euler a":
            return EulerAncestralDiscreteScheduler(**base)
        elif name == "Euler":
            return EulerDiscreteScheduler(**base)
        else:
            return DDIMScheduler(**base, clip_sample=False,
                                 set_alpha_to_one=False)

    def set_scheduler(self, name):
        """Select the scheduler used by subsequent generate()/refine() calls."""
        self.scheduler = self._make_scheduler(name)
        self.scheduler_name = name

    def load_model(self, model_path: Path, res_label: str = "256"):
        """Build the UNet and load its weights from ``model_path``.

        Looks for ``diffusion_pytorch_model.safetensors``, then
        ``ema_unet.pt``, then ``unet.pt``. ``res_label == "512"`` selects
        64x64 latents / 512px output, anything else 32x32 / 256px.

        Raises:
            FileNotFoundError: none of the weight files exists.
            ValueError: an EMA checkpoint's ``shadow_params`` count does not
                match the UNet's parameter count.
        """
        latent_size, output_size = (64, 512) if res_label == "512" else (32, 256)
        # Cached only when both the folder and the resolution match.
        if (str(model_path) == self.loaded_checkpoint
                and latent_size == self.latent_size):
            return
        from diffusers import UNet2DConditionModel
        self.load_shared()

        # Invalidate first: if loading fails below, the next call retries
        # rather than reusing a half-initialised network.
        self.loaded_checkpoint = None
        self.unet = None
        self.latent_size = latent_size
        self.output_size = output_size

        unet_cfg = dict(UNET_CONFIG)
        unet_cfg["sample_size"] = self.latent_size
        print(f"Loading UNet from {model_path.name} ({res_label}px)...")
        unet = UNet2DConditionModel(**unet_cfg).to(self.device)

        safetensors_path = model_path / "diffusion_pytorch_model.safetensors"
        ema_path = model_path / "ema_unet.pt"
        unet_path = model_path / "unet.pt"
        if safetensors_path.exists():
            from safetensors.torch import load_file
            state = load_file(str(safetensors_path), device=str(self.device))
            unet.load_state_dict(state)
            print("Loaded safetensors weights.")
        elif ema_path.exists():
            state = torch.load(ema_path, map_location=self.device, weights_only=True)
            if "shadow_params" in state:
                # shadow_params is a positional list; it is copied onto the
                # UNet parameters in named_parameters() order.
                shadow = state["shadow_params"]
                params = [p for _, p in unet.named_parameters()]
                if len(shadow) != len(params):
                    raise ValueError(
                        f"EMA checkpoint has {len(shadow)} tensors but the UNet "
                        f"has {len(params)} parameters")
                for param, sp in zip(params, shadow):
                    param.data.copy_(sp)
            else:
                unet.load_state_dict(state)
            print("Loaded EMA weights.")
        elif unet_path.exists():
            unet.load_state_dict(
                torch.load(unet_path, map_location=self.device, weights_only=True))
            print("Loaded UNet weights.")
        else:
            raise FileNotFoundError(f"No weights found in {model_path}")

        unet.eval()
        self.unet = unet
        self.loaded_checkpoint = str(model_path)
        print(f"Ready to generate at {self.output_size}x{self.output_size}!")

    def _decode_latents(self, latents, post_process=False):
        """Decode the first latent of the batch to a PIL image via the VAE."""
        scaled = latents / self.vae.config.scaling_factor
        with torch.no_grad():
            image = self.vae.decode(scaled.float()).sample
        # VAE output is mapped from [-1, 1] to [0, 1] and then to uint8 HWC.
        image = (image.float() / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
        image = (image * 255).round().astype("uint8")
        img = Image.fromarray(image)
        if post_process:
            img = self._post_process(img)
        return img

    def _sharpen_latents(self, latents, amount=0.08):
        """Unsharp-mask the latents: add ``amount`` x (latents - 3x3 box blur)."""
        blurred = F.avg_pool2d(latents, kernel_size=3, stride=1, padding=1)
        return latents + amount * (latents - blurred)

    def _post_process(self, img):
        """Apply a mild unsharp mask plus small contrast and colour boosts."""
        img = img.filter(ImageFilter.UnsharpMask(radius=1.5, percent=40, threshold=2))
        img = ImageEnhance.Contrast(img).enhance(1.06)
        img = ImageEnhance.Color(img).enhance(1.10)
        return img

    def _image_quality_score(self, img: "Image.Image") -> float:
        """Return a heuristic 0-100 score from sharpness and colour variance.

        Sharpness is the variance of a 4-neighbour Laplacian of the greyscale
        image (edges wrap around because ``np.roll`` is used); colour variance
        is the mean per-channel variance. The weighted sum is divided by 10
        and capped at 100.
        """
        arr = np.array(img.convert("L"), dtype=np.float32)
        lap = (np.roll(arr, 1, 0) + np.roll(arr, -1, 0)
               + np.roll(arr, 1, 1) + np.roll(arr, -1, 1) - 4.0 * arr)
        sharpness = float(np.var(lap))
        arr_rgb = np.array(img, dtype=np.float32)
        color_var = float(np.mean(np.var(arr_rgb, axis=(0, 1))))
        score = (sharpness * 0.6 + color_var * 0.4)
        return min(100.0, score / 10.0)

    @torch.no_grad()
    def _encode_prompts(self, prompt, negative_prompt=""):
        """Return CLIP embeddings stacked as [negative, positive] for CFG."""
        def embed(text):
            tok = self.tokenizer(text, padding="max_length",
                                 max_length=self.tokenizer.model_max_length,
                                 truncation=True, return_tensors="pt")
            return self.text_encoder(tok.input_ids.to(self.device))[0]

        text_emb = embed(prompt)
        neg_emb = embed(negative_prompt if negative_prompt else "")
        return torch.cat([neg_emb, text_emb])

    @torch.no_grad()
    def _denoise(self, scheduler, timesteps, latents, text_emb_combined,
                 guidance_scale, preview_callback=None, preview_every=5):
        """Run the guided denoising loop over ``timesteps``.

        Returns the final latents, or ``None`` if ``self.cancelled`` was set.
        ``preview_callback(image, step, total)`` is invoked every
        ``preview_every`` steps, excluding the first and the last step.
        """
        total_steps = len(timesteps)
        for step_i, t in enumerate(timesteps):
            if self.cancelled:
                return None
            # Duplicate the batch: one half is paired with the negative
            # embedding, the other with the prompt embedding.
            latent_input = torch.cat([latents] * 2)
            latent_input = scheduler.scale_model_input(latent_input, t)
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16,
                                enabled=(self.device == "cuda")):
                pred = self.unet(latent_input, t,
                                 encoder_hidden_states=text_emb_combined).sample
            pred_neg, pred_text = pred.chunk(2)
            pred = pred_neg + guidance_scale * (pred_text - pred_neg)
            latents = scheduler.step(pred, t, latents).prev_sample
            if (preview_callback and step_i > 0
                    and step_i % preview_every == 0
                    and step_i < total_steps - 1):
                preview = self._decode_latents(latents, post_process=False)
                preview_callback(preview, step_i + 1, total_steps)
        return latents

    @staticmethod
    def _refine_start_index(num_timesteps, strength):
        """Index of the first timestep to run for an img2img ``strength``.

        ``strength`` 1.0 (or more) starts at index 0 (full schedule); smaller
        values skip the early steps. The result is clamped so at least one
        timestep always remains, which keeps ``strength <= 0`` from producing
        an empty schedule.
        """
        start = max(0, int(num_timesteps * (1 - strength)))
        return min(start, max(0, num_timesteps - 1))

    @torch.no_grad()
    def generate(self, prompt: str, negative_prompt: str = "",
                 steps: int = 25, guidance_scale: float = 7.5,
                 seed: int = -1, preview_callback=None,
                 preview_every: int = 5) -> tuple:
        """Sample one image from ``prompt``.

        A negative ``seed`` picks a random one. Returns ``(image, seed)``;
        ``image`` is ``None`` when the run was cancelled.
        """
        if seed < 0:
            seed = torch.randint(0, 2**32, (1,)).item()
        gen = torch.Generator(device=self.device).manual_seed(seed)
        text_emb_combined = self._encode_prompts(prompt, negative_prompt)

        # A fresh scheduler per call, so no step state carries over.
        scheduler = self._make_scheduler(self.scheduler_name)
        scheduler.set_timesteps(steps, device=self.device)
        latents = torch.randn(1, 4, self.latent_size, self.latent_size,
                              generator=gen, device=self.device)
        latents = latents * scheduler.init_noise_sigma

        latents = self._denoise(scheduler, scheduler.timesteps, latents,
                                text_emb_combined, guidance_scale,
                                preview_callback, preview_every)
        if latents is None:
            return None, seed
        latents = self._sharpen_latents(latents)
        final = self._decode_latents(latents, post_process=True)
        return final, seed

    @torch.no_grad()
    def generate_adaptive(self, prompt: str, negative_prompt: str = "",
                          base_steps: int = 25, max_steps: int = 85,
                          guidance_scale: float = 7.5,
                          quality_threshold: float = 45.0,
                          preview_callback=None, preview_every: int = 5,
                          status_callback=None) -> tuple:
        """Generate, then refine in 20-step rounds while quality is low.

        At most ``(max_steps - base_steps) // 20`` refinement rounds are run.
        Returns ``(image, seed)``; ``image`` is ``None`` only when the initial
        generation was cancelled. A cancel during refinement returns the best
        image so far.
        """
        result = self.generate(
            prompt=prompt, negative_prompt=negative_prompt,
            steps=base_steps, guidance_scale=guidance_scale,
            preview_callback=preview_callback, preview_every=preview_every)
        if result[0] is None:
            return result
        image, seed = result
        quality = self._image_quality_score(image)
        if status_callback:
            status_callback(f"Quality: {quality:.1f}/100")
        if quality >= quality_threshold:
            return image, seed

        rounds = 0
        max_rounds = (max_steps - base_steps) // 20
        while quality < quality_threshold and rounds < max_rounds:
            if self.cancelled:
                return image, seed
            rounds += 1
            if status_callback:
                status_callback(f"Refining +20 steps (round {rounds})...")
            refined = self.refine(
                source_image=image, prompt=prompt,
                negative_prompt=negative_prompt,
                extra_steps=20, strength=0.3,
                guidance_scale=guidance_scale,
                preview_callback=preview_callback,
                preview_every=preview_every)
            if refined is None:
                return image, seed
            image = refined
            quality = self._image_quality_score(image)
            if status_callback:
                status_callback(f"Quality after round {rounds}: {quality:.1f}/100")
        return image, seed

    @torch.no_grad()
    def refine(self, source_image: "Image.Image", prompt: str,
               negative_prompt: str = "", extra_steps: int = 20,
               strength: float = 0.35, guidance_scale: float = 7.5,
               preview_callback=None, preview_every: int = 5) -> "Image.Image":
        """Re-noise ``source_image`` and denoise it again (img2img).

        Only the tail of an ``extra_steps``-long schedule is run; ``strength``
        chooses how much of it. Returns the refined image, or ``None`` when
        cancelled.
        """
        # Force 3-channel input: the VAE encode below is fed a (1, 3, H, W)
        # tensor, which RGBA / greyscale / palette images would not produce.
        img = source_image.convert("RGB").resize(
            (self.output_size, self.output_size), Image.LANCZOS)
        img_tensor = torch.from_numpy(np.array(img)).float().div(127.5).sub(1.0)
        img_tensor = img_tensor.permute(2, 0, 1).unsqueeze(0).to(self.device)
        latents = self.vae.encode(img_tensor.float()).latent_dist.sample()
        latents = latents * self.vae.config.scaling_factor

        text_emb_combined = self._encode_prompts(prompt, negative_prompt)

        scheduler = self._make_scheduler(self.scheduler_name)
        scheduler.set_timesteps(extra_steps, device=self.device)
        start_step = self._refine_start_index(len(scheduler.timesteps), strength)
        timesteps = scheduler.timesteps[start_step:]
        # Noise the encoded image up to the level of the first timestep we run.
        noise = torch.randn_like(latents)
        latents = scheduler.add_noise(latents, noise, timesteps[:1])

        latents = self._denoise(scheduler, timesteps, latents,
                                text_emb_combined, guidance_scale,
                                preview_callback, preview_every)
        if latents is None:
            return None
        latents = self._sharpen_latents(latents)
        return self._decode_latents(latents, post_process=True)
# ── GUI ───────────────────────────────────────────────────────────────────────
class App:
def __init__(self):
    """Create the generator backend, discover local models and build the window."""
    # Backend and model list
    self.gen = Generator()
    self.models = find_models()

    # Gallery state: images and their seeds are parallel lists
    self.generated_images, self.generated_seeds = [], []
    self.photo_refs = []
    self.generating = False
    self.selected_index = None

    # Main window
    window = tk.Tk()
    self.root = window
    window.title("Aniimage")
    window.configure(bg=C["bg"])
    window.resizable(True, True)
    window.geometry("900x780")
    window.minsize(640, 500)

    self._setup_styles()
    self._build_ui()
def _setup_styles(self):
    """Apply the dark palette from ``C`` to the ttk "clam" theme."""
    style = ttk.Style()
    style.theme_use("clam")

    white = "#ffffff"
    ui_font = ("Segoe UI", 10)
    bg, text = C["bg"], C["text"]
    field, field_fg = C["input"], C["input_fg"]

    # Base
    style.configure(".", background=bg, foreground=text, font=ui_font)
    style.configure("TFrame", background=bg)
    for widget_class in ("TLabel", "TCheckbutton"):
        style.configure(widget_class, background=bg, foreground=text)

    # Combobox — readable text
    style.configure("TCombobox", fieldbackground=field, foreground=field_fg,
                    selectbackground=C["accent"], selectforeground=white,
                    arrowcolor=C["text2"], padding=4)
    style.map("TCombobox",
              fieldbackground=[("readonly", field)],
              foreground=[("readonly", field_fg)],
              selectbackground=[("readonly", C["accent"])],
              selectforeground=[("readonly", white)])

    # Combobox dropdown list colors (a plain Tk listbox, so set via the
    # option database rather than ttk styles)
    listbox_options = (
        ("background", field),
        ("foreground", field_fg),
        ("selectBackground", C["accent"]),
        ("selectForeground", white),
        ("font", ui_font),
    )
    for option, value in listbox_options:
        self.root.option_add("*TCombobox*Listbox." + option, value)

    # Spinbox
    style.configure("TSpinbox", fieldbackground=field, foreground=field_fg,
                    arrowcolor=C["text2"], padding=3)

    # Buttons
    style.configure("TButton", font=ui_font, padding=(14, 7),
                    background=C["card"], foreground=text)
    style.map("TButton",
              background=[("active", C["card_sel"]), ("disabled", bg)],
              foreground=[("disabled", C["text3"])])

    style.configure("Go.TButton", font=("Segoe UI", 11, "bold"), padding=(20, 9),
                    background=C["accent"], foreground=white)
    style.map("Go.TButton",
              background=[("active", C["accent_h"]), ("disabled", C["border"])])

    style.configure("Stop.TButton", font=("Segoe UI", 10, "bold"), padding=(14, 7),
                    background=C["red"], foreground=white)
    style.map("Stop.TButton",
              background=[("active", "#c0392b"), ("disabled", C["border"])])

    # Labelframe
    style.configure("TLabelframe", background=bg, foreground=C["text2"])
    style.configure("TLabelframe.Label", background=bg,
                    foreground=C["text2"], font=("Segoe UI", 9, "bold"))

    # Scrollbar
    style.configure("Vertical.TScrollbar", background=C["card"],
                    troughcolor=bg, arrowcolor=C["text3"])
def _make_entry(self, parent, font_size=11, dim=False):
    """Create a styled tk.Entry with readable text.

    ``dim=True`` uses the secondary text colour instead of the normal
    input colour. The widget is returned without being packed/gridded.
    """
    text_colour = C["text2"] if dim else C["input_fg"]
    options = dict(
        font=("Segoe UI", font_size),
        bg=C["input"],
        fg=text_colour,
        insertbackground=C["input_fg"],
        relief="flat",
        bd=6,
        selectbackground=C["accent"],
        selectforeground="#ffffff",
        highlightthickness=1,
        highlightcolor=C["accent"],
        highlightbackground=C["border"],
    )
    return tk.Entry(parent, **options)
def _build_ui(self):
    """Assemble the window: header bar on top, controls left, image grid right."""
    panel_bg = C["panel"]

    # ── Header ────────────────────────────────────────────────────────
    header_bar = tk.Frame(self.root, bg=panel_bg, padx=20, pady=12)
    header_bar.pack(fill=tk.X)

    title = tk.Label(header_bar, text="Aniimage", bg=panel_bg, fg=C["accent"],
                     font=("Segoe UI", 20, "bold"))
    title.pack(side=tk.LEFT)
    byline = tk.Label(header_bar, text="by 8BitStudio", bg=panel_bg, fg=C["text3"],
                      font=("Segoe UI", 10))
    byline.pack(side=tk.LEFT, padx=(10, 0), pady=(6, 0))

    # Device switch — right side of header
    device_box = tk.Frame(header_bar, bg=panel_bg)
    device_box.pack(side=tk.RIGHT)
    device_caption = tk.Label(device_box, text="Device:", bg=panel_bg,
                              fg=C["text2"], font=("Segoe UI", 9))
    device_caption.pack(side=tk.LEFT, padx=(0, 5))

    initial_choice = "GPU" if self.gen.device == "cuda" else "CPU"
    self.device_var = tk.StringVar(value=initial_choice)
    # GPU is only offered when CUDA is actually available.
    if torch.cuda.is_available():
        choices = ["GPU", "CPU"]
    else:
        choices = ["CPU"]
    picker = ttk.Combobox(device_box, textvariable=self.device_var,
                          values=choices, state="readonly", width=5)
    picker.pack(side=tk.LEFT)
    picker.bind("<<ComboboxSelected>>", self._on_device_change)

    # ── Main content — two-column: controls left, images right ────────
    body = tk.Frame(self.root, bg=C["bg"])
    body.pack(fill=tk.BOTH, expand=True, padx=12, pady=(8, 12))

    # Left panel (controls): fixed 340px wide, so children must not resize it
    controls_panel = tk.Frame(body, bg=panel_bg, width=340, padx=16, pady=12)
    controls_panel.pack(side=tk.LEFT, fill=tk.Y, padx=(0, 8))
    controls_panel.pack_propagate(False)

    # Right panel (image grid) takes the remaining space
    gallery_panel = tk.Frame(body, bg=C["bg"])
    gallery_panel.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)

    self._build_controls(controls_panel)
    self._build_grid(gallery_panel)
def _build_controls(self, parent):
    """Populate the left panel.

    Creates, top to bottom: model selector, prompt and negative-prompt
    fields, sampler settings, the auto-quality toggle, action buttons, the
    prompt queue and the status line. Widgets and Tk variables that other
    methods need are stored on ``self``.
    """
    panel_bg = C["panel"]

    def heading(text):
        # Bold section caption, left-aligned in the panel.
        tk.Label(parent, text=text, bg=panel_bg, fg=C["text2"],
                 font=("Segoe UI", 9, "bold")).pack(anchor=tk.W)

    def grid_caption(text, row, column):
        # Caption cell inside the settings grid.
        tk.Label(grid, text=text, bg=panel_bg, fg=C["text2"],
                 font=("Segoe UI", 9)).grid(row=row, column=column,
                                            sticky="w", pady=(0, 6))

    def number_entry(variable, column, padx):
        # Small numeric field on row 1 of the settings grid.
        tk.Entry(grid, textvariable=variable, width=5, font=("Segoe UI", 10),
                 bg=C["input"], fg=C["input_fg"], insertbackground=C["input_fg"],
                 relief="flat", bd=4).grid(row=1, column=column, sticky="w",
                                           padx=padx, pady=(0, 6))

    # ── Model ─────────────────────────────────────────────────────────
    heading("Model")
    self.model_var = tk.StringVar()
    model_names = [m[1] for m in self.models] or ["No models found"]
    self.model_combo = ttk.Combobox(parent, textvariable=self.model_var,
                                    values=model_names, state="readonly", width=32)
    self.model_combo.pack(fill=tk.X, pady=(3, 12))
    # Pre-select the last entry of the list.
    self.model_combo.current(len(model_names) - 1)

    # ── Prompt ────────────────────────────────────────────────────────
    heading("Prompt")
    self.prompt_entry = self._make_entry(parent)
    self.prompt_entry.pack(fill=tk.X, pady=(3, 8))
    self.prompt_entry.insert(0, "a smiling anime girl with long blue hair")
    self.prompt_entry.bind("<Return>", lambda e: self.on_generate())

    # ── Negative prompt ───────────────────────────────────────────────
    tk.Label(parent, text="Negative prompt", bg=panel_bg, fg=C["text3"],
             font=("Segoe UI", 9)).pack(anchor=tk.W)
    self.neg_entry = self._make_entry(parent, font_size=9, dim=True)
    self.neg_entry.pack(fill=tk.X, pady=(3, 12))
    self.neg_entry.insert(0, DEFAULT_NEGATIVE)

    # ── Settings grid ─────────────────────────────────────────────────
    grid = tk.Frame(parent, bg=panel_bg)
    grid.pack(fill=tk.X, pady=(0, 8))

    # Row 0: Scheduler
    grid_caption("Scheduler", 0, 0)
    self.scheduler_var = tk.StringVar(value="DPM++ 2M Karras")
    sched_combo = ttk.Combobox(grid, textvariable=self.scheduler_var,
                               values=SCHEDULER_LIST, state="readonly", width=18)
    sched_combo.grid(row=0, column=1, columnspan=3, sticky="ew",
                     padx=(8, 0), pady=(0, 6))
    sched_combo.bind("<<ComboboxSelected>>", self._on_scheduler_change)

    # Row 1: Steps, CFG
    grid_caption("Steps", 1, 0)
    self.steps_var = tk.StringVar(value="25")
    number_entry(self.steps_var, 1, (8, 12))
    grid_caption("CFG", 1, 2)
    self.cfg_var = tk.StringVar(value="7.5")
    number_entry(self.cfg_var, 3, (8, 0))

    # Row 2: Count, Live preview
    grid_caption("Count", 2, 0)
    self.count_var = tk.StringVar(value="4")
    count_spin = ttk.Spinbox(grid, from_=1, to=12, textvariable=self.count_var,
                             width=4, font=("Segoe UI", 10))
    count_spin.grid(row=2, column=1, sticky="w", padx=(8, 12), pady=(0, 6))
    self.live_preview_var = tk.BooleanVar(value=False)
    preview_check = ttk.Checkbutton(grid, text="Live preview",
                                    variable=self.live_preview_var)
    preview_check.grid(row=2, column=2, columnspan=2, sticky="w", pady=(0, 6))
    for stretch_column in (1, 3):
        grid.columnconfigure(stretch_column, weight=1)

    # ── Auto quality ──────────────────────────────────────────────────
    self.auto_quality_var = tk.BooleanVar(value=False)
    auto_check = ttk.Checkbutton(parent, text="Auto quality (refine if undercooked)",
                                 variable=self.auto_quality_var)
    auto_check.pack(anchor=tk.W, pady=(0, 12))

    # ── Buttons ───────────────────────────────────────────────────────
    btn_frame = tk.Frame(parent, bg=panel_bg)
    btn_frame.pack(fill=tk.X, pady=(0, 10))
    self.gen_btn = ttk.Button(btn_frame, text="Generate", command=self.on_generate,
                              style="Go.TButton")
    self.gen_btn.pack(fill=tk.X, pady=(0, 5))

    btn_row = tk.Frame(btn_frame, bg=panel_bg)
    btn_row.pack(fill=tk.X)
    self.stop_btn = ttk.Button(btn_row, text="Stop", command=self.on_stop,
                               state=tk.DISABLED, style="Stop.TButton")
    self.stop_btn.pack(side=tk.LEFT, fill=tk.X, expand=True, padx=(0, 3))
    self.save_btn = ttk.Button(btn_row, text="Save Selected", command=self.on_save,
                               state=tk.DISABLED)
    self.save_btn.pack(side=tk.LEFT, fill=tk.X, expand=True, padx=(3, 3))
    self.save_all_btn = ttk.Button(btn_row, text="Save All",
                                   command=self.on_save_all, state=tk.DISABLED)
    self.save_all_btn.pack(side=tk.LEFT, fill=tk.X, expand=True, padx=(3, 0))

    # ── Prompt queue ─────────────────────────────────────────────────
    divider = tk.Frame(parent, height=1, bg=C["border"])
    divider.pack(fill=tk.X, pady=(8, 10))
    heading("Prompt Queue")

    queue_input = tk.Frame(parent, bg=panel_bg)
    queue_input.pack(fill=tk.X, pady=(4, 0))
    self.queue_entry = self._make_entry(queue_input, font_size=9)
    self.queue_entry.pack(side=tk.LEFT, fill=tk.X, expand=True, padx=(0, 4))
    self.queue_entry.bind("<Return>", lambda e: self._queue_add())
    add_btn = ttk.Button(queue_input, text="Add", width=4, command=self._queue_add)
    add_btn.pack(side=tk.LEFT)

    self.queue_listbox = tk.Listbox(
        parent, height=4, bg=C["input"], fg=C["input_fg"],
        selectbackground=C["accent"], selectforeground="#fff",
        font=("Segoe UI", 9), activestyle="none",
        relief="flat", bd=4, highlightthickness=0)
    self.queue_listbox.pack(fill=tk.X, pady=(5, 0))

    queue_btns = tk.Frame(parent, bg=panel_bg)
    queue_btns.pack(fill=tk.X, pady=(4, 0))
    self.queue_run_btn = ttk.Button(queue_btns, text="Run Queue",
                                    command=self.on_run_queue, style="Go.TButton")
    self.queue_run_btn.pack(side=tk.LEFT, padx=(0, 4))
    queue_actions = (
        ("Remove", self._queue_remove),
        ("Clear", self._queue_clear),
        ("Up", self._queue_move_up),
        ("Down", self._queue_move_down),
        ("+ Current", self._queue_add_current),
    )
    for caption, handler in queue_actions:
        ttk.Button(queue_btns, text=caption, command=handler).pack(side=tk.LEFT, padx=2)

    # ── Status bar ────────────────────────────────────────────────────
    status_frame = tk.Frame(parent, bg=C["bg"], padx=8, pady=6)
    status_frame.pack(fill=tk.X, side=tk.BOTTOM)
    self.status_var = tk.StringVar(value="Ready")
    status_label = tk.Label(status_frame, textvariable=self.status_var,
                            bg=C["bg"], fg=C["green"], font=("Segoe UI", 9),
                            anchor="w")
    status_label.pack(fill=tk.X)
def _build_grid(self, parent):
    """Build the scrollable image area: a canvas hosting ``grid_frame``.

    Stores ``canvas``, ``grid_frame``, ``canvas_window`` and the initial
    ``placeholder`` label on ``self``.
    """
    self.canvas = tk.Canvas(parent, bg=C["bg"], highlightthickness=0)
    scrollbar = ttk.Scrollbar(parent, orient=tk.VERTICAL, command=self.canvas.yview)
    self.grid_frame = tk.Frame(self.canvas, bg=C["bg"])
    # Keep the scrollable region in sync with the frame's size.
    self.grid_frame.bind("<Configure>",
                         lambda e: self.canvas.configure(
                             scrollregion=self.canvas.bbox("all")))
    self.canvas_window = self.canvas.create_window((0, 0), window=self.grid_frame,
                                                   anchor="nw")
    self.canvas.configure(yscrollcommand=scrollbar.set)
    self.canvas.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)
    scrollbar.pack(side=tk.RIGHT, fill=tk.Y)
    self.canvas.bind("<Configure>", self._on_canvas_resize)

    # Wheel scrolling. X11 builds of Tk 8.6 report the wheel as mouse
    # buttons 4/5 instead of <MouseWheel>, so bind those there as well.
    self.canvas.bind_all("<MouseWheel>", self._on_mousewheel)
    if self.canvas.tk.call("tk", "windowingsystem") == "x11":
        self.canvas.bind_all("<Button-4>", self._on_mousewheel)
        self.canvas.bind_all("<Button-5>", self._on_mousewheel)

    self.placeholder = tk.Label(self.grid_frame,
                                text="Generated images\nwill appear here",
                                bg=C["bg"], fg=C["text3"],
                                font=("Segoe UI", 13), justify="center")
    self.placeholder.grid(row=0, column=0, pady=80)

def _on_mousewheel(self, event):
    """Scroll the image canvas in response to a wheel event."""
    units = self._scroll_units(event)
    if units:
        self.canvas.yview_scroll(units, "units")

def _scroll_units(self, event):
    """Translate a wheel event into canvas scroll units (negative = up).

    Deltas of 120 or more keep the original ``-delta / 120`` scaling.
    Smaller non-zero deltas scroll one unit in the matching direction
    instead of being truncated to zero. Button 4 / 5 events map to
    -1 / +1.
    """
    button = getattr(event, "num", None)
    if button == 4:
        return -1
    if button == 5:
        return 1
    delta = getattr(event, "delta", 0) or 0
    if delta == 0:
        return 0
    if abs(delta) >= 120:
        return int(-1 * (delta / 120))
    return -1 if delta > 0 else 1
# ── Event handlers ────────────────────────────────────────────────────
def _on_device_change(self, event=None):
    """Handle the Device combobox: move the loaded models to GPU or CPU.

    The switch is refused while a generation/refine worker is running,
    because moving the models mid-run would leave the worker's tensors and
    the model weights on different devices. The combobox is reset to the
    device actually in use in that case.
    """
    if self.generating:
        current = "GPU" if self.gen.device == "cuda" else "CPU"
        self.device_var.set(current)
        self.status_var.set("Cannot switch device while generating")
        return

    choice = self.device_var.get()
    new_dev = "cuda" if choice == "GPU" else "cpu"
    self.status_var.set(f"Switching to {choice}...")
    self.root.update()  # show the message before the (blocking) transfer
    self.gen.switch_device(new_dev)
    self.status_var.set(f"Now using {choice}")
def _on_scheduler_change(self, event=None):
    """Handle the Scheduler combobox: pass the selection on to the generator."""
    chosen = self.scheduler_var.get()
    self.gen.set_scheduler(chosen)
    self.status_var.set("Scheduler: " + f"{chosen}")
def _on_canvas_resize(self, event):
    """Stretch the embedded grid frame to the canvas width and re-flow tiles."""
    self.canvas.itemconfig(self.canvas_window, width=event.width)
    if not self.generated_images:
        return
    # The column count depends on the canvas width, so lay out again.
    self._layout_grid()
def _get_grid_cols(self):
    """Return how many tiles fit side by side in the canvas (at least 1)."""
    width = self.canvas.winfo_width()
    # A width below 50px is replaced by a 560px default (the canvas reports
    # such values here — presumably before it has been laid out).
    width = 560 if width < 50 else width
    cell = self._get_tile_size() + 16  # tile plus 16px of spacing
    return max(1, width // cell)
def _get_tile_size(self):
    """Return the thumbnail edge length in pixels; more images -> smaller tiles."""
    count = len(self.generated_images)
    # (maximum image count, tile size) pairs, checked in order
    for limit, size in ((2, 260), (4, 220), (6, 180)):
        if count <= limit:
            return size
    return 160
def _layout_grid(self):
    """Rebuild the thumbnail grid from ``generated_images`` / ``generated_seeds``.

    All existing child widgets of ``grid_frame`` are destroyed and recreated;
    the selected tile gets the accent colour as its card background.
    """
    for child in self.grid_frame.winfo_children():
        child.destroy()
    self.photo_refs.clear()
    if not self.generated_images:
        return

    tile = self._get_tile_size()
    n_cols = self._get_grid_cols()
    entries = zip(self.generated_images, self.generated_seeds)
    for index, (picture, seed) in enumerate(entries):
        grid_row, grid_col = divmod(index, n_cols)
        if index == self.selected_index:
            card_bg = C["accent"]
        else:
            card_bg = C["card"]

        card = tk.Frame(self.grid_frame, bg=card_bg, padx=3, pady=3)
        card.grid(row=grid_row, column=grid_col, padx=5, pady=5, sticky="nsew")

        thumbnail = ImageTk.PhotoImage(picture.resize((tile, tile), Image.LANCZOS))
        # Held in photo_refs — presumably so the PhotoImage is not
        # garbage-collected while the label still shows it.
        self.photo_refs.append(thumbnail)

        img_label = tk.Label(card, image=thumbnail, bg=card_bg, bd=0)
        img_label.pack()
        # idx=index binds the current value; a plain closure would see the
        # loop variable's final value.
        img_label.bind("<Button-1>", lambda e, idx=index: self._select_image(idx))
        img_label.bind("<Button-3>", lambda e, idx=index: self._show_refine_menu(e, idx))

        caption = tk.Label(card, text=f"seed: {seed}", bg=card_bg,
                           fg=C["text3"], font=("Segoe UI", 8))
        caption.pack()

    for column in range(n_cols):
        self.grid_frame.columnconfigure(column, weight=1)
def _select_image(self, idx):
    """Mark tile ``idx`` as selected, enable "Save Selected" and redraw."""
    if idx < len(self.generated_images):
        self.selected_index = idx
        self.save_btn.configure(state=tk.NORMAL)
        seed = self.generated_seeds[idx]
        self.status_var.set(f"Selected image {idx + 1} (seed: {seed})")
        self._layout_grid()
def _show_refine_menu(self, event, idx):
    """Pop up the right-click menu for tile ``idx`` (ignored while generating)."""
    if not self.generating:
        popup = tk.Menu(self.root, tearoff=0, bg=C["card"], fg=C["text"],
                        activebackground=C["accent"], activeforeground="#fff",
                        font=("Segoe UI", 10), bd=0)
        popup.add_command(label=" Refine (more steps)... ",
                          command=lambda: self._ask_refine(idx))
        # Show the menu at the pointer's screen position.
        popup.tk_popup(event.x_root, event.y_root)
def _ask_refine(self, idx):
    """Ask for a step count and start refining image ``idx`` in a worker thread.

    Does nothing when the dialog is cancelled.
    """
    extra_steps = simpledialog.askinteger(
        "Refine Image", "Extra denoising steps:",
        initialvalue=20, minvalue=5, maxvalue=200, parent=self.root)
    if extra_steps is None:
        return

    self._select_image(idx)

    # Enter the "busy" state: Generate off, Stop on, cancel flag cleared.
    self.generating = True
    self.gen.cancelled = False
    self.gen_btn.configure(state=tk.DISABLED)
    self.stop_btn.configure(state=tk.NORMAL)
    self.status_var.set(f"Refining image {idx + 1}...")
    self.root.update()

    worker = Thread(target=self._refine_thread, args=(idx, extra_steps), daemon=True)
    worker.start()
def _refine_thread(self, idx, extra_steps):
    """Worker body for a refine: replace image ``idx`` with its refined version.

    The seed label gets a ``+R<steps>`` suffix on success. Errors are shown
    in the status bar and printed; the busy state is always cleared at the end.
    """
    # NOTE(review): this runs on the worker thread started by _ask_refine
    # yet reads and updates Tk widgets directly — confirm this is safe for
    # the Tk build in use.
    try:
        source = self.generated_images[idx]
        prompt = self.prompt_entry.get().strip()
        neg = self.neg_entry.get().strip()
        cfg = float(self.cfg_var.get())
        if self.live_preview_var.get():
            callback = self._show_preview
        else:
            callback = None

        refined = self.gen.refine(
            source_image=source, prompt=prompt, negative_prompt=neg,
            extra_steps=extra_steps, guidance_scale=cfg,
            preview_callback=callback, preview_every=5)

        if refined is None:
            # refine() returns None when it was cancelled
            self.status_var.set("Refine stopped.")
        else:
            previous_label = self.generated_seeds[idx]
            self.generated_images[idx] = refined
            self.generated_seeds[idx] = f"{previous_label}+R{extra_steps}"
            self._layout_grid()
            self.status_var.set(f"Refined image {idx + 1}")
        self.root.update()
    except Exception as e:
        self.status_var.set(f"Refine error: {e}")
        import traceback
        traceback.print_exc()
    finally:
        # Leave the busy state whatever happened above.
        self.generating = False
        self.gen.cancelled = False
        self.gen_btn.configure(state=tk.NORMAL)
        self.stop_btn.configure(state=tk.DISABLED)
# ── Queue ─────────────────────────────────────────────────────────────
def _queue_add(self):
    """Append the queue entry's text to the queue and clear the entry."""
    new_prompt = self.queue_entry.get().strip()
    if not new_prompt:
        return  # nothing but whitespace: keep the entry untouched
    self.queue_listbox.insert(tk.END, new_prompt)
    self.queue_entry.delete(0, tk.END)
def _queue_add_current(self):
    """Append the main prompt field's text to the queue (if not blank)."""
    current_prompt = self.prompt_entry.get().strip()
    if not current_prompt:
        return
    self.queue_listbox.insert(tk.END, current_prompt)
def _queue_remove(self):
    """Delete the selected queue item (the first one if several are selected)."""
    selection = self.queue_listbox.curselection()
    if not selection:
        return
    first_selected = selection[0]
    self.queue_listbox.delete(first_selected)
def _queue_clear(self):
    """Remove every prompt from the queue."""
    listbox = self.queue_listbox
    listbox.delete(0, tk.END)
def _queue_move_up(self):
    """Move the selected queue item one position up and keep it selected."""
    listbox = self.queue_listbox
    selection = listbox.curselection()
    if not selection or selection[0] <= 0:
        return  # nothing selected, or already at the top
    position = selection[0]
    target = position - 1
    prompt_text = listbox.get(position)
    listbox.delete(position)
    listbox.insert(target, prompt_text)
    listbox.selection_set(target)
def _queue_move_down(self):
    """Move the selected queue item one position down and keep it selected."""
    listbox = self.queue_listbox
    selection = listbox.curselection()
    if not selection:
        return
    position = selection[0]
    last_index = listbox.size() - 1
    if position >= last_index:
        return  # already at the bottom
    target = position + 1
    prompt_text = listbox.get(position)
    listbox.delete(position)
    listbox.insert(target, prompt_text)
    listbox.selection_set(target)
def on_run_queue(self):
    """Start generating every queued prompt in a background thread.

    Ignored while another job is running or when no model was found; an
    empty queue only updates the status line.
    """
    busy_or_no_model = self.generating or not self.models
    if busy_or_no_model:
        return

    queued_prompts = list(self.queue_listbox.get(0, tk.END))
    if not queued_prompts:
        self.status_var.set("Queue is empty")
        return

    # Enter the "busy" state before the worker starts.
    self.generating = True
    self.gen.cancelled = False
    for button in (self.gen_btn, self.queue_run_btn):
        button.configure(state=tk.DISABLED)
    self.stop_btn.configure(state=tk.NORMAL)

    worker = Thread(target=self._queue_thread, args=(queued_prompts,), daemon=True)
    worker.start()
def _queue_thread(self, prompts):
# Worker for on_run_queue: for each prompt in `prompts`, generate the
# requested number of images, add them to the result grid and save each
# one to disk as soon as it is produced.
#
# prompts: list of prompt strings (on_run_queue passes a snapshot of the
#          queue listbox contents).
#
# NOTE(review): on_run_queue starts this on a background Thread, yet it
# calls Tk methods (status_var.set, root.update, widget.configure)
# directly -- confirm that is safe with the Tk build in use.
try:
# Load the model picked in the combo box. mdl[1] is shown as the model's
# name; mdl[2] and mdl[3] are handed to load_model unchanged.
idx = self.model_combo.current()
mdl = self.models[idx]
self.status_var.set(f"Loading {mdl[1]}...")
self.root.update()
self.gen.load_model(mdl[2], mdl[3])
# Read the generation settings from the UI once, before the loops.
neg = self.neg_entry.get().strip()
steps = int(self.steps_var.get())
cfg = float(self.cfg_var.get())
# Images per prompt, clamped to the range 1..12.
num_images = max(1, min(12, int(self.count_var.get())))
live_preview = self.live_preview_var.get()
auto_quality = self.auto_quality_var.get()
# Start from an empty result set and drop the placeholder widget if present.
self.generated_images.clear()
self.generated_seeds.clear()
self.selected_index = None
if self.placeholder:
self.placeholder.destroy()
self.placeholder = None
for p_idx, prompt in enumerate(prompts):
if self.gen.cancelled:
break
# Highlight the prompt being processed and scroll it into view.
self.queue_listbox.selection_clear(0, tk.END)
self.queue_listbox.selection_set(p_idx)
self.queue_listbox.see(p_idx)
for img_i in range(num_images):
if self.gen.cancelled:
break
self.status_var.set(
f"[{p_idx + 1}/{len(prompts)}] image {img_i + 1}/{num_images}")
self.root.update()
# Live preview: create a preview card and route step previews to it.
callback = None
if live_preview:
self._setup_preview_card()
callback = self._show_preview
# Auto-quality uses generate_adaptive with max_steps = steps + 60;
# otherwise a plain fixed-step generate call is made.
if auto_quality:
image, used_seed = self.gen.generate_adaptive(
prompt=prompt, negative_prompt=neg,
base_steps=steps, max_steps=steps + 60,
guidance_scale=cfg,
preview_callback=callback, preview_every=5,
status_callback=lambda m: (
self.status_var.set(m), self.root.update()))
else:
image, used_seed = self.gen.generate(
prompt=prompt, negative_prompt=neg,
steps=steps, guidance_scale=cfg,
preview_callback=callback, preview_every=5)
# No image came back -- presumably after a stop request; confirm what
# else makes generate/generate_adaptive return None.
if image is None:
break
self.generated_images.append(image)
self.generated_seeds.append(used_seed)
# Queue results are written to disk immediately, named after their prompt.
save_path = self._next_save_path(prompt)
image.save(save_path)
self._layout_grid()
self.root.update()
if self.gen.cancelled:
break
# Report how many images were produced and whether the run was stopped.
done = len(self.generated_images)
self.status_var.set(
f"Queue {'stopped' if self.gen.cancelled else 'done'}! {done} images saved.")
if done > 0:
self.save_all_btn.configure(state=tk.NORMAL)
except Exception as e:
# Any failure is shown in the status line and its traceback printed.
self.status_var.set(f"Queue error: {e}")
import traceback; traceback.print_exc()
finally:
# Always clear the busy/cancel flags and restore the buttons to idle.
self.generating = False
self.gen.cancelled = False
self.gen_btn.configure(state=tk.NORMAL)
self.queue_run_btn.configure(state=tk.NORMAL)
self.stop_btn.configure(state=tk.DISABLED)
# ── Generation ────────────────────────────────────────────────────────
def on_stop(self):
    """Request that the running generation stop.

    Sets ``self.gen.cancelled`` (the flag the worker loops poll) and shows
    "Stopping..." in the status line. Does nothing when idle.
    """
    if not self.generating:
        return
    self.gen.cancelled = True
    self.status_var.set("Stopping...")
    self.root.update()
def on_generate(self):
    """Start generating images for the current prompt on a background thread.

    Ignored while a generation is already running or when ``self.models``
    is empty. Otherwise marks the app busy, clears the cancel flag,
    disables Generate, enables Stop, shows "Loading model..." and launches
    ``_generate_thread`` as a daemon thread.
    """
    if self.generating:
        return
    if not self.models:
        return
    self.generating = True
    self.gen.cancelled = False
    self.gen_btn.configure(state=tk.DISABLED)
    self.stop_btn.configure(state=tk.NORMAL)
    self.status_var.set("Loading model...")
    self.root.update()
    worker = Thread(target=self._generate_thread, daemon=True)
    worker.start()
def _setup_preview_card(self):
    """Create the grid card that will display live previews of the next image.

    The card is placed in the grid cell following the images generated so
    far, and its label is stored as ``self._preview_label``.
    """
    side = self._get_tile_size()
    columns = self._get_grid_cols()
    # The next free cell is indexed by the number of images already generated.
    next_slot = len(self.generated_images)
    grid_row, grid_col = divmod(next_slot, columns)
    card_bg = C["card"]
    card = tk.Frame(self.grid_frame, bg=card_bg, padx=3, pady=3)
    card.grid(row=grid_row, column=grid_col, padx=5, pady=5, sticky="nsew")
    label = tk.Label(card, bg=card_bg, width=side, height=side)
    self._preview_label = label
    label.pack()
    self.root.update()
def _show_preview(self, preview_img, step, total):
    """Display an intermediate preview image and report step progress.

    Args:
        preview_img: Image object exposing ``resize`` (used as a PIL image).
        step: Current step number, shown in the status line.
        total: Total step count, shown in the status line.
    """
    side = self._get_tile_size()
    thumbnail = preview_img.resize((side, side), Image.LANCZOS)
    photo = ImageTk.PhotoImage(thumbnail)
    # Stored on self -- presumably so the PhotoImage stays referenced while
    # the label displays it.
    self._preview_photo = photo
    label_alive = (hasattr(self, '_preview_label')
                   and self._preview_label.winfo_exists())
    if label_alive:
        self._preview_label.configure(image=photo)
    self.status_var.set(f"Step {step}/{total}")
    self.root.update()
def _generate_thread(self):
# Worker for on_generate: generates the requested number of images for the
# prompt in the prompt entry and shows them in the grid. Images are kept in
# self.generated_images; nothing is written to disk here.
#
# NOTE(review): on_generate starts this on a background Thread, yet it
# calls Tk methods (status_var.set, root.update, widget.configure)
# directly -- confirm that is safe with the Tk build in use.
try:
# Load the model picked in the combo box. mdl[1] is shown as the model's
# name; mdl[2] and mdl[3] are handed to load_model unchanged.
idx = self.model_combo.current()
mdl = self.models[idx]
self.status_var.set(f"Loading {mdl[1]}...")
self.root.update()
self.gen.load_model(mdl[2], mdl[3])
# Read the generation settings from the UI once, before the loop.
prompt = self.prompt_entry.get().strip()
neg = self.neg_entry.get().strip()
steps = int(self.steps_var.get())
cfg = float(self.cfg_var.get())
# Number of images, clamped to the range 1..12.
num_images = max(1, min(12, int(self.count_var.get())))
live_preview = self.live_preview_var.get()
auto_quality = self.auto_quality_var.get()
# Start from an empty result set and drop the placeholder widget if present.
self.generated_images.clear()
self.generated_seeds.clear()
self.selected_index = None
if self.placeholder:
self.placeholder.destroy()
self.placeholder = None
for i in range(num_images):
if self.gen.cancelled:
break
self.status_var.set(f"Generating {i + 1}/{num_images}...")
self.root.update()
# Live preview: create a preview card and route step previews to it.
callback = None
if live_preview:
self._setup_preview_card()
callback = self._show_preview
# Auto-quality uses generate_adaptive with max_steps = steps + 60;
# otherwise a plain fixed-step generate call is made.
if auto_quality:
image, used_seed = self.gen.generate_adaptive(
prompt=prompt, negative_prompt=neg,
base_steps=steps, max_steps=steps + 60,
guidance_scale=cfg,
preview_callback=callback, preview_every=5,
status_callback=lambda m: (
self.status_var.set(m), self.root.update()))
else:
image, used_seed = self.gen.generate(
prompt=prompt, negative_prompt=neg,
steps=steps, guidance_scale=cfg,
preview_callback=callback, preview_every=5)
# No image came back -- presumably after a stop request; confirm what
# else makes generate/generate_adaptive return None.
if image is None:
break
self.generated_images.append(image)
self.generated_seeds.append(used_seed)
self._layout_grid()
self.root.update()
# Report the outcome; images produced before a stop are kept.
done = len(self.generated_images)
if self.gen.cancelled:
self.status_var.set(f"Stopped. {done} image(s) kept.")
else:
self.status_var.set(f"Done! {done} images. Click to select.")
if done > 0:
self.save_all_btn.configure(state=tk.NORMAL)
# Single-image Save is disabled; selected_index was reset to None above.
self.save_btn.configure(state=tk.DISABLED)
except Exception as e:
# Any failure is shown in the status line and its traceback printed.
self.status_var.set(f"Error: {e}")
import traceback; traceback.print_exc()
finally:
# Always clear the busy/cancel flags and restore the buttons to idle.
self.generating = False
self.gen.cancelled = False
self.gen_btn.configure(state=tk.NORMAL)
self.stop_btn.configure(state=tk.DISABLED)
# ── Save ──────────────────────────────────────────────────────────────
def _next_save_path(self, prompt_text):
    """Return an unused ``.png`` path inside OUTPUT_DIR named after the prompt.

    The file stem is the first 50 characters of the stripped prompt, with
    every character that cannot appear in a file name (path separators,
    ``< > : " | ? *`` and ASCII control characters) replaced by ``_``.
    Without this a prompt such as ``"cat/dog"`` would resolve to a
    sub-directory that does not exist (or outside OUTPUT_DIR) and saving
    would fail. Prompts made only of safe characters keep the same file
    name as before. An empty prompt yields ``untitled``.

    If ``<stem>.png`` already exists, ``<stem> 1.png``, ``<stem> 2.png``,
    ... are tried in order.

    Args:
        prompt_text: Prompt the image was generated from.

    Returns:
        A path inside OUTPUT_DIR that does not exist yet. OUTPUT_DIR is
        created if missing; the file itself is not created.
    """
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    # Translation table: reserved punctuation and control chars -> "_".
    unsafe = dict.fromkeys(map(ord, '<>:"/\\|?*'), "_")
    unsafe.update(dict.fromkeys(range(32), "_"))
    cleaned = prompt_text.strip().translate(unsafe)
    slug = cleaned[:50] if cleaned else "untitled"
    candidate = OUTPUT_DIR / f"{slug}.png"
    n = 0
    while candidate.exists():
        n += 1
        candidate = OUTPUT_DIR / f"{slug} {n}.png"
    return candidate
def on_save(self):
    """Save the selected image under a name derived from the prompt entry.

    Does nothing when no image is selected or no images exist. On success
    the status line shows the saved file's name.
    """
    nothing_selected = self.selected_index is None
    if nothing_selected or not self.generated_images:
        return
    chosen = self.generated_images[self.selected_index]
    prompt_text = self.prompt_entry.get().strip()
    target = self._next_save_path(prompt_text)
    chosen.save(target)
    self.status_var.set(f"Saved: {target.name}")
def on_save_all(self):
    """Save every generated image, each under the next free prompt-based name.

    All files are named after the text currently in the prompt entry. A
    fresh path is requested per image, after the previous one was written.
    Does nothing when there are no images.
    """
    images = self.generated_images
    if not images:
        return
    prompt_text = self.prompt_entry.get().strip()
    for picture in images:
        picture.save(self._next_save_path(prompt_text))
    self.status_var.set(f"Saved {len(images)} images")
def run(self):
    """Hand control to the root window's Tk main loop."""
    root = self.root
    root.mainloop()
# ── Entry point ───────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Prefer weights already on disk; fall back to a HuggingFace download
    # and rescan only when the download reports success.
    models = find_models()
    if not models:
        print("No models found locally. Downloading from HuggingFace...")
        if download_from_hf():
            models = find_models()

    # Still nothing usable: explain where weights are expected and exit.
    if not models:
        print("No models found!")
        print(f"Place model weights in: {MODEL_DIR}/YourModelName/")
        print("Expected files: diffusion_pytorch_model.safetensors or ema_unet.pt")
        sys.exit(1)

    model_names = ', '.join(m[1] for m in models)
    print(f"Found {len(models)} model(s): {model_names}")
    device_label = 'CUDA (GPU)' if torch.cuda.is_available() else 'CPU'
    print(f"Device: {device_label}")
    print("Starting Aniimage...")
    app = App()
    app.run()