pixelmodel / model.py
wop's picture
Upload 23 files
38bfe91 verified
Raw
History Blame Contribute Delete
4.25 kB
"""
PixelModel - the weights ARE the image.
model.png stores all weights as pixel RGB values (0-255 → floats).
Architecture:
- Rows 0..HIDDEN-1 → Layer 1 weights (prompt_dim → hidden)
- Rows HIDDEN..2*HIDDEN-1 → Layer 2 weights (hidden → hidden)
- Rows 2*HIDDEN..H-1 → Layer 3 weights (hidden → output_flat)
Width W = max(prompt_dim, hidden); each row holds one output neuron's input weights, and narrower layers use only the leftmost columns.
RGB channels encode sign & magnitude:
R = weight magnitude (0..255 → 0..1)
G = reserved (intended for layer 1 bias values; not read by pixels_to_weights)
B = sign bit: <128 = negative, >=128 = positive
"""
import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
# ── config ────────────────────────────────────────────────────────────────────
# Network dimensions.
PROMPT_DIM = 32  # length of the prompt embedding vector
HIDDEN = 64  # width of both hidden layers
OUT_SIZE = 32  # side length of the square output image
OUT_FLAT = 3 * OUT_SIZE * OUT_SIZE  # flattened RGB output: 3072 values

# Dimensions of model.png, chosen so that every weight fits.
# The three weight matrices (PROMPT_DIM→HIDDEN, HIDDEN→HIDDEN, HIDDEN→OUT_FLAT)
# are stacked vertically with one output neuron per image row, so the image
# needs one row per neuron and as many columns as the widest layer input.
MODEL_W = max(HIDDEN, PROMPT_DIM)  # 64
MODEL_H = 2 * HIDDEN + OUT_FLAT  # 64 + 64 + 3072 = 3200
# → model.png is 64×3200 px (~600 KB of raw RGB data)
def prompt_to_embedding(prompt: str) -> torch.Tensor:
    """Map a prompt to a deterministic, unit-length PROMPT_DIM vector.

    Each lower-cased character adds ``ord(ch) / 127`` to the slot selected by
    its position modulo PROMPT_DIM. The result is scaled to unit L2 norm;
    an empty prompt yields the all-zero vector unchanged.
    """
    embedding = torch.zeros(PROMPT_DIM)
    lowered = prompt.lower()
    for position, char in enumerate(lowered):
        # Positions wrap around, so long prompts keep accumulating in place.
        embedding[position % PROMPT_DIM] += ord(char) / 127.0

    length = embedding.norm()
    # Skip the division for a zero vector to avoid producing NaNs.
    return embedding / length if length > 0 else embedding  # (PROMPT_DIM,)
def pixels_to_weights(pixels: torch.Tensor) -> tuple:
    """Decode the model image into the three layer weight matrices.

    pixels: (H, W, 3) float tensor with values in 0..1. Channel 0 (R) is the
        weight magnitude and channel 2 (B) is the sign (>= 0.5 positive,
        otherwise negative). Channel 1 (G) is not read.

    Returns (W1, W2, W3) with shapes (HIDDEN, PROMPT_DIM), (HIDDEN, HIDDEN)
    and (OUT_FLAT, HIDDEN), taken from consecutive row bands of the image.

    Raises ValueError if ``pixels`` is not 3-D or is too small to hold all
    three matrices; without this check the slices would be silently truncated
    and only fail later as a shape error in the forward pass.
    """
    if pixels.dim() != 3:
        raise ValueError(
            f"pixels must have shape (H, W, 3), got {tuple(pixels.shape)}"
        )
    rows, cols, _ = pixels.shape
    need_rows = 2 * HIDDEN + OUT_FLAT
    need_cols = max(PROMPT_DIM, HIDDEN)
    if rows < need_rows or cols < need_cols:
        raise ValueError(
            f"model image is {cols}x{rows} px but at least "
            f"{need_cols}x{need_rows} px are required"
        )

    magnitude = pixels[:, :, 0]  # 0..1
    sign_channel = pixels[:, :, 2]  # < 0.5 = negative, >= 0.5 = positive

    # sign-magnitude → signed float in roughly [-1, 1]
    ones = torch.ones_like(sign_channel)
    sign = torch.where(sign_channel >= 0.5, ones, -ones)
    vals = sign * magnitude

    # Row bands, top to bottom: W1, then W2, then W3.
    w1_end = HIDDEN
    w2_end = w1_end + HIDDEN
    w3_end = w2_end + OUT_FLAT

    W1 = vals[:w1_end, :PROMPT_DIM]  # (HIDDEN, PROMPT_DIM)
    W2 = vals[w1_end:w2_end, :HIDDEN]  # (HIDDEN, HIDDEN)
    W3 = vals[w2_end:w3_end, :HIDDEN]  # (OUT_FLAT, HIDDEN)
    return W1, W2, W3
def forward(pixels: torch.Tensor, prompt: str) -> torch.Tensor:
    """Run the three-layer network encoded in ``pixels`` on ``prompt``.

    pixels : (H, W, 3) float32, values 0..1 — the model image
    prompt : text that is embedded into the network input
    returns: (OUT_SIZE, OUT_SIZE, 3) float32, values 0..1
    """
    signal = prompt_to_embedding(prompt)  # (PROMPT_DIM,)
    first, second, last = pixels_to_weights(pixels)

    # Two tanh hidden layers ...
    for weights in (first, second):
        signal = torch.tanh(torch.matmul(weights, signal))  # (HIDDEN,)

    # ... followed by a sigmoid output layer so every value lands in 0..1.
    flat_rgb = torch.sigmoid(torch.matmul(last, signal))  # (OUT_FLAT,)
    return flat_rgb.reshape(OUT_SIZE, OUT_SIZE, 3)
def load_model(path: str) -> torch.Tensor:
    """Load model.png → float32 tensor (H, W, 3) with values in [0, 1].

    The image is converted to RGB, so any alpha channel or palette in the
    file is dropped before the pixel values are scaled by 1/255.
    """
    # The context manager closes the file handle even if decoding fails;
    # convert() returns an independent in-memory image that outlives it.
    with Image.open(path) as img:
        rgb = img.convert("RGB")
    arr = np.array(rgb, dtype=np.float32) / 255.0
    return torch.tensor(arr)
def save_model(pixels: torch.Tensor, path: str) -> None:
    """Save float tensor (H, W, 3) in [0, 1] → model.png.

    Values outside [0, 1] are clamped, then quantised to 8 bits per channel.
    """
    # .cpu() lets tensors living on an accelerator be saved as well;
    # .numpy() only works on CPU tensors.
    unit = pixels.detach().cpu().clamp(0, 1).numpy()
    # Round to the nearest level instead of truncating: truncation biases
    # every weight downward and turns a sign value of exactly 0.5 (positive)
    # into 127, which decodes as negative after reloading.
    arr = np.rint(unit * 255).astype(np.uint8)
    Image.fromarray(arr, mode="RGB").save(path)
def init_model(path: str):
    """Create a fresh random model.png at ``path`` and report its size.

    The image is MODEL_W px wide and MODEL_H px tall. Magnitudes are small so
    the initial weights stay close to zero; signs are chosen at random.
    """
    # R: magnitude drawn uniformly from [0, 0.2)
    r = torch.rand(MODEL_H, MODEL_W) * 0.2
    # G: reserved channel, left at zero
    g = torch.zeros(MODEL_H, MODEL_W)
    # B: sign channel, 0.0 (negative) or 1.0 (positive)
    b = (torch.rand(MODEL_H, MODEL_W) > 0.5).float()
    pixels = torch.stack([r, g, b], dim=2)  # (MODEL_H, MODEL_W, 3)
    save_model(pixels, path)
    print(f"Initialised model: {path} ({MODEL_W}×{MODEL_H} px)")