File size: 4,254 Bytes
38bfe91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
"""
PixelModel - the weights ARE the image.

model.png stores all weights as pixel RGB values (0-255 → floats).
Architecture (weight matrices are stacked vertically, one output neuron per row):
  - Rows 0..HIDDEN-1                   → Layer 1 weights  (prompt_dim → hidden)
  - Rows HIDDEN..2*HIDDEN-1            → Layer 2 weights  (hidden → hidden)
  - Rows 2*HIDDEN..2*HIDDEN+OUT_FLAT-1 → Layer 3 weights  (hidden → output_flat)

Width W = max(prompt_dim, hidden); each layer reads only its first input-dim columns.
RGB channels encode sign & magnitude:
  R = weight magnitude (0..255 → 0..1)
  G = reserved (bias values, layer 1); not read by the current code
  B = sign bit: <128 = negative, >=128 = positive
"""

import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image


# ── config ────────────────────────────────────────────────────────────────────
PROMPT_DIM  = 32    # prompt embedding size
HIDDEN      = 64    # hidden layer width
OUT_SIZE    = 32    # output image side length
OUT_FLAT    = OUT_SIZE * OUT_SIZE * 3   # 3072

# model.png dimensions that fit all weights
# Layer sizes: (PROMPT_DIM→HIDDEN), (HIDDEN→HIDDEN), (HIDDEN→OUT_FLAT)
# We stack weight matrices vertically, each row = one output neuron's weights
# pad width to W = max of all input dims
MODEL_W = max(PROMPT_DIM, HIDDEN)          # 64
MODEL_H = HIDDEN + HIDDEN + OUT_FLAT       # 64+64+3072 = 3200
# β†’ model.png is 64Γ—3200 px  (~600 KB uncompressed, tiny compressed)


def prompt_to_embedding(prompt: str) -> torch.Tensor:
    """Turn a prompt into a deterministic character-level embedding.

    The prompt is lower-cased, then every character adds ``ord(ch) / 127.0``
    to the slot given by its position modulo ``PROMPT_DIM`` (so long prompts
    wrap around). The accumulated vector is finally scaled to unit L2 norm.

    Args:
        prompt: Text to embed.

    Returns:
        Tensor of shape ``(PROMPT_DIM,)``. It has unit norm unless nothing
        was accumulated (e.g. an empty prompt), in which case the
        un-normalised all-zero vector is returned as is.
    """
    embedding = torch.zeros(PROMPT_DIM)
    for position, char in enumerate(prompt.lower()):
        embedding[position % PROMPT_DIM] += ord(char) / 127.0

    length = embedding.norm()
    # Guard against dividing by zero when no character contributed anything.
    if not length > 0:
        return embedding
    return embedding / length


def pixels_to_weights(pixels: torch.Tensor) -> tuple:
    """Decode the model image into the three weight matrices.

    Each pixel encodes one signed weight: the R channel is the magnitude and
    the B channel the sign (``b >= 0.5`` is positive, otherwise negative).
    The G channel is not read here. The matrices are stacked vertically in
    the image, one output neuron per row, each using only the leading
    columns that correspond to its input dimension.

    Args:
        pixels: ``(H, W, 3)`` float tensor with values in 0..1.

    Returns:
        ``(W1, W2, W3)`` with shapes ``(HIDDEN, PROMPT_DIM)``,
        ``(HIDDEN, HIDDEN)`` and ``(OUT_FLAT, HIDDEN)``.

    Raises:
        ValueError: If ``pixels`` is not 3-dimensional, or is too small to
            hold all three matrices. Without this check, slicing would
            silently return truncated matrices that fail later in the
            forward pass.
    """
    # Unpacking also enforces that the tensor is exactly 3-D.
    height, width, _ = pixels.shape

    needed_rows = HIDDEN + HIDDEN + OUT_FLAT
    needed_cols = max(PROMPT_DIM, HIDDEN)
    if height < needed_rows or width < needed_cols:
        raise ValueError(
            f"model image too small: got {width}x{height} px (WxH), "
            f"need at least {needed_cols}x{needed_rows} px"
        )

    r = pixels[:, :, 0]   # magnitude 0..1
    b = pixels[:, :, 2]   # sign: <0.5 = neg, >=0.5 = pos

    # sign-magnitude → signed float in roughly [-1, 1]
    sign = torch.where(b >= 0.5, torch.ones_like(b), -torch.ones_like(b))
    vals = sign * r

    # Row offsets of the three vertically stacked matrices.
    w1_end = HIDDEN
    w2_end = w1_end + HIDDEN
    w3_end = w2_end + OUT_FLAT

    W1 = vals[:w1_end, :PROMPT_DIM]        # HIDDEN × PROMPT_DIM
    W2 = vals[w1_end:w2_end, :HIDDEN]      # HIDDEN × HIDDEN
    W3 = vals[w2_end:w3_end, :HIDDEN]      # OUT_FLAT × HIDDEN

    return W1, W2, W3


def forward(pixels: torch.Tensor, prompt: str) -> torch.Tensor:
    """Generate an image from a prompt using the weights stored in ``pixels``.

    The prompt embedding is passed through two tanh layers and a final
    sigmoid layer whose weights are decoded from the model image.

    Args:
        pixels: ``(H, W, 3)`` float tensor with values in 0..1 - the model.
        prompt: Text prompt to condition on.

    Returns:
        ``(OUT_SIZE, OUT_SIZE, 3)`` float tensor; the sigmoid keeps every
        value between 0 and 1.
    """
    embedding = prompt_to_embedding(prompt)                 # (PROMPT_DIM,)
    first, second, last = pixels_to_weights(pixels)

    hidden = torch.matmul(first, embedding).tanh()          # (HIDDEN,)
    hidden = torch.matmul(second, hidden).tanh()            # (HIDDEN,)
    flat_rgb = torch.matmul(last, hidden).sigmoid()         # (OUT_FLAT,)

    return flat_rgb.reshape(OUT_SIZE, OUT_SIZE, 3)


def load_model(path: str) -> torch.Tensor:
    """Load model.png → float tensor (H, W, 3) in [0,1].

    Args:
        path: Location of the model image on disk.

    Returns:
        Float32 tensor of shape ``(H, W, 3)`` with every 0..255 channel
        value scaled to 0..1.
    """
    # Image.open is lazy and keeps the file handle open; the context manager
    # guarantees it is released even if decoding fails. convert() returns an
    # independent in-memory image, so it stays usable after the file closes.
    with Image.open(path) as img:
        rgb = img.convert("RGB")
    arr = np.array(rgb, dtype=np.float32) / 255.0
    return torch.tensor(arr)


def save_model(pixels: torch.Tensor, path: str):
    """Save float tensor (H, W, 3) in [0,1] → model.png.

    Values outside 0..1 are clamped, then quantised to 8 bits per channel.

    Args:
        pixels: ``(H, W, 3)`` tensor; may require grad or live on any device.
        path: Destination file path.
    """
    # Round to the nearest 8-bit level instead of truncating: truncation
    # biases every weight downward and can turn a value that was loaded as
    # k/255 into k-1 through float error, so load → save would lose data.
    quantised = pixels.detach().clamp(0, 1).mul(255).round().to(torch.uint8)
    # .cpu() is a no-op for CPU tensors and makes device tensors convertible.
    arr = quantised.cpu().numpy()
    Image.fromarray(arr, mode="RGB").save(path)


def init_model(path: str):
    """Create a fresh random model.png.

    Magnitudes (R) are drawn uniformly from [0, 0.2), G is left at zero and
    the sign channel (B) is 0 or 1 with equal probability, so the decoded
    weights are small and centred on zero.

    Args:
        path: Destination file path for the new model image.
    """
    # small random weights around 0 → R~0.1, B random sign
    r = torch.rand(MODEL_H, MODEL_W) * 0.2            # small magnitude
    g = torch.zeros(MODEL_H, MODEL_W)                 # reserved channel
    b = (torch.rand(MODEL_H, MODEL_W) > 0.5).float()  # random sign
    pixels = torch.stack([r, g, b], dim=2)            # (MODEL_H, MODEL_W, 3)
    save_model(pixels, path)
    # The message previously printed a mis-encoded multiplication sign.
    print(f"Initialised model: {path}  ({MODEL_W}×{MODEL_H} px)")