File size: 960 Bytes
c336648
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
#
# These linear RGB -> latent encoders are currently near-useless, but at least they're instant.
#
import json
from collections.abc import Mapping

import torch

def linear_encoder(img, ver="v1", weights="./linear_weights.json"):
	"""Encode tensor RGB[3,W,H](0.0-1.0) into tensor LATENT[4,W,H] with a linear map.

	Each latent channel X in (A, B, C, D) is an independent affine
	combination of the three input planes:

		img[0]*w[X]["R"] + img[1]*w[X]["G"] + img[2]*w[X]["B"] + w[X]["C"]

	Args:
		img: tensor whose first three entries along dim 0 are the R, G and B
			planes; only img[0], img[1] and img[2] are read.
		ver: top-level key selecting which weight set to use (e.g. "v1").
		weights: path to a UTF-8 JSON weights file, or an already-loaded
			mapping with the same structure:
			{ver: {"A": {"R": .., "G": .., "B": .., "C": ..}, "B": {...}, ...}}.
			Passing a mapping avoids re-reading the file on every call.

	Returns:
		Tensor with the four latent channels stacked along a new dim 0,
		in the order A, B, C, D.

	Raises:
		KeyError: if `ver` or any expected channel/coefficient key is missing.
		OSError / json.JSONDecodeError: if `weights` is a path that cannot be
			opened or does not contain valid JSON.
	"""
	if isinstance(weights, Mapping):
		table = weights
	else:
		# JSON is UTF-8 by spec; don't depend on the platform locale encoding.
		with open(weights, encoding="utf-8") as f:
			table = json.load(f)
	w = table[ver]

	red, green, blue = img[0], img[1], img[2]
	# Same left-to-right summation order as the per-channel formula above, so
	# the floating-point results match the original hand-unrolled version.
	return torch.stack([
		red * w[ch]["R"] + green * w[ch]["G"] + blue * w[ch]["B"] + w[ch]["C"]
		for ch in ("A", "B", "C", "D")
	])