File size: 1,594 Bytes
c8b42eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import os

import cv2
import torch
from matplotlib import pyplot as plt

from uniception.models.encoders.base import ViTEncoderInput
from uniception.models.encoders.cosmos import CosmosEncoder
from uniception.models.prediction_heads.cosmos import CosmosSingleChannel

# Example script: round-trip `example.png` through the pretrained Cosmos
# encoder and its matching prediction head, then save a side-by-side
# comparison of the original and the reconstruction next to this file.
base_path = os.path.dirname(os.path.abspath(__file__))

# Checkpoint locations are resolved relative to this file so the script can be
# launched from any working directory.
encoder = CosmosEncoder(
    name="cosmos",
    patch_size=8,
    pretrained_checkpoint_path=os.path.join(
        base_path, "../../../checkpoints/encoders/cosmos/Cosmos-Tokenizer-CI8x8/encoder.pth"
    ),
)

decoder = CosmosSingleChannel(
    patch_size=8,
    pretrained_checkpoint_path=os.path.join(base_path, "../../../checkpoints/prediction_heads/cosmos/decoder_8.pth"),
)

example_image_path = os.path.join(base_path, "./example.png")
example_image = cv2.imread(example_image_path)
if example_image is None:
    # cv2.imread signals failure by returning None instead of raising; fail
    # here with a clear message rather than inside cv2.cvtColor.
    raise FileNotFoundError(f"Could not read image (missing or unreadable): {example_image_path}")
example_image = cv2.cvtColor(example_image, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; matplotlib expects RGB

# HWC uint8 -> 1xCxHxW float in [0, 1]
example_tensor = torch.tensor(example_image).permute(2, 0, 1).unsqueeze(0).float() / 255.0
example_tensor = example_tensor * 2.0 - 1.0  # Normalize to [-1, 1] according to the COSMOS Encoder

# Pure inference: disable autograd so no graph is built for the forward passes.
with torch.no_grad():
    encoded_latent = encoder(ViTEncoderInput("cosmos", example_tensor)).features
    decoded_image = decoder(encoded_latent)

decoded_image = (decoded_image + 1.0) / 2.0  # Denormalize to [0, 1] for visualization
# The reconstruction can overshoot the nominal range; clamp explicitly so
# imshow does not have to clip (and warn about) out-of-range float RGB data.
decoded_image = decoded_image.clamp(0.0, 1.0)

# plot the original and decoded images
plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.imshow(example_image)
plt.title("Original Image")
plt.axis("off")

plt.subplot(1, 2, 2)
# Drop the batch dimension and move channels last (CHW -> HWC) for matplotlib.
plt.imshow(decoded_image.squeeze().detach().permute(1, 2, 0).cpu().numpy())
plt.title("Decoded Image")
plt.axis("off")

plt.savefig(os.path.join(base_path, "example_decoded.png"))
plt.close()  # release the figure once it has been written to disk