| """ |
Aniimage Generator — Generate anime images from text prompts.
| https://huggingface.co/8BitStudio/Aniimage-1 |
| |
| Usage: |
| pip install torch torchvision diffusers transformers safetensors pillow huggingface_hub |
| python generate_hf.py |
| """ |
|
|
| import os |
| import sys |
| import torch |
| import torch.nn.functional as F |
| import numpy as np |
| import tkinter as tk |
| from tkinter import ttk, simpledialog |
| from pathlib import Path |
| from PIL import Image, ImageTk, ImageEnhance, ImageFilter |
| from threading import Thread |
|
|
| |
SCRIPT_DIR = Path(__file__).resolve().parent
MODEL_DIR = SCRIPT_DIR / "models"      # local model checkpoints live here
OUTPUT_DIR = SCRIPT_DIR / "generated"  # saved images are written here


# HuggingFace repo hosting the pretrained Aniimage-1 UNet weights.
HF_REPO_ID = "8BitStudio/Aniimage-1"
|
|
| |
# Architecture of the Aniimage-1 UNet. Must match the checkpoint exactly:
# Generator.load_model() rebuilds the network from this config before loading
# weights. sample_size is overwritten per resolution in load_model().
UNET_CONFIG = dict(
    sample_size=32,
    in_channels=4,
    out_channels=4,
    block_out_channels=(256, 512, 768, 1024),
    layers_per_block=2,
    cross_attention_dim=768,  # width of the CLIP text embeddings
    attention_head_dim=8,
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D",
                      "CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D",
                    "CrossAttnUpBlock2D", "UpBlock2D"),
)


# Pretrained components shared by every checkpoint: SD VAE + CLIP encoder.
VAE_ID = "stabilityai/sd-vae-ft-mse"
CLIP_ID = "openai/clip-vit-large-patch14"


# Scheduler names offered in the UI; mapped to diffusers classes in
# Generator._make_scheduler().
SCHEDULER_LIST = [
    "DPM++ 2M Karras",
    "DPM++ SDE Karras",
    "Euler a",
    "Euler",
    "DDIM",
]


# Default negative prompt prefilled into the UI's negative-prompt field.
DEFAULT_NEGATIVE = (
    "low quality, ugly, blurry, distorted, deformed, bad anatomy, "
    "bad proportions, extra limbs, missing limbs, watermark, text, "
    "signature, washed out, flat colors, manga panel, disfigured, "
    "poorly drawn, jpeg artifacts, cropped, out of frame"
)
|
|
|
|
| |
|
|
def download_from_hf():
    """Download Aniimage-1 weights from HuggingFace into MODEL_DIR.

    Returns the local model directory (Path) on success, or None when the
    huggingface_hub package is missing or the download fails.
    """
    try:
        from huggingface_hub import hf_hub_download
    except ImportError:
        print("Install huggingface_hub: pip install huggingface_hub")
        return None

    import shutil

    MODEL_DIR.mkdir(parents=True, exist_ok=True)
    aniimage_dir = MODEL_DIR / "Aniimage-1"
    weights_path = aniimage_dir / "diffusion_pytorch_model.safetensors"

    # Skip the network entirely when the weights are already in place.
    if weights_path.exists():
        print("Aniimage-1 weights already downloaded.")
        return aniimage_dir

    print(f"Downloading Aniimage-1 from {HF_REPO_ID}...")
    aniimage_dir.mkdir(parents=True, exist_ok=True)

    try:
        dl_weights = hf_hub_download(repo_id=HF_REPO_ID,
                                     filename="diffusion_pytorch_model.safetensors")
        shutil.copy2(dl_weights, weights_path)
    except Exception as e:
        # Network/auth failures must not crash the app; report and bail out
        # the same way the missing-package path does.
        print(f"Download failed: {e}")
        return None

    # The config is optional: the UNet is rebuilt from UNET_CONFIG anyway.
    try:
        dl_config = hf_hub_download(repo_id=HF_REPO_ID, filename="config.json")
        shutil.copy2(dl_config, aniimage_dir / "config.json")
    except Exception:
        pass

    print("Download complete!")
    return aniimage_dir
|
|
|
|
def find_models():
    """Scan MODEL_DIR and return a list of (kind, name, path, res_label)."""
    found = []
    if not MODEL_DIR.exists():
        return found
    for entry in sorted(MODEL_DIR.iterdir()):
        if not entry.is_dir():
            continue
        has_safetensors = (entry / "diffusion_pytorch_model.safetensors").exists()
        has_checkpoint = ((entry / "ema_unet.pt").exists()
                          or (entry / "unet.pt").exists())
        if has_safetensors:
            found.append(("safetensors", entry.name, entry, "256"))
        elif has_checkpoint:
            found.append(("checkpoint", entry.name, entry, "256"))
    return found
|
|
|
|
| |
|
|
# UI color palette (dark theme), shared by all tk/ttk widgets.
C = {
    "bg": "#111119",        # window background
    "panel": "#1b1b2f",     # header / side-panel background
    "card": "#24243e",      # image tile background
    "card_sel": "#3a3a6e",  # active-button / selected-card tone
    "border": "#2e2e52",    # separators, disabled buttons
    "accent": "#6c5ce7",    # primary purple accent
    "accent_h": "#8577ed",  # accent hover
    "red": "#e74c3c",       # stop / destructive actions
    "green": "#2ecc71",     # status text
    "text": "#eaeaea",      # primary text
    "text2": "#a0a0b8",     # secondary labels
    "text3": "#60607a",     # dim labels
    "input": "#16162a",     # entry/listbox background
    "input_fg": "#dcdcf0",  # entry/listbox text
}
|
|
|
|
class Generator:
    """Latent-diffusion text-to-image pipeline.

    Wires together a CLIP text encoder, a custom UNet (Aniimage-1 weights)
    and the Stable Diffusion VAE, and runs classifier-free-guidance sampling
    on top of a diffusers scheduler.
    """

    def __init__(self, device="cuda"):
        # Silently fall back to CPU when CUDA is not available.
        self.device = device if device == "cuda" and torch.cuda.is_available() else "cpu"
        self.vae = None
        self.text_encoder = None
        self.tokenizer = None
        self.unet = None
        self.scheduler = None
        self.loaded_checkpoint = None  # str(path) of the UNet currently loaded
        self.latent_size = 32          # latent H/W; the VAE upscales by 8x
        self.output_size = 256         # decoded image H/W in pixels
        self.cancelled = False         # set from the UI thread to abort sampling

    def switch_device(self, new_device):
        """Move all loaded models to a new device."""
        new_device = new_device if new_device == "cuda" and torch.cuda.is_available() else "cpu"
        if new_device == self.device:
            return
        self.device = new_device
        if self.vae is not None:
            self.vae = self.vae.to(self.device)
        if self.text_encoder is not None:
            self.text_encoder = self.text_encoder.to(self.device)
        if self.unet is not None:
            self.unet = self.unet.to(self.device)
        # Forces load_model() to re-load weights on the next generate call.
        self.loaded_checkpoint = None
        print(f"Switched to {self.device.upper()}")

    def load_shared(self):
        """Lazily load the VAE, tokenizer, text encoder and default scheduler."""
        if self.vae is not None:
            return  # already loaded
        from diffusers import AutoencoderKL
        from transformers import CLIPTextModel, CLIPTokenizer

        print("Loading VAE...")
        self.vae = AutoencoderKL.from_pretrained(VAE_ID).to(self.device)
        self.vae.eval()

        print("Loading CLIP text encoder...")
        self.tokenizer = CLIPTokenizer.from_pretrained(CLIP_ID)
        self.text_encoder = CLIPTextModel.from_pretrained(CLIP_ID).to(self.device)
        self.text_encoder.eval()

        self.scheduler = self._make_scheduler("DPM++ 2M Karras")
        self.scheduler_name = "DPM++ 2M Karras"
        print("Shared models loaded.")

    def _make_scheduler(self, name="DPM++ 2M Karras"):
        """Build a fresh scheduler instance for the given UI name.

        Unrecognized names fall back to DDIM.
        """
        from diffusers import (DDIMScheduler, DPMSolverMultistepScheduler,
                               EulerAncestralDiscreteScheduler,
                               EulerDiscreteScheduler)
        # SD-style training config shared by all scheduler choices.
        base = dict(num_train_timesteps=1000, beta_schedule="scaled_linear",
                    prediction_type="epsilon")
        if name == "DPM++ 2M Karras":
            return DPMSolverMultistepScheduler(
                **base, algorithm_type="dpmsolver++",
                solver_order=2, use_karras_sigmas=True)
        elif name == "DPM++ SDE Karras":
            return DPMSolverMultistepScheduler(
                **base, algorithm_type="sde-dpmsolver++",
                use_karras_sigmas=True)
        elif name == "Euler a":
            return EulerAncestralDiscreteScheduler(**base)
        elif name == "Euler":
            return EulerDiscreteScheduler(**base)
        else:
            return DDIMScheduler(**base, clip_sample=False,
                                 set_alpha_to_one=False)

    def set_scheduler(self, name):
        """Select the scheduler used by subsequent generate()/refine() calls."""
        self.scheduler = self._make_scheduler(name)
        self.scheduler_name = name

    def load_model(self, model_path: Path, res_label: str = "256"):
        """Build the UNet from UNET_CONFIG and load weights from model_path.

        res_label "512" selects 64x64 latents / 512px output; anything else
        uses 32x32 latents / 256px output. Raises FileNotFoundError when no
        recognized weight file exists in model_path.
        """
        if str(model_path) == self.loaded_checkpoint:
            return  # already loaded
        from diffusers import UNet2DConditionModel

        self.load_shared()

        if res_label == "512":
            self.latent_size = 64
            self.output_size = 512
        else:
            self.latent_size = 32
            self.output_size = 256

        unet_cfg = dict(UNET_CONFIG)
        unet_cfg["sample_size"] = self.latent_size

        print(f"Loading UNet from {model_path.name} ({res_label}px)...")
        self.unet = UNet2DConditionModel(**unet_cfg).to(self.device)

        safetensors_path = model_path / "diffusion_pytorch_model.safetensors"
        ema_path = model_path / "ema_unet.pt"
        unet_path = model_path / "unet.pt"

        if safetensors_path.exists():
            from safetensors.torch import load_file
            state = load_file(str(safetensors_path), device=str(self.device))
            self.unet.load_state_dict(state)
            print("Loaded safetensors weights.")
        elif ema_path.exists():
            state = torch.load(ema_path, map_location=self.device, weights_only=True)
            if "shadow_params" in state:
                # NOTE(review): EMA shadow params are matched to the UNet's
                # parameters purely by iteration order — assumes the
                # checkpoint came from an identically-configured UNet.
                params = dict(self.unet.named_parameters())
                keys = list(params.keys())
                for i, sp in enumerate(state["shadow_params"]):
                    params[keys[i]].data.copy_(sp)
            else:
                self.unet.load_state_dict(state)
            print("Loaded EMA weights.")
        elif unet_path.exists():
            self.unet.load_state_dict(
                torch.load(unet_path, map_location=self.device, weights_only=True))
            print("Loaded UNet weights.")
        else:
            raise FileNotFoundError(f"No weights found in {model_path}")

        self.unet.eval()
        self.loaded_checkpoint = str(model_path)
        print(f"Ready to generate at {self.output_size}x{self.output_size}!")

    def _decode_latents(self, latents, post_process=False):
        """Decode latents through the VAE into a PIL image (first batch item)."""
        scaled = latents / self.vae.config.scaling_factor
        with torch.no_grad():
            image = self.vae.decode(scaled.float()).sample
        # VAE output is in [-1, 1]; map to [0, 1] then uint8 HWC.
        image = (image.float() / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
        image = (image * 255).round().astype("uint8")
        img = Image.fromarray(image)
        if post_process:
            img = self._post_process(img)
        return img

    def _sharpen_latents(self, latents, amount=0.08):
        """Mild unsharp mask in latent space (blur via 3x3 average pool)."""
        blurred = F.avg_pool2d(latents, kernel_size=3, stride=1, padding=1)
        return latents + amount * (latents - blurred)

    def _post_process(self, img):
        """Final sharpen plus slight contrast/saturation boost (PIL)."""
        img = img.filter(ImageFilter.UnsharpMask(radius=1.5, percent=40, threshold=2))
        img = ImageEnhance.Contrast(img).enhance(1.06)
        img = ImageEnhance.Color(img).enhance(1.10)
        return img

    def _image_quality_score(self, img: Image.Image) -> float:
        """Heuristic 0-100 score: Laplacian sharpness + per-channel color variance."""
        arr = np.array(img.convert("L"), dtype=np.float32)
        # 4-neighbour Laplacian via np.roll; its variance measures sharpness.
        lap = (np.roll(arr, 1, 0) + np.roll(arr, -1, 0)
               + np.roll(arr, 1, 1) + np.roll(arr, -1, 1) - 4.0 * arr)
        sharpness = float(np.var(lap))
        arr_rgb = np.array(img, dtype=np.float32)
        color_var = float(np.mean(np.var(arr_rgb, axis=(0, 1))))
        score = (sharpness * 0.6 + color_var * 0.4)
        return min(100.0, score / 10.0)

    @torch.no_grad()
    def generate(self, prompt: str, negative_prompt: str = "",
                 steps: int = 25, guidance_scale: float = 7.5,
                 seed: int = -1, preview_callback=None,
                 preview_every: int = 5) -> tuple:
        """Sample one image with classifier-free guidance.

        Returns (PIL image, seed), or (None, seed) when cancelled mid-run.
        seed < 0 picks a random seed. preview_callback(img, step, total)
        is invoked every preview_every steps when provided.
        """
        if seed < 0:
            seed = torch.randint(0, 2**32, (1,)).item()
        gen = torch.Generator(device=self.device).manual_seed(seed)

        # Encode the positive and negative prompts with CLIP.
        tok = self.tokenizer(prompt, padding="max_length",
                             max_length=self.tokenizer.model_max_length,
                             truncation=True, return_tensors="pt")
        text_emb = self.text_encoder(tok.input_ids.to(self.device))[0]

        tok_neg = self.tokenizer(negative_prompt if negative_prompt else "",
                                 padding="max_length",
                                 max_length=self.tokenizer.model_max_length,
                                 truncation=True, return_tensors="pt")
        neg_emb = self.text_encoder(tok_neg.input_ids.to(self.device))[0]

        # Batch [negative, positive] so one UNet call serves both CFG branches.
        text_emb_combined = torch.cat([neg_emb, text_emb])

        # Fresh scheduler per call: schedulers carry per-run internal state.
        scheduler = self._make_scheduler(self.scheduler_name)
        scheduler.set_timesteps(steps, device=self.device)

        latents = torch.randn(1, 4, self.latent_size, self.latent_size,
                              generator=gen, device=self.device)
        latents = latents * scheduler.init_noise_sigma

        timesteps = scheduler.timesteps
        total_steps = len(timesteps)

        for step_i, t in enumerate(timesteps):
            if self.cancelled:
                return None, seed

            latent_input = torch.cat([latents] * 2)
            latent_input = scheduler.scale_model_input(latent_input, t)

            with torch.autocast(device_type="cuda", dtype=torch.bfloat16,
                                enabled=(self.device == "cuda")):
                pred = self.unet(latent_input, t,
                                 encoder_hidden_states=text_emb_combined).sample

            # Classifier-free guidance: push the prediction away from the
            # negative-prompt branch.
            pred_neg, pred_text = pred.chunk(2)
            pred = pred_neg + guidance_scale * (pred_text - pred_neg)

            latents = scheduler.step(pred, t, latents).prev_sample

            # Intermediate previews, skipping the first and last steps.
            if (preview_callback and step_i > 0
                    and step_i % preview_every == 0
                    and step_i < total_steps - 1):
                preview = self._decode_latents(latents, post_process=False)
                preview_callback(preview, step_i + 1, total_steps)

        latents = self._sharpen_latents(latents)
        final = self._decode_latents(latents, post_process=True)
        return final, seed

    @torch.no_grad()
    def generate_adaptive(self, prompt: str, negative_prompt: str = "",
                          base_steps: int = 25, max_steps: int = 85,
                          guidance_scale: float = 7.5,
                          quality_threshold: float = 45.0,
                          preview_callback=None, preview_every: int = 5,
                          status_callback=None) -> tuple:
        """Generate, then refine in +20-step rounds while the heuristic
        quality score stays below quality_threshold.

        Returns (PIL image, seed); (None, seed) if the base generation was
        cancelled. status_callback(str) receives progress messages.
        """
        result = self.generate(
            prompt=prompt, negative_prompt=negative_prompt,
            steps=base_steps, guidance_scale=guidance_scale,
            preview_callback=preview_callback, preview_every=preview_every)

        if result[0] is None:
            return result  # cancelled during the base pass

        image, seed = result
        quality = self._image_quality_score(image)

        if status_callback:
            status_callback(f"Quality: {quality:.1f}/100")

        if quality >= quality_threshold:
            return image, seed

        # Budget: one +20-step refine round per 20 steps of headroom.
        rounds = 0
        max_rounds = (max_steps - base_steps) // 20

        while quality < quality_threshold and rounds < max_rounds:
            if self.cancelled:
                return image, seed
            rounds += 1
            if status_callback:
                status_callback(f"Refining +20 steps (round {rounds})...")

            refined = self.refine(
                source_image=image, prompt=prompt,
                negative_prompt=negative_prompt,
                extra_steps=20, strength=0.3,
                guidance_scale=guidance_scale,
                preview_callback=preview_callback, preview_every=5)

            if refined is None:
                return image, seed  # cancelled mid-refine; keep last good image
            image = refined
            quality = self._image_quality_score(image)

            if status_callback:
                status_callback(f"Quality after round {rounds}: {quality:.1f}/100")

        return image, seed

    @torch.no_grad()
    def refine(self, source_image: Image.Image, prompt: str,
               negative_prompt: str = "", extra_steps: int = 20,
               strength: float = 0.35, guidance_scale: float = 7.5,
               preview_callback=None, preview_every: int = 5) -> Image.Image:
        """Img2img refinement: encode source_image to latents, re-noise,
        and denoise only the final `strength` fraction of the schedule.

        Returns the refined PIL image, or None when cancelled.
        """
        # FIX: normalize to RGB first — an RGBA or grayscale source would
        # break the 3-channel (2, 0, 1) permute below.
        img = source_image.convert("RGB").resize(
            (self.output_size, self.output_size), Image.LANCZOS)
        # uint8 [0, 255] -> float [-1, 1], NCHW.
        img_tensor = torch.from_numpy(np.array(img)).float().div(127.5).sub(1.0)
        img_tensor = img_tensor.permute(2, 0, 1).unsqueeze(0).to(self.device)

        with torch.no_grad():
            latents = self.vae.encode(img_tensor.float()).latent_dist.sample()
        latents = latents * self.vae.config.scaling_factor

        tok = self.tokenizer(prompt, padding="max_length",
                             max_length=self.tokenizer.model_max_length,
                             truncation=True, return_tensors="pt")
        text_emb = self.text_encoder(tok.input_ids.to(self.device))[0]

        tok_neg = self.tokenizer(negative_prompt if negative_prompt else "",
                                 padding="max_length",
                                 max_length=self.tokenizer.model_max_length,
                                 truncation=True, return_tensors="pt")
        neg_emb = self.text_encoder(tok_neg.input_ids.to(self.device))[0]
        text_emb_combined = torch.cat([neg_emb, text_emb])

        scheduler = self._make_scheduler(self.scheduler_name)
        scheduler.set_timesteps(extra_steps, device=self.device)
        # Skip the early (high-noise) timesteps; only the last `strength`
        # fraction of the schedule is denoised.
        start_step = max(0, int(len(scheduler.timesteps) * (1 - strength)))
        timesteps = scheduler.timesteps[start_step:]

        # Noise the clean latents up to the level of the first kept timestep.
        noise = torch.randn_like(latents)
        latents = scheduler.add_noise(latents, noise, timesteps[:1])

        total_steps = len(timesteps)
        for step_i, t in enumerate(timesteps):
            if self.cancelled:
                return None
            latent_input = torch.cat([latents] * 2)
            latent_input = scheduler.scale_model_input(latent_input, t)
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16,
                                enabled=(self.device == "cuda")):
                pred = self.unet(latent_input, t,
                                 encoder_hidden_states=text_emb_combined).sample
            pred_neg, pred_text = pred.chunk(2)
            pred = pred_neg + guidance_scale * (pred_text - pred_neg)
            latents = scheduler.step(pred, t, latents).prev_sample

            if (preview_callback and step_i > 0
                    and step_i % preview_every == 0
                    and step_i < total_steps - 1):
                preview = self._decode_latents(latents, post_process=False)
                preview_callback(preview, step_i + 1, total_steps)

        latents = self._sharpen_latents(latents)
        return self._decode_latents(latents, post_process=True)
|
|
|
|
| |
|
|
| class App: |
    def __init__(self):
        """Build the generator backend, discover models, and construct the UI."""
        self.gen = Generator()
        self.models = find_models()   # [(kind, name, path, res_label), ...]
        self.generated_images = []    # PIL images of the current batch
        self.generated_seeds = []     # seed (or "seed+R<steps>") per image
        self.photo_refs = []          # keep PhotoImage refs alive for tk
        self.generating = False       # True while a worker thread is running
        self.selected_index = None    # index into generated_images, or None

        self.root = tk.Tk()
        self.root.title("Aniimage")
        self.root.configure(bg=C["bg"])
        self.root.resizable(True, True)
        self.root.geometry("900x780")
        self.root.minsize(640, 500)

        self._setup_styles()
        self._build_ui()
|
|
    def _setup_styles(self):
        """Configure the ttk theme, widget styles, and combobox popup colors."""
        s = ttk.Style()
        s.theme_use("clam")

        # Base look shared by all ttk widgets.
        s.configure(".", background=C["bg"], foreground=C["text"], font=("Segoe UI", 10))
        s.configure("TFrame", background=C["bg"])
        s.configure("TLabel", background=C["bg"], foreground=C["text"])
        s.configure("TCheckbutton", background=C["bg"], foreground=C["text"])

        # Combobox field colors, plus readonly-state overrides.
        s.configure("TCombobox", fieldbackground=C["input"], foreground=C["input_fg"],
                    selectbackground=C["accent"], selectforeground="#ffffff",
                    arrowcolor=C["text2"], padding=4)
        s.map("TCombobox",
              fieldbackground=[("readonly", C["input"])],
              foreground=[("readonly", C["input_fg"])],
              selectbackground=[("readonly", C["accent"])],
              selectforeground=[("readonly", "#ffffff")])
        # The dropdown list is a plain tk Listbox; style it via the option db.
        self.root.option_add("*TCombobox*Listbox.background", C["input"])
        self.root.option_add("*TCombobox*Listbox.foreground", C["input_fg"])
        self.root.option_add("*TCombobox*Listbox.selectBackground", C["accent"])
        self.root.option_add("*TCombobox*Listbox.selectForeground", "#ffffff")
        self.root.option_add("*TCombobox*Listbox.font", ("Segoe UI", 10))

        s.configure("TSpinbox", fieldbackground=C["input"], foreground=C["input_fg"],
                    arrowcolor=C["text2"], padding=3)

        # Buttons: default, primary ("Go") and destructive ("Stop") variants.
        s.configure("TButton", font=("Segoe UI", 10), padding=(14, 7),
                    background=C["card"], foreground=C["text"])
        s.map("TButton", background=[("active", C["card_sel"]), ("disabled", C["bg"])],
              foreground=[("disabled", C["text3"])])

        s.configure("Go.TButton", font=("Segoe UI", 11, "bold"), padding=(20, 9),
                    background=C["accent"], foreground="#ffffff")
        s.map("Go.TButton", background=[("active", C["accent_h"]),
                                        ("disabled", C["border"])])

        s.configure("Stop.TButton", font=("Segoe UI", 10, "bold"), padding=(14, 7),
                    background=C["red"], foreground="#ffffff")
        s.map("Stop.TButton", background=[("active", "#c0392b"),
                                          ("disabled", C["border"])])

        s.configure("TLabelframe", background=C["bg"], foreground=C["text2"])
        s.configure("TLabelframe.Label", background=C["bg"],
                    foreground=C["text2"], font=("Segoe UI", 9, "bold"))

        s.configure("Vertical.TScrollbar", background=C["card"],
                    troughcolor=C["bg"], arrowcolor=C["text3"])
|
|
| def _make_entry(self, parent, font_size=11, dim=False): |
| """Create a styled tk.Entry with readable text.""" |
| return tk.Entry(parent, font=("Segoe UI", font_size), |
| bg=C["input"], fg=C["input_fg"] if not dim else C["text2"], |
| insertbackground=C["input_fg"], |
| relief="flat", bd=6, |
| selectbackground=C["accent"], selectforeground="#ffffff", |
| highlightthickness=1, highlightcolor=C["accent"], |
| highlightbackground=C["border"]) |
|
|
    def _build_ui(self):
        """Assemble the header, left control panel, and right image grid."""
        # Header bar with the app title and device selector.
        header = tk.Frame(self.root, bg=C["panel"], padx=20, pady=12)
        header.pack(fill=tk.X)

        tk.Label(header, text="Aniimage", bg=C["panel"], fg=C["accent"],
                 font=("Segoe UI", 20, "bold")).pack(side=tk.LEFT)
        tk.Label(header, text="by 8BitStudio", bg=C["panel"], fg=C["text3"],
                 font=("Segoe UI", 10)).pack(side=tk.LEFT, padx=(10, 0), pady=(6, 0))

        # GPU/CPU selector ("GPU" offered only when CUDA is available).
        device_frame = tk.Frame(header, bg=C["panel"])
        device_frame.pack(side=tk.RIGHT)

        tk.Label(device_frame, text="Device:", bg=C["panel"], fg=C["text2"],
                 font=("Segoe UI", 9)).pack(side=tk.LEFT, padx=(0, 5))

        self.device_var = tk.StringVar(value="GPU" if self.gen.device == "cuda" else "CPU")
        devices = ["GPU", "CPU"] if torch.cuda.is_available() else ["CPU"]
        device_combo = ttk.Combobox(device_frame, textvariable=self.device_var,
                                    values=devices, state="readonly", width=5)
        device_combo.pack(side=tk.LEFT)
        device_combo.bind("<<ComboboxSelected>>", self._on_device_change)

        # Main area: fixed-width controls on the left, image grid on the right.
        main = tk.Frame(self.root, bg=C["bg"])
        main.pack(fill=tk.BOTH, expand=True, padx=12, pady=(8, 12))

        left = tk.Frame(main, bg=C["panel"], width=340, padx=16, pady=12)
        left.pack(side=tk.LEFT, fill=tk.Y, padx=(0, 8))
        left.pack_propagate(False)  # keep the 340px width regardless of content

        right = tk.Frame(main, bg=C["bg"])
        right.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)

        self._build_controls(left)
        self._build_grid(right)
|
|
    def _build_controls(self, parent):
        """Populate the left panel: model/prompt inputs, sampling parameters,
        action buttons, the prompt queue, and the status line."""
        # Model picker (defaults to the last discovered model).
        tk.Label(parent, text="Model", bg=C["panel"], fg=C["text2"],
                 font=("Segoe UI", 9, "bold")).pack(anchor=tk.W)

        self.model_var = tk.StringVar()
        model_names = [m[1] for m in self.models] or ["No models found"]
        self.model_combo = ttk.Combobox(parent, textvariable=self.model_var,
                                        values=model_names, state="readonly", width=32)
        self.model_combo.pack(fill=tk.X, pady=(3, 12))
        self.model_combo.current(len(model_names) - 1)

        # Positive prompt; Enter triggers generation.
        tk.Label(parent, text="Prompt", bg=C["panel"], fg=C["text2"],
                 font=("Segoe UI", 9, "bold")).pack(anchor=tk.W)
        self.prompt_entry = self._make_entry(parent)
        self.prompt_entry.pack(fill=tk.X, pady=(3, 8))
        self.prompt_entry.insert(0, "a smiling anime girl with long blue hair")
        self.prompt_entry.bind("<Return>", lambda e: self.on_generate())

        # Negative prompt, prefilled with the standard quality negatives.
        tk.Label(parent, text="Negative prompt", bg=C["panel"], fg=C["text3"],
                 font=("Segoe UI", 9)).pack(anchor=tk.W)
        self.neg_entry = self._make_entry(parent, font_size=9, dim=True)
        self.neg_entry.pack(fill=tk.X, pady=(3, 12))
        self.neg_entry.insert(0, DEFAULT_NEGATIVE)

        # Sampling-parameter grid: scheduler / steps / CFG / count / preview.
        grid = tk.Frame(parent, bg=C["panel"])
        grid.pack(fill=tk.X, pady=(0, 8))

        tk.Label(grid, text="Scheduler", bg=C["panel"], fg=C["text2"],
                 font=("Segoe UI", 9)).grid(row=0, column=0, sticky="w", pady=(0, 6))
        self.scheduler_var = tk.StringVar(value="DPM++ 2M Karras")
        sched_combo = ttk.Combobox(grid, textvariable=self.scheduler_var,
                                   values=SCHEDULER_LIST, state="readonly", width=18)
        sched_combo.grid(row=0, column=1, columnspan=3, sticky="ew", padx=(8, 0), pady=(0, 6))
        sched_combo.bind("<<ComboboxSelected>>", self._on_scheduler_change)

        tk.Label(grid, text="Steps", bg=C["panel"], fg=C["text2"],
                 font=("Segoe UI", 9)).grid(row=1, column=0, sticky="w", pady=(0, 6))
        self.steps_var = tk.StringVar(value="25")
        tk.Entry(grid, textvariable=self.steps_var, width=5, font=("Segoe UI", 10),
                 bg=C["input"], fg=C["input_fg"], insertbackground=C["input_fg"],
                 relief="flat", bd=4).grid(row=1, column=1, sticky="w", padx=(8, 12), pady=(0, 6))

        tk.Label(grid, text="CFG", bg=C["panel"], fg=C["text2"],
                 font=("Segoe UI", 9)).grid(row=1, column=2, sticky="w", pady=(0, 6))
        self.cfg_var = tk.StringVar(value="7.5")
        tk.Entry(grid, textvariable=self.cfg_var, width=5, font=("Segoe UI", 10),
                 bg=C["input"], fg=C["input_fg"], insertbackground=C["input_fg"],
                 relief="flat", bd=4).grid(row=1, column=3, sticky="w", padx=(8, 0), pady=(0, 6))

        # Batch size (1-12 images per run).
        tk.Label(grid, text="Count", bg=C["panel"], fg=C["text2"],
                 font=("Segoe UI", 9)).grid(row=2, column=0, sticky="w", pady=(0, 6))
        self.count_var = tk.StringVar(value="4")
        ttk.Spinbox(grid, from_=1, to=12, textvariable=self.count_var, width=4,
                    font=("Segoe UI", 10)).grid(row=2, column=1, sticky="w", padx=(8, 12), pady=(0, 6))

        self.live_preview_var = tk.BooleanVar(value=False)
        ttk.Checkbutton(grid, text="Live preview",
                        variable=self.live_preview_var).grid(
            row=2, column=2, columnspan=2, sticky="w", pady=(0, 6))

        grid.columnconfigure(1, weight=1)
        grid.columnconfigure(3, weight=1)

        # Optional auto-refine pass when the quality score comes back low.
        self.auto_quality_var = tk.BooleanVar(value=False)
        ttk.Checkbutton(parent, text="Auto quality (refine if undercooked)",
                        variable=self.auto_quality_var).pack(anchor=tk.W, pady=(0, 12))

        # Action buttons: Generate on top, Stop/Save/Save All in a row below.
        btn_frame = tk.Frame(parent, bg=C["panel"])
        btn_frame.pack(fill=tk.X, pady=(0, 10))

        self.gen_btn = ttk.Button(btn_frame, text="Generate", command=self.on_generate,
                                  style="Go.TButton")
        self.gen_btn.pack(fill=tk.X, pady=(0, 5))

        btn_row = tk.Frame(btn_frame, bg=C["panel"])
        btn_row.pack(fill=tk.X)

        self.stop_btn = ttk.Button(btn_row, text="Stop", command=self.on_stop,
                                   state=tk.DISABLED, style="Stop.TButton")
        self.stop_btn.pack(side=tk.LEFT, fill=tk.X, expand=True, padx=(0, 3))

        self.save_btn = ttk.Button(btn_row, text="Save Selected", command=self.on_save,
                                   state=tk.DISABLED)
        self.save_btn.pack(side=tk.LEFT, fill=tk.X, expand=True, padx=(3, 3))

        self.save_all_btn = ttk.Button(btn_row, text="Save All", command=self.on_save_all,
                                       state=tk.DISABLED)
        self.save_all_btn.pack(side=tk.LEFT, fill=tk.X, expand=True, padx=(3, 0))

        # Prompt queue: entry + Add button, listbox, and queue controls.
        sep = tk.Frame(parent, height=1, bg=C["border"])
        sep.pack(fill=tk.X, pady=(8, 10))

        tk.Label(parent, text="Prompt Queue", bg=C["panel"], fg=C["text2"],
                 font=("Segoe UI", 9, "bold")).pack(anchor=tk.W)

        queue_input = tk.Frame(parent, bg=C["panel"])
        queue_input.pack(fill=tk.X, pady=(4, 0))

        self.queue_entry = self._make_entry(queue_input, font_size=9)
        self.queue_entry.pack(side=tk.LEFT, fill=tk.X, expand=True, padx=(0, 4))
        self.queue_entry.bind("<Return>", lambda e: self._queue_add())

        ttk.Button(queue_input, text="Add", width=4,
                   command=self._queue_add).pack(side=tk.LEFT)

        self.queue_listbox = tk.Listbox(
            parent, height=4, bg=C["input"], fg=C["input_fg"],
            selectbackground=C["accent"], selectforeground="#fff",
            font=("Segoe UI", 9), activestyle="none",
            relief="flat", bd=4, highlightthickness=0)
        self.queue_listbox.pack(fill=tk.X, pady=(5, 0))

        queue_btns = tk.Frame(parent, bg=C["panel"])
        queue_btns.pack(fill=tk.X, pady=(4, 0))

        self.queue_run_btn = ttk.Button(queue_btns, text="Run Queue",
                                        command=self.on_run_queue, style="Go.TButton")
        self.queue_run_btn.pack(side=tk.LEFT, padx=(0, 4))

        for txt, cmd in [("Remove", self._queue_remove), ("Clear", self._queue_clear),
                         ("Up", self._queue_move_up), ("Down", self._queue_move_down),
                         ("+ Current", self._queue_add_current)]:
            ttk.Button(queue_btns, text=txt, command=cmd).pack(side=tk.LEFT, padx=2)

        # Status line pinned to the bottom of the panel.
        status_frame = tk.Frame(parent, bg=C["bg"], padx=8, pady=6)
        status_frame.pack(fill=tk.X, side=tk.BOTTOM)

        self.status_var = tk.StringVar(value="Ready")
        tk.Label(status_frame, textvariable=self.status_var,
                 bg=C["bg"], fg=C["green"], font=("Segoe UI", 9),
                 anchor="w").pack(fill=tk.X)
|
|
    def _build_grid(self, parent):
        """Build the scrollable canvas that hosts the generated-image grid."""
        self.canvas = tk.Canvas(parent, bg=C["bg"], highlightthickness=0)
        scrollbar = ttk.Scrollbar(parent, orient=tk.VERTICAL, command=self.canvas.yview)
        self.grid_frame = tk.Frame(self.canvas, bg=C["bg"])

        # Grow the scrollregion whenever the inner frame changes size.
        self.grid_frame.bind("<Configure>",
                             lambda e: self.canvas.configure(
                                 scrollregion=self.canvas.bbox("all")))
        self.canvas_window = self.canvas.create_window((0, 0), window=self.grid_frame,
                                                       anchor="nw")
        self.canvas.configure(yscrollcommand=scrollbar.set)

        self.canvas.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)
        scrollbar.pack(side=tk.RIGHT, fill=tk.Y)

        self.canvas.bind("<Configure>", self._on_canvas_resize)
        # NOTE(review): bind_all routes the mouse wheel app-wide to this
        # canvas; e.delta/120 is the Windows wheel-unit convention.
        self.canvas.bind_all("<MouseWheel>",
                             lambda e: self.canvas.yview_scroll(
                                 int(-1 * (e.delta / 120)), "units"))

        # Shown until the first batch replaces the grid contents.
        self.placeholder = tk.Label(self.grid_frame,
                                    text="Generated images\nwill appear here",
                                    bg=C["bg"], fg=C["text3"],
                                    font=("Segoe UI", 13), justify="center")
        self.placeholder.grid(row=0, column=0, pady=80)
|
|
| |
|
|
| def _on_device_change(self, event=None): |
| choice = self.device_var.get() |
| new_dev = "cuda" if choice == "GPU" else "cpu" |
| self.status_var.set(f"Switching to {choice}...") |
| self.root.update() |
| self.gen.switch_device(new_dev) |
| self.status_var.set(f"Now using {choice}") |
|
|
| def _on_scheduler_change(self, event=None): |
| name = self.scheduler_var.get() |
| self.gen.set_scheduler(name) |
| self.status_var.set(f"Scheduler: {name}") |
|
|
| def _on_canvas_resize(self, event): |
| self.canvas.itemconfig(self.canvas_window, width=event.width) |
| if self.generated_images: |
| self._layout_grid() |
|
|
| def _get_grid_cols(self): |
| canvas_w = self.canvas.winfo_width() |
| if canvas_w < 50: |
| canvas_w = 560 |
| tile_size = self._get_tile_size() |
| return max(1, canvas_w // (tile_size + 16)) |
|
|
| def _get_tile_size(self): |
| n = len(self.generated_images) |
| if n <= 2: return 260 |
| elif n <= 4: return 220 |
| elif n <= 6: return 180 |
| else: return 160 |
|
|
    def _layout_grid(self):
        """Rebuild the tile grid from scratch for the current batch."""
        # Destroy old tiles; the PhotoImage refs must be dropped too.
        for w in self.grid_frame.winfo_children():
            w.destroy()
        self.photo_refs.clear()

        if not self.generated_images:
            return

        tile_size = self._get_tile_size()
        cols = self._get_grid_cols()

        for i, (img, seed) in enumerate(zip(self.generated_images, self.generated_seeds)):
            row, col = divmod(i, cols)
            is_selected = (i == self.selected_index)

            # The selected tile gets an accent-colored border via the card bg.
            card_bg = C["accent"] if is_selected else C["card"]
            card = tk.Frame(self.grid_frame, bg=card_bg, padx=3, pady=3)
            card.grid(row=row, column=col, padx=5, pady=5, sticky="nsew")

            display = img.resize((tile_size, tile_size), Image.LANCZOS)
            photo = ImageTk.PhotoImage(display)
            # Keep a reference: tk does not hold PhotoImages alive by itself.
            self.photo_refs.append(photo)

            img_label = tk.Label(card, image=photo, bg=card_bg, bd=0)
            img_label.pack()
            # idx=i binds the current index (avoids the late-binding closure bug).
            img_label.bind("<Button-1>", lambda e, idx=i: self._select_image(idx))
            img_label.bind("<Button-3>", lambda e, idx=i: self._show_refine_menu(e, idx))

            tk.Label(card, text=f"seed: {seed}", bg=card_bg,
                     fg=C["text3"], font=("Segoe UI", 8)).pack()

        for c in range(cols):
            self.grid_frame.columnconfigure(c, weight=1)
|
|
| def _select_image(self, idx): |
| if idx >= len(self.generated_images): |
| return |
| self.selected_index = idx |
| self.save_btn.configure(state=tk.NORMAL) |
| self.status_var.set(f"Selected image {idx + 1} (seed: {self.generated_seeds[idx]})") |
| self._layout_grid() |
|
|
| def _show_refine_menu(self, event, idx): |
| if self.generating: |
| return |
| menu = tk.Menu(self.root, tearoff=0, bg=C["card"], fg=C["text"], |
| activebackground=C["accent"], activeforeground="#fff", |
| font=("Segoe UI", 10), bd=0) |
| menu.add_command(label=" Refine (more steps)... ", |
| command=lambda: self._ask_refine(idx)) |
| menu.tk_popup(event.x_root, event.y_root) |
|
|
    def _ask_refine(self, idx):
        """Ask for a step count, then launch a refine worker for tile idx."""
        extra = simpledialog.askinteger(
            "Refine Image", "Extra denoising steps:",
            initialvalue=20, minvalue=5, maxvalue=200, parent=self.root)
        if extra is None:
            return  # user cancelled the dialog
        self._select_image(idx)
        # Flip the UI into the busy state before the worker starts.
        self.generating = True
        self.gen.cancelled = False
        self.gen_btn.configure(state=tk.DISABLED)
        self.stop_btn.configure(state=tk.NORMAL)
        self.status_var.set(f"Refining image {idx + 1}...")
        self.root.update()
        Thread(target=self._refine_thread, args=(idx, extra), daemon=True).start()
|
|
    def _refine_thread(self, idx, extra_steps):
        """Worker thread: run an img2img refine pass on image idx.

        NOTE(review): this updates tk widgets (status_var, _layout_grid) from
        a non-main thread, which tkinter does not guarantee to be safe —
        consider marshalling via root.after. Left as-is here.
        """
        try:
            source = self.generated_images[idx]
            prompt = self.prompt_entry.get().strip()
            neg = self.neg_entry.get().strip()
            cfg = float(self.cfg_var.get())
            callback = self._show_preview if self.live_preview_var.get() else None

            refined = self.gen.refine(
                source_image=source, prompt=prompt, negative_prompt=neg,
                extra_steps=extra_steps, guidance_scale=cfg,
                preview_callback=callback, preview_every=5)

            if refined is not None:
                self.generated_images[idx] = refined
                # Tag the seed label so refined tiles are distinguishable.
                self.generated_seeds[idx] = f"{self.generated_seeds[idx]}+R{extra_steps}"
                self._layout_grid()
                self.status_var.set(f"Refined image {idx + 1}")
            else:
                # refine() returns None when the user pressed Stop.
                self.status_var.set("Refine stopped.")
            self.root.update()
        except Exception as e:
            self.status_var.set(f"Refine error: {e}")
            import traceback; traceback.print_exc()
        finally:
            # Always restore the idle UI state, even after an error.
            self.generating = False
            self.gen.cancelled = False
            self.gen_btn.configure(state=tk.NORMAL)
            self.stop_btn.configure(state=tk.DISABLED)
|
|
| |
|
|
| def _queue_add(self): |
| text = self.queue_entry.get().strip() |
| if text: |
| self.queue_listbox.insert(tk.END, text) |
| self.queue_entry.delete(0, tk.END) |
|
|
| def _queue_add_current(self): |
| text = self.prompt_entry.get().strip() |
| if text: |
| self.queue_listbox.insert(tk.END, text) |
|
|
| def _queue_remove(self): |
| sel = self.queue_listbox.curselection() |
| if sel: |
| self.queue_listbox.delete(sel[0]) |
|
|
| def _queue_clear(self): |
| self.queue_listbox.delete(0, tk.END) |
|
|
| def _queue_move_up(self): |
| sel = self.queue_listbox.curselection() |
| if sel and sel[0] > 0: |
| idx = sel[0] |
| text = self.queue_listbox.get(idx) |
| self.queue_listbox.delete(idx) |
| self.queue_listbox.insert(idx - 1, text) |
| self.queue_listbox.selection_set(idx - 1) |
|
|
| def _queue_move_down(self): |
| sel = self.queue_listbox.curselection() |
| if sel and sel[0] < self.queue_listbox.size() - 1: |
| idx = sel[0] |
| text = self.queue_listbox.get(idx) |
| self.queue_listbox.delete(idx) |
| self.queue_listbox.insert(idx + 1, text) |
| self.queue_listbox.selection_set(idx + 1) |
|
|
| def on_run_queue(self): |
| if self.generating or not self.models: |
| return |
| prompts = list(self.queue_listbox.get(0, tk.END)) |
| if not prompts: |
| self.status_var.set("Queue is empty") |
| return |
| self.generating = True |
| self.gen.cancelled = False |
| self.gen_btn.configure(state=tk.DISABLED) |
| self.queue_run_btn.configure(state=tk.DISABLED) |
| self.stop_btn.configure(state=tk.NORMAL) |
| Thread(target=self._queue_thread, args=(prompts,), daemon=True).start() |
|
|
    def _queue_thread(self, prompts):
        """Worker thread: generate and auto-save images for every queued prompt.

        For each prompt in *prompts*, generates `count_var` images (clamped to
        1..12), saving each to disk immediately. Honors `self.gen.cancelled`
        between images and between prompts. Clears any previously generated
        images before starting. Always restores button states on exit.

        NOTE(review): this runs off the Tk main thread but calls widget methods
        directly (status_var.set, listbox selection, root.update) — relies on
        Tk tolerating cross-thread access; confirm this is acceptable here.
        """
        try:
            # Load the model selected in the combo box: mdl is (?, name, path, kind)
            # — exact tuple layout defined where find_models() builds it.
            idx = self.model_combo.current()
            mdl = self.models[idx]
            self.status_var.set(f"Loading {mdl[1]}...")
            self.root.update()
            self.gen.load_model(mdl[2], mdl[3])


            # Snapshot all generation settings once, before the loop.
            neg = self.neg_entry.get().strip()
            steps = int(self.steps_var.get())
            cfg = float(self.cfg_var.get())
            num_images = max(1, min(12, int(self.count_var.get())))
            live_preview = self.live_preview_var.get()
            auto_quality = self.auto_quality_var.get()


            # Discard results of any previous run and remove the placeholder card.
            self.generated_images.clear()
            self.generated_seeds.clear()
            self.selected_index = None
            if self.placeholder:
                self.placeholder.destroy()
                self.placeholder = None


            for p_idx, prompt in enumerate(prompts):
                if self.gen.cancelled:
                    break
                # Highlight the prompt currently being processed in the queue list.
                self.queue_listbox.selection_clear(0, tk.END)
                self.queue_listbox.selection_set(p_idx)
                self.queue_listbox.see(p_idx)


                for img_i in range(num_images):
                    if self.gen.cancelled:
                        break
                    self.status_var.set(
                        f"[{p_idx + 1}/{len(prompts)}] image {img_i + 1}/{num_images}")
                    self.root.update()


                    # With live preview on, reserve a grid tile and stream steps into it.
                    callback = None
                    if live_preview:
                        self._setup_preview_card()
                        callback = self._show_preview


                    # Adaptive mode lets the generator extend up to steps+60 steps.
                    if auto_quality:
                        image, used_seed = self.gen.generate_adaptive(
                            prompt=prompt, negative_prompt=neg,
                            base_steps=steps, max_steps=steps + 60,
                            guidance_scale=cfg,
                            preview_callback=callback, preview_every=5,
                            status_callback=lambda m: (
                                self.status_var.set(m), self.root.update()))
                    else:
                        image, used_seed = self.gen.generate(
                            prompt=prompt, negative_prompt=neg,
                            steps=steps, guidance_scale=cfg,
                            preview_callback=callback, preview_every=5)


                    # image is None when generation was cancelled mid-run.
                    if image is None:
                        break
                    self.generated_images.append(image)
                    self.generated_seeds.append(used_seed)
                    # Queue runs auto-save every image (unlike single-generate).
                    save_path = self._next_save_path(prompt)
                    image.save(save_path)
                    self._layout_grid()
                    self.root.update()


                if self.gen.cancelled:
                    break


            done = len(self.generated_images)
            self.status_var.set(
                f"Queue {'stopped' if self.gen.cancelled else 'done'}! {done} images saved.")
            if done > 0:
                self.save_all_btn.configure(state=tk.NORMAL)


        except Exception as e:
            self.status_var.set(f"Queue error: {e}")
            import traceback; traceback.print_exc()
        finally:
            # Always unlock the UI, even after an error or cancellation.
            self.generating = False
            self.gen.cancelled = False
            self.gen_btn.configure(state=tk.NORMAL)
            self.queue_run_btn.configure(state=tk.NORMAL)
            self.stop_btn.configure(state=tk.DISABLED)
|
|
| |
|
|
| def on_stop(self): |
| if self.generating: |
| self.gen.cancelled = True |
| self.status_var.set("Stopping...") |
| self.root.update() |
|
|
| def on_generate(self): |
| if self.generating or not self.models: |
| return |
| self.generating = True |
| self.gen.cancelled = False |
| self.gen_btn.configure(state=tk.DISABLED) |
| self.stop_btn.configure(state=tk.NORMAL) |
| self.status_var.set("Loading model...") |
| self.root.update() |
| Thread(target=self._generate_thread, daemon=True).start() |
|
|
| def _setup_preview_card(self): |
| tile_size = self._get_tile_size() |
| cols = self._get_grid_cols() |
| row, col = divmod(len(self.generated_images), cols) |
| card = tk.Frame(self.grid_frame, bg=C["card"], padx=3, pady=3) |
| card.grid(row=row, column=col, padx=5, pady=5, sticky="nsew") |
| self._preview_label = tk.Label(card, bg=C["card"], |
| width=tile_size, height=tile_size) |
| self._preview_label.pack() |
| self.root.update() |
|
|
| def _show_preview(self, preview_img, step, total): |
| tile_size = self._get_tile_size() |
| display = preview_img.resize((tile_size, tile_size), Image.LANCZOS) |
| photo = ImageTk.PhotoImage(display) |
| self._preview_photo = photo |
| if hasattr(self, '_preview_label') and self._preview_label.winfo_exists(): |
| self._preview_label.configure(image=photo) |
| self.status_var.set(f"Step {step}/{total}") |
| self.root.update() |
|
|
    def _generate_thread(self):
        """Worker thread: generate a batch of images for the current prompt.

        Reads all settings from the UI, clears previous results, then produces
        `count_var` images (clamped to 1..12), checking `self.gen.cancelled`
        before each one. Unlike the queue runner, images are NOT auto-saved —
        the user saves via the Save buttons. Always restores button states.

        NOTE(review): touches Tk widgets from a non-main thread (status_var,
        root.update) — relies on Tk tolerating this; confirm.
        """
        try:
            # Load the model selected in the combo box: mdl is (?, name, path, kind)
            # — exact tuple layout defined where find_models() builds it.
            idx = self.model_combo.current()
            mdl = self.models[idx]
            self.status_var.set(f"Loading {mdl[1]}...")
            self.root.update()
            self.gen.load_model(mdl[2], mdl[3])


            # Snapshot all generation settings once, before the loop.
            prompt = self.prompt_entry.get().strip()
            neg = self.neg_entry.get().strip()
            steps = int(self.steps_var.get())
            cfg = float(self.cfg_var.get())
            num_images = max(1, min(12, int(self.count_var.get())))
            live_preview = self.live_preview_var.get()
            auto_quality = self.auto_quality_var.get()


            # Discard results of any previous run and remove the placeholder card.
            self.generated_images.clear()
            self.generated_seeds.clear()
            self.selected_index = None
            if self.placeholder:
                self.placeholder.destroy()
                self.placeholder = None


            for i in range(num_images):
                if self.gen.cancelled:
                    break
                self.status_var.set(f"Generating {i + 1}/{num_images}...")
                self.root.update()


                # With live preview on, reserve a grid tile and stream steps into it.
                callback = None
                if live_preview:
                    self._setup_preview_card()
                    callback = self._show_preview


                # Adaptive mode lets the generator extend up to steps+60 steps.
                if auto_quality:
                    image, used_seed = self.gen.generate_adaptive(
                        prompt=prompt, negative_prompt=neg,
                        base_steps=steps, max_steps=steps + 60,
                        guidance_scale=cfg,
                        preview_callback=callback, preview_every=5,
                        status_callback=lambda m: (
                            self.status_var.set(m), self.root.update()))
                else:
                    image, used_seed = self.gen.generate(
                        prompt=prompt, negative_prompt=neg,
                        steps=steps, guidance_scale=cfg,
                        preview_callback=callback, preview_every=5)


                # image is None when generation was cancelled mid-run.
                if image is None:
                    break
                self.generated_images.append(image)
                self.generated_seeds.append(used_seed)
                self._layout_grid()
                self.root.update()


            done = len(self.generated_images)
            if self.gen.cancelled:
                self.status_var.set(f"Stopped. {done} image(s) kept.")
            else:
                self.status_var.set(f"Done! {done} images. Click to select.")
            if done > 0:
                self.save_all_btn.configure(state=tk.NORMAL)
            # Per-image Save stays disabled until the user selects an image.
            self.save_btn.configure(state=tk.DISABLED)


        except Exception as e:
            self.status_var.set(f"Error: {e}")
            import traceback; traceback.print_exc()
        finally:
            # Always unlock the UI, even after an error or cancellation.
            self.generating = False
            self.gen.cancelled = False
            self.gen_btn.configure(state=tk.NORMAL)
            self.stop_btn.configure(state=tk.DISABLED)
|
|
| |
|
|
| def _next_save_path(self, prompt_text): |
| OUTPUT_DIR.mkdir(parents=True, exist_ok=True) |
| slug = prompt_text.strip()[:50] if prompt_text.strip() else "untitled" |
| base = OUTPUT_DIR / f"{slug}.png" |
| if not base.exists(): |
| return base |
| n = 1 |
| while True: |
| path = OUTPUT_DIR / f"{slug} {n}.png" |
| if not path.exists(): |
| return path |
| n += 1 |
|
|
| def on_save(self): |
| if self.selected_index is None or not self.generated_images: |
| return |
| img = self.generated_images[self.selected_index] |
| path = self._next_save_path(self.prompt_entry.get().strip()) |
| img.save(path) |
| self.status_var.set(f"Saved: {path.name}") |
|
|
| def on_save_all(self): |
| if not self.generated_images: |
| return |
| prompt_text = self.prompt_entry.get().strip() |
| for img in self.generated_images: |
| path = self._next_save_path(prompt_text) |
| img.save(path) |
| self.status_var.set(f"Saved {len(self.generated_images)} images") |
|
|
| def run(self): |
| self.root.mainloop() |
|
|
|
|
| |
|
|
if __name__ == "__main__":
    # Look for local weights first; fall back to a HuggingFace download.
    models = find_models()
    if not models:
        print("No models found locally. Downloading from HuggingFace...")
        if download_from_hf():
            models = find_models()

    if not models:
        print("No models found!")
        print(f"Place model weights in: {MODEL_DIR}/YourModelName/")
        print("Expected files: diffusion_pytorch_model.safetensors or ema_unet.pt")
        sys.exit(1)

    print(f"Found {len(models)} model(s): {', '.join(m[1] for m in models)}")
    print(f"Device: {'CUDA (GPU)' if torch.cuda.is_available() else 'CPU'}")
    print("Starting Aniimage...")

    App().run()
|
|