"""
PixelMAE v10 — Neural Sprite Engine
Hugging Face Spaces · Production UI
"""

import os
import math
import warnings
from typing import Optional, Dict, List, Tuple
from dataclasses import dataclass, field

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image, ImageDraw
import gradio as gr
from huggingface_hub import hf_hub_download

warnings.filterwarnings("ignore")

# ============================================================================
# REPO CONFIG — edit these two lines
# ============================================================================
REPO_ID = "MidFord327/PixelArt-MAE-v7"
MODEL_FILES = {"best": "best.pth", "latest": "latest.pth"}


# ============================================================================
# 1. CONFIGURATION
# ============================================================================
@dataclass
class Config:
    """Hyper-parameters shared by the model, the inference engine and the UI.

    The values must match the checkpoint that gets loaded. ``__post_init__``
    rejects combinations the architecture cannot be built with, so a bad
    config fails immediately with a readable message instead of an opaque
    reshape / shape-mismatch error deep inside the model constructor.

    Raises:
        ValueError: on an internally inconsistent configuration.
    """

    IMAGE_SIZE: int = 16          # sprite side length in pixels
    CHANNELS: int = 4             # RGBA
    NUM_PIXELS: int = 256         # one token per pixel; must equal IMAGE_SIZE ** 2
    LATENT_DIM: int = 192         # must be divisible by NUM_HEADS and by 4
    ENCODER_LAYERS: int = 6
    DECODER_LAYERS: int = 4
    NUM_HEADS: int = 6
    FFN_RATIO: int = 4
    DROPOUT: float = 0.0
    PALETTE_SIZE: int = 16
    USE_EMA: bool = True
    EMA_DECAY: float = 0.9995
    EMA_UPDATE_EVERY: int = 1
    DEVICE: torch.device = field(
        default_factory=lambda: torch.device("cpu"))

    def __post_init__(self):
        # The model uses one token per pixel and reshapes predictions back to
        # (IMAGE_SIZE, IMAGE_SIZE), so the two sizes must agree.
        if self.NUM_PIXELS != self.IMAGE_SIZE ** 2:
            raise ValueError(
                f"NUM_PIXELS ({self.NUM_PIXELS}) must equal IMAGE_SIZE ** 2 "
                f"({self.IMAGE_SIZE ** 2})")
        # Attention splits LATENT_DIM evenly across heads.
        if self.LATENT_DIM % self.NUM_HEADS != 0:
            raise ValueError(
                f"LATENT_DIM ({self.LATENT_DIM}) must be divisible by "
                f"NUM_HEADS ({self.NUM_HEADS})")
        # The 2-D sin/cos positional embedding interleaves 4 groups of columns.
        if self.LATENT_DIM % 4 != 0:
            raise ValueError(
                f"LATENT_DIM ({self.LATENT_DIM}) must be divisible by 4")
# ============================================================================
# 2. ARCHITECTURE
# ============================================================================
def get_2d_sincos_pos_embed(embed_dim: int, grid_size: int) -> torch.Tensor:
    """Build a fixed 2-D sine/cosine positional embedding.

    The embedding width is split into four interleaved column groups:
    ``sin(row)`` at ``0::4``, ``cos(row)`` at ``1::4``, ``sin(col)`` at
    ``2::4`` and ``cos(col)`` at ``3::4``.

    Args:
        embed_dim: Embedding width; must be a multiple of 4.
        grid_size: Side length of the square token grid.

    Returns:
        Tensor of shape ``(1, grid_size * grid_size, embed_dim)`` with tokens
        in row-major order (``meshgrid(..., indexing="ij")``).

    Raises:
        ValueError: if ``embed_dim`` is not a multiple of 4.
    """
    if embed_dim % 4 != 0:
        # Otherwise the strided column assignments below fail with an
        # unhelpful shape-mismatch error.
        raise ValueError(f"embed_dim must be divisible by 4, got {embed_dim}")
    grid_h = torch.arange(grid_size, dtype=torch.float32)
    grid_w = torch.arange(grid_size, dtype=torch.float32)
    grid = torch.meshgrid(grid_h, grid_w, indexing="ij")
    grid = torch.stack(grid, dim=0).reshape(2, -1)
    quarter = embed_dim // 4
    omega = 1.0 / (10000 ** (torch.arange(quarter, dtype=torch.float32) / quarter))
    out_h = torch.einsum("n,d->nd", grid[0], omega)
    out_w = torch.einsum("n,d->nd", grid[1], omega)
    pe = torch.zeros(grid_size * grid_size, embed_dim)
    pe[:, 0::4] = torch.sin(out_h)
    pe[:, 1::4] = torch.cos(out_h)
    pe[:, 2::4] = torch.sin(out_w)
    pe[:, 3::4] = torch.cos(out_w)
    return pe.unsqueeze(0)


class EMA:
    """Exponential-moving-average copy of a model's trainable parameters.

    Only the inference side is implemented: ``apply_shadow`` swaps the EMA
    weights into the live model and ``restore`` puts the original weights
    back. ``decay``, ``update_every`` and ``step`` are stored but no update
    rule lives here.
    """

    def __init__(self, model, decay=0.9995, update_every=1):
        self.model = model
        self.decay = decay
        self.update_every = update_every
        self.step = 0
        # The shadow starts as a copy of the model's current weights.
        self.shadow = {n: p.data.clone() for n, p in model.named_parameters()
                       if p.requires_grad}
        self.backup = {}

    def apply_shadow(self):
        """Copy shadow weights into the model, backing up the live ones.

        Trainable parameters without a shadow entry are left untouched. If a
        copy fails (e.g. a shape mismatch from an incompatible checkpoint),
        the parameters already swapped are restored before the error is
        re-raised, so the model is never left half-swapped.
        """
        try:
            for name, param in self.model.named_parameters():
                if param.requires_grad and name in self.shadow:
                    self.backup[name] = param.data.clone()
                    param.data.copy_(self.shadow[name])
        except Exception:
            self.restore()
            raise

    def restore(self):
        """Write back the weights saved by ``apply_shadow`` and clear the backup."""
        for name, param in self.model.named_parameters():
            if name in self.backup:
                param.data.copy_(self.backup[name])
        self.backup = {}

    def load_state_dict(self, state, device):
        """Load ``state["shadow"]`` (moved to ``device``) and ``state["step"]``.

        ``module.`` (DistributedDataParallel) and ``_orig_mod.``
        (``torch.compile``) are stripped from the keys so that a shadow saved
        from a wrapped model still matches ``model.named_parameters()``;
        otherwise ``apply_shadow`` would silently skip every parameter.

        Raises:
            KeyError: if ``state`` lacks ``"shadow"`` or ``"step"``.
        """
        self.shadow = {
            k.replace("module.", "").replace("_orig_mod.", ""): v.to(device)
            for k, v in state["shadow"].items()
        }
        self.step = state["step"]


class MHSA(nn.Module):
    """Multi-head self-attention over a ``(B, N, C)`` token sequence."""

    def __init__(self, dim, heads, dropout=0.0):
        super().__init__()
        self.heads = heads
        self.head_dim = dim // heads
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)
        self.drop = dropout

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.heads, self.head_dim)
        # -> three tensors of shape (B, heads, N, head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)
        # Attention dropout only while training; eval stays deterministic.
        out = F.scaled_dot_product_attention(
            q, k, v, dropout_p=self.drop if self.training else 0.0)
        return self.proj(out.transpose(1, 2).reshape(B, N, C))
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: ``x + attn(LN(x))`` then ``x + ffn(LN(x))``."""

    def __init__(self, dim, heads, ffn_ratio=4, dropout=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = MHSA(dim, heads, dropout)
        self.norm2 = nn.LayerNorm(dim)
        self.ffn = nn.Sequential(
            nn.Linear(dim, dim * ffn_ratio), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(dim * ffn_ratio, dim), nn.Dropout(dropout))

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.ffn(self.norm2(x))
        return x


class AsymmetricPixelMAE(nn.Module):
    """Per-pixel masked autoencoder for small RGBA images.

    Every pixel is one token. The encoder processes a CLS token plus the
    visible pixel tokens only. The decoder re-inserts ``mask_token`` at the
    masked positions and predicts RGB and alpha for every pixel.
    ``palette_head`` maps a token to ``PALETTE_SIZE * CHANNELS`` values in
    ``[0, 1]``; it is not used by ``forward``.
    """

    def __init__(self, config: Config):
        super().__init__()
        self.config = config
        D, N = config.LATENT_DIM, config.NUM_PIXELS
        # NOTE: attribute names become state_dict keys; renaming any of them
        # breaks loading of existing checkpoints.
        self.pixel_embed = nn.Linear(config.CHANNELS, D)
        self.embed_norm = nn.LayerNorm(D)
        # Fixed (non-trainable) 2-D sin/cos positions for the N pixel tokens.
        self.pos_embed = nn.Parameter(
            get_2d_sincos_pos_embed(D, config.IMAGE_SIZE), requires_grad=False)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, D))
        self.mask_token = nn.Parameter(torch.zeros(1, 1, D))
        nn.init.normal_(self.cls_token, std=0.02)
        nn.init.normal_(self.mask_token, std=0.02)
        self.encoder = nn.ModuleList([
            TransformerBlock(D, config.NUM_HEADS, config.FFN_RATIO, config.DROPOUT)
            for _ in range(config.ENCODER_LAYERS)])
        self.encoder_norm = nn.LayerNorm(D)
        self.decoder_embed = nn.Linear(D, D)
        # Decoder positions cover CLS + N pixels; the CLS slot (index 0)
        # stays all zeros.
        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, N + 1, D), requires_grad=False)
        dp = torch.zeros(1, N + 1, D)
        dp[:, 1:, :] = get_2d_sincos_pos_embed(D, config.IMAGE_SIZE)
        self.decoder_pos_embed.data.copy_(dp)
        self.decoder = nn.ModuleList([
            TransformerBlock(D, config.NUM_HEADS, config.FFN_RATIO, config.DROPOUT)
            for _ in range(config.DECODER_LAYERS)])
        self.decoder_norm = nn.LayerNorm(D)
        # Separate heads for colour and alpha; sigmoid is applied in forward().
        self.pixel_head_rgb = nn.Sequential(
            nn.LayerNorm(D), nn.Linear(D, D), nn.GELU(), nn.Linear(D, 3))
        self.pixel_head_alpha = nn.Sequential(
            nn.LayerNorm(D), nn.Linear(D, D), nn.GELU(), nn.Linear(D, 1))
        self.palette_head = nn.Sequential(
            nn.LayerNorm(D), nn.Linear(D, D), nn.GELU(),
            nn.Linear(D, config.PALETTE_SIZE * config.CHANNELS), nn.Sigmoid())

    def _tokenize(self, x):
        # (B, C, H, W) -> (B, H*W, C) -> embedded + normalised (B, H*W, D)
        return self.embed_norm(self.pixel_embed(x.flatten(2).transpose(1, 2)))

    def forward_encoder(self, x, mask=None):
        """Encode the visible pixels of ``x``.

        Args:
            x: Image batch, unpacked elsewhere as ``(B, C, H, W)`` with
                ``H * W == NUM_PIXELS``.
            mask: Optional per-pixel mask where 1 marks a pixel to hide
                (assumed ``(B, 1, H, W)`` or anything flattening to
                ``(B, NUM_PIXELS)`` — TODO confirm for other callers).

        Returns:
            ``(encoded, ids_rest, n_vis)`` — encoder output of shape
            ``(B, 1 + n_vis, D)`` (CLS first), the permutation that restores
            the original token order, and the number of kept tokens.
        """
        B, N, D = x.shape[0], self.config.NUM_PIXELS, self.config.LATENT_DIM
        tokens = self._tokenize(x) + self.pos_embed
        if mask is not None:
            mf = mask.flatten(1)
            if mf.shape[1] != N:
                mf = mask.reshape(B, -1)
            # NOTE(review): flatten(1) / reshape(B, -1) always give a 2-D
            # tensor, so this branch appears unreachable.
            if mf.dim() == 3:
                mf = mf.squeeze(1)
            # Random order among visible tokens; the +1e6 pushes masked tokens
            # to the end of the sort.
            noise = torch.rand(B, N, device=x.device)
            ids_shuf = torch.argsort(noise + mf.float() * 1e6, dim=1)
            ids_rest = torch.argsort(ids_shuf, dim=1)
            # One shared length for the whole batch: the smallest visible count
            # (at least 1). Samples with more visible pixels drop the extras,
            # and a fully masked sample keeps one masked token as "visible".
            n_vis = max(1, int((1 - mf.float()).sum(dim=1).min().item()))
            ids_keep = ids_shuf[:, :n_vis]
            tokens = torch.gather(tokens, 1, ids_keep.unsqueeze(-1).expand(-1, -1, D))
        else:
            n_vis = N
            ids_rest = torch.arange(N, device=x.device).unsqueeze(0).expand(B, -1)
        cls = self.cls_token.expand(B, -1, -1)
        x_enc = torch.cat([cls, tokens], dim=1)
        for blk in self.encoder:
            x_enc = blk(x_enc)
        return self.encoder_norm(x_enc), ids_rest, n_vis

    def forward_decoder(self, x_enc, ids_rest, n_vis):
        """Decode a full ``(B, N, D)`` pixel-token sequence from the encoder output.

        Visible tokens go through ``decoder_embed``; the ``N - n_vis`` missing
        slots are filled with ``mask_token``, then ``ids_rest`` puts every
        token back at its pixel position before the decoder runs. The CLS
        token is dropped from the returned sequence.
        """
        B, N, D = x_enc.shape[0], self.config.NUM_PIXELS, self.config.LATENT_DIM
        cls_enc, vis_enc = x_enc[:, :1], x_enc[:, 1:]
        vis_dec = self.decoder_embed(vis_enc)
        n_masked = N - n_vis
        if n_masked > 0:
            mask_tok = self.mask_token.expand(B, n_masked, -1)
            full_seq = torch.cat([vis_dec, mask_tok], dim=1)
        else:
            full_seq = vis_dec
        # Undo the encoder's shuffle so token i corresponds to pixel i again.
        full_seq = torch.gather(full_seq, 1, ids_rest.unsqueeze(-1).expand(-1, -1, D))
        full_seq = torch.cat([self.decoder_embed(cls_enc), full_seq], dim=1) + self.decoder_pos_embed
        for blk in self.decoder:
            full_seq = blk(full_seq)
        return self.decoder_norm(full_seq)[:, 1:]

    def forward(self, x, mask=None):
        """Reconstruct every pixel of ``x`` given an optional hide-mask.

        Returns:
            ``{"pixel_pred": tensor}`` of shape ``(B, 4, H, W)``; channels 0-2
            are sigmoid RGB and channel 3 is sigmoid alpha, all in ``[0, 1]``.
        """
        B, C, H, W = x.shape
        enc, ids, n_vis = self.forward_encoder(x, mask)
        dec = self.forward_decoder(enc, ids, n_vis)
        rgb = torch.sigmoid(self.pixel_head_rgb(dec)).transpose(1, 2).reshape(B, 3, H, W)
        alp = torch.sigmoid(self.pixel_head_alpha(dec)).transpose(1, 2).reshape(B, 1, H, W)
        return {"pixel_pred": torch.cat([rgb, alp], dim=1)}
# ============================================================================
# 3. INFERENCE ENGINE (extended)
# ============================================================================
class InferenceEngine:
    """Inference-only wrapper around ``AsymmetricPixelMAE``.

    Provides inpainting, random-mask restoration, MaskGIT-style iterative
    generation, variations and palette extraction. With ``use_ema=True`` each
    call temporarily swaps the EMA weights into the model and swaps the live
    weights back afterwards.

    That swap mutates the shared model in place, so every swap/forward/restore
    section and every checkpoint load runs under ``self._lock``. Gradio may
    run handlers for different events concurrently, and interleaved swaps
    would otherwise leave EMA weights permanently baked into the model.
    """

    def __init__(self, config: Config, device: torch.device):
        import threading

        self.config = config
        self.device = device
        self.model = AsymmetricPixelMAE(config).to(device)
        self.ema = EMA(self.model, config.EMA_DECAY, config.EMA_UPDATE_EVERY) if config.USE_EMA else None
        self.ckpt_meta: Dict = {}
        # Re-entrant so a locked method may safely call another locked one.
        self._lock = threading.RLock()

    def load_checkpoint(self, path: str) -> Dict:
        """Load model (and, if present, EMA) weights from ``path``.

        ``module.`` / ``_orig_mod.`` key prefixes are stripped, and loading is
        non-strict. If the checkpoint has no usable EMA state, the EMA shadow
        is re-synced to the freshly loaded weights, so ``use_ema=True`` never
        runs stale (random-init or previous-checkpoint) weights.

        Returns:
            Metadata dict with ``epoch``, ``best_loss``, ``has_ema`` (True
            only if EMA state was actually loaded), and the counts
            ``missing_keys`` / ``unexpected_keys`` from the non-strict load.

        Raises:
            Whatever ``torch.load`` raises, or ``KeyError`` when the
            checkpoint has no ``"model"`` entry.
        """
        # SECURITY: weights_only=False lets torch.load unpickle arbitrary
        # objects — only point this at checkpoints from a trusted source.
        ckpt = torch.load(path, map_location=self.device, weights_only=False)
        state = {k.replace("module.", "").replace("_orig_mod.", ""): v
                 for k, v in ckpt["model"].items()}
        with self._lock:
            incompatible = self.model.load_state_dict(state, strict=False)
            ema_loaded = False
            if self.ema is not None:
                ema_state = ckpt.get("ema")
                if ema_state:
                    try:
                        self.ema.load_state_dict(ema_state, self.device)
                        ema_loaded = True
                    except Exception as err:  # best effort: fall back to raw weights
                        print(f"  ⚠️ EMA state ignored: {err}")
                if not ema_loaded:
                    # Otherwise the shadow still holds the weights it was created
                    # from (random init) or a previous checkpoint's EMA.
                    self._sync_ema_to_model()
        self.ckpt_meta = {
            "epoch": ckpt.get("epoch", "?"),
            "best_loss": ckpt.get("best_loss", None),
            "has_ema": ema_loaded,
            "missing_keys": len(incompatible.missing_keys),
            "unexpected_keys": len(incompatible.unexpected_keys),
        }
        return self.ckpt_meta

    # ── Core helpers ──────────────────────────────────────────────────────
    def _sync_ema_to_model(self):
        """Make the EMA shadow a copy of the model's current trainable weights."""
        self.ema.shadow = {n: p.data.clone() for n, p in self.model.named_parameters()
                           if p.requires_grad}
        self.ema.backup = {}

    def _apply_ema(self):
        if self.ema:
            self.ema.apply_shadow()

    def _restore_ema(self):
        if self.ema:
            self.ema.restore()

    def _run_locked(self, use_ema: bool, fn):
        """Run ``fn()`` in eval mode under the lock, with EMA weights if requested.

        The EMA swap happens inside ``try`` so that a partially failed swap is
        still undone by ``finally``.
        """
        with self._lock:
            self.model.eval()
            try:
                if use_ema:
                    self._apply_ema()
                return fn()
            finally:
                if use_ema:
                    self._restore_ema()

    # ── Inpaint ───────────────────────────────────────────────────────────
    @torch.no_grad()
    def inpaint(self, image: torch.Tensor, mask: torch.Tensor,
                use_ema: bool = True) -> torch.Tensor:
        """Re-predict the pixels where ``mask`` is non-zero; keep the rest of ``image``.

        Args:
            image: ``(C, H, W)`` or ``(B, C, H, W)`` tensor in ``[0, 1]``.
            mask: ``(H, W)``, ``(1, H, W)`` or ``(B, 1, H, W)``; 1 = redraw.

        Returns:
            ``(B, C, H, W)`` tensor clamped to ``[0, 1]``.
        """
        if image.dim() == 3:
            image = image.unsqueeze(0)
        if mask.dim() == 2:
            mask = mask.unsqueeze(0)
        if mask.dim() == 3:
            mask = mask.unsqueeze(0)

        def _predict():
            pred = self.model(image, mask=mask.float())["pixel_pred"]
            return torch.where(mask.expand_as(image).bool(), pred, image)

        return self._run_locked(use_ema, _predict).clamp(0, 1)

    # ── Restore ───────────────────────────────────────────────────────────
    @torch.no_grad()
    def restore(self, image: torch.Tensor, strength: float = 0.5,
                use_ema: bool = True) -> torch.Tensor:
        """Randomly hide ``5% + 50% * strength`` of the pixels and re-predict them."""
        if image.dim() == 3:
            image = image.unsqueeze(0)
        B, C, H, W = image.shape
        ratio = 0.05 + strength * 0.50
        mask = (torch.rand(B, 1, H, W, device=image.device) < ratio).float()
        return self.inpaint(image, mask, use_ema=use_ema)

    @torch.no_grad()
    def restore_multi_pass(self, image: torch.Tensor, strength: float = 0.5,
                           passes: int = 3, use_ema: bool = True) -> torch.Tensor:
        """Iterative multi-pass restore — pass ``i`` uses ``strength * (1 - 0.1 i)``, floored at 0.05."""
        result = image.clone()
        for i in range(passes):
            s = strength * (1.0 - i * 0.1)
            result = self.restore(result, strength=max(0.05, s), use_ema=use_ema)
        return result

    # ── Generate from scratch ─────────────────────────────────────────────
    @torch.no_grad()
    def generate(self, n_samples: int = 1, n_steps: int = 12,
                 seed_image: Optional[torch.Tensor] = None,
                 seed_ratio: float = 0.0, temperature: float = 1.0,
                 use_ema: bool = True) -> torch.Tensor:
        """MaskGIT-style iterative generation.

        Starts from a noise canvas (or ``seed_image``, of which a random
        ``seed_ratio`` fraction of pixels is kept fixed). Each step predicts
        all unrevealed pixels, then permanently reveals the most confident
        ones on a cosine schedule. At the last step every still-unrevealed
        pixel takes the model's prediction.

        ``temperature`` > 1 adds Gaussian noise to predictions (more varied).
        ``temperature`` < 1 shrinks the random jitter in the reveal order
        (more confidence-driven / conservative). 1 is the neutral setting.

        Returns:
            ``(n_samples, C, H, W)`` tensor clamped to ``[0, 1]``.

        Raises:
            ValueError: if ``n_samples < 1``.
        """
        if n_samples < 1:
            raise ValueError(f"n_samples must be >= 1, got {n_samples}")
        return self._run_locked(use_ema, lambda: self._generate_impl(
            n_samples, n_steps, seed_image, seed_ratio, temperature))

    def _generate_impl(self, n_samples, n_steps, seed_image, seed_ratio, temperature):
        """Decoding loop behind ``generate``; caller handles eval mode, EMA and the lock."""
        H = W = self.config.IMAGE_SIZE
        C, N = self.config.CHANNELS, self.config.NUM_PIXELS
        dev = self.device
        B = n_samples
        n_steps = max(1, int(n_steps))  # at least one prediction pass

        if seed_image is not None:
            if seed_image.dim() == 3:
                seed_image = seed_image.unsqueeze(0)
            canvas = seed_image.expand(B, -1, -1, -1).clone()
            keep = (torch.rand(B, 1, H, W, device=dev) < seed_ratio)
        else:
            canvas = torch.rand(B, C, H, W, device=dev) * 0.5
            keep = torch.zeros(B, 1, H, W, device=dev, dtype=torch.bool)
        revealed = keep.float()

        # T > 1: pixel noise. T < 1: less reveal-order jitter. T == 1: neither change.
        pixel_noise = max(0.0, temperature - 1.0) * 0.05
        order_jitter = 0.05 * min(1.0, max(0.0, temperature))

        for step in range(n_steps):
            mask = 1.0 - revealed
            if not bool(mask.any()):
                break
            # Batch-average count, used only to cap how many pixels to reveal.
            n_unrevealed = int(mask.sum().item() / B)
            pred = self.model(canvas, mask=mask)["pixel_pred"]
            if pixel_noise > 0:
                pred = (pred + torch.randn_like(pred) * pixel_noise).clamp(0, 1)
            canvas = torch.where(mask.expand_as(canvas).bool(), pred, canvas)

            if step < n_steps - 1:
                # Confidence = distance of the prediction from 0.5, plus jitter.
                confidence = torch.abs(pred - 0.5).mean(dim=1, keepdim=True)
                confidence = confidence * mask + torch.rand_like(confidence) * order_jitter
                progress = (step + 1) / n_steps
                target_frac = 1.0 - 0.5 * (1 + math.cos(math.pi * progress))
                cur_revealed = int(revealed.sum().item() / B)
                n_reveal = max(1, int(target_frac * N) - cur_revealed)
                n_reveal = min(n_reveal, n_unrevealed)
                # Already revealed pixels can never be chosen again.
                flat_conf = confidence.reshape(B, -1).masked_fill(
                    mask.reshape(B, -1) < 0.5, -float("inf"))
                _, top = flat_conf.topk(min(n_reveal, N), dim=1)
                new_rev = torch.zeros(B, N, device=dev).scatter_(1, top, 1.0)
                revealed = (revealed + new_rev.reshape(B, 1, H, W)).clamp(0, 1)
        return canvas.clamp(0, 1)

    # ── Variation ─────────────────────────────────────────────────────────
    @torch.no_grad()
    def generate_variation(self, image: torch.Tensor, diversity: float = 0.5,
                           n_samples: int = 1, n_steps: int = 16,
                           use_ema: bool = True) -> torch.Tensor:
        """Generate variations of an existing sprite.

        Keeps a random ``1 - diversity`` fraction of the source pixels fixed:
        diversity=0 → near-identical copy; diversity=1 → fully free generation.
        """
        seed_ratio = 1.0 - diversity
        return self.generate(
            n_samples=n_samples, n_steps=n_steps,
            seed_image=image, seed_ratio=seed_ratio, use_ema=use_ema)

    # ── Palette extraction from CLS token ─────────────────────────────────
    @torch.no_grad()
    def extract_palette(self, image: torch.Tensor, use_ema: bool = True) -> torch.Tensor:
        """Return a ``(PALETTE_SIZE, 4)`` RGBA palette from the model's palette head.

        The palette is predicted from the encoder's CLS token on the unmasked
        image. For a batch, only the first image's palette is returned.
        """
        if image.dim() == 3:
            image = image.unsqueeze(0)

        def _palette():
            enc, _, _ = self.model.forward_encoder(image)
            cls_tok = enc[:, :1, :]                      # (B, 1, D)
            return self.model.palette_head(cls_tok)      # (B, 1, P*C)

        palette = self._run_locked(use_ema, _palette)
        P, C = self.config.PALETTE_SIZE, self.config.CHANNELS
        return palette.reshape(-1, P, C)[0].clamp(0, 1)


# ============================================================================
# 4. INIT ENGINE + MODEL MANAGEMENT
# ============================================================================
print("Initializing PixelMAE Inference Engine …")
config = Config()
engine = InferenceEngine(config, config.DEVICE)

# Global state
_loaded_model_key = None
_ckpt_info_text = "No model loaded."
_model_paths: Dict[str, Optional[str]] = {"best": None, "latest": None}


def _download_model(key: str) -> Tuple[Optional[str], str]:
    """Download ``MODEL_FILES[key]`` from ``REPO_ID``.

    Returns:
        ``(local_path, status_message)``; ``local_path`` is ``None`` on failure.
    """
    filename = MODEL_FILES[key]
    try:
        path = hf_hub_download(repo_id=REPO_ID, filename=filename)
    except Exception as e:  # network/auth/missing-file errors are reported in the UI
        return None, f"❌ Could not download `{filename}`: {e}"
    return path, f"✅ Downloaded `{filename}` from `{REPO_ID}`."


def _format_loss(value) -> str:
    """Format a checkpoint loss value; tolerates ``None``, tensors and strings."""
    try:
        return f"{float(value):.6f}"
    except (TypeError, ValueError):
        return "N/A" if value is None else str(value)


def load_model(model_choice: str) -> str:
    """Load the checkpoint chosen in the UI into the global ``engine``.

    Downloads the file on first use and reuses the local path afterwards.
    Updates ``_ckpt_info_text`` (and ``_loaded_model_key`` on success) and
    returns the Markdown status text. It never raises: failures are reported
    in the returned text.
    """
    global _loaded_model_key, _ckpt_info_text
    key = "best" if model_choice == "Best (best.pth)" else "latest"
    if _model_paths[key] is None:
        path, msg = _download_model(key)
        if path is None:
            _ckpt_info_text = msg
            return _ckpt_info_text
        _model_paths[key] = path
    try:
        meta = engine.load_checkpoint(_model_paths[key])
    except Exception as e:
        _ckpt_info_text = f"❌ Load error: {e}"
        return _ckpt_info_text

    lines = [
        f"**Model:** `{MODEL_FILES[key]}`",
        f"**Epoch:** {meta.get('epoch', '?')}",
        f"**Best Val Loss:** {_format_loss(meta.get('best_loss'))}",
        f"**EMA Weights:** {'✅ loaded' if meta.get('has_ema') else '⚠️ not found'}",
        f"**Device:** {config.DEVICE}",
    ]
    missing = meta.get("missing_keys")
    if missing:
        lines.append(f"**Missing weights:** ⚠️ {missing} keys not in checkpoint")
    # Two trailing spaces = Markdown hard line break (one space renders as a
    # single run-on paragraph).
    _ckpt_info_text = "  \n".join(lines)
    _loaded_model_key = key
    return _ckpt_info_text


# Pre-load best model at startup (non-fatal; the UI can retry via "Load / Reload").
# Any download/load error ends up in _ckpt_info_text, which seeds the UI panel.
try:
    load_model("Best (best.pth)")
    print(f"  └─ {_ckpt_info_text.replace(chr(10), ' ')}")
except Exception as ex:
    print(f"  └─ Startup load failed: {ex}")
# ============================================================================
# 5. IMAGE UTILITIES
# ============================================================================
def preprocess(pil_img: Image.Image) -> torch.Tensor:
    """PIL image → ``(1, 4, S, S)`` RGBA float tensor in ``[0, 1]`` on ``config.DEVICE``.

    ``S = config.IMAGE_SIZE``. Resizing uses nearest-neighbour sampling so
    pixel-art colours are never blended.
    """
    img = pil_img.convert("RGBA").resize(
        (config.IMAGE_SIZE, config.IMAGE_SIZE), Image.Resampling.NEAREST)
    arr = np.array(img, np.float32) / 255.0
    return torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0).to(config.DEVICE)


def postprocess(tensor: torch.Tensor, upscale: int = 16) -> Image.Image:
    """Convert tensor → pixel-perfect upscaled PIL image (RGBA)."""
    if tensor is None:
        return None
    if tensor.dim() == 4:
        tensor = tensor.squeeze(0)
    arr = (tensor.permute(1, 2, 0).cpu().clamp(0, 1).numpy() * 255).astype(np.uint8)
    img = Image.fromarray(arr, "RGBA")
    return img.resize(
        (config.IMAGE_SIZE * upscale, config.IMAGE_SIZE * upscale),
        Image.Resampling.NEAREST)


def tensor_to_pils(batch: torch.Tensor, upscale: int = 16) -> List[Image.Image]:
    """Batch tensor → list of PIL images."""
    if batch.dim() == 3:
        batch = batch.unsqueeze(0)
    return [postprocess(batch[i], upscale) for i in range(batch.shape[0])]


def make_palette_image(palette: torch.Tensor, swatch_size: int = 32) -> Image.Image:
    """Render a ``(n, 4)`` RGBA palette in ``[0, 1]`` as a row of square swatches."""
    n = palette.shape[0]
    img = Image.new("RGBA", (n * swatch_size, swatch_size), (0, 0, 0, 0))
    draw = ImageDraw.Draw(img)
    for i, color in enumerate(palette):
        rgba = tuple((color.cpu().numpy() * 255).astype(int).tolist())
        draw.rectangle([i * swatch_size, 0, (i + 1) * swatch_size - 1, swatch_size - 1],
                       fill=rgba)
    return img


def compare_images(before: Image.Image, after: Image.Image,
                   label_size: int = 10) -> Image.Image:
    """Place ``before`` and ``after`` side by side, separated by a 4-px dark gutter.

    The canvas size comes from ``before``. ``label_size`` is accepted for API
    compatibility but is currently unused: no text labels are drawn.
    """
    w, h = before.size
    total = Image.new("RGBA", (w * 2 + 4, h), (20, 20, 20, 255))
    total.paste(before, (0, 0))
    total.paste(after, (w + 4, 0))
    return total


def extract_unique_colors(img: Image.Image, max_colors: int = 32) -> List[Tuple]:
    """Return up to ``max_colors`` unique RGBA colours, lexicographically sorted.

    Pixels with alpha <= 10 are ignored as near-transparent.
    """
    arr = np.array(img.convert("RGBA"))
    pixels = arr.reshape(-1, 4)
    pixels = pixels[pixels[:, 3] > 10]  # ignore near-transparent
    unique = np.unique(pixels, axis=0)
    return [tuple(c) for c in unique[:max_colors]]


# ============================================================================
# 6. GRADIO FUNCTIONS
# ============================================================================
def _palette_preview(batch: torch.Tensor, use_ema: bool) -> Image.Image:
    """Render the model-predicted palette of the first image in ``batch``."""
    palette = engine.extract_palette(batch[0:1], use_ema=use_ema)
    return make_palette_image(palette, swatch_size=40)


def fn_load_model(choice):
    """Handler for the "Load / Reload Model" button."""
    return load_model(choice)


def fn_generate(seed_img, seed_ratio, steps, n_samples, temperature, use_ema, upscale):
    """Generate ``n_samples`` sprites; return (gallery items, first sprite, palette strip)."""
    n_samples = int(n_samples)
    steps = int(steps)
    upscale = int(upscale)
    seed_t = preprocess(seed_img) if seed_img is not None else None
    out = engine.generate(
        n_samples=n_samples, n_steps=steps, seed_image=seed_t,
        seed_ratio=seed_ratio, temperature=temperature, use_ema=use_ema)
    pils = tensor_to_pils(out, upscale)
    gallery = [(p, f"Sample {i+1}") for i, p in enumerate(pils)]
    return gallery, pils[0], _palette_preview(out, use_ema)


def fn_generate_single(seed_img, seed_ratio, steps, temperature, use_ema, upscale):
    """Generate one sprite; return (sprite, palette strip)."""
    upscale = int(upscale)
    steps = int(steps)
    seed_t = preprocess(seed_img) if seed_img is not None else None
    out = engine.generate(
        n_samples=1, n_steps=steps, seed_image=seed_t,
        seed_ratio=seed_ratio, temperature=temperature, use_ema=use_ema)
    return postprocess(out[0], upscale), _palette_preview(out, use_ema)


def fn_restore(image, strength, passes, use_ema, upscale):
    """Restore a sprite (single or multi-pass); return (result, before/after, palette)."""
    if image is None:
        return None, None, None
    upscale = int(upscale)
    passes = int(passes)
    in_t = preprocess(image)
    if passes > 1:
        out_t = engine.restore_multi_pass(in_t, strength=strength, passes=passes,
                                          use_ema=use_ema)
    else:
        out_t = engine.restore(in_t, strength=strength, use_ema=use_ema)
    before = postprocess(in_t[0], upscale)
    after = postprocess(out_t[0], upscale)
    compare = compare_images(before, after)
    return after, compare, _palette_preview(out_t, use_ema)


def _editor_mask(layers) -> torch.Tensor:
    """Union of all painted (alpha > 0) pixels in ``layers`` → ``(1, 1, S, S)`` mask.

    A sprite cell counts as masked if *any* painted canvas pixel falls inside
    it (adaptive max-pooling), so thin strokes on an upscaled canvas are not
    lost the way they are with nearest-neighbour downsampling.
    """
    size = config.IMAGE_SIZE
    mask_t = torch.zeros((1, 1, size, size), device=config.DEVICE)
    for layer in layers or []:
        if layer is None:
            continue
        painted = np.array(layer.convert("RGBA"))[:, :, 3] > 0
        cells = F.adaptive_max_pool2d(
            torch.from_numpy(painted.astype(np.float32))[None, None], size)
        mask_t = torch.maximum(mask_t, cells.to(config.DEVICE))
    return mask_t


def fn_inpaint(editor_dict, use_ema, upscale):
    """Redraw the pixels covered by brush strokes in the ImageEditor.

    Returns (inpainted sprite, palette strip), or ``(None, None)`` when no
    background image has been uploaded.
    """
    if not editor_dict or editor_dict.get("background") is None:
        return None, None
    upscale = int(upscale)
    in_t = preprocess(editor_dict["background"].convert("RGBA"))
    mask_t = _editor_mask(editor_dict.get("layers", []))
    out_t = engine.inpaint(in_t, mask_t, use_ema=use_ema)
    return postprocess(out_t[0], upscale), _palette_preview(out_t, use_ema)


def fn_inpaint_alpha(image, use_ema, upscale):
    """Fill transparent pixels (alpha < 0.5) using the model."""
    if image is None:
        return None, None
    upscale = int(upscale)
    in_t = preprocess(image)
    mask_t = (in_t[:, 3:4] < 0.5).float()
    out_t = engine.inpaint(in_t, mask_t, use_ema=use_ema)
    return postprocess(out_t[0], upscale), _palette_preview(out_t, use_ema)


def fn_variation(image, diversity, n_var, steps, use_ema, upscale):
    """Generate ``n_var`` variations of a sprite; return (gallery items, palette strip)."""
    if image is None:
        return [], None
    n_var = int(n_var)
    steps = int(steps)
    upscale = int(upscale)
    in_t = preprocess(image)
    out_t = engine.generate_variation(
        in_t, diversity=diversity, n_samples=n_var, n_steps=steps, use_ema=use_ema)
    pils = tensor_to_pils(out_t, upscale)
    gallery = [(p, f"Variation {i+1}") for i, p in enumerate(pils)]
    return gallery, _palette_preview(out_t, use_ema)


def fn_batch_generate(n_total, steps, temperature, use_ema, upscale):
    """Generate ``n_total`` unseeded sprites; return gallery items."""
    n_total = int(n_total)
    steps = int(steps)
    upscale = int(upscale)
    out_t = engine.generate(
        n_samples=n_total, n_steps=steps, temperature=temperature, use_ema=use_ema)
    pils = tensor_to_pils(out_t, upscale)
    return [(p, f"#{i+1}") for i, p in enumerate(pils)]


# ============================================================================
# 7. GRADIO UI
# ============================================================================
CSS = """
/* ── Pixel-perfect rendering for all output images ── */
.pixel-out img, .pixel-gallery img, .pixel-compare img {
    image-rendering: pixelated !important;
    image-rendering: crisp-edges !important;
}
/* ── Dark card styling ── */
.card { background: #1a1a2e; border-radius: 10px; padding: 12px; }
/* ── Status badge ── */
.status-box textarea { font-family: monospace; font-size: 12px; }
/* ── Palette strip ── */
.pal-out img {
    image-rendering: pixelated !important;
    border-radius: 4px;
    border: 1px solid #444;
}
"""

_UPSCALE_CHOICES = ["4", "8", "16", "24", "32"]
_MODEL_CHOICES = ["Best (best.pth)", "Latest (latest.pth)"]


def _upscale_ctrl(default="16"):
    """Preview-scale radio shared by every tab."""
    return gr.Radio(
        _UPSCALE_CHOICES, value=default, label="Preview Scale",
        info="Pixel-perfect zoom multiplier (× original 16px)")


def _ema_ctrl():
    """EMA-weights checkbox shared by every tab."""
    return gr.Checkbox(value=True, label="Use EMA Weights", info="Usually better quality")


def _palette_out():
    """Palette-strip output image shared by every tab."""
    return gr.Image(label="Predicted Palette", type="pil", elem_classes=["pal-out"])


with gr.Blocks(theme=gr.themes.Monochrome(), css=CSS,
               title="PixelMAE v10 — Neural Sprite Engine") as app:
    gr.Markdown(
        """
# 👾 PixelMAE — Neural Sprite Engine `v10`
Asymmetric Masked Autoencoder · 16 × 16 RGBA · Generation · Restoration · Inpainting · Variations
""")

    # ── Model loader ─────────────────────────────────────────────────────
    with gr.Accordion("⚙️ Model & Checkpoint", open=False):
        with gr.Row():
            model_choice = gr.Radio(
                _MODEL_CHOICES, value="Best (best.pth)", label="Checkpoint")
            load_btn = gr.Button("🔄 Load / Reload Model", variant="secondary")
        ckpt_info = gr.Markdown(value=_ckpt_info_text, label="Checkpoint Info")
    load_btn.click(fn_load_model, inputs=model_choice, outputs=ckpt_info)

    gr.Markdown("---")

    # ── TABS ─────────────────────────────────────────────────────────────
    with gr.Tabs():

        # ── TAB 1: GENERATE ──────────────────────────────────────────────
        with gr.TabItem("✨ Generate"):
            gr.Markdown(
                "Generate sprites from noise using iterative MaskGIT decoding. "
                "Optionally seed from an existing image.")
            with gr.Row():
                with gr.Column(scale=1):
                    gen_seed = gr.Image(
                        label="Seed Image (optional)", type="pil",
                        image_mode="RGBA", height=200)
                    gen_ratio = gr.Slider(
                        0.0, 1.0, 0.0, step=0.05, label="Seed Fidelity",
                        info="0 = fully random · 1 = preserve seed completely")
                    gen_steps = gr.Slider(
                        4, 64, 16, step=1, label="Decoding Steps",
                        info="More steps → smoother, slower")
                    gen_temp = gr.Slider(
                        0.5, 2.0, 1.0, step=0.05, label="Temperature",
                        info="< 1 sharper/conservative · > 1 more chaotic/varied")
                    gen_use_ema = _ema_ctrl()
                    gen_upscale = _upscale_ctrl("16")
                    gen_btn = gr.Button("⚡ Generate", variant="primary")
                with gr.Column(scale=1):
                    gen_out = gr.Image(
                        label="Output Sprite", type="pil", format="png",
                        elem_classes=["pixel-out"])
                    gen_pal = _palette_out()
            gen_btn.click(
                fn_generate_single,
                inputs=[gen_seed, gen_ratio, gen_steps, gen_temp, gen_use_ema, gen_upscale],
                outputs=[gen_out, gen_pal])

        # ── TAB 2: BATCH GENERATE ────────────────────────────────────────
        with gr.TabItem("🗂️ Batch Generate"):
            gr.Markdown(
                "Generate multiple sprites in one shot. "
                "All sprites are rendered pixel-perfect in the gallery.")
            with gr.Row():
                with gr.Column(scale=1):
                    bg_n = gr.Slider(1, 16, 8, step=1, label="Number of Sprites")
                    bg_steps = gr.Slider(4, 64, 16, step=1, label="Decoding Steps")
                    bg_temp = gr.Slider(0.5, 2.0, 1.0, step=0.05, label="Temperature")
                    bg_use_ema = _ema_ctrl()
                    bg_upscale = _upscale_ctrl("16")
                    bg_btn = gr.Button("⚡ Generate Batch", variant="primary")
                with gr.Column(scale=2):
                    bg_gallery = gr.Gallery(
                        label="Generated Sprites", columns=4, rows=4, height="auto",
                        elem_classes=["pixel-gallery"], format="png")
            bg_btn.click(
                fn_batch_generate,
                inputs=[bg_n, bg_steps, bg_temp, bg_use_ema, bg_upscale],
                outputs=bg_gallery)

        # ── TAB 3: RESTORE & REFINE ──────────────────────────────────────
        with gr.TabItem("🔄 Restore & Refine"):
            gr.Markdown(
                "Feed the model an existing sprite. It randomly masks regions and "
                "reconstructs them, improving pixel coherence. Multi-pass iteratively refines.")
            with gr.Row():
                with gr.Column(scale=1):
                    rest_img = gr.Image(
                        label="Input Sprite", type="pil", image_mode="RGBA", height=200)
                    rest_str = gr.Slider(
                        0.0, 1.0, 0.5, step=0.05, label="Mask Strength",
                        info="What fraction of pixels are randomly re-predicted")
                    rest_passes = gr.Slider(
                        1, 6, 1, step=1, label="Passes",
                        info="Multi-pass: each pass refines further")
                    rest_use_ema = _ema_ctrl()
                    rest_upscale = _upscale_ctrl("16")
                    rest_btn = gr.Button("🔄 Restore", variant="primary")
                with gr.Column(scale=1):
                    rest_out = gr.Image(
                        label="Restored Sprite", type="pil", format="png",
                        elem_classes=["pixel-out"])
                    rest_compare = gr.Image(
                        label="Before · After", type="pil",
                        elem_classes=["pixel-compare"])
                    rest_pal = _palette_out()
            rest_btn.click(
                fn_restore,
                inputs=[rest_img, rest_str, rest_passes, rest_use_ema, rest_upscale],
                outputs=[rest_out, rest_compare, rest_pal])

        # ── TAB 4: SMART INPAINT ─────────────────────────────────────────
        with gr.TabItem("🖌️ Smart Inpaint"):
            gr.Markdown(
                "Upload a sprite and **paint the mask** over pixels you want the AI to redraw. "
                "Use any brush color — coverage counts, not color.")
            with gr.Row():
                with gr.Column(scale=1):
                    inp_editor = gr.ImageEditor(
                        label="Draw Mask (paint = redo)", type="pil", image_mode="RGBA",
                        brush=gr.Brush(colors=["#ff0000"], color_mode="fixed"),
                        height=300)
                    inp_use_ema = _ema_ctrl()
                    inp_upscale = _upscale_ctrl("16")
                    inp_btn = gr.Button("🖌️ Inpaint", variant="primary")
                with gr.Column(scale=1):
                    inp_out = gr.Image(
                        label="Inpainted Output", type="pil", format="png",
                        elem_classes=["pixel-out"])
                    inp_pal = _palette_out()
            inp_btn.click(
                fn_inpaint,
                inputs=[inp_editor, inp_use_ema, inp_upscale],
                outputs=[inp_out, inp_pal])

        # ── TAB 5: FILL TRANSPARENT ──────────────────────────────────────
        with gr.TabItem("🔍 Fill Transparent"):
            gr.Markdown(
                "Upload a partially transparent sprite (RGBA). "
                "The model will fill all transparent pixels based on the visible context.")
            with gr.Row():
                with gr.Column(scale=1):
                    alp_img = gr.Image(
                        label="Partial Sprite (RGBA)", type="pil",
                        image_mode="RGBA", height=200)
                    alp_use_ema = _ema_ctrl()
                    alp_upscale = _upscale_ctrl("16")
                    alp_btn = gr.Button("✨ Fill Transparent", variant="primary")
                with gr.Column(scale=1):
                    alp_out = gr.Image(
                        label="Completed Sprite", type="pil", format="png",
                        elem_classes=["pixel-out"])
                    alp_pal = _palette_out()
            alp_btn.click(
                fn_inpaint_alpha,
                inputs=[alp_img, alp_use_ema, alp_upscale],
                outputs=[alp_out, alp_pal])

        # ── TAB 6: VARIATIONS ────────────────────────────────────────────
        with gr.TabItem("🎲 Variations"):
            gr.Markdown(
                "Upload a sprite and generate **N creative variations**. "
                "Diversity 0 = almost identical · 1 = free improvisation on the theme.")
            with gr.Row():
                with gr.Column(scale=1):
                    var_img = gr.Image(
                        label="Source Sprite", type="pil", image_mode="RGBA", height=200)
                    var_div = gr.Slider(
                        0.0, 1.0, 0.5, step=0.05, label="Diversity",
                        info="How far variants can deviate from the source")
                    var_n = gr.Slider(1, 16, 4, step=1, label="Number of Variations")
                    var_steps = gr.Slider(4, 64, 16, step=1, label="Decoding Steps")
                    var_use_ema = _ema_ctrl()
                    var_upscale = _upscale_ctrl("16")
                    var_btn = gr.Button("🎲 Generate Variations", variant="primary")
                with gr.Column(scale=2):
                    var_gallery = gr.Gallery(
                        label="Variations", columns=4, rows=4, height="auto",
                        elem_classes=["pixel-gallery"], format="png")
                    var_pal = _palette_out()
            var_btn.click(
                fn_variation,
                inputs=[var_img, var_div, var_n, var_steps, var_use_ema, var_upscale],
                outputs=[var_gallery, var_pal])

    # ── Footer ───────────────────────────────────────────────────────────
    gr.Markdown(
        f"""
---
**PixelMAE v10** · Asymmetric MAE · 4.2M params · 16×16 RGBA
Model: [`{REPO_ID}`](https://huggingface.co/{REPO_ID})
""")


if __name__ == "__main__":
    app.launch()