| """ | |
| PixelMAE v10 β Neural Sprite Engine | |
| Hugging Face Spaces Β· Production UI | |
| """ | |
| import os | |
| import math | |
| import warnings | |
| from typing import Optional, Dict, List, Tuple | |
| from dataclasses import dataclass, field | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from PIL import Image, ImageDraw | |
| import gradio as gr | |
| from huggingface_hub import hf_hub_download | |
| warnings.filterwarnings("ignore") | |
| # ============================================================================ | |
| # REPO CONFIG β edit these two lines | |
| # ============================================================================ | |
| REPO_ID = "MidFord327/PixelArt-MAE-v7" | |
| MODEL_FILES = {"best": "best.pth", "latest": "latest.pth"} | |
| # ============================================================================ | |
| # 1. CONFIGURATION | |
| # ============================================================================ | |
@dataclass
class Config:
    IMAGE_SIZE: int = 16
    CHANNELS: int = 4
    NUM_PIXELS: int = 256
    LATENT_DIM: int = 192
    ENCODER_LAYERS: int = 6
    DECODER_LAYERS: int = 4
    NUM_HEADS: int = 6
    FFN_RATIO: int = 4
    DROPOUT: float = 0.0
    PALETTE_SIZE: int = 16
    USE_EMA: bool = True
    EMA_DECAY: float = 0.9995
    EMA_UPDATE_EVERY: int = 1
    DEVICE: torch.device = field(
        default_factory=lambda: torch.device("cpu"))
# ============================================================================
# 2. ARCHITECTURE
# ============================================================================
def get_2d_sincos_pos_embed(embed_dim: int, grid_size: int) -> torch.Tensor:
    grid_h = torch.arange(grid_size, dtype=torch.float32)
    grid_w = torch.arange(grid_size, dtype=torch.float32)
    grid = torch.meshgrid(grid_h, grid_w, indexing="ij")
    grid = torch.stack(grid, dim=0).reshape(2, -1)
    half = embed_dim // 4
    omega = 1.0 / (10000 ** (torch.arange(half, dtype=torch.float32) / half))
    out_h = torch.einsum("n,d->nd", grid[0], omega)
    out_w = torch.einsum("n,d->nd", grid[1], omega)
    pe = torch.zeros(grid_size * grid_size, embed_dim)
    pe[:, 0::4] = torch.sin(out_h); pe[:, 1::4] = torch.cos(out_h)
    pe[:, 2::4] = torch.sin(out_w); pe[:, 3::4] = torch.cos(out_w)
    return pe.unsqueeze(0)
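# Illustrative note (added for clarity, not in the original file): each pixel position
# gets a fixed embedding built from interleaved sin/cos frequency bands, half for the
# row index and half for the column index, e.g.
#   pe = get_2d_sincos_pos_embed(192, 16)   # expected shape: (1, 256, 192)
#   pe[:, :, 0::4] / pe[:, :, 1::4] hold sin/cos of the row, 2::4 / 3::4 of the column.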
class EMA:
    def __init__(self, model, decay=0.9995, update_every=1):
        self.model = model
        self.decay = decay
        self.update_every = update_every
        self.step = 0
        self.shadow = {n: p.data.clone()
                       for n, p in model.named_parameters() if p.requires_grad}
        self.backup = {}

    def apply_shadow(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad and name in self.shadow:
                self.backup[name] = param.data.clone()
                param.data.copy_(self.shadow[name])

    def restore(self):
        for name, param in self.model.named_parameters():
            if name in self.backup:
                param.data.copy_(self.backup[name])
        self.backup = {}

    def load_state_dict(self, state, device):
        self.shadow = {k: v.to(device) for k, v in state["shadow"].items()}
        self.step = state["step"]
class MHSA(nn.Module):
    def __init__(self, dim, heads, dropout=0.0):
        super().__init__()
        self.heads = heads
        self.head_dim = dim // heads
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)
        self.drop = dropout

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)
        out = F.scaled_dot_product_attention(
            q, k, v, dropout_p=self.drop if self.training else 0.0)
        return self.proj(out.transpose(1, 2).reshape(B, N, C))
class TransformerBlock(nn.Module):
    def __init__(self, dim, heads, ffn_ratio=4, dropout=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = MHSA(dim, heads, dropout)
        self.norm2 = nn.LayerNorm(dim)
        self.ffn = nn.Sequential(
            nn.Linear(dim, dim * ffn_ratio), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(dim * ffn_ratio, dim), nn.Dropout(dropout))

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.ffn(self.norm2(x))
        return x
class AsymmetricPixelMAE(nn.Module):
    def __init__(self, config: Config):
        super().__init__()
        self.config = config
        D, N = config.LATENT_DIM, config.NUM_PIXELS
        self.pixel_embed = nn.Linear(config.CHANNELS, D)
        self.embed_norm = nn.LayerNorm(D)
        self.pos_embed = nn.Parameter(
            get_2d_sincos_pos_embed(D, config.IMAGE_SIZE), requires_grad=False)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, D))
        self.mask_token = nn.Parameter(torch.zeros(1, 1, D))
        nn.init.normal_(self.cls_token, std=0.02)
        nn.init.normal_(self.mask_token, std=0.02)
        self.encoder = nn.ModuleList([
            TransformerBlock(D, config.NUM_HEADS, config.FFN_RATIO, config.DROPOUT)
            for _ in range(config.ENCODER_LAYERS)])
        self.encoder_norm = nn.LayerNorm(D)
        self.decoder_embed = nn.Linear(D, D)
        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, N + 1, D), requires_grad=False)
        dp = torch.zeros(1, N + 1, D)
        dp[:, 1:, :] = get_2d_sincos_pos_embed(D, config.IMAGE_SIZE)
        self.decoder_pos_embed.data.copy_(dp)
        self.decoder = nn.ModuleList([
            TransformerBlock(D, config.NUM_HEADS, config.FFN_RATIO, config.DROPOUT)
            for _ in range(config.DECODER_LAYERS)])
        self.decoder_norm = nn.LayerNorm(D)
        self.pixel_head_rgb = nn.Sequential(
            nn.LayerNorm(D), nn.Linear(D, D), nn.GELU(), nn.Linear(D, 3))
        self.pixel_head_alpha = nn.Sequential(
            nn.LayerNorm(D), nn.Linear(D, D), nn.GELU(), nn.Linear(D, 1))
        self.palette_head = nn.Sequential(
            nn.LayerNorm(D), nn.Linear(D, D), nn.GELU(),
            nn.Linear(D, config.PALETTE_SIZE * config.CHANNELS), nn.Sigmoid())

    def _tokenize(self, x):
        return self.embed_norm(self.pixel_embed(x.flatten(2).transpose(1, 2)))
    def forward_encoder(self, x, mask=None):
        B, N, D = x.shape[0], self.config.NUM_PIXELS, self.config.LATENT_DIM
        tokens = self._tokenize(x) + self.pos_embed
        if mask is not None:
            mf = mask.flatten(1)
            if mf.shape[1] != N: mf = mask.reshape(B, -1)
            if mf.dim() == 3: mf = mf.squeeze(1)
            noise = torch.rand(B, N, device=x.device)
            ids_shuf = torch.argsort(noise + mf.float() * 1e6, dim=1)
            ids_rest = torch.argsort(ids_shuf, dim=1)
            n_vis = max(1, int((1 - mf.float()).sum(dim=1).min().item()))
            ids_keep = ids_shuf[:, :n_vis]
            tokens = torch.gather(tokens, 1, ids_keep.unsqueeze(-1).expand(-1, -1, D))
        else:
            n_vis = N
            ids_rest = torch.arange(N, device=x.device).unsqueeze(0).expand(B, -1)
        cls = self.cls_token.expand(B, -1, -1)
        x_enc = torch.cat([cls, tokens], dim=1)
        for blk in self.encoder: x_enc = blk(x_enc)
        return self.encoder_norm(x_enc), ids_rest, n_vis
    def forward_decoder(self, x_enc, ids_rest, n_vis):
        B, N, D = x_enc.shape[0], self.config.NUM_PIXELS, self.config.LATENT_DIM
        cls_enc, vis_enc = x_enc[:, :1], x_enc[:, 1:]
        vis_dec = self.decoder_embed(vis_enc)
        n_masked = N - n_vis
        if n_masked > 0:
            mask_tok = self.mask_token.expand(B, n_masked, -1)
            full_seq = torch.cat([vis_dec, mask_tok], dim=1)
        else:
            full_seq = vis_dec
        full_seq = torch.gather(full_seq, 1, ids_rest.unsqueeze(-1).expand(-1, -1, D))
        full_seq = torch.cat([self.decoder_embed(cls_enc), full_seq], dim=1) + self.decoder_pos_embed
        for blk in self.decoder: full_seq = blk(full_seq)
        return self.decoder_norm(full_seq)[:, 1:]

    def forward(self, x, mask=None):
        B, C, H, W = x.shape
        enc, ids, n_vis = self.forward_encoder(x, mask)
        dec = self.forward_decoder(enc, ids, n_vis)
        rgb = torch.sigmoid(self.pixel_head_rgb(dec)).transpose(1, 2).reshape(B, 3, H, W)
        alp = torch.sigmoid(self.pixel_head_alpha(dec)).transpose(1, 2).reshape(B, 1, H, W)
        return {"pixel_pred": torch.cat([rgb, alp], dim=1)}
# ============================================================================
# 3. INFERENCE ENGINE (extended)
# ============================================================================
class InferenceEngine:
    def __init__(self, config: Config, device: torch.device):
        self.config = config
        self.device = device
        self.model = AsymmetricPixelMAE(config).to(device)
        self.ema = EMA(self.model, config.EMA_DECAY, config.EMA_UPDATE_EVERY) if config.USE_EMA else None
        self.ckpt_meta: Dict = {}

    def load_checkpoint(self, path: str) -> Dict:
        ckpt = torch.load(path, map_location=self.device, weights_only=False)
        state = {k.replace("module.", "").replace("_orig_mod.", ""): v
                 for k, v in ckpt["model"].items()}
        self.model.load_state_dict(state, strict=False)
        if self.ema and ckpt.get("ema"):
            try: self.ema.load_state_dict(ckpt["ema"], self.device)
            except Exception: pass
        self.ckpt_meta = {
            "epoch": ckpt.get("epoch", "?"),
            "best_loss": ckpt.get("best_loss", None),
            "has_ema": bool(self.ema and ckpt.get("ema")),
        }
        return self.ckpt_meta
    # ── Core helpers ──────────────────────────────────────────────────────
    def _apply_ema(self):
        if self.ema: self.ema.apply_shadow()

    def _restore_ema(self):
        if self.ema: self.ema.restore()

    # ── Inpaint ───────────────────────────────────────────────────────────
    def inpaint(self, image: torch.Tensor, mask: torch.Tensor,
                use_ema: bool = True) -> torch.Tensor:
        self.model.eval()
        if image.dim() == 3: image = image.unsqueeze(0)
        if mask.dim() == 3: mask = mask.unsqueeze(0)
        if use_ema: self._apply_ema()
        try:
            pred = self.model(image, mask=mask.float())["pixel_pred"]
            result = torch.where(mask.expand_as(image).bool(), pred, image)
        finally:
            if use_ema: self._restore_ema()
        return result.clamp(0, 1)

    # ── Restore ───────────────────────────────────────────────────────────
    def restore(self, image: torch.Tensor, strength: float = 0.5,
                use_ema: bool = True) -> torch.Tensor:
        self.model.eval()
        if image.dim() == 3: image = image.unsqueeze(0)
        B, C, H, W = image.shape
        ratio = 0.05 + strength * 0.50
        mask = (torch.rand(B, 1, H, W, device=image.device) < ratio).float()
        return self.inpaint(image, mask, use_ema=use_ema)
    def restore_multi_pass(self, image: torch.Tensor, strength: float = 0.5,
                           passes: int = 3, use_ema: bool = True) -> torch.Tensor:
        """Iterative multi-pass restore – each pass uses slightly less strength."""
        result = image.clone()
        for i in range(passes):
            s = strength * (1.0 - i * 0.1)
            result = self.restore(result, strength=max(0.05, s), use_ema=use_ema)
        return result
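    # Illustrative note (added for clarity, not in the original): pass i uses
    # strength_i = max(0.05, strength * (1 - 0.1 * i)), so e.g. strength=0.5 over
    # 3 passes re-predicts roughly 30%, 27.5%, then 25% of the pixels
    # (via ratio = 0.05 + strength * 0.50 inside restore()).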
    # ── Generate from scratch ─────────────────────────────────────────────
    def generate(self, n_samples: int = 1, n_steps: int = 12,
                 seed_image: Optional[torch.Tensor] = None,
                 seed_ratio: float = 0.0,
                 temperature: float = 1.0,
                 use_ema: bool = True) -> torch.Tensor:
        self.model.eval()
        if use_ema: self._apply_ema()
        try:
            H = W = self.config.IMAGE_SIZE
            C, N = self.config.CHANNELS, self.config.NUM_PIXELS
            dev = self.device
            B = n_samples
            if seed_image is not None:
                if seed_image.dim() == 3: seed_image = seed_image.unsqueeze(0)
                seed_image = seed_image.expand(B, -1, -1, -1)
                canvas = seed_image.clone()
                keep = (torch.rand(B, 1, H, W, device=dev) < seed_ratio)
            else:
                canvas = torch.rand(B, C, H, W, device=dev) * 0.5
                keep = torch.zeros(B, 1, H, W, device=dev, dtype=torch.bool)
            revealed = keep.float()
            for step in range(n_steps):
                mask = 1.0 - revealed
                n_unrevealed = int(mask.sum().item() / B)
                if n_unrevealed == 0: break
                pred = self.model(canvas, mask=mask)["pixel_pred"]
                # temperature scaling – add noise before confidence scoring
                if temperature != 1.0:
                    noise = torch.randn_like(pred) * (temperature - 1.0) * 0.05
                    pred = (pred + noise).clamp(0, 1)
                canvas = torch.where(mask.expand_as(canvas).bool(), pred, canvas)
                if step < n_steps - 1:
                    confidence = torch.abs(pred - 0.5).mean(dim=1, keepdim=True)
                    confidence = confidence * mask + torch.rand_like(confidence) * 0.05
                    progress = (step + 1) / n_steps
                    target_frac = 1.0 - 0.5 * (1 + math.cos(math.pi * progress))
                    cur_revealed = int(revealed.sum().item() / B)
                    n_reveal = max(1, int(target_frac * N) - cur_revealed)
                    n_reveal = min(n_reveal, n_unrevealed)
                    flat_conf = confidence.reshape(B, -1).masked_fill(
                        mask.reshape(B, -1) < 0.5, -float("inf"))
                    _, top = flat_conf.topk(min(n_reveal, N), dim=1)
                    new_rev = torch.zeros(B, N, device=dev).scatter_(1, top, 1.0)
                    revealed = (revealed + new_rev.reshape(B, 1, H, W)).clamp(0, 1)
        finally:
            if use_ema: self._restore_ema()
        return canvas.clamp(0, 1)
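    # Illustrative note on the reveal schedule above (a sketch of the intent, not an
    # authoritative spec): target_frac = 1 - 0.5 * (1 + cos(pi * progress)) is the
    # cosine schedule used by MaskGIT-style decoders – it commits only a few pixels in
    # early steps and most of them near the end; e.g. with n_steps=4 the cumulative
    # reveal targets are roughly 15%, 50%, 85% and then 100% of the 256 pixels.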
    # ── Variation ─────────────────────────────────────────────────────────
    def generate_variation(self, image: torch.Tensor,
                           diversity: float = 0.5,
                           n_samples: int = 1,
                           n_steps: int = 16,
                           use_ema: bool = True) -> torch.Tensor:
        """Generate variations of an existing sprite.
        diversity=0 → near-identical copy; diversity=1 → fully free generation."""
        seed_ratio = 1.0 - diversity
        return self.generate(
            n_samples=n_samples, n_steps=n_steps,
            seed_image=image, seed_ratio=seed_ratio,
            use_ema=use_ema)

    # ── Palette extraction from CLS token ─────────────────────────────────
    def extract_palette(self, image: torch.Tensor,
                        use_ema: bool = True) -> torch.Tensor:
        """Returns (PALETTE_SIZE, 4) RGBA palette from the model's palette head."""
        self.model.eval()
        if image.dim() == 3: image = image.unsqueeze(0)
        if use_ema: self._apply_ema()
        try:
            enc, _, _ = self.model.forward_encoder(image)
            cls_tok = enc[:, :1, :]                        # (1, 1, D)
            palette = self.model.palette_head(cls_tok)     # (1, 1, P*C)
            palette = palette.reshape(self.config.PALETTE_SIZE, self.config.CHANNELS)
        finally:
            if use_ema: self._restore_ema()
        return palette.clamp(0, 1)
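# Illustrative usage sketch of the engine (assumptions: a checkpoint has already been
# downloaded, and `preprocess` / `postprocess` from section 5 handle PIL I/O):
#   engine = InferenceEngine(Config(), torch.device("cpu"))
#   engine.load_checkpoint("best.pth")                        # hypothetical local path
#   sprites = engine.generate(n_samples=4, n_steps=16)        # (4, 4, 16, 16) in [0, 1]
#   variant = engine.generate_variation(sprites[0:1], diversity=0.3)
#   palette = engine.extract_palette(sprites[0:1])            # (16, 4) RGBA in [0, 1]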
# ============================================================================
# 4. INIT ENGINE + MODEL MANAGEMENT
# ============================================================================
print("Initializing PixelMAE Inference Engine …")
config = Config()
engine = InferenceEngine(config, config.DEVICE)

# Global state
_loaded_model_key = None
_ckpt_info_text = "No model loaded."
_model_paths: Dict[str, Optional[str]] = {"best": None, "latest": None}
def _download_model(key: str) -> Tuple[Optional[str], str]:
    filename = MODEL_FILES[key]
    try:
        path = hf_hub_download(repo_id=REPO_ID, filename=filename)
        return path, f"✅ Downloaded `{filename}` from `{REPO_ID}`."
    except Exception as e:
        return None, f"❌ Could not download `{filename}`: {e}"

def load_model(model_choice: str) -> str:
    global _loaded_model_key, _ckpt_info_text, _model_paths
    key = "best" if model_choice == "Best (best.pth)" else "latest"
    if _model_paths[key] is None:
        path, msg = _download_model(key)
        if path is None:
            _ckpt_info_text = msg
            return _ckpt_info_text
        _model_paths[key] = path
    try:
        meta = engine.load_checkpoint(_model_paths[key])
        loss_str = f"{meta['best_loss']:.6f}" if meta["best_loss"] is not None else "N/A"
        _ckpt_info_text = (
            f"**Model:** `{MODEL_FILES[key]}`  \n"
            f"**Epoch:** {meta['epoch']}  \n"
            f"**Best Val Loss:** {loss_str}  \n"
            f"**EMA Weights:** {'✅ loaded' if meta['has_ema'] else '⚠️ not found'}  \n"
            f"**Device:** {config.DEVICE}"
        )
        _loaded_model_key = key
    except Exception as e:
        _ckpt_info_text = f"❌ Load error: {e}"
    return _ckpt_info_text
# Pre-load best model at startup (non-fatal)
try:
    _path, _msg = _download_model("best")
    if _path:
        _model_paths["best"] = _path
        load_model("Best (best.pth)")
        print(f"   └─ {_ckpt_info_text.replace(chr(10), ' ')}")
    else:
        print(f"   └─ {_msg}")
except Exception as ex:
    print(f"   └─ Startup load failed: {ex}")
# ============================================================================
# 5. IMAGE UTILITIES
# ============================================================================
def preprocess(pil_img: Image.Image) -> torch.Tensor:
    img = pil_img.convert("RGBA").resize(
        (config.IMAGE_SIZE, config.IMAGE_SIZE), Image.Resampling.NEAREST)
    arr = np.array(img, np.float32) / 255.0
    return torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0).to(config.DEVICE)

def postprocess(tensor: torch.Tensor, upscale: int = 16) -> Image.Image:
    """Convert tensor → pixel-perfect upscaled PIL image (RGBA)."""
    if tensor is None: return None
    if tensor.dim() == 4: tensor = tensor.squeeze(0)
    arr = (tensor.permute(1, 2, 0).cpu().clamp(0, 1).numpy() * 255).astype(np.uint8)
    img = Image.fromarray(arr, "RGBA")
    return img.resize(
        (config.IMAGE_SIZE * upscale, config.IMAGE_SIZE * upscale),
        Image.Resampling.NEAREST)

def tensor_to_pils(batch: torch.Tensor, upscale: int = 16) -> List[Image.Image]:
    """Batch tensor → list of PIL images."""
    if batch.dim() == 3: batch = batch.unsqueeze(0)
    return [postprocess(batch[i], upscale) for i in range(batch.shape[0])]
def make_palette_image(palette: torch.Tensor, swatch_size: int = 32) -> Image.Image:
    """Render palette as a row of color swatches."""
    n = palette.shape[0]
    img = Image.new("RGBA", (n * swatch_size, swatch_size), (0, 0, 0, 0))
    draw = ImageDraw.Draw(img)
    for i, color in enumerate(palette):
        rgba = tuple((color.cpu().numpy() * 255).astype(int).tolist())
        draw.rectangle([i * swatch_size, 0, (i + 1) * swatch_size - 1, swatch_size - 1],
                       fill=rgba)
    return img
def compare_images(before: Image.Image, after: Image.Image,
                   gap: int = 4) -> Image.Image:
    """Side-by-side before/after comparison."""
    w, h = before.size
    total = Image.new("RGBA", (w * 2 + gap, h), (20, 20, 20, 255))
    total.paste(before, (0, 0))
    total.paste(after, (w + gap, 0))
    return total
def extract_unique_colors(img: Image.Image, max_colors: int = 32) -> List[Tuple]:
    """Extract sorted unique RGBA colors (ignoring transparent)."""
    arr = np.array(img.convert("RGBA"))
    pixels = arr.reshape(-1, 4)
    pixels = pixels[pixels[:, 3] > 10]  # ignore near-transparent
    unique = np.unique(pixels, axis=0)
    return [tuple(c) for c in unique[:max_colors]]
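# Illustrative round-trip sketch (assumption: any RGBA PIL image as input): the UI
# always moves between PIL images and model tensors through these two helpers, e.g.
#   t = preprocess(Image.new("RGBA", (64, 64), (255, 0, 0, 255)))   # (1, 4, 16, 16)
#   big = postprocess(t, upscale=16)                                # 256 × 256 PIL, nearest-neighbour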
# ============================================================================
# 6. GRADIO FUNCTIONS
# ============================================================================
def fn_load_model(choice):
    return load_model(choice)

def fn_generate(seed_img, seed_ratio, steps, n_samples,
                temperature, use_ema, upscale):
    n_samples = int(n_samples)
    steps = int(steps)
    upscale = int(upscale)
    seed_t = preprocess(seed_img) if seed_img is not None else None
    out = engine.generate(
        n_samples=n_samples, n_steps=steps,
        seed_image=seed_t, seed_ratio=seed_ratio,
        temperature=temperature, use_ema=use_ema)
    pils = tensor_to_pils(out, upscale)
    # Return gallery items + first image separately + palette
    palette = engine.extract_palette(out[0:1], use_ema=use_ema)
    pal_img = make_palette_image(palette, swatch_size=40)
    gallery = [(p, f"Sample {i+1}") for i, p in enumerate(pils)]
    return gallery, pils[0], pal_img
def fn_generate_single(seed_img, seed_ratio, steps, temperature, use_ema, upscale):
    upscale = int(upscale)
    steps = int(steps)
    seed_t = preprocess(seed_img) if seed_img is not None else None
    out = engine.generate(
        n_samples=1, n_steps=steps,
        seed_image=seed_t, seed_ratio=seed_ratio,
        temperature=temperature, use_ema=use_ema)
    pil = postprocess(out[0], upscale)
    palette = engine.extract_palette(out[0:1], use_ema=use_ema)
    pal_img = make_palette_image(palette, swatch_size=40)
    return pil, pal_img
def fn_restore(image, strength, passes, use_ema, upscale):
    if image is None: return None, None, None
    upscale = int(upscale)
    passes = int(passes)
    in_t = preprocess(image)
    if passes > 1:
        out_t = engine.restore_multi_pass(in_t, strength=strength, passes=passes, use_ema=use_ema)
    else:
        out_t = engine.restore(in_t, strength=strength, use_ema=use_ema)
    before = postprocess(in_t[0], upscale)
    after = postprocess(out_t[0], upscale)
    compare = compare_images(before, after)
    palette = engine.extract_palette(out_t[0:1], use_ema=use_ema)
    pal_img = make_palette_image(palette, swatch_size=40)
    return after, compare, pal_img
def fn_inpaint(editor_dict, use_ema, upscale):
    if not editor_dict or editor_dict.get("background") is None:
        return None, None
    upscale = int(upscale)
    bg = editor_dict["background"].convert("RGBA")
    in_t = preprocess(bg)
    mask_t = torch.zeros((1, 1, 16, 16), device=config.DEVICE)
    layers = editor_dict.get("layers", [])
    if layers:
        drawing = layers[0].convert("RGBA").resize((16, 16), Image.Resampling.NEAREST)
        mask_np = (np.array(drawing)[:, :, 3] > 0).astype(np.float32)
        mask_t = torch.from_numpy(mask_np).unsqueeze(0).unsqueeze(0).to(config.DEVICE)
    out_t = engine.inpaint(in_t, mask_t, use_ema=use_ema)
    result = postprocess(out_t[0], upscale)
    palette = engine.extract_palette(out_t[0:1], use_ema=use_ema)
    pal_img = make_palette_image(palette, swatch_size=40)
    return result, pal_img
def fn_inpaint_alpha(image, use_ema, upscale):
    """Fill transparent pixels using the model."""
    if image is None: return None, None
    upscale = int(upscale)
    in_t = preprocess(image)
    mask_t = (in_t[:, 3:4] < 0.5).float()
    out_t = engine.inpaint(in_t, mask_t, use_ema=use_ema)
    result = postprocess(out_t[0], upscale)
    palette = engine.extract_palette(out_t[0:1], use_ema=use_ema)
    pal_img = make_palette_image(palette, swatch_size=40)
    return result, pal_img
def fn_variation(image, diversity, n_var, steps, use_ema, upscale):
    if image is None: return [], None
    n_var = int(n_var)
    steps = int(steps)
    upscale = int(upscale)
    in_t = preprocess(image)
    out_t = engine.generate_variation(
        in_t, diversity=diversity, n_samples=n_var,
        n_steps=steps, use_ema=use_ema)
    pils = tensor_to_pils(out_t, upscale)
    palette = engine.extract_palette(out_t[0:1], use_ema=use_ema)
    pal_img = make_palette_image(palette, swatch_size=40)
    gallery = [(p, f"Variation {i+1}") for i, p in enumerate(pils)]
    return gallery, pal_img
def fn_batch_generate(n_total, steps, temperature, use_ema, upscale):
    n_total = int(n_total)
    steps = int(steps)
    upscale = int(upscale)
    out_t = engine.generate(
        n_samples=n_total, n_steps=steps,
        temperature=temperature, use_ema=use_ema)
    pils = tensor_to_pils(out_t, upscale)
    gallery = [(p, f"#{i+1}") for i, p in enumerate(pils)]
    return gallery
# ============================================================================
# 7. GRADIO UI
# ============================================================================
CSS = """
/* ── Pixel-perfect rendering for all output images ── */
.pixel-out img,
.pixel-gallery img,
.pixel-compare img {
    image-rendering: pixelated !important;
    image-rendering: crisp-edges !important;
}
/* ── Dark card styling ── */
.card { background: #1a1a2e; border-radius: 10px; padding: 12px; }
/* ── Status badge ── */
.status-box textarea { font-family: monospace; font-size: 12px; }
/* ── Palette strip ── */
.pal-out img {
    image-rendering: pixelated !important;
    border-radius: 4px;
    border: 1px solid #444;
}
"""
_UPSCALE_CHOICES = ["4", "8", "16", "24", "32"]
_MODEL_CHOICES = ["Best (best.pth)", "Latest (latest.pth)"]

def _upscale_ctrl(default="16"):
    return gr.Radio(
        _UPSCALE_CHOICES, value=default, label="Preview Scale",
        info="Pixel-perfect zoom multiplier (× original 16px)")

def _ema_ctrl():
    return gr.Checkbox(value=True, label="Use EMA Weights",
                       info="Usually better quality")

def _palette_out():
    return gr.Image(label="Predicted Palette", type="pil",
                    elem_classes=["pal-out"])
with gr.Blocks(theme=gr.themes.Monochrome(), css=CSS,
               title="PixelMAE v10 – Neural Sprite Engine") as app:
    gr.Markdown(
        """
        # 👾 PixelMAE – Neural Sprite Engine `v10`
        Asymmetric Masked Autoencoder · 16 × 16 RGBA · Generation · Restoration · Inpainting · Variations
        """)

    # ── Model loader ──────────────────────────────────────────────────────
    with gr.Accordion("⚙️ Model & Checkpoint", open=False):
        with gr.Row():
            model_choice = gr.Radio(
                _MODEL_CHOICES, value="Best (best.pth)",
                label="Checkpoint")
            load_btn = gr.Button("🔄 Load / Reload Model", variant="secondary")
        ckpt_info = gr.Markdown(value=_ckpt_info_text, label="Checkpoint Info")
        load_btn.click(fn_load_model, inputs=model_choice, outputs=ckpt_info)

    gr.Markdown("---")
    # ── TABS ──────────────────────────────────────────────────────────────
    with gr.Tabs():

        # ── TAB 1: GENERATE ───────────────────────────────────────────────
        with gr.TabItem("✨ Generate"):
            gr.Markdown(
                "Generate sprites from noise using iterative MaskGIT decoding. "
                "Optionally seed from an existing image.")
            with gr.Row():
                with gr.Column(scale=1):
                    gen_seed = gr.Image(
                        label="Seed Image (optional)", type="pil",
                        image_mode="RGBA", height=200)
                    gen_ratio = gr.Slider(0.0, 1.0, 0.0, step=0.05,
                                          label="Seed Fidelity",
                                          info="0 = fully random · 1 = preserve seed completely")
                    gen_steps = gr.Slider(4, 64, 16, step=1,
                                          label="Decoding Steps",
                                          info="More steps → smoother, slower")
                    gen_temp = gr.Slider(0.5, 2.0, 1.0, step=0.05,
                                         label="Temperature",
                                         info="< 1 sharper/conservative · > 1 more chaotic/varied")
                    gen_use_ema = _ema_ctrl()
                    gen_upscale = _upscale_ctrl("16")
                    gen_btn = gr.Button("⚡ Generate", variant="primary")
                with gr.Column(scale=1):
                    gen_out = gr.Image(
                        label="Output Sprite", type="pil", format="png",
                        elem_classes=["pixel-out"])
                    gen_pal = _palette_out()
            gen_btn.click(
                fn_generate_single,
                inputs=[gen_seed, gen_ratio, gen_steps,
                        gen_temp, gen_use_ema, gen_upscale],
                outputs=[gen_out, gen_pal])
        # ── TAB 2: BATCH GENERATE ─────────────────────────────────────────
        with gr.TabItem("🖼️ Batch Generate"):
            gr.Markdown(
                "Generate multiple sprites in one shot. "
                "All sprites are rendered pixel-perfect in the gallery.")
            with gr.Row():
                with gr.Column(scale=1):
                    bg_n = gr.Slider(1, 16, 8, step=1,
                                     label="Number of Sprites")
                    bg_steps = gr.Slider(4, 64, 16, step=1,
                                         label="Decoding Steps")
                    bg_temp = gr.Slider(0.5, 2.0, 1.0, step=0.05,
                                        label="Temperature")
                    bg_use_ema = _ema_ctrl()
                    bg_upscale = _upscale_ctrl("16")
                    bg_btn = gr.Button("⚡ Generate Batch", variant="primary")
                with gr.Column(scale=2):
                    bg_gallery = gr.Gallery(
                        label="Generated Sprites", columns=4, rows=4,
                        height="auto", elem_classes=["pixel-gallery"], format="png")
            bg_btn.click(
                fn_batch_generate,
                inputs=[bg_n, bg_steps, bg_temp, bg_use_ema, bg_upscale],
                outputs=bg_gallery)
        # ── TAB 3: RESTORE & REFINE ───────────────────────────────────────
        with gr.TabItem("🔄 Restore & Refine"):
            gr.Markdown(
                "Feed the model an existing sprite. It randomly masks regions and "
                "reconstructs them, improving pixel coherence. Multi-pass iteratively refines.")
            with gr.Row():
                with gr.Column(scale=1):
                    rest_img = gr.Image(
                        label="Input Sprite", type="pil",
                        image_mode="RGBA", height=200)
                    rest_str = gr.Slider(0.0, 1.0, 0.5, step=0.05,
                                         label="Mask Strength",
                                         info="What fraction of pixels are randomly re-predicted")
                    rest_passes = gr.Slider(1, 6, 1, step=1,
                                            label="Passes",
                                            info="Multi-pass: each pass refines further")
                    rest_use_ema = _ema_ctrl()
                    rest_upscale = _upscale_ctrl("16")
                    rest_btn = gr.Button("🔄 Restore", variant="primary")
                with gr.Column(scale=1):
                    rest_out = gr.Image(
                        label="Restored Sprite", type="pil", format="png",
                        elem_classes=["pixel-out"])
                    rest_compare = gr.Image(
                        label="Before · After", type="pil",
                        elem_classes=["pixel-compare"])
                    rest_pal = _palette_out()
            rest_btn.click(
                fn_restore,
                inputs=[rest_img, rest_str, rest_passes, rest_use_ema, rest_upscale],
                outputs=[rest_out, rest_compare, rest_pal])
        # ── TAB 4: SMART INPAINT ──────────────────────────────────────────
        with gr.TabItem("🖌️ Smart Inpaint"):
            gr.Markdown(
                "Upload a sprite and **paint the mask** over pixels you want the AI to redraw. "
                "Use any brush color – coverage counts, not color.")
            with gr.Row():
                with gr.Column(scale=1):
                    inp_editor = gr.ImageEditor(
                        label="Draw Mask (paint = redo)",
                        type="pil", image_mode="RGBA",
                        brush=gr.Brush(colors=["#ff0000"], color_mode="fixed"),
                        height=300)
                    inp_use_ema = _ema_ctrl()
                    inp_upscale = _upscale_ctrl("16")
                    inp_btn = gr.Button("🖌️ Inpaint", variant="primary")
                with gr.Column(scale=1):
                    inp_out = gr.Image(
                        label="Inpainted Output", type="pil", format="png",
                        elem_classes=["pixel-out"])
                    inp_pal = _palette_out()
            inp_btn.click(
                fn_inpaint,
                inputs=[inp_editor, inp_use_ema, inp_upscale],
                outputs=[inp_out, inp_pal])
        # ── TAB 5: FILL TRANSPARENT ───────────────────────────────────────
        with gr.TabItem("🔍 Fill Transparent"):
            gr.Markdown(
                "Upload a partially transparent sprite (RGBA). "
                "The model will fill all transparent pixels based on the visible context.")
            with gr.Row():
                with gr.Column(scale=1):
                    alp_img = gr.Image(
                        label="Partial Sprite (RGBA)", type="pil",
                        image_mode="RGBA", height=200)
                    alp_use_ema = _ema_ctrl()
                    alp_upscale = _upscale_ctrl("16")
                    alp_btn = gr.Button("✨ Fill Transparent", variant="primary")
                with gr.Column(scale=1):
                    alp_out = gr.Image(
                        label="Completed Sprite", type="pil", format="png",
                        elem_classes=["pixel-out"])
                    alp_pal = _palette_out()
            alp_btn.click(
                fn_inpaint_alpha,
                inputs=[alp_img, alp_use_ema, alp_upscale],
                outputs=[alp_out, alp_pal])
        # ── TAB 6: VARIATIONS ─────────────────────────────────────────────
        with gr.TabItem("🎲 Variations"):
            gr.Markdown(
                "Upload a sprite and generate **N creative variations**. "
                "Diversity 0 = almost identical · 1 = free improvisation on the theme.")
            with gr.Row():
                with gr.Column(scale=1):
                    var_img = gr.Image(
                        label="Source Sprite", type="pil",
                        image_mode="RGBA", height=200)
                    var_div = gr.Slider(0.0, 1.0, 0.5, step=0.05,
                                        label="Diversity",
                                        info="How far variants can deviate from the source")
                    var_n = gr.Slider(1, 16, 4, step=1,
                                      label="Number of Variations")
                    var_steps = gr.Slider(4, 64, 16, step=1,
                                          label="Decoding Steps")
                    var_use_ema = _ema_ctrl()
                    var_upscale = _upscale_ctrl("16")
                    var_btn = gr.Button("🎲 Generate Variations", variant="primary")
                with gr.Column(scale=2):
                    var_gallery = gr.Gallery(
                        label="Variations", columns=4, rows=4,
                        height="auto", elem_classes=["pixel-gallery"],
                        format="png")
                    var_pal = _palette_out()
            var_btn.click(
                fn_variation,
                inputs=[var_img, var_div, var_n, var_steps, var_use_ema, var_upscale],
                outputs=[var_gallery, var_pal])
    # ── Footer ────────────────────────────────────────────────────────────
    gr.Markdown(
        """
        ---
        **PixelMAE v10** · Asymmetric MAE · 4.2M params · 16×16 RGBA
        Model: [`MidFord327/PixelArt-MAE-v7`](https://huggingface.co/MidFord327/PixelArt-MAE-v7)
        """)

if __name__ == "__main__":
    app.launch()