Spaces:
Sleeping
Sleeping
| """Prototype of the actual app engine: pixel-perfect editing of real PNG | |
| textures via grammar-constrained decoding on the local Gemma server. | |
| Pipeline: PNG -> token-stable wire format (palette quantized to <=26 keys, | |
| alpha thresholded, grid cells separated by spaces) -> per-file GBNF grammar | |
| locking the footprint (exact or 2x upscale) -> local llama.cpp generation -> | |
| parse -> true-alpha PNG out + checkerboard preview. | |
| Supports non-square textures (Minecraft mob atlases are 64x32 etc.). | |
| Usage: | |
| python pixel_editor.py <input.png> "<instruction>" [--upscale] [--out DIR] | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import asyncio | |
| import sys | |
| import time | |
| from collections import Counter | |
| from pathlib import Path | |
| import httpx | |
| from PIL import Image | |
| sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) | |
| import config | |
| import render | |
| import validate | |
# The app permits richer palettes than the benchmark's 12-entry cap.
# validate.py keeps its own module-level copy of the constant, so the same
# value is written to both modules.
config.MAX_PALETTE = validate.MAX_PALETTE = 64
# Chat-completions endpoint of the local llama.cpp server.
URL = "http://localhost:8080/v1/chat/completions"

# One uppercase letter per opaque palette entry, hence at most 26 opaque colors.
KEY_ALPHABET = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
MAX_COLORS = len(KEY_ALPHABET)

# System prompt sent with every request. It is a single line of text: the
# fragments below are concatenated with exactly one space between sentences.
APP_SYSTEM = (
    "You are an expert pixel artist editing real game textures "
    "(Minecraft resource pack style). You receive a texture as a palette-indexed "
    "character grid and an instruction. Recolor and restyle it with conviction: "
    "full color ramps (dark, mid, light per material), hue-shifted shading (shadows "
    "toward purple/blue, highlights warm), consistent upper-left light, no flat "
    "single-color fills, no lazy tint shifts. The pixel layout is locked; you choose "
    "only the colors. Output the PALETTE block then the GRID block in the exact "
    "format of the input. Grid rows use one space between single-character cells; "
    "preserve that row format exactly. Write nothing else."
)
def _format_rows(rows: list[str], spaced: bool) -> list[str]:
    """Return a new list of grid rows, optionally with spaced-out cells.

    With ``spaced`` true, the characters of every row are joined by single
    spaces; otherwise the rows are copied through unchanged.
    """
    return [" ".join(cells) if spaced else cells for cells in rows]
def sprite_to_wire(sprite: validate.Sprite, spaced: bool = False) -> str:
    """Serialize a sprite into the PALETTE/GRID wire text.

    The text is a ``PALETTE`` header, one line per palette entry
    (``<key> transparent`` when the color is ``None``, ``<key> r,g,b``
    otherwise), a ``GRID <width>x<height>`` header and then the grid rows.
    With ``spaced`` true, the cells of each row are separated by one space.
    """
    palette_lines = [
        "{} transparent".format(key)
        if rgb is None
        else "{} {},{},{}".format(key, *rgb)
        for key, rgb in sprite.palette.items()
    ]
    grid_header = "GRID {}x{}".format(sprite.width, sprite.height)
    grid_lines = _format_rows(sprite.rows, spaced)
    return "\n".join(["PALETTE", *palette_lines, grid_header, *grid_lines])
def png_to_wire(path: Path, spaced: bool = False) -> tuple[str, int, int]:
    """Convert a PNG to wire format.

    Pixels with alpha below 128 become the transparent cell ``.``; all other
    pixels are treated as opaque and keyed by their RGB value. When the image
    holds more than ``MAX_COLORS`` distinct opaque colors, they are reduced
    with Pillow's quantizer and every original color is mapped to the nearest
    resulting center (squared RGB distance). Keys are assigned from
    ``KEY_ALPHABET`` in descending order of pixel count.

    Args:
        path: PNG file to read.
        spaced: Separate grid cells with single spaces in the wire text.

    Returns:
        ``(wire_text, width, height)``.
    """
    # Close the source file deterministically. ``convert`` returns an
    # independent in-memory image, so it stays valid after the ``with`` block.
    with Image.open(path) as src:
        im = src.convert("RGBA")
    w, h = im.size
    pixels = list(
        im.get_flattened_data() if hasattr(im, "get_flattened_data") else im.getdata()
    )
    opaque = [(r, g, b) for r, g, b, a in pixels if a >= 128]
    counts = Counter(opaque)

    if len(counts) > MAX_COLORS:
        # Quantize opaque colors down to MAX_COLORS with Pillow's median cut.
        tmp = Image.new("RGB", (len(opaque), 1))
        tmp.putdata(opaque)
        quant = tmp.quantize(colors=MAX_COLORS)
        qpal = quant.getpalette()[: MAX_COLORS * 3]
        # Keep complete RGB triples only. A truncated or empty tuple would
        # compare at distance 0 to every color (zip stops early) and hijack
        # the nearest-center search. Duplicates are dropped, order preserved.
        centers = list(
            dict.fromkeys(
                tuple(qpal[i : i + 3]) for i in range(0, len(qpal) - 2, 3)
            )
        )

        def nearest(c: tuple[int, int, int]) -> tuple[int, int, int]:
            return min(centers, key=lambda k: sum((a - b) ** 2 for a, b in zip(k, c)))

        mapping = {c: nearest(c) for c in counts}
        counts = Counter(mapping[c] for c in opaque)
    else:
        mapping = {c: c for c in counts}

    # Most frequent color gets "A", the next one "B", and so on.
    keys: dict[tuple[int, int, int], str] = {
        color: KEY_ALPHABET[i]
        for i, (color, _n) in enumerate(counts.most_common())
    }

    # Pixel data is row-major, so each grid row is one width-sized slice.
    rows = [
        "".join(
            "." if a < 128 else keys[mapping[(r, g, b)]]
            for r, g, b, a in pixels[y * w : (y + 1) * w]
        )
        for y in range(h)
    ]

    palette: dict[str, tuple[int, int, int] | None] = {".": None}
    for color, key in sorted(keys.items(), key=lambda kv: kv[1]):
        palette[key] = color

    sprite = validate.Sprite(palette=palette, width=w, height=h, rows=rows)
    return sprite_to_wire(sprite, spaced=spaced), w, h
def wire_to_png(sprite: validate.Sprite, out_path: Path) -> None:
    """Save the sprite as a true-alpha 1:1 PNG (the resource-pack artifact).

    Cells whose palette color is ``None`` are written as fully transparent
    pixels; every other cell is written fully opaque. Missing parent
    directories of ``out_path`` are created before saving.
    """

    def rgba(cell: str) -> tuple[int, int, int, int]:
        color = sprite.palette[cell]
        if color is None:
            return (0, 0, 0, 0)
        return (*color, 255)

    image = Image.new("RGBA", (sprite.width, sprite.height))
    image.putdata([rgba(cell) for row in sprite.rows for cell in row])
    out_path.parent.mkdir(parents=True, exist_ok=True)
    image.save(out_path)
def _row_grammar(row: str, spaced: bool) -> str:
    """Build the GBNF right-hand side that matches one grid row.

    ``.`` cells are emitted as string literals and every other cell as a
    reference to the ``ckey`` rule. In spaced mode a literal ``" "`` sits
    between neighbouring cells. In compact mode consecutive ``.`` cells are
    fused into one string literal to keep the grammar small (matters at
    128x128).
    """
    if spaced:
        cells = ['"."' if ch == "." else "ckey" for ch in row]
        return ' " " '.join(cells)

    tokens: list[str] = []
    for ch in row:
        if ch != ".":
            tokens.append("ckey")
        elif tokens and tokens[-1] != "ckey":
            # Previous token is a dot literal: grow it by one more dot.
            tokens[-1] = tokens[-1][:-1] + '."'
        else:
            tokens.append('"."')
    return " ".join(tokens)
def build_grammar(
    rows: list[str], n_keys: int, upscale: bool, spaced: bool = False
) -> str:
    """Build a footprint-locked GBNF grammar for the given grid rows.

    '.' cells are literal; colored cells sample ``ckey``, a character class
    over the first ``n_keys`` letters of ``KEY_ALPHABET``. The root rule fixes
    the palette block (one ``rgb`` line per key, components 0-255), the
    ``GRID <w>x<h>`` header and one rule per row. With ``upscale`` every
    input cell is doubled horizontally and every row is emitted twice, so
    the grammar describes the 2x grid.
    """
    if upscale:
        doubled: list[str] = []
        for source_row in rows:
            wide = "".join(ch + ch for ch in source_row)
            doubled += [wide, wide]
        rows = doubled

    width = len(rows[0])
    height = len(rows)
    key_class = KEY_ALPHABET[:n_keys]

    # Each palette fragment carries a trailing space, which the root rule
    # relies on as the separator between palette lines.
    palette_part = "".join(f'"{k}" " " rgb "\\n" ' for k in key_class)
    row_refs = " ".join(f"r{y}" for y in range(height))

    root_rule = (
        f'root ::= "PALETTE\\n. transparent\\n" {palette_part} '
        f'"GRID {width}x{height}\\n" {row_refs}'
    )
    fixed_rules = [
        root_rule,
        'rgb ::= num "," num "," num',
        'num ::= ("25" [0-5]) | ("2" [0-4] [0-9]) | ("1" [0-9] [0-9]) | ([1-9] [0-9]) | [0-9]',
        f"ckey ::= [{key_class}]",
    ]
    row_rules = [
        f'r{y} ::= {_row_grammar(row, spaced)} "\\n"' for y, row in enumerate(rows)
    ]
    return "\n".join(fixed_rules + row_rules)
async def edit_file(path: Path, instruction: str, upscale: bool, out_dir: Path) -> dict:
    """Edit (or 2x-upscale) one PNG texture through the local model server.

    Steps: convert the PNG to spaced wire text, parse it back to obtain the
    input sprite, build a grammar locked to that footprint, send a
    chat-completions request to ``URL`` and parse the reply. When the reply
    parses, the 1:1 PNG and a scaled preview are written to ``out_dir`` and
    the output footprint is compared against the (possibly doubled) input
    footprint.

    Args:
        path: Input PNG.
        instruction: Free-text edit instruction placed in the user message.
        upscale: Request a 2x grid instead of an exact-size edit.
        out_dir: Directory that receives ``<stem>__edit|2x.png`` and
            ``<stem>__edit|2x_preview.png``.

    Returns:
        A dict with ``file``, ``mode``, ``latency_s``, ``completion_tokens``,
        ``finish_reason``, ``parsed`` and ``error``; ``footprint_perfect`` is
        added only when the reply parsed.

    Raises:
        ValueError: The wire text generated from the input PNG did not parse.
        httpx.HTTPStatusError: The server answered with an error status.
    """
    wire, w, h = png_to_wire(path, spaced=True)
    in_sprite, perr = validate.parse_sprite(wire)
    if in_sprite is None:
        # Explicit raise rather than ``assert`` so the check survives
        # ``python -O`` instead of failing later with an AttributeError.
        raise ValueError(
            "could not parse wire text generated from {}: {}".format(path, perr)
        )

    n_keys = sum(1 for k in in_sprite.palette if k != ".")
    grammar = build_grammar(in_sprite.rows, n_keys, upscale, spaced=True)

    if upscale:
        contract = (
            "Redraw this texture at {}x{} (2x). Every input pixel becomes a 2x2 "
            "block: transparent stays transparent, colored stays colored. Add "
            "finer shading and detail within that constraint."
        ).format(w * 2, h * 2)
    else:
        contract = (
            "Edit this texture. The grid stays {}x{} and every transparent cell "
            "stays transparent; change only the colors of non-transparent cells."
        ).format(w, h)
    user_msg = "{}\n\nInstruction: {}\n\nHere is the input texture:\n{}".format(
        contract, instruction, wire
    )

    # Token budget grows with the number of output cells, capped at 40000.
    out_cells = (w * 2) * (h * 2) if upscale else w * h
    max_tokens = min(int(out_cells * 1.6) + 800, 40000)
    payload = {
        "model": "gemma-4-12b",
        "messages": [
            {"role": "system", "content": APP_SYSTEM},
            {"role": "user", "content": user_msg},
        ],
        "max_tokens": max_tokens,
        "temperature": 0.7,
        "chat_template_kwargs": {"enable_thinking": False},
        "grammar": grammar,
    }

    started = time.perf_counter()
    async with httpx.AsyncClient(timeout=1800.0) as http:
        resp = await http.post(URL, json=payload)
    latency = time.perf_counter() - started
    resp.raise_for_status()

    data = resp.json()
    choice = data["choices"][0]
    text = choice["message"]["content"] or ""
    finish = choice.get("finish_reason")
    # ``usage`` may be absent or null in the response; treat both as "no data".
    usage = data.get("usage") or {}

    sprite, perr = validate.parse_sprite(text)
    result = {
        "file": path.name,
        "mode": "upscale2x" if upscale else "exact",
        "latency_s": round(latency, 1),
        "completion_tokens": usage.get("completion_tokens"),
        "finish_reason": finish,
        "parsed": sprite is not None,
        "error": perr,
    }
    if sprite is None:
        return result

    stem = "{}__{}".format(path.stem, "2x" if upscale else "edit")
    wire_to_png(sprite, out_dir / (stem + ".png"))
    scale = max(2, 512 // max(sprite.width, sprite.height))
    render.save_render(sprite, out_dir / (stem + "_preview.png"), scale=scale)

    # Footprint check (the pixel-perfect guarantee, verified not assumed).
    in_fp = {
        (x, y)
        for y, r in enumerate(in_sprite.rows)
        for x, ch in enumerate(r)
        if ch != "."
    }
    if upscale:
        # Every input cell covers a 2x2 block of the upscaled grid.
        in_fp = {
            (2 * x + dx, 2 * y + dy)
            for x, y in in_fp
            for dx in (0, 1)
            for dy in (0, 1)
        }
    out_fp = {
        (x, y)
        for y, r in enumerate(sprite.rows)
        for x, ch in enumerate(r)
        if ch != "."
    }
    result["footprint_perfect"] = in_fp == out_fp
    return result
async def main() -> None:
    """Parse the command line, run a single edit and print each result field."""
    parser = argparse.ArgumentParser()
    parser.add_argument("input")
    parser.add_argument("instruction")
    parser.add_argument("--upscale", action="store_true")
    default_out = Path(__file__).parent / "edited"
    parser.add_argument("--out", default=str(default_out))
    opts = parser.parse_args()

    summary = await edit_file(
        Path(opts.input), opts.instruction, opts.upscale, Path(opts.out)
    )
    for field, value in summary.items():
        print(f"{field}: {value}")


if __name__ == "__main__":
    asyncio.run(main())