""" Aniimage Generator — Generate anime images from text prompts. https://huggingface.co/8BitStudio/Aniimage-1 Usage: pip install torch torchvision diffusers transformers safetensors pillow huggingface_hub python generate_hf.py """ import os import sys import torch import torch.nn.functional as F import numpy as np import tkinter as tk from tkinter import ttk, simpledialog from pathlib import Path from PIL import Image, ImageTk, ImageEnhance, ImageFilter from threading import Thread # ── Paths ───────────────────────────────────────────────────────────────────── SCRIPT_DIR = Path(__file__).resolve().parent MODEL_DIR = SCRIPT_DIR / "models" OUTPUT_DIR = SCRIPT_DIR / "generated" # ── HuggingFace repo ───────────────────────────────────────────────────────── HF_REPO_ID = "8BitStudio/Aniimage-1" # ── UNet config (must match training) ───────────────────────────────────────── UNET_CONFIG = dict( sample_size=32, in_channels=4, out_channels=4, block_out_channels=(256, 512, 768, 1024), layers_per_block=2, cross_attention_dim=768, attention_head_dim=8, down_block_types=("DownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"), up_block_types=("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "UpBlock2D"), ) VAE_ID = "stabilityai/sd-vae-ft-mse" CLIP_ID = "openai/clip-vit-large-patch14" SCHEDULER_LIST = [ "DPM++ 2M Karras", "DPM++ SDE Karras", "Euler a", "Euler", "DDIM", ] DEFAULT_NEGATIVE = ( "low quality, ugly, blurry, distorted, deformed, bad anatomy, " "bad proportions, extra limbs, missing limbs, watermark, text, " "signature, washed out, flat colors, manga panel, disfigured, " "poorly drawn, jpeg artifacts, cropped, out of frame" ) # ── Model discovery ─────────────────────────────────────────────────────────── def download_from_hf(): """Download model weights from HuggingFace if not already cached.""" try: from huggingface_hub import hf_hub_download except ImportError: print("Install huggingface_hub: pip install huggingface_hub") return None MODEL_DIR.mkdir(parents=True, exist_ok=True) aniimage_dir = MODEL_DIR / "Aniimage-1" weights_path = aniimage_dir / "diffusion_pytorch_model.safetensors" if weights_path.exists(): print("Aniimage-1 weights already downloaded.") return aniimage_dir print(f"Downloading Aniimage-1 from {HF_REPO_ID}...") aniimage_dir.mkdir(parents=True, exist_ok=True) import shutil dl_weights = hf_hub_download(repo_id=HF_REPO_ID, filename="diffusion_pytorch_model.safetensors") shutil.copy2(dl_weights, weights_path) try: dl_config = hf_hub_download(repo_id=HF_REPO_ID, filename="config.json") shutil.copy2(dl_config, aniimage_dir / "config.json") except Exception: pass print("Download complete!") return aniimage_dir def find_models(): """Find all available models.""" options = [] if MODEL_DIR.exists(): for d in sorted(MODEL_DIR.iterdir()): if d.is_dir(): safetensors = d / "diffusion_pytorch_model.safetensors" ema_path = d / "ema_unet.pt" unet_path = d / "unet.pt" if safetensors.exists(): options.append(("safetensors", d.name, d, "256")) elif ema_path.exists() or unet_path.exists(): options.append(("checkpoint", d.name, d, "256")) return options # ── Theme ───────────────────────────────────────────────────────────────────── C = { "bg": "#111119", "panel": "#1b1b2f", "card": "#24243e", "card_sel": "#3a3a6e", "border": "#2e2e52", "accent": "#6c5ce7", "accent_h": "#8577ed", "red": "#e74c3c", "green": "#2ecc71", "text": "#eaeaea", "text2": "#a0a0b8", "text3": "#60607a", "input": "#16162a", "input_fg": "#dcdcf0", } class Generator: def __init__(self, device="cuda"): self.device = device if device == "cuda" and torch.cuda.is_available() else "cpu" self.vae = None self.text_encoder = None self.tokenizer = None self.unet = None self.scheduler = None self.loaded_checkpoint = None self.latent_size = 32 self.output_size = 256 self.cancelled = False def switch_device(self, new_device): """Move all loaded models to a new device.""" new_device = new_device if new_device == "cuda" and torch.cuda.is_available() else "cpu" if new_device == self.device: return self.device = new_device if self.vae is not None: self.vae = self.vae.to(self.device) if self.text_encoder is not None: self.text_encoder = self.text_encoder.to(self.device) if self.unet is not None: self.unet = self.unet.to(self.device) self.loaded_checkpoint = None # force reload on next generate print(f"Switched to {self.device.upper()}") def load_shared(self): if self.vae is not None: return from diffusers import AutoencoderKL from transformers import CLIPTextModel, CLIPTokenizer print("Loading VAE...") self.vae = AutoencoderKL.from_pretrained(VAE_ID).to(self.device) self.vae.eval() print("Loading CLIP text encoder...") self.tokenizer = CLIPTokenizer.from_pretrained(CLIP_ID) self.text_encoder = CLIPTextModel.from_pretrained(CLIP_ID).to(self.device) self.text_encoder.eval() self.scheduler = self._make_scheduler("DPM++ 2M Karras") self.scheduler_name = "DPM++ 2M Karras" print("Shared models loaded.") def _make_scheduler(self, name="DPM++ 2M Karras"): from diffusers import (DDIMScheduler, DPMSolverMultistepScheduler, EulerAncestralDiscreteScheduler, EulerDiscreteScheduler) base = dict(num_train_timesteps=1000, beta_schedule="scaled_linear", prediction_type="epsilon") if name == "DPM++ 2M Karras": return DPMSolverMultistepScheduler( **base, algorithm_type="dpmsolver++", solver_order=2, use_karras_sigmas=True) elif name == "DPM++ SDE Karras": return DPMSolverMultistepScheduler( **base, algorithm_type="sde-dpmsolver++", use_karras_sigmas=True) elif name == "Euler a": return EulerAncestralDiscreteScheduler(**base) elif name == "Euler": return EulerDiscreteScheduler(**base) else: return DDIMScheduler(**base, clip_sample=False, set_alpha_to_one=False) def set_scheduler(self, name): self.scheduler = self._make_scheduler(name) self.scheduler_name = name def load_model(self, model_path: Path, res_label: str = "256"): if str(model_path) == self.loaded_checkpoint: return from diffusers import UNet2DConditionModel self.load_shared() if res_label == "512": self.latent_size = 64 self.output_size = 512 else: self.latent_size = 32 self.output_size = 256 unet_cfg = dict(UNET_CONFIG) unet_cfg["sample_size"] = self.latent_size print(f"Loading UNet from {model_path.name} ({res_label}px)...") self.unet = UNet2DConditionModel(**unet_cfg).to(self.device) safetensors_path = model_path / "diffusion_pytorch_model.safetensors" ema_path = model_path / "ema_unet.pt" unet_path = model_path / "unet.pt" if safetensors_path.exists(): from safetensors.torch import load_file state = load_file(str(safetensors_path), device=str(self.device)) self.unet.load_state_dict(state) print("Loaded safetensors weights.") elif ema_path.exists(): state = torch.load(ema_path, map_location=self.device, weights_only=True) if "shadow_params" in state: params = dict(self.unet.named_parameters()) keys = list(params.keys()) for i, sp in enumerate(state["shadow_params"]): params[keys[i]].data.copy_(sp) else: self.unet.load_state_dict(state) print("Loaded EMA weights.") elif unet_path.exists(): self.unet.load_state_dict( torch.load(unet_path, map_location=self.device, weights_only=True)) print("Loaded UNet weights.") else: raise FileNotFoundError(f"No weights found in {model_path}") self.unet.eval() self.loaded_checkpoint = str(model_path) print(f"Ready to generate at {self.output_size}x{self.output_size}!") def _decode_latents(self, latents, post_process=False): scaled = latents / self.vae.config.scaling_factor with torch.no_grad(): image = self.vae.decode(scaled.float()).sample image = (image.float() / 2 + 0.5).clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).numpy()[0] image = (image * 255).round().astype("uint8") img = Image.fromarray(image) if post_process: img = self._post_process(img) return img def _sharpen_latents(self, latents, amount=0.08): blurred = F.avg_pool2d(latents, kernel_size=3, stride=1, padding=1) return latents + amount * (latents - blurred) def _post_process(self, img): img = img.filter(ImageFilter.UnsharpMask(radius=1.5, percent=40, threshold=2)) img = ImageEnhance.Contrast(img).enhance(1.06) img = ImageEnhance.Color(img).enhance(1.10) return img def _image_quality_score(self, img: Image.Image) -> float: arr = np.array(img.convert("L"), dtype=np.float32) lap = (np.roll(arr, 1, 0) + np.roll(arr, -1, 0) + np.roll(arr, 1, 1) + np.roll(arr, -1, 1) - 4.0 * arr) sharpness = float(np.var(lap)) arr_rgb = np.array(img, dtype=np.float32) color_var = float(np.mean(np.var(arr_rgb, axis=(0, 1)))) score = (sharpness * 0.6 + color_var * 0.4) return min(100.0, score / 10.0) @torch.no_grad() def generate(self, prompt: str, negative_prompt: str = "", steps: int = 25, guidance_scale: float = 7.5, seed: int = -1, preview_callback=None, preview_every: int = 5) -> tuple: if seed < 0: seed = torch.randint(0, 2**32, (1,)).item() gen = torch.Generator(device=self.device).manual_seed(seed) tok = self.tokenizer(prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt") text_emb = self.text_encoder(tok.input_ids.to(self.device))[0] tok_neg = self.tokenizer(negative_prompt if negative_prompt else "", padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt") neg_emb = self.text_encoder(tok_neg.input_ids.to(self.device))[0] text_emb_combined = torch.cat([neg_emb, text_emb]) scheduler = self._make_scheduler(self.scheduler_name) scheduler.set_timesteps(steps, device=self.device) latents = torch.randn(1, 4, self.latent_size, self.latent_size, generator=gen, device=self.device) latents = latents * scheduler.init_noise_sigma timesteps = scheduler.timesteps total_steps = len(timesteps) for step_i, t in enumerate(timesteps): if self.cancelled: return None, seed latent_input = torch.cat([latents] * 2) latent_input = scheduler.scale_model_input(latent_input, t) with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=(self.device == "cuda")): pred = self.unet(latent_input, t, encoder_hidden_states=text_emb_combined).sample pred_neg, pred_text = pred.chunk(2) pred = pred_neg + guidance_scale * (pred_text - pred_neg) latents = scheduler.step(pred, t, latents).prev_sample if (preview_callback and step_i > 0 and step_i % preview_every == 0 and step_i < total_steps - 1): preview = self._decode_latents(latents, post_process=False) preview_callback(preview, step_i + 1, total_steps) latents = self._sharpen_latents(latents) final = self._decode_latents(latents, post_process=True) return final, seed @torch.no_grad() def generate_adaptive(self, prompt: str, negative_prompt: str = "", base_steps: int = 25, max_steps: int = 85, guidance_scale: float = 7.5, quality_threshold: float = 45.0, preview_callback=None, preview_every: int = 5, status_callback=None) -> tuple: result = self.generate( prompt=prompt, negative_prompt=negative_prompt, steps=base_steps, guidance_scale=guidance_scale, preview_callback=preview_callback, preview_every=preview_every) if result[0] is None: return result image, seed = result quality = self._image_quality_score(image) if status_callback: status_callback(f"Quality: {quality:.1f}/100") if quality >= quality_threshold: return image, seed rounds = 0 max_rounds = (max_steps - base_steps) // 20 while quality < quality_threshold and rounds < max_rounds: if self.cancelled: return image, seed rounds += 1 if status_callback: status_callback(f"Refining +20 steps (round {rounds})...") refined = self.refine( source_image=image, prompt=prompt, negative_prompt=negative_prompt, extra_steps=20, strength=0.3, guidance_scale=guidance_scale, preview_callback=preview_callback, preview_every=5) if refined is None: return image, seed image = refined quality = self._image_quality_score(image) if status_callback: status_callback(f"Quality after round {rounds}: {quality:.1f}/100") return image, seed @torch.no_grad() def refine(self, source_image: Image.Image, prompt: str, negative_prompt: str = "", extra_steps: int = 20, strength: float = 0.35, guidance_scale: float = 7.5, preview_callback=None, preview_every: int = 5) -> Image.Image: img = source_image.resize((self.output_size, self.output_size), Image.LANCZOS) img_tensor = torch.from_numpy(np.array(img)).float().div(127.5).sub(1.0) img_tensor = img_tensor.permute(2, 0, 1).unsqueeze(0).to(self.device) with torch.no_grad(): latents = self.vae.encode(img_tensor.float()).latent_dist.sample() latents = latents * self.vae.config.scaling_factor tok = self.tokenizer(prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt") text_emb = self.text_encoder(tok.input_ids.to(self.device))[0] tok_neg = self.tokenizer(negative_prompt if negative_prompt else "", padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt") neg_emb = self.text_encoder(tok_neg.input_ids.to(self.device))[0] text_emb_combined = torch.cat([neg_emb, text_emb]) scheduler = self._make_scheduler(self.scheduler_name) scheduler.set_timesteps(extra_steps, device=self.device) start_step = max(0, int(len(scheduler.timesteps) * (1 - strength))) timesteps = scheduler.timesteps[start_step:] noise = torch.randn_like(latents) latents = scheduler.add_noise(latents, noise, timesteps[:1]) total_steps = len(timesteps) for step_i, t in enumerate(timesteps): if self.cancelled: return None latent_input = torch.cat([latents] * 2) latent_input = scheduler.scale_model_input(latent_input, t) with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=(self.device == "cuda")): pred = self.unet(latent_input, t, encoder_hidden_states=text_emb_combined).sample pred_neg, pred_text = pred.chunk(2) pred = pred_neg + guidance_scale * (pred_text - pred_neg) latents = scheduler.step(pred, t, latents).prev_sample if (preview_callback and step_i > 0 and step_i % preview_every == 0 and step_i < total_steps - 1): preview = self._decode_latents(latents, post_process=False) preview_callback(preview, step_i + 1, total_steps) latents = self._sharpen_latents(latents) return self._decode_latents(latents, post_process=True) # ── GUI ─────────────────────────────────────────────────────────────────────── class App: def __init__(self): self.gen = Generator() self.models = find_models() self.generated_images = [] self.generated_seeds = [] self.photo_refs = [] self.generating = False self.selected_index = None self.root = tk.Tk() self.root.title("Aniimage") self.root.configure(bg=C["bg"]) self.root.resizable(True, True) self.root.geometry("900x780") self.root.minsize(640, 500) self._setup_styles() self._build_ui() def _setup_styles(self): s = ttk.Style() s.theme_use("clam") # Base s.configure(".", background=C["bg"], foreground=C["text"], font=("Segoe UI", 10)) s.configure("TFrame", background=C["bg"]) s.configure("TLabel", background=C["bg"], foreground=C["text"]) s.configure("TCheckbutton", background=C["bg"], foreground=C["text"]) # Combobox — readable text s.configure("TCombobox", fieldbackground=C["input"], foreground=C["input_fg"], selectbackground=C["accent"], selectforeground="#ffffff", arrowcolor=C["text2"], padding=4) s.map("TCombobox", fieldbackground=[("readonly", C["input"])], foreground=[("readonly", C["input_fg"])], selectbackground=[("readonly", C["accent"])], selectforeground=[("readonly", "#ffffff")]) # Combobox dropdown list colors self.root.option_add("*TCombobox*Listbox.background", C["input"]) self.root.option_add("*TCombobox*Listbox.foreground", C["input_fg"]) self.root.option_add("*TCombobox*Listbox.selectBackground", C["accent"]) self.root.option_add("*TCombobox*Listbox.selectForeground", "#ffffff") self.root.option_add("*TCombobox*Listbox.font", ("Segoe UI", 10)) # Spinbox s.configure("TSpinbox", fieldbackground=C["input"], foreground=C["input_fg"], arrowcolor=C["text2"], padding=3) # Buttons s.configure("TButton", font=("Segoe UI", 10), padding=(14, 7), background=C["card"], foreground=C["text"]) s.map("TButton", background=[("active", C["card_sel"]), ("disabled", C["bg"])], foreground=[("disabled", C["text3"])]) s.configure("Go.TButton", font=("Segoe UI", 11, "bold"), padding=(20, 9), background=C["accent"], foreground="#ffffff") s.map("Go.TButton", background=[("active", C["accent_h"]), ("disabled", C["border"])]) s.configure("Stop.TButton", font=("Segoe UI", 10, "bold"), padding=(14, 7), background=C["red"], foreground="#ffffff") s.map("Stop.TButton", background=[("active", "#c0392b"), ("disabled", C["border"])]) # Labelframe s.configure("TLabelframe", background=C["bg"], foreground=C["text2"]) s.configure("TLabelframe.Label", background=C["bg"], foreground=C["text2"], font=("Segoe UI", 9, "bold")) # Scrollbar s.configure("Vertical.TScrollbar", background=C["card"], troughcolor=C["bg"], arrowcolor=C["text3"]) def _make_entry(self, parent, font_size=11, dim=False): """Create a styled tk.Entry with readable text.""" return tk.Entry(parent, font=("Segoe UI", font_size), bg=C["input"], fg=C["input_fg"] if not dim else C["text2"], insertbackground=C["input_fg"], relief="flat", bd=6, selectbackground=C["accent"], selectforeground="#ffffff", highlightthickness=1, highlightcolor=C["accent"], highlightbackground=C["border"]) def _build_ui(self): # ── Header ──────────────────────────────────────────────────────── header = tk.Frame(self.root, bg=C["panel"], padx=20, pady=12) header.pack(fill=tk.X) tk.Label(header, text="Aniimage", bg=C["panel"], fg=C["accent"], font=("Segoe UI", 20, "bold")).pack(side=tk.LEFT) tk.Label(header, text="by 8BitStudio", bg=C["panel"], fg=C["text3"], font=("Segoe UI", 10)).pack(side=tk.LEFT, padx=(10, 0), pady=(6, 0)) # Device switch — right side of header device_frame = tk.Frame(header, bg=C["panel"]) device_frame.pack(side=tk.RIGHT) tk.Label(device_frame, text="Device:", bg=C["panel"], fg=C["text2"], font=("Segoe UI", 9)).pack(side=tk.LEFT, padx=(0, 5)) self.device_var = tk.StringVar(value="GPU" if self.gen.device == "cuda" else "CPU") devices = ["GPU", "CPU"] if torch.cuda.is_available() else ["CPU"] device_combo = ttk.Combobox(device_frame, textvariable=self.device_var, values=devices, state="readonly", width=5) device_combo.pack(side=tk.LEFT) device_combo.bind("<>", self._on_device_change) # ── Main content — two-column: controls left, images right ──────── main = tk.Frame(self.root, bg=C["bg"]) main.pack(fill=tk.BOTH, expand=True, padx=12, pady=(8, 12)) # Left panel (controls) left = tk.Frame(main, bg=C["panel"], width=340, padx=16, pady=12) left.pack(side=tk.LEFT, fill=tk.Y, padx=(0, 8)) left.pack_propagate(False) # Right panel (image grid) right = tk.Frame(main, bg=C["bg"]) right.pack(side=tk.LEFT, fill=tk.BOTH, expand=True) self._build_controls(left) self._build_grid(right) def _build_controls(self, parent): # ── Model ───────────────────────────────────────────────────────── tk.Label(parent, text="Model", bg=C["panel"], fg=C["text2"], font=("Segoe UI", 9, "bold")).pack(anchor=tk.W) self.model_var = tk.StringVar() model_names = [m[1] for m in self.models] or ["No models found"] self.model_combo = ttk.Combobox(parent, textvariable=self.model_var, values=model_names, state="readonly", width=32) self.model_combo.pack(fill=tk.X, pady=(3, 12)) self.model_combo.current(len(model_names) - 1) # ── Prompt ──────────────────────────────────────────────────────── tk.Label(parent, text="Prompt", bg=C["panel"], fg=C["text2"], font=("Segoe UI", 9, "bold")).pack(anchor=tk.W) self.prompt_entry = self._make_entry(parent) self.prompt_entry.pack(fill=tk.X, pady=(3, 8)) self.prompt_entry.insert(0, "a smiling anime girl with long blue hair") self.prompt_entry.bind("", lambda e: self.on_generate()) # ── Negative prompt ─────────────────────────────────────────────── tk.Label(parent, text="Negative prompt", bg=C["panel"], fg=C["text3"], font=("Segoe UI", 9)).pack(anchor=tk.W) self.neg_entry = self._make_entry(parent, font_size=9, dim=True) self.neg_entry.pack(fill=tk.X, pady=(3, 12)) self.neg_entry.insert(0, DEFAULT_NEGATIVE) # ── Settings grid ───────────────────────────────────────────────── grid = tk.Frame(parent, bg=C["panel"]) grid.pack(fill=tk.X, pady=(0, 8)) # Row 1: Scheduler tk.Label(grid, text="Scheduler", bg=C["panel"], fg=C["text2"], font=("Segoe UI", 9)).grid(row=0, column=0, sticky="w", pady=(0, 6)) self.scheduler_var = tk.StringVar(value="DPM++ 2M Karras") sched_combo = ttk.Combobox(grid, textvariable=self.scheduler_var, values=SCHEDULER_LIST, state="readonly", width=18) sched_combo.grid(row=0, column=1, columnspan=3, sticky="ew", padx=(8, 0), pady=(0, 6)) sched_combo.bind("<>", self._on_scheduler_change) # Row 2: Steps, CFG, Count tk.Label(grid, text="Steps", bg=C["panel"], fg=C["text2"], font=("Segoe UI", 9)).grid(row=1, column=0, sticky="w", pady=(0, 6)) self.steps_var = tk.StringVar(value="25") tk.Entry(grid, textvariable=self.steps_var, width=5, font=("Segoe UI", 10), bg=C["input"], fg=C["input_fg"], insertbackground=C["input_fg"], relief="flat", bd=4).grid(row=1, column=1, sticky="w", padx=(8, 12), pady=(0, 6)) tk.Label(grid, text="CFG", bg=C["panel"], fg=C["text2"], font=("Segoe UI", 9)).grid(row=1, column=2, sticky="w", pady=(0, 6)) self.cfg_var = tk.StringVar(value="7.5") tk.Entry(grid, textvariable=self.cfg_var, width=5, font=("Segoe UI", 10), bg=C["input"], fg=C["input_fg"], insertbackground=C["input_fg"], relief="flat", bd=4).grid(row=1, column=3, sticky="w", padx=(8, 0), pady=(0, 6)) # Row 3: Count, Live preview tk.Label(grid, text="Count", bg=C["panel"], fg=C["text2"], font=("Segoe UI", 9)).grid(row=2, column=0, sticky="w", pady=(0, 6)) self.count_var = tk.StringVar(value="4") ttk.Spinbox(grid, from_=1, to=12, textvariable=self.count_var, width=4, font=("Segoe UI", 10)).grid(row=2, column=1, sticky="w", padx=(8, 12), pady=(0, 6)) self.live_preview_var = tk.BooleanVar(value=False) ttk.Checkbutton(grid, text="Live preview", variable=self.live_preview_var).grid( row=2, column=2, columnspan=2, sticky="w", pady=(0, 6)) grid.columnconfigure(1, weight=1) grid.columnconfigure(3, weight=1) # ── Auto quality ────────────────────────────────────────────────── self.auto_quality_var = tk.BooleanVar(value=False) ttk.Checkbutton(parent, text="Auto quality (refine if undercooked)", variable=self.auto_quality_var).pack(anchor=tk.W, pady=(0, 12)) # ── Buttons ─────────────────────────────────────────────────────── btn_frame = tk.Frame(parent, bg=C["panel"]) btn_frame.pack(fill=tk.X, pady=(0, 10)) self.gen_btn = ttk.Button(btn_frame, text="Generate", command=self.on_generate, style="Go.TButton") self.gen_btn.pack(fill=tk.X, pady=(0, 5)) btn_row = tk.Frame(btn_frame, bg=C["panel"]) btn_row.pack(fill=tk.X) self.stop_btn = ttk.Button(btn_row, text="Stop", command=self.on_stop, state=tk.DISABLED, style="Stop.TButton") self.stop_btn.pack(side=tk.LEFT, fill=tk.X, expand=True, padx=(0, 3)) self.save_btn = ttk.Button(btn_row, text="Save Selected", command=self.on_save, state=tk.DISABLED) self.save_btn.pack(side=tk.LEFT, fill=tk.X, expand=True, padx=(3, 3)) self.save_all_btn = ttk.Button(btn_row, text="Save All", command=self.on_save_all, state=tk.DISABLED) self.save_all_btn.pack(side=tk.LEFT, fill=tk.X, expand=True, padx=(3, 0)) # ── Prompt queue ───────────────────────────────────────────────── sep = tk.Frame(parent, height=1, bg=C["border"]) sep.pack(fill=tk.X, pady=(8, 10)) tk.Label(parent, text="Prompt Queue", bg=C["panel"], fg=C["text2"], font=("Segoe UI", 9, "bold")).pack(anchor=tk.W) queue_input = tk.Frame(parent, bg=C["panel"]) queue_input.pack(fill=tk.X, pady=(4, 0)) self.queue_entry = self._make_entry(queue_input, font_size=9) self.queue_entry.pack(side=tk.LEFT, fill=tk.X, expand=True, padx=(0, 4)) self.queue_entry.bind("", lambda e: self._queue_add()) ttk.Button(queue_input, text="Add", width=4, command=self._queue_add).pack(side=tk.LEFT) self.queue_listbox = tk.Listbox( parent, height=4, bg=C["input"], fg=C["input_fg"], selectbackground=C["accent"], selectforeground="#fff", font=("Segoe UI", 9), activestyle="none", relief="flat", bd=4, highlightthickness=0) self.queue_listbox.pack(fill=tk.X, pady=(5, 0)) queue_btns = tk.Frame(parent, bg=C["panel"]) queue_btns.pack(fill=tk.X, pady=(4, 0)) self.queue_run_btn = ttk.Button(queue_btns, text="Run Queue", command=self.on_run_queue, style="Go.TButton") self.queue_run_btn.pack(side=tk.LEFT, padx=(0, 4)) for txt, cmd in [("Remove", self._queue_remove), ("Clear", self._queue_clear), ("Up", self._queue_move_up), ("Down", self._queue_move_down), ("+ Current", self._queue_add_current)]: ttk.Button(queue_btns, text=txt, command=cmd).pack(side=tk.LEFT, padx=2) # ── Status bar ──────────────────────────────────────────────────── status_frame = tk.Frame(parent, bg=C["bg"], padx=8, pady=6) status_frame.pack(fill=tk.X, side=tk.BOTTOM) self.status_var = tk.StringVar(value="Ready") tk.Label(status_frame, textvariable=self.status_var, bg=C["bg"], fg=C["green"], font=("Segoe UI", 9), anchor="w").pack(fill=tk.X) def _build_grid(self, parent): self.canvas = tk.Canvas(parent, bg=C["bg"], highlightthickness=0) scrollbar = ttk.Scrollbar(parent, orient=tk.VERTICAL, command=self.canvas.yview) self.grid_frame = tk.Frame(self.canvas, bg=C["bg"]) self.grid_frame.bind("", lambda e: self.canvas.configure( scrollregion=self.canvas.bbox("all"))) self.canvas_window = self.canvas.create_window((0, 0), window=self.grid_frame, anchor="nw") self.canvas.configure(yscrollcommand=scrollbar.set) self.canvas.pack(side=tk.LEFT, fill=tk.BOTH, expand=True) scrollbar.pack(side=tk.RIGHT, fill=tk.Y) self.canvas.bind("", self._on_canvas_resize) self.canvas.bind_all("", lambda e: self.canvas.yview_scroll( int(-1 * (e.delta / 120)), "units")) self.placeholder = tk.Label(self.grid_frame, text="Generated images\nwill appear here", bg=C["bg"], fg=C["text3"], font=("Segoe UI", 13), justify="center") self.placeholder.grid(row=0, column=0, pady=80) # ── Event handlers ──────────────────────────────────────────────────── def _on_device_change(self, event=None): choice = self.device_var.get() new_dev = "cuda" if choice == "GPU" else "cpu" self.status_var.set(f"Switching to {choice}...") self.root.update() self.gen.switch_device(new_dev) self.status_var.set(f"Now using {choice}") def _on_scheduler_change(self, event=None): name = self.scheduler_var.get() self.gen.set_scheduler(name) self.status_var.set(f"Scheduler: {name}") def _on_canvas_resize(self, event): self.canvas.itemconfig(self.canvas_window, width=event.width) if self.generated_images: self._layout_grid() def _get_grid_cols(self): canvas_w = self.canvas.winfo_width() if canvas_w < 50: canvas_w = 560 tile_size = self._get_tile_size() return max(1, canvas_w // (tile_size + 16)) def _get_tile_size(self): n = len(self.generated_images) if n <= 2: return 260 elif n <= 4: return 220 elif n <= 6: return 180 else: return 160 def _layout_grid(self): for w in self.grid_frame.winfo_children(): w.destroy() self.photo_refs.clear() if not self.generated_images: return tile_size = self._get_tile_size() cols = self._get_grid_cols() for i, (img, seed) in enumerate(zip(self.generated_images, self.generated_seeds)): row, col = divmod(i, cols) is_selected = (i == self.selected_index) card_bg = C["accent"] if is_selected else C["card"] card = tk.Frame(self.grid_frame, bg=card_bg, padx=3, pady=3) card.grid(row=row, column=col, padx=5, pady=5, sticky="nsew") display = img.resize((tile_size, tile_size), Image.LANCZOS) photo = ImageTk.PhotoImage(display) self.photo_refs.append(photo) img_label = tk.Label(card, image=photo, bg=card_bg, bd=0) img_label.pack() img_label.bind("", lambda e, idx=i: self._select_image(idx)) img_label.bind("", lambda e, idx=i: self._show_refine_menu(e, idx)) tk.Label(card, text=f"seed: {seed}", bg=card_bg, fg=C["text3"], font=("Segoe UI", 8)).pack() for c in range(cols): self.grid_frame.columnconfigure(c, weight=1) def _select_image(self, idx): if idx >= len(self.generated_images): return self.selected_index = idx self.save_btn.configure(state=tk.NORMAL) self.status_var.set(f"Selected image {idx + 1} (seed: {self.generated_seeds[idx]})") self._layout_grid() def _show_refine_menu(self, event, idx): if self.generating: return menu = tk.Menu(self.root, tearoff=0, bg=C["card"], fg=C["text"], activebackground=C["accent"], activeforeground="#fff", font=("Segoe UI", 10), bd=0) menu.add_command(label=" Refine (more steps)... ", command=lambda: self._ask_refine(idx)) menu.tk_popup(event.x_root, event.y_root) def _ask_refine(self, idx): extra = simpledialog.askinteger( "Refine Image", "Extra denoising steps:", initialvalue=20, minvalue=5, maxvalue=200, parent=self.root) if extra is None: return self._select_image(idx) self.generating = True self.gen.cancelled = False self.gen_btn.configure(state=tk.DISABLED) self.stop_btn.configure(state=tk.NORMAL) self.status_var.set(f"Refining image {idx + 1}...") self.root.update() Thread(target=self._refine_thread, args=(idx, extra), daemon=True).start() def _refine_thread(self, idx, extra_steps): try: source = self.generated_images[idx] prompt = self.prompt_entry.get().strip() neg = self.neg_entry.get().strip() cfg = float(self.cfg_var.get()) callback = self._show_preview if self.live_preview_var.get() else None refined = self.gen.refine( source_image=source, prompt=prompt, negative_prompt=neg, extra_steps=extra_steps, guidance_scale=cfg, preview_callback=callback, preview_every=5) if refined is not None: self.generated_images[idx] = refined self.generated_seeds[idx] = f"{self.generated_seeds[idx]}+R{extra_steps}" self._layout_grid() self.status_var.set(f"Refined image {idx + 1}") else: self.status_var.set("Refine stopped.") self.root.update() except Exception as e: self.status_var.set(f"Refine error: {e}") import traceback; traceback.print_exc() finally: self.generating = False self.gen.cancelled = False self.gen_btn.configure(state=tk.NORMAL) self.stop_btn.configure(state=tk.DISABLED) # ── Queue ───────────────────────────────────────────────────────────── def _queue_add(self): text = self.queue_entry.get().strip() if text: self.queue_listbox.insert(tk.END, text) self.queue_entry.delete(0, tk.END) def _queue_add_current(self): text = self.prompt_entry.get().strip() if text: self.queue_listbox.insert(tk.END, text) def _queue_remove(self): sel = self.queue_listbox.curselection() if sel: self.queue_listbox.delete(sel[0]) def _queue_clear(self): self.queue_listbox.delete(0, tk.END) def _queue_move_up(self): sel = self.queue_listbox.curselection() if sel and sel[0] > 0: idx = sel[0] text = self.queue_listbox.get(idx) self.queue_listbox.delete(idx) self.queue_listbox.insert(idx - 1, text) self.queue_listbox.selection_set(idx - 1) def _queue_move_down(self): sel = self.queue_listbox.curselection() if sel and sel[0] < self.queue_listbox.size() - 1: idx = sel[0] text = self.queue_listbox.get(idx) self.queue_listbox.delete(idx) self.queue_listbox.insert(idx + 1, text) self.queue_listbox.selection_set(idx + 1) def on_run_queue(self): if self.generating or not self.models: return prompts = list(self.queue_listbox.get(0, tk.END)) if not prompts: self.status_var.set("Queue is empty") return self.generating = True self.gen.cancelled = False self.gen_btn.configure(state=tk.DISABLED) self.queue_run_btn.configure(state=tk.DISABLED) self.stop_btn.configure(state=tk.NORMAL) Thread(target=self._queue_thread, args=(prompts,), daemon=True).start() def _queue_thread(self, prompts): try: idx = self.model_combo.current() mdl = self.models[idx] self.status_var.set(f"Loading {mdl[1]}...") self.root.update() self.gen.load_model(mdl[2], mdl[3]) neg = self.neg_entry.get().strip() steps = int(self.steps_var.get()) cfg = float(self.cfg_var.get()) num_images = max(1, min(12, int(self.count_var.get()))) live_preview = self.live_preview_var.get() auto_quality = self.auto_quality_var.get() self.generated_images.clear() self.generated_seeds.clear() self.selected_index = None if self.placeholder: self.placeholder.destroy() self.placeholder = None for p_idx, prompt in enumerate(prompts): if self.gen.cancelled: break self.queue_listbox.selection_clear(0, tk.END) self.queue_listbox.selection_set(p_idx) self.queue_listbox.see(p_idx) for img_i in range(num_images): if self.gen.cancelled: break self.status_var.set( f"[{p_idx + 1}/{len(prompts)}] image {img_i + 1}/{num_images}") self.root.update() callback = None if live_preview: self._setup_preview_card() callback = self._show_preview if auto_quality: image, used_seed = self.gen.generate_adaptive( prompt=prompt, negative_prompt=neg, base_steps=steps, max_steps=steps + 60, guidance_scale=cfg, preview_callback=callback, preview_every=5, status_callback=lambda m: ( self.status_var.set(m), self.root.update())) else: image, used_seed = self.gen.generate( prompt=prompt, negative_prompt=neg, steps=steps, guidance_scale=cfg, preview_callback=callback, preview_every=5) if image is None: break self.generated_images.append(image) self.generated_seeds.append(used_seed) save_path = self._next_save_path(prompt) image.save(save_path) self._layout_grid() self.root.update() if self.gen.cancelled: break done = len(self.generated_images) self.status_var.set( f"Queue {'stopped' if self.gen.cancelled else 'done'}! {done} images saved.") if done > 0: self.save_all_btn.configure(state=tk.NORMAL) except Exception as e: self.status_var.set(f"Queue error: {e}") import traceback; traceback.print_exc() finally: self.generating = False self.gen.cancelled = False self.gen_btn.configure(state=tk.NORMAL) self.queue_run_btn.configure(state=tk.NORMAL) self.stop_btn.configure(state=tk.DISABLED) # ── Generation ──────────────────────────────────────────────────────── def on_stop(self): if self.generating: self.gen.cancelled = True self.status_var.set("Stopping...") self.root.update() def on_generate(self): if self.generating or not self.models: return self.generating = True self.gen.cancelled = False self.gen_btn.configure(state=tk.DISABLED) self.stop_btn.configure(state=tk.NORMAL) self.status_var.set("Loading model...") self.root.update() Thread(target=self._generate_thread, daemon=True).start() def _setup_preview_card(self): tile_size = self._get_tile_size() cols = self._get_grid_cols() row, col = divmod(len(self.generated_images), cols) card = tk.Frame(self.grid_frame, bg=C["card"], padx=3, pady=3) card.grid(row=row, column=col, padx=5, pady=5, sticky="nsew") self._preview_label = tk.Label(card, bg=C["card"], width=tile_size, height=tile_size) self._preview_label.pack() self.root.update() def _show_preview(self, preview_img, step, total): tile_size = self._get_tile_size() display = preview_img.resize((tile_size, tile_size), Image.LANCZOS) photo = ImageTk.PhotoImage(display) self._preview_photo = photo if hasattr(self, '_preview_label') and self._preview_label.winfo_exists(): self._preview_label.configure(image=photo) self.status_var.set(f"Step {step}/{total}") self.root.update() def _generate_thread(self): try: idx = self.model_combo.current() mdl = self.models[idx] self.status_var.set(f"Loading {mdl[1]}...") self.root.update() self.gen.load_model(mdl[2], mdl[3]) prompt = self.prompt_entry.get().strip() neg = self.neg_entry.get().strip() steps = int(self.steps_var.get()) cfg = float(self.cfg_var.get()) num_images = max(1, min(12, int(self.count_var.get()))) live_preview = self.live_preview_var.get() auto_quality = self.auto_quality_var.get() self.generated_images.clear() self.generated_seeds.clear() self.selected_index = None if self.placeholder: self.placeholder.destroy() self.placeholder = None for i in range(num_images): if self.gen.cancelled: break self.status_var.set(f"Generating {i + 1}/{num_images}...") self.root.update() callback = None if live_preview: self._setup_preview_card() callback = self._show_preview if auto_quality: image, used_seed = self.gen.generate_adaptive( prompt=prompt, negative_prompt=neg, base_steps=steps, max_steps=steps + 60, guidance_scale=cfg, preview_callback=callback, preview_every=5, status_callback=lambda m: ( self.status_var.set(m), self.root.update())) else: image, used_seed = self.gen.generate( prompt=prompt, negative_prompt=neg, steps=steps, guidance_scale=cfg, preview_callback=callback, preview_every=5) if image is None: break self.generated_images.append(image) self.generated_seeds.append(used_seed) self._layout_grid() self.root.update() done = len(self.generated_images) if self.gen.cancelled: self.status_var.set(f"Stopped. {done} image(s) kept.") else: self.status_var.set(f"Done! {done} images. Click to select.") if done > 0: self.save_all_btn.configure(state=tk.NORMAL) self.save_btn.configure(state=tk.DISABLED) except Exception as e: self.status_var.set(f"Error: {e}") import traceback; traceback.print_exc() finally: self.generating = False self.gen.cancelled = False self.gen_btn.configure(state=tk.NORMAL) self.stop_btn.configure(state=tk.DISABLED) # ── Save ────────────────────────────────────────────────────────────── def _next_save_path(self, prompt_text): OUTPUT_DIR.mkdir(parents=True, exist_ok=True) slug = prompt_text.strip()[:50] if prompt_text.strip() else "untitled" base = OUTPUT_DIR / f"{slug}.png" if not base.exists(): return base n = 1 while True: path = OUTPUT_DIR / f"{slug} {n}.png" if not path.exists(): return path n += 1 def on_save(self): if self.selected_index is None or not self.generated_images: return img = self.generated_images[self.selected_index] path = self._next_save_path(self.prompt_entry.get().strip()) img.save(path) self.status_var.set(f"Saved: {path.name}") def on_save_all(self): if not self.generated_images: return prompt_text = self.prompt_entry.get().strip() for img in self.generated_images: path = self._next_save_path(prompt_text) img.save(path) self.status_var.set(f"Saved {len(self.generated_images)} images") def run(self): self.root.mainloop() # ── Entry point ─────────────────────────────────────────────────────────────── if __name__ == "__main__": models = find_models() if not models: print("No models found locally. Downloading from HuggingFace...") result = download_from_hf() if result: models = find_models() if not models: print("No models found!") print(f"Place model weights in: {MODEL_DIR}/YourModelName/") print("Expected files: diffusion_pytorch_model.safetensors or ema_unet.pt") sys.exit(1) print(f"Found {len(models)} model(s): {', '.join(m[1] for m in models)}") print(f"Device: {'CUDA (GPU)' if torch.cuda.is_available() else 'CPU'}") print("Starting Aniimage...") app = App() app.run()