File size: 960 Bytes
c336648 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 |
#
# These are currently near-useless, but at least they're instant.
#
import json
import torch
def linear_encoder(img, ver="v1", weights="./linear_weights.json"):
    """Encode tensor RGB[3,W,H](0.0-1.0) into tensor LATENT[4,W,H].

    Each latent channel k in ("A", "B", "C", "D") is a per-pixel affine
    combination of the three input channels:

        img[0]*w[k]["R"] + img[1]*w[k]["G"] + img[2]*w[k]["B"] + w[k]["C"]

    so the trailing (spatial) dimensions of ``img`` are preserved.

    Args:
        img: Tensor whose first dimension indexes R, G, B. The 0.0-1.0 value
            range is the intended input domain; it is not checked here.
        ver: Top-level key in the weights file selecting a coefficient set.
        weights: Path to a JSON file shaped like
            ``{ver: {"A": {"R": .., "G": .., "B": .., "C": ..}, "B": {...},
            "C": {...}, "D": {...}}}``.

    Returns:
        Tensor of shape ``[4, *img.shape[1:]]`` (channels A, B, C, D).

    Raises:
        OSError: if the weights file cannot be opened.
        json.JSONDecodeError: if the weights file is not valid JSON.
        KeyError: if ``ver`` or any channel/coefficient key is missing.
    """
    # JSON is UTF-8 by spec; don't depend on the platform default encoding.
    with open(weights, encoding="utf-8") as f:
        coeffs = json.load(f)[ver]

    red, green, blue = img[0], img[1], img[2]
    # Keep the original left-to-right evaluation order (R, G, B, then the
    # constant) so floating-point results match the unrolled version exactly.
    return torch.stack([
        red * coeffs[ch]["R"]
        + green * coeffs[ch]["G"]
        + blue * coeffs[ch]["B"]
        + coeffs[ch]["C"]
        for ch in ("A", "B", "C", "D")
    ])
|