# PixelMAE / app.py
# Source: Hugging Face Space by MidFord327 (commit 2feaac2, verified)
"""
PixelMAE v10 — Neural Sprite Engine
Hugging Face Spaces · Production UI
"""
import os
import math
import warnings
from typing import Optional, Dict, List, Tuple
from dataclasses import dataclass, field
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image, ImageDraw
import gradio as gr
from huggingface_hub import hf_hub_download
warnings.filterwarnings("ignore")
# ============================================================================
# REPO CONFIG β€” edit these two lines
# ============================================================================
REPO_ID = "MidFord327/PixelArt-MAE-v7"  # Hub repo that hosts the checkpoints
MODEL_FILES = {"best": "best.pth", "latest": "latest.pth"}  # key -> filename on the Hub
# ============================================================================
# 1. CONFIGURATION
# ============================================================================
@dataclass
class Config:
    """Model/runtime hyper-parameters — must match the trained checkpoint."""
    IMAGE_SIZE: int = 16            # sprite side length in pixels
    CHANNELS: int = 4               # RGBA
    NUM_PIXELS: int = 256           # IMAGE_SIZE ** 2 — one token per pixel
    LATENT_DIM: int = 192           # transformer width
    ENCODER_LAYERS: int = 6
    DECODER_LAYERS: int = 4         # asymmetric: decoder is shallower
    NUM_HEADS: int = 6
    FFN_RATIO: int = 4              # MLP hidden size = LATENT_DIM * FFN_RATIO
    DROPOUT: float = 0.0
    PALETTE_SIZE: int = 16          # colors emitted by the palette head
    USE_EMA: bool = True
    EMA_DECAY: float = 0.9995
    EMA_UPDATE_EVERY: int = 1
    # CPU-only here (Spaces free tier); factory so the dataclass stays picklable.
    DEVICE: torch.device = field(
        default_factory=lambda: torch.device("cpu"))
# ============================================================================
# 2. ARCHITECTURE
# ============================================================================
def get_2d_sincos_pos_embed(embed_dim: int, grid_size: int) -> torch.Tensor:
    """Build fixed 2-D sine/cosine positional embeddings.

    Interleaves sin/cos of the row coordinate (channels 0,1 mod 4) and of the
    column coordinate (channels 2,3 mod 4). Assumes embed_dim % 4 == 0.

    Returns a (1, grid_size**2, embed_dim) tensor.
    """
    coords = torch.arange(grid_size, dtype=torch.float32)
    yy, xx = torch.meshgrid(coords, coords, indexing="ij")
    flat = torch.stack((yy, xx), dim=0).reshape(2, -1)
    quarter = embed_dim // 4
    # Log-spaced inverse frequencies, as in the original Transformer paper.
    freqs = 1.0 / (10000 ** (torch.arange(quarter, dtype=torch.float32) / quarter))
    angles_y = flat[0][:, None] * freqs[None, :]
    angles_x = flat[1][:, None] * freqs[None, :]
    pe = torch.zeros(grid_size * grid_size, embed_dim)
    pe[:, 0::4] = torch.sin(angles_y)
    pe[:, 1::4] = torch.cos(angles_y)
    pe[:, 2::4] = torch.sin(angles_x)
    pe[:, 3::4] = torch.cos(angles_x)
    return pe.unsqueeze(0)
class EMA:
    """Exponential-moving-average weight tracker (inference-side subset).

    Holds a `shadow` copy of every trainable parameter; `apply_shadow` swaps
    the shadow weights into the model (stashing the live weights in `backup`)
    and `restore` swaps them back.
    """
    def __init__(self, model, decay=0.9995, update_every=1):
        self.model = model
        self.decay = decay
        self.update_every = update_every
        self.step = 0
        self.shadow = {}
        for name, param in model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.data.clone()
        self.backup = {}

    def apply_shadow(self):
        """Copy EMA weights into the model, backing up the current weights."""
        for name, param in self.model.named_parameters():
            if not param.requires_grad or name not in self.shadow:
                continue
            self.backup[name] = param.data.clone()
            param.data.copy_(self.shadow[name])

    def restore(self):
        """Undo `apply_shadow`: put the backed-up weights back."""
        for name, param in self.model.named_parameters():
            if name in self.backup:
                param.data.copy_(self.backup[name])
        self.backup = {}

    def load_state_dict(self, state, device):
        """Load shadow weights/step counter from a checkpoint dict."""
        self.shadow = {key: tensor.to(device) for key, tensor in state["shadow"].items()}
        self.step = state["step"]
class MHSA(nn.Module):
    """Multi-head self-attention using PyTorch's fused SDPA kernel."""
    def __init__(self, dim, heads, dropout=0.0):
        super().__init__()
        self.heads = heads
        self.head_dim = dim // heads
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)
        self.drop = dropout  # kept for interface parity; SDPA is called with p=0

    def forward(self, x):
        batch, seq, width = x.shape
        # Single projection -> (3, B, heads, seq, head_dim)
        projected = self.qkv(x).view(batch, seq, 3, self.heads, self.head_dim)
        projected = projected.permute(2, 0, 3, 1, 4)
        q, k, v = projected[0], projected[1], projected[2]
        attended = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)
        merged = attended.transpose(1, 2).reshape(batch, seq, width)
        return self.proj(merged)
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention + MLP, each with a residual."""
    def __init__(self, dim, heads, ffn_ratio=4, dropout=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = MHSA(dim, heads, dropout)
        self.norm2 = nn.LayerNorm(dim)
        self.ffn = nn.Sequential(
            nn.Linear(dim, dim * ffn_ratio),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim * ffn_ratio, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        attended = x + self.attn(self.norm1(x))
        return attended + self.ffn(self.norm2(attended))
class AsymmetricPixelMAE(nn.Module):
    """Asymmetric masked autoencoder over 16x16 RGBA sprites.

    The encoder only sees visible pixel tokens (plus a CLS token); the
    shallower decoder re-inserts mask tokens and predicts RGB + alpha for
    every pixel. A palette head reads the CLS token.
    """
    def __init__(self, config: Config):
        super().__init__()
        self.config = config
        D, N = config.LATENT_DIM, config.NUM_PIXELS
        # Per-pixel tokenization: each RGBA pixel -> one D-dim token.
        self.pixel_embed = nn.Linear(config.CHANNELS, D)
        self.embed_norm = nn.LayerNorm(D)
        # Fixed (non-trainable) 2-D sin/cos positional embedding.
        self.pos_embed = nn.Parameter(
            get_2d_sincos_pos_embed(D, config.IMAGE_SIZE), requires_grad=False)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, D))
        self.mask_token = nn.Parameter(torch.zeros(1, 1, D))
        nn.init.normal_(self.cls_token, std=0.02)
        nn.init.normal_(self.mask_token, std=0.02)
        self.encoder = nn.ModuleList([
            TransformerBlock(D, config.NUM_HEADS, config.FFN_RATIO, config.DROPOUT)
            for _ in range(config.ENCODER_LAYERS)])
        self.encoder_norm = nn.LayerNorm(D)
        self.decoder_embed = nn.Linear(D, D)
        # Decoder positions: slot 0 is CLS (left at zeros), slots 1..N get sin/cos.
        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, N + 1, D), requires_grad=False)
        dp = torch.zeros(1, N + 1, D)
        dp[:, 1:, :] = get_2d_sincos_pos_embed(D, config.IMAGE_SIZE)
        self.decoder_pos_embed.data.copy_(dp)
        self.decoder = nn.ModuleList([
            TransformerBlock(D, config.NUM_HEADS, config.FFN_RATIO, config.DROPOUT)
            for _ in range(config.DECODER_LAYERS)])
        self.decoder_norm = nn.LayerNorm(D)
        # Separate RGB / alpha heads; sigmoid is applied in forward().
        self.pixel_head_rgb = nn.Sequential(
            nn.LayerNorm(D), nn.Linear(D, D), nn.GELU(), nn.Linear(D, 3))
        self.pixel_head_alpha = nn.Sequential(
            nn.LayerNorm(D), nn.Linear(D, D), nn.GELU(), nn.Linear(D, 1))
        # Palette head: token -> PALETTE_SIZE RGBA colors in [0, 1].
        self.palette_head = nn.Sequential(
            nn.LayerNorm(D), nn.Linear(D, D), nn.GELU(),
            nn.Linear(D, config.PALETTE_SIZE * config.CHANNELS), nn.Sigmoid())

    def _tokenize(self, x):
        # (B, C, H, W) -> (B, N, D): flatten pixels, embed, normalize.
        return self.embed_norm(self.pixel_embed(x.flatten(2).transpose(1, 2)))

    def forward_encoder(self, x, mask=None):
        """Encode visible tokens; mask==1 marks pixels to drop/re-predict.

        Returns (encoded sequence incl. CLS at index 0, un-shuffle indices,
        number of visible tokens kept).
        """
        B, N, D = x.shape[0], self.config.NUM_PIXELS, self.config.LATENT_DIM
        tokens = self._tokenize(x) + self.pos_embed
        if mask is not None:
            # Coerce the mask to shape (B, N) regardless of input layout.
            mf = mask.flatten(1)
            if mf.shape[1] != N: mf = mask.reshape(B, -1)
            if mf.dim() == 3: mf = mf.squeeze(1)
            # Random shuffle that pushes masked tokens to the end (the +1e6
            # bias guarantees masked positions sort after all visible ones).
            noise = torch.rand(B, N, device=x.device)
            ids_shuf = torch.argsort(noise + mf.float() * 1e6, dim=1)
            ids_rest = torch.argsort(ids_shuf, dim=1)  # inverse permutation
            # Keep the per-batch minimum visible count so the batch stays rectangular.
            n_vis = max(1, int((1 - mf.float()).sum(dim=1).min().item()))
            ids_keep = ids_shuf[:, :n_vis]
            tokens = torch.gather(tokens, 1, ids_keep.unsqueeze(-1).expand(-1, -1, D))
        else:
            n_vis = N
            ids_rest = torch.arange(N, device=x.device).unsqueeze(0).expand(B, -1)
        cls = self.cls_token.expand(B, -1, -1)
        x_enc = torch.cat([cls, tokens], dim=1)
        for blk in self.encoder: x_enc = blk(x_enc)
        return self.encoder_norm(x_enc), ids_rest, n_vis

    def forward_decoder(self, x_enc, ids_rest, n_vis):
        """Re-insert mask tokens, un-shuffle, decode; returns (B, N, D) features."""
        B, N, D = x_enc.shape[0], self.config.NUM_PIXELS, self.config.LATENT_DIM
        cls_enc, vis_enc = x_enc[:, :1], x_enc[:, 1:]
        vis_dec = self.decoder_embed(vis_enc)
        n_masked = N - n_vis
        if n_masked > 0:
            mask_tok = self.mask_token.expand(B, n_masked, -1)
            full_seq = torch.cat([vis_dec, mask_tok], dim=1)
        else:
            full_seq = vis_dec
        # Undo the encoder's shuffle so tokens line up with pixel positions.
        full_seq = torch.gather(full_seq, 1, ids_rest.unsqueeze(-1).expand(-1, -1, D))
        full_seq = torch.cat([self.decoder_embed(cls_enc), full_seq], dim=1) + self.decoder_pos_embed
        for blk in self.decoder: full_seq = blk(full_seq)
        return self.decoder_norm(full_seq)[:, 1:]  # drop CLS

    def forward(self, x, mask=None):
        """Full pass; returns {"pixel_pred": (B, 4, H, W) tensor in [0, 1]}."""
        B, C, H, W = x.shape
        enc, ids, n_vis = self.forward_encoder(x, mask)
        dec = self.forward_decoder(enc, ids, n_vis)
        rgb = torch.sigmoid(self.pixel_head_rgb(dec)).transpose(1, 2).reshape(B, 3, H, W)
        alp = torch.sigmoid(self.pixel_head_alpha(dec)).transpose(1, 2).reshape(B, 1, H, W)
        return {"pixel_pred": torch.cat([rgb, alp], dim=1)}
# ============================================================================
# 3. INFERENCE ENGINE (extended)
# ============================================================================
class InferenceEngine:
    """Owns the model (+ optional EMA weights) and implements all inference ops."""
    def __init__(self, config: Config, device: torch.device):
        self.config = config
        self.device = device
        self.model = AsymmetricPixelMAE(config).to(device)
        self.ema = EMA(self.model, config.EMA_DECAY, config.EMA_UPDATE_EVERY) if config.USE_EMA else None
        self.ckpt_meta: Dict = {}

    def load_checkpoint(self, path: str) -> Dict:
        """Load model (and EMA shadow, if present) from a checkpoint file.

        Returns a small metadata dict (epoch, best_loss, has_ema).
        """
        # weights_only=False: checkpoint stores objects beyond raw tensors.
        # NOTE(review): this is pickle-based — only load from the trusted repo.
        ckpt = torch.load(path, map_location=self.device, weights_only=False)
        # Strip DataParallel ("module.") / torch.compile ("_orig_mod.") prefixes.
        state = {k.replace("module.", "").replace("_orig_mod.", ""): v
                 for k, v in ckpt["model"].items()}
        self.model.load_state_dict(state, strict=False)
        if self.ema and ckpt.get("ema"):
            try: self.ema.load_state_dict(ckpt["ema"], self.device)
            except Exception: pass  # EMA weights are best-effort / optional
        self.ckpt_meta = {
            "epoch": ckpt.get("epoch", "?"),
            "best_loss": ckpt.get("best_loss", None),
            "has_ema": bool(self.ema and ckpt.get("ema")),
        }
        return self.ckpt_meta

    # ── Core helpers ──────────────────────────────────────────────────────
    def _apply_ema(self):
        # Swap EMA shadow weights in (no-op when EMA is disabled).
        if self.ema: self.ema.apply_shadow()
    def _restore_ema(self):
        # Put the raw weights back after inference.
        if self.ema: self.ema.restore()

    # ── Inpaint ───────────────────────────────────────────────────────────
    @torch.no_grad()
    def inpaint(self, image: torch.Tensor, mask: torch.Tensor,
                use_ema: bool = True) -> torch.Tensor:
        """Re-predict pixels where mask==1; keep original pixels elsewhere."""
        self.model.eval()
        if image.dim() == 3: image = image.unsqueeze(0)
        if mask.dim() == 3: mask = mask.unsqueeze(0)
        if use_ema: self._apply_ema()
        try:
            pred = self.model(image, mask=mask.float())["pixel_pred"]
            # Only masked pixels take the model's prediction.
            result = torch.where(mask.expand_as(image).bool(), pred, image)
        finally:
            if use_ema: self._restore_ema()  # always restore raw weights
        return result.clamp(0, 1)

    # ── Restore ───────────────────────────────────────────────────────────
    @torch.no_grad()
    def restore(self, image: torch.Tensor, strength: float = 0.5,
                use_ema: bool = True) -> torch.Tensor:
        """Randomly mask 5%-55% of pixels (scaled by strength) and re-predict them."""
        self.model.eval()
        if image.dim() == 3: image = image.unsqueeze(0)
        B, C, H, W = image.shape
        ratio = 0.05 + strength * 0.50
        mask = (torch.rand(B, 1, H, W, device=image.device) < ratio).float()
        return self.inpaint(image, mask, use_ema=use_ema)

    @torch.no_grad()
    def restore_multi_pass(self, image: torch.Tensor, strength: float = 0.5,
                           passes: int = 3, use_ema: bool = True) -> torch.Tensor:
        """Iterative multi-pass restore; each pass uses slightly less strength."""
        result = image.clone()
        for i in range(passes):
            s = strength * (1.0 - i * 0.1)  # decay 10% of base strength per pass
            result = self.restore(result, strength=max(0.05, s), use_ema=use_ema)
        return result

    # ── Generate from scratch ─────────────────────────────────────────────
    @torch.no_grad()
    def generate(self, n_samples: int = 1, n_steps: int = 12,
                 seed_image: Optional[torch.Tensor] = None,
                 seed_ratio: float = 0.0,
                 temperature: float = 1.0,
                 use_ema: bool = True) -> torch.Tensor:
        """MaskGIT-style iterative generation.

        Starts from noise (or from a partially kept seed image) and, over
        n_steps, reveals the most confident predicted pixels following a
        cosine schedule. Returns (n_samples, C, H, W) in [0, 1].
        """
        self.model.eval()
        if use_ema: self._apply_ema()
        try:
            H = W = self.config.IMAGE_SIZE
            C, N = self.config.CHANNELS, self.config.NUM_PIXELS
            dev = self.device
            B = n_samples
            if seed_image is not None:
                if seed_image.dim() == 3: seed_image = seed_image.unsqueeze(0)
                seed_image = seed_image.expand(B, -1, -1, -1)
                canvas = seed_image.clone()
                # Pixels kept from the seed (each kept with prob. seed_ratio).
                keep = (torch.rand(B, 1, H, W, device=dev) < seed_ratio)
            else:
                canvas = torch.rand(B, C, H, W, device=dev) * 0.5
                keep = torch.zeros(B, 1, H, W, device=dev, dtype=torch.bool)
            revealed = keep.float()
            for step in range(n_steps):
                mask = 1.0 - revealed
                n_unrevealed = int(mask.sum().item() / B)
                if n_unrevealed == 0: break
                pred = self.model(canvas, mask=mask)["pixel_pred"]
                # temperature scaling β€” add noise before confidence scoring
                if temperature != 1.0:
                    noise = torch.randn_like(pred) * (temperature - 1.0) * 0.05
                    pred = (pred + noise).clamp(0, 1)
                canvas = torch.where(mask.expand_as(canvas).bool(), pred, canvas)
                if step < n_steps - 1:
                    # Confidence = mean distance from 0.5 over channels;
                    # a small random jitter breaks ties.
                    confidence = torch.abs(pred - 0.5).mean(dim=1, keepdim=True)
                    confidence = confidence * mask + torch.rand_like(confidence) * 0.05
                    progress = (step + 1) / n_steps
                    # Cosine schedule for the fraction of pixels revealed so far.
                    target_frac = 1.0 - 0.5 * (1 + math.cos(math.pi * progress))
                    cur_revealed = int(revealed.sum().item() / B)
                    n_reveal = max(1, int(target_frac * N) - cur_revealed)
                    n_reveal = min(n_reveal, n_unrevealed)
                    # Reveal the top-confidence still-masked pixels.
                    flat_conf = confidence.reshape(B, -1).masked_fill(
                        mask.reshape(B, -1) < 0.5, -float("inf"))
                    _, top = flat_conf.topk(min(n_reveal, N), dim=1)
                    new_rev = torch.zeros(B, N, device=dev).scatter_(1, top, 1.0)
                    revealed = (revealed + new_rev.reshape(B, 1, H, W)).clamp(0, 1)
        finally:
            if use_ema: self._restore_ema()
        return canvas.clamp(0, 1)

    # ── Variation ─────────────────────────────────────────────────────────
    @torch.no_grad()
    def generate_variation(self, image: torch.Tensor,
                           diversity: float = 0.5,
                           n_samples: int = 1,
                           n_steps: int = 16,
                           use_ema: bool = True) -> torch.Tensor:
        """Generate variations of an existing sprite.
        diversity=0 -> near-identical copy; diversity=1 -> fully free generation."""
        seed_ratio = 1.0 - diversity
        return self.generate(
            n_samples=n_samples, n_steps=n_steps,
            seed_image=image, seed_ratio=seed_ratio,
            use_ema=use_ema)

    # ── Palette extraction from CLS token ─────────────────────────────────
    @torch.no_grad()
    def extract_palette(self, image: torch.Tensor,
                        use_ema: bool = True) -> torch.Tensor:
        """Returns (PALETTE_SIZE, 4) RGBA palette from the model's palette head."""
        self.model.eval()
        if image.dim() == 3: image = image.unsqueeze(0)
        if use_ema: self._apply_ema()
        try:
            enc, _, _ = self.model.forward_encoder(image)
            cls_tok = enc[:, :1, :]                     # (1, 1, D)
            palette = self.model.palette_head(cls_tok)  # (1, 1, P*C)
            palette = palette.reshape(self.config.PALETTE_SIZE, self.config.CHANNELS)
        finally:
            if use_ema: self._restore_ema()
        return palette.clamp(0, 1)
# ============================================================================
# 4. INIT ENGINE + MODEL MANAGEMENT
# ============================================================================
print("Initializing PixelMAE Inference Engine …")
# Module-level singletons shared by every Gradio callback.
config = Config()
engine = InferenceEngine(config, config.DEVICE)
# Global state
_loaded_model_key = None               # key of the currently loaded checkpoint ("best"/"latest")
_ckpt_info_text = "No model loaded."   # markdown status shown in the UI
_model_paths: Dict[str, Optional[str]] = {"best": None, "latest": None}  # cached local file paths
def _download_model(key: str) -> Tuple[Optional[str], str]:
    """Download a checkpoint from the Hub.

    key: "best" or "latest" (indexes MODEL_FILES).
    Returns (local_path, status_message); local_path is None on failure.
    """
    filename = MODEL_FILES[key]
    try:
        path = hf_hub_download(repo_id=REPO_ID, filename=filename)
        # Fix: interpolate the actual filename into the status messages
        # (the message previously contained a literal placeholder).
        return path, f"βœ… Downloaded `{filename}` from `{REPO_ID}`."
    except Exception as e:
        # Broad catch is deliberate: download failure must never crash the app;
        # the error text is surfaced in the UI instead.
        return None, f"❌ Could not download `{filename}`: {e}"
def load_model(model_choice: str) -> str:
    """Download (if needed) and load the selected checkpoint.

    Returns a markdown status string, which is also cached in _ckpt_info_text.
    """
    global _loaded_model_key, _ckpt_info_text, _model_paths
    key = "best" if model_choice == "Best (best.pth)" else "latest"
    if _model_paths[key] is None:
        path, msg = _download_model(key)
        if path is None:
            _ckpt_info_text = msg
            return _ckpt_info_text
        _model_paths[key] = path
    try:
        meta = engine.load_checkpoint(_model_paths[key])
        if meta["best_loss"] is not None:
            loss_str = f"{meta['best_loss']:.6f}"
        else:
            loss_str = "N/A"
        ema_str = "βœ… loaded" if meta["has_ema"] else "⚠️ not found"
        _ckpt_info_text = (
            f"**Model:** `{MODEL_FILES[key]}` \n"
            f"**Epoch:** {meta['epoch']} \n"
            f"**Best Val Loss:** {loss_str} \n"
            f"**EMA Weights:** {ema_str} \n"
            f"**Device:** {config.DEVICE}"
        )
        _loaded_model_key = key
    except Exception as e:
        _ckpt_info_text = f"❌ Load error: {e}"
    return _ckpt_info_text
# Pre-load best model at startup (non-fatal)
try:
    _path, _msg = _download_model("best")
    if _path:
        _model_paths["best"] = _path
        load_model("Best (best.pth)")
        # Flatten the multi-line status into one log line (chr(10) == "\n").
        print(f" └─ {_ckpt_info_text.replace(chr(10), ' ')}")
    else:
        print(f" └─ {_msg}")
except Exception as ex:
    # Startup must never crash the Space; the UI can retry via the load button.
    print(f" └─ Startup load failed: {ex}")
# ============================================================================
# 5. IMAGE UTILITIES
# ============================================================================
def preprocess(pil_img: Image.Image) -> torch.Tensor:
    """PIL image -> (1, 4, 16, 16) float tensor in [0, 1] on the app device."""
    resized = pil_img.convert("RGBA").resize(
        (config.IMAGE_SIZE, config.IMAGE_SIZE), Image.Resampling.NEAREST)
    data = np.asarray(resized, np.float32) / 255.0
    chw = torch.from_numpy(data).permute(2, 0, 1)
    return chw.unsqueeze(0).to(config.DEVICE)
def postprocess(tensor: torch.Tensor, upscale: int = 16) -> Image.Image:
    """Tensor -> pixel-perfect (nearest-neighbor) upscaled RGBA PIL image."""
    if tensor is None:
        return None
    if tensor.dim() == 4:
        tensor = tensor.squeeze(0)
    hwc = tensor.permute(1, 2, 0).cpu().clamp(0, 1).numpy()
    img = Image.fromarray((hwc * 255).astype(np.uint8), "RGBA")
    side = config.IMAGE_SIZE * upscale
    return img.resize((side, side), Image.Resampling.NEAREST)
def tensor_to_pils(batch: torch.Tensor, upscale: int = 16) -> List[Image.Image]:
    """Split a batch tensor into a list of upscaled PIL images."""
    if batch.dim() == 3:
        batch = batch.unsqueeze(0)
    return [postprocess(sample, upscale) for sample in batch]
def make_palette_image(palette: torch.Tensor, swatch_size: int = 32) -> Image.Image:
    """Render an (N, 4) palette tensor as a horizontal strip of color swatches."""
    count = palette.shape[0]
    strip = Image.new("RGBA", (count * swatch_size, swatch_size), (0, 0, 0, 0))
    painter = ImageDraw.Draw(strip)
    for idx in range(count):
        rgba = tuple((palette[idx].cpu().numpy() * 255).astype(int).tolist())
        x0 = idx * swatch_size
        painter.rectangle([x0, 0, x0 + swatch_size - 1, swatch_size - 1],
                          fill=rgba)
    return strip
def compare_images(before: Image.Image, after: Image.Image,
                   label_size: int = 10) -> Image.Image:
    """Paste two images side by side on a dark background.

    NOTE(review): label_size is currently unused (kept for interface
    compatibility); no labels are actually drawn.
    """
    width, height = before.size
    canvas = Image.new("RGBA", (width * 2 + 4, height), (20, 20, 20, 255))
    canvas.paste(before, (0, 0))
    canvas.paste(after, (width + 4, 0))
    return canvas
def extract_unique_colors(img: Image.Image, max_colors: int = 32) -> List[Tuple]:
    """Return up to max_colors distinct RGBA tuples, skipping near-transparent pixels."""
    flat = np.array(img.convert("RGBA")).reshape(-1, 4)
    opaque = flat[flat[:, 3] > 10]  # alpha <= 10 treated as transparent
    distinct = np.unique(opaque, axis=0)
    return [tuple(row) for row in distinct[:max_colors]]
# ============================================================================
# 6. GRADIO FUNCTIONS
# ============================================================================
def fn_load_model(choice):
    # Thin Gradio wrapper: returns the markdown status string from load_model.
    return load_model(choice)
def fn_generate(seed_img, seed_ratio, steps, n_samples,
                temperature, use_ema, upscale):
    """Generate a batch of sprites.

    Returns (gallery items, first image, palette strip of the first sample).
    """
    n_samples, steps, upscale = int(n_samples), int(steps), int(upscale)
    seed_t = None if seed_img is None else preprocess(seed_img)
    batch = engine.generate(
        n_samples=n_samples, n_steps=steps,
        seed_image=seed_t, seed_ratio=seed_ratio,
        temperature=temperature, use_ema=use_ema)
    images = tensor_to_pils(batch, upscale)
    pal_img = make_palette_image(
        engine.extract_palette(batch[0:1], use_ema=use_ema), swatch_size=40)
    gallery = [(img, f"Sample {i+1}") for i, img in enumerate(images)]
    return gallery, images[0], pal_img
def fn_generate_single(seed_img, seed_ratio, steps, temperature, use_ema, upscale):
    """Generate one sprite; returns (image, palette strip)."""
    steps, upscale = int(steps), int(upscale)
    seed_t = None if seed_img is None else preprocess(seed_img)
    out = engine.generate(
        n_samples=1, n_steps=steps,
        seed_image=seed_t, seed_ratio=seed_ratio,
        temperature=temperature, use_ema=use_ema)
    sprite = postprocess(out[0], upscale)
    strip = make_palette_image(
        engine.extract_palette(out[0:1], use_ema=use_ema), swatch_size=40)
    return sprite, strip
def fn_restore(image, strength, passes, use_ema, upscale):
    """Restore/refine a sprite; returns (after, before/after strip, palette)."""
    if image is None:
        return None, None, None
    upscale, passes = int(upscale), int(passes)
    in_t = preprocess(image)
    if passes > 1:
        out_t = engine.restore_multi_pass(in_t, strength=strength, passes=passes, use_ema=use_ema)
    else:
        out_t = engine.restore(in_t, strength=strength, use_ema=use_ema)
    before_img = postprocess(in_t[0], upscale)
    after_img = postprocess(out_t[0], upscale)
    side_by_side = compare_images(before_img, after_img)
    pal_img = make_palette_image(
        engine.extract_palette(out_t[0:1], use_ema=use_ema), swatch_size=40)
    return after_img, side_by_side, pal_img
def fn_inpaint(editor_dict, use_ema, upscale):
    """Redraw user-painted regions; returns (result image, palette strip).

    The editor dict comes from gr.ImageEditor: "background" is the sprite,
    "layers" hold the painted mask (any opaque stroke counts as masked).
    """
    if not editor_dict or editor_dict.get("background") is None:
        return None, None
    upscale = int(upscale)
    background = editor_dict["background"].convert("RGBA")
    in_t = preprocess(background)
    mask_t = torch.zeros((1, 1, 16, 16), device=config.DEVICE)
    layers = editor_dict.get("layers", [])
    if layers:
        painted = layers[0].convert("RGBA").resize((16, 16), Image.Resampling.NEAREST)
        coverage = (np.array(painted)[:, :, 3] > 0).astype(np.float32)
        mask_t = torch.from_numpy(coverage).unsqueeze(0).unsqueeze(0).to(config.DEVICE)
    out_t = engine.inpaint(in_t, mask_t, use_ema=use_ema)
    strip = make_palette_image(
        engine.extract_palette(out_t[0:1], use_ema=use_ema), swatch_size=40)
    return postprocess(out_t[0], upscale), strip
def fn_inpaint_alpha(image, use_ema, upscale):
    """Fill transparent pixels using the model; returns (result, palette strip)."""
    if image is None:
        return None, None
    upscale = int(upscale)
    in_t = preprocess(image)
    transparent = (in_t[:, 3:4] < 0.5).float()  # mask = alpha below 0.5
    out_t = engine.inpaint(in_t, transparent, use_ema=use_ema)
    strip = make_palette_image(
        engine.extract_palette(out_t[0:1], use_ema=use_ema), swatch_size=40)
    return postprocess(out_t[0], upscale), strip
def fn_variation(image, diversity, n_var, steps, use_ema, upscale):
    """Produce N variations of a source sprite; returns (gallery, palette strip)."""
    if image is None:
        return [], None
    n_var, steps, upscale = int(n_var), int(steps), int(upscale)
    out_t = engine.generate_variation(
        preprocess(image), diversity=diversity, n_samples=n_var,
        n_steps=steps, use_ema=use_ema)
    images = tensor_to_pils(out_t, upscale)
    strip = make_palette_image(
        engine.extract_palette(out_t[0:1], use_ema=use_ema), swatch_size=40)
    gallery = [(img, f"Variation {i+1}") for i, img in enumerate(images)]
    return gallery, strip
def fn_batch_generate(n_total, steps, temperature, use_ema, upscale):
    """Generate n_total sprites from noise; returns gallery (image, caption) pairs."""
    n_total, steps, upscale = int(n_total), int(steps), int(upscale)
    out_t = engine.generate(
        n_samples=n_total, n_steps=steps,
        temperature=temperature, use_ema=use_ema)
    images = tensor_to_pils(out_t, upscale)
    return [(img, f"#{i+1}") for i, img in enumerate(images)]
# ============================================================================
# 7. GRADIO UI
# ============================================================================
CSS = """
/* ── Pixel-perfect rendering for all output images ── */
.pixel-out img,
.pixel-gallery img,
.pixel-compare img {
image-rendering: pixelated !important;
image-rendering: crisp-edges !important;
}
/* ── Dark card styling ── */
.card { background: #1a1a2e; border-radius: 10px; padding: 12px; }
/* ── Status badge ── */
.status-box textarea { font-family: monospace; font-size: 12px; }
/* ── Palette strip ── */
.pal-out img {
image-rendering: pixelated !important;
border-radius: 4px;
border: 1px solid #444;
}
"""
_UPSCALE_CHOICES = ["4", "8", "16", "24", "32"]
_MODEL_CHOICES = ["Best (best.pth)", "Latest (latest.pth)"]
def _upscale_ctrl(default="16"):
    """Shared preview-scale radio control (one per tab)."""
    return gr.Radio(
        choices=_UPSCALE_CHOICES,
        value=default,
        label="Preview Scale",
        info="Pixel-perfect zoom multiplier (Γ— original 16px)",
    )
def _ema_ctrl():
    """Shared EMA-weights checkbox (one per tab)."""
    return gr.Checkbox(
        value=True,
        label="Use EMA Weights",
        info="Usually better quality",
    )
def _palette_out():
    """Shared palette-strip output image (one per tab)."""
    return gr.Image(
        label="Predicted Palette",
        type="pil",
        elem_classes=["pal-out"],
    )
# Declarative UI layout; all callbacks are the fn_* functions above.
with gr.Blocks(theme=gr.themes.Monochrome(), css=CSS,
               title="PixelMAE v10 β€” Neural Sprite Engine") as app:
    gr.Markdown(
        """
# πŸ‘Ύ PixelMAE β€” Neural Sprite Engine `v10`
Asymmetric Masked Autoencoder Β· 16 Γ— 16 RGBA Β· Generation Β· Restoration Β· Inpainting Β· Variations
""")
    # ── Model loader ─────────────────────────────────────────────────────
    with gr.Accordion("βš™οΈ Model & Checkpoint", open=False):
        with gr.Row():
            model_choice = gr.Radio(
                _MODEL_CHOICES, value="Best (best.pth)",
                label="Checkpoint")
            load_btn = gr.Button("πŸ”„ Load / Reload Model", variant="secondary")
        ckpt_info = gr.Markdown(value=_ckpt_info_text, label="Checkpoint Info")
        load_btn.click(fn_load_model, inputs=model_choice, outputs=ckpt_info)
    gr.Markdown("---")
    # ── TABS ─────────────────────────────────────────────────────────────
    with gr.Tabs():
        # ── TAB 1: GENERATE ──────────────────────────────────────────────
        with gr.TabItem("✨ Generate"):
            gr.Markdown(
                "Generate sprites from noise using iterative MaskGIT decoding. "
                "Optionally seed from an existing image.")
            with gr.Row():
                with gr.Column(scale=1):
                    gen_seed = gr.Image(
                        label="Seed Image (optional)", type="pil",
                        image_mode="RGBA", height=200)
                    gen_ratio = gr.Slider(0.0, 1.0, 0.0, step=0.05,
                                          label="Seed Fidelity",
                                          info="0 = fully random Β· 1 = preserve seed completely")
                    gen_steps = gr.Slider(4, 64, 16, step=1,
                                          label="Decoding Steps",
                                          info="More steps β†’ smoother, slower")
                    gen_temp = gr.Slider(0.5, 2.0, 1.0, step=0.05,
                                         label="Temperature",
                                         info="< 1 sharper/conservative Β· > 1 more chaotic/varied")
                    gen_use_ema = _ema_ctrl()
                    gen_upscale = _upscale_ctrl("16")
                    gen_btn = gr.Button("⚑ Generate", variant="primary")
                with gr.Column(scale=1):
                    gen_out = gr.Image(
                        label="Output Sprite", type="pil", format="png",
                        elem_classes=["pixel-out"])
                    gen_pal = _palette_out()
            gen_btn.click(
                fn_generate_single,
                inputs=[gen_seed, gen_ratio, gen_steps,
                        gen_temp, gen_use_ema, gen_upscale],
                outputs=[gen_out, gen_pal])
        # ── TAB 2: BATCH GENERATE ────────────────────────────────────────
        with gr.TabItem("πŸ—‚οΈ Batch Generate"):
            gr.Markdown(
                "Generate multiple sprites in one shot. "
                "All sprites are rendered pixel-perfect in the gallery.")
            with gr.Row():
                with gr.Column(scale=1):
                    bg_n = gr.Slider(1, 16, 8, step=1,
                                     label="Number of Sprites")
                    bg_steps = gr.Slider(4, 64, 16, step=1,
                                         label="Decoding Steps")
                    bg_temp = gr.Slider(0.5, 2.0, 1.0, step=0.05,
                                        label="Temperature")
                    bg_use_ema = _ema_ctrl()
                    bg_upscale = _upscale_ctrl("16")
                    bg_btn = gr.Button("⚑ Generate Batch", variant="primary")
                with gr.Column(scale=2):
                    bg_gallery = gr.Gallery(
                        label="Generated Sprites", columns=4, rows=4,
                        height="auto", elem_classes=["pixel-gallery"], format="png")
            bg_btn.click(
                fn_batch_generate,
                inputs=[bg_n, bg_steps, bg_temp, bg_use_ema, bg_upscale],
                outputs=bg_gallery)
        # ── TAB 3: RESTORE & REFINE ──────────────────────────────────────
        with gr.TabItem("πŸ”„ Restore & Refine"):
            gr.Markdown(
                "Feed the model an existing sprite. It randomly masks regions and "
                "reconstructs them, improving pixel coherence. Multi-pass iteratively refines.")
            with gr.Row():
                with gr.Column(scale=1):
                    rest_img = gr.Image(
                        label="Input Sprite", type="pil",
                        image_mode="RGBA", height=200)
                    rest_str = gr.Slider(0.0, 1.0, 0.5, step=0.05,
                                         label="Mask Strength",
                                         info="What fraction of pixels are randomly re-predicted")
                    rest_passes = gr.Slider(1, 6, 1, step=1,
                                            label="Passes",
                                            info="Multi-pass: each pass refines further")
                    rest_use_ema = _ema_ctrl()
                    rest_upscale = _upscale_ctrl("16")
                    rest_btn = gr.Button("πŸ”„ Restore", variant="primary")
                with gr.Column(scale=1):
                    rest_out = gr.Image(
                        label="Restored Sprite", type="pil", format="png",
                        elem_classes=["pixel-out"])
                    rest_compare = gr.Image(
                        label="Before Β· After", type="pil",
                        elem_classes=["pixel-compare"])
                    rest_pal = _palette_out()
            rest_btn.click(
                fn_restore,
                inputs=[rest_img, rest_str, rest_passes, rest_use_ema, rest_upscale],
                outputs=[rest_out, rest_compare, rest_pal])
        # ── TAB 4: SMART INPAINT ─────────────────────────────────────────
        with gr.TabItem("πŸ–ŒοΈ Smart Inpaint"):
            gr.Markdown(
                "Upload a sprite and **paint the mask** over pixels you want the AI to redraw. "
                "Use any brush color β€” coverage counts, not color.")
            with gr.Row():
                with gr.Column(scale=1):
                    inp_editor = gr.ImageEditor(
                        label="Draw Mask (paint = redo)",
                        type="pil", image_mode="RGBA",
                        brush=gr.Brush(colors=["#ff0000"], color_mode="fixed"),
                        height=300)
                    inp_use_ema = _ema_ctrl()
                    inp_upscale = _upscale_ctrl("16")
                    inp_btn = gr.Button("πŸ–ŒοΈ Inpaint", variant="primary")
                with gr.Column(scale=1):
                    inp_out = gr.Image(
                        label="Inpainted Output", type="pil", format="png",
                        elem_classes=["pixel-out"])
                    inp_pal = _palette_out()
            inp_btn.click(
                fn_inpaint,
                inputs=[inp_editor, inp_use_ema, inp_upscale],
                outputs=[inp_out, inp_pal])
        # ── TAB 5: FILL TRANSPARENT ──────────────────────────────────────
        with gr.TabItem("πŸ” Fill Transparent"):
            gr.Markdown(
                "Upload a partially transparent sprite (RGBA). "
                "The model will fill all transparent pixels based on the visible context.")
            with gr.Row():
                with gr.Column(scale=1):
                    alp_img = gr.Image(
                        label="Partial Sprite (RGBA)", type="pil",
                        image_mode="RGBA", height=200)
                    alp_use_ema = _ema_ctrl()
                    alp_upscale = _upscale_ctrl("16")
                    alp_btn = gr.Button("✨ Fill Transparent", variant="primary")
                with gr.Column(scale=1):
                    alp_out = gr.Image(
                        label="Completed Sprite", type="pil", format="png",
                        elem_classes=["pixel-out"])
                    alp_pal = _palette_out()
            alp_btn.click(
                fn_inpaint_alpha,
                inputs=[alp_img, alp_use_ema, alp_upscale],
                outputs=[alp_out, alp_pal])
        # ── TAB 6: VARIATIONS ────────────────────────────────────────────
        with gr.TabItem("🎲 Variations"):
            gr.Markdown(
                "Upload a sprite and generate **N creative variations**. "
                "Diversity 0 = almost identical Β· 1 = free improvisation on the theme.")
            with gr.Row():
                with gr.Column(scale=1):
                    var_img = gr.Image(
                        label="Source Sprite", type="pil",
                        image_mode="RGBA", height=200)
                    var_div = gr.Slider(0.0, 1.0, 0.5, step=0.05,
                                        label="Diversity",
                                        info="How far variants can deviate from the source")
                    var_n = gr.Slider(1, 16, 4, step=1,
                                      label="Number of Variations")
                    var_steps = gr.Slider(4, 64, 16, step=1,
                                          label="Decoding Steps")
                    var_use_ema = _ema_ctrl()
                    var_upscale = _upscale_ctrl("16")
                    var_btn = gr.Button("🎲 Generate Variations", variant="primary")
                with gr.Column(scale=2):
                    var_gallery = gr.Gallery(
                        label="Variations", columns=4, rows=4,
                        height="auto", elem_classes=["pixel-gallery"],
                        format="png")
                    var_pal = _palette_out()
            var_btn.click(
                fn_variation,
                inputs=[var_img, var_div, var_n, var_steps, var_use_ema, var_upscale],
                outputs=[var_gallery, var_pal])
    # ── Footer ────────────────────────────────────────────────────────────
    gr.Markdown(
        """
---
**PixelMAE v10** Β· Asymmetric MAE Β· 4.2M params Β· 16Γ—16 RGBA
Model: [`MidFord327/PixelArt-MAE-v7`](https://huggingface.co/MidFord327/PixelArt-MAE-v7)
""")
if __name__ == "__main__":
    # Entry point for local runs; Spaces also executes this module directly.
    app.launch()