# File size: 5,756 Bytes
# 4ec6f12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# https://tree.rocks/get-heatmap-from-cnn-convolution-neural-network-aka-grad-cam-222e08f57a34

import cv2
import os
import re

import matplotlib.pyplot as plt
import numpy as np
import torch
from scipy.ndimage import zoom
from tqdm import tqdm

from model import MakiAlexNet

# Keep only the top N percentile of each activation map when thresholding.
TOP_ACCURACY_PERCENTILE = 10

TEST_IMAGE = "dataset/root/train/left1_frame_10.jpg"
MODEL_PARAMS = "alexnet_cognitive.pth"
GIF_STORE = "dataset/gifs/"
TRAIN_STORE = "dataset/root/train/"

model = MakiAlexNet()
# Load onto CPU first so a checkpoint saved on a GPU machine still loads
# on a CPU-only host; move to CUDA afterwards if it is available.
model.load_state_dict(torch.load(MODEL_PARAMS, map_location="cpu"))
model.eval()  # inference mode: freezes dropout / batch-norm statistics

# Make model run on cuda if available.
if torch.cuda.is_available():
    model = model.cuda()
    print("Running on cuda")


print(dir(model))

# Print every layer name so the hooked layer ('Conv2d') can be identified.
for name, module in model.named_modules():
    print(name)

def extract_file_paths(filename):
    """Split a frame filename like ``left1_frame_10.jpg`` into its parts.

    With aid from https://regex101.com/, regex.

    Parameters
    ----------
    filename : str
        Name containing ``(left|right)<video_no>_frame_<frame_no>``.

    Returns
    -------
    tuple[str, str, str]
        ``(frame_no, frame_name, video_no)``, each still a string.

    Raises
    ------
    ValueError
        If the filename does not match the expected pattern (previously
        this surfaced as an opaque ``AttributeError`` on ``None.group``).
    """
    extractor_reg = r"(left|right)([0-9]+)(_frame_)([0-9]+)"
    result = re.search(extractor_reg, filename)
    if result is None:
        raise ValueError(f"Unrecognised frame filename: {filename!r}")
    frame_no = result.group(4)
    frame_name = result.group(1)
    video_no = result.group(2)
    return frame_no, frame_name, video_no


def create_mp4_from_frames(file_name, frames):
    """Write the given frame image paths to ``dataset/gifs/<file_name>.mp4``.

    Frames are written in sorted-path order at a fixed 20 FPS.  The output
    size is taken from the first frame *after* sorting; all frames are
    assumed to share that size — TODO confirm against the dataset.
    """
    ordered = sorted(frames)
    print("Sorted frames: ", ordered)
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    # Size must come from the same frame that is written first, i.e. the
    # sorted list (the original read frames[0] but wrote sorted(frames)).
    height, width, _ = cv2.imread(ordered[0]).shape
    fps = 20  # Adjust the frames per second (FPS) as needed
    video_path = os.path.join(os.getcwd(), "dataset", "gifs", f"{file_name}.mp4")
    video = cv2.VideoWriter(video_path, fourcc, fps, (width, height))
    for frame_path in ordered:
        # cv2.imread returns BGR and cv2.VideoWriter expects BGR, so no
        # colour conversion is needed.  The previous BGR->RGB conversion
        # swapped the red/blue channels in the written video.
        video.write(cv2.imread(frame_path))

    # Release the VideoWriter so the file is flushed and closed.
    video.release()



current_video_name = None
selected_frames = []  # heatmap frame paths collected for the current video

# NOTE(review): restricted to a single image for debugging; restore
# tqdm(sorted(os.listdir(TRAIN_STORE)), desc="Running Images") for a full run.
for image_filename in ["left1_frame_5.jpg"]:

    frame_no, frame_name, video_no = extract_file_paths(image_filename)
    obtained_video_name = video_no + "vid" + frame_name
    if current_video_name != obtained_video_name:
        # New video sequence: flush the frames collected for the previous one.
        # Bug fix: the state update used to be nested inside the
        # ``if selected_frames:`` check, so ``current_video_name`` was never
        # initialised on the first iteration and sequences never flushed.
        if selected_frames and current_video_name:
            # Create the clip from the collected frames.
            create_mp4_from_frames(f"{current_video_name}", selected_frames)
        # Clear frames and hand off to the new sequence name.
        selected_frames = []
        current_video_name = obtained_video_name

    # Load the frame, convert OpenCV's BGR to RGB, and build a float32
    # (batch, channel, H, W) tensor for the network.
    img = cv2.imread(os.path.join(TRAIN_STORE, image_filename))
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = torch.unsqueeze(torch.tensor(img.astype(np.float32)), 0)  # (1, H, W, C)
    X = torch.einsum("BWHC->BCWH", img)
    if torch.cuda.is_available():
        X = X.cuda()

    output = model(X)
    # Activations captured by the model's forward hook on its Conv2d layer.
    conv = model.layer_outputs['Conv2d']
    # (1, 256, 12, 12) -> (1, 12, 12, 256) numpy array for per-filter plotting.
    conv = torch.einsum("BCWH->BWHC", conv).cpu().detach().numpy()
    scale = 224 / 12  # upsample factor from the 12x12 feature map to image size

    # Overlay each of the 256 filter activation maps on the input image.
    plt.figure(figsize=(16, 16))
    for i in range(256):
        plt.subplot(16, 16, i + 1)
        # matplotlib clips float RGB data to [0, 1]; cast the 0-255 frame to
        # uint8 so the underlying image displays instead of washing out white.
        plt.imshow(img.squeeze(0).numpy().astype(np.uint8))
        plt.imshow(zoom(conv[0, :, :, i], zoom=(scale, scale)), cmap='jet', alpha=0.3)
    plt.show()
    # TODO(review): the removed commented-out draft thresholded each map at
    # TOP_ACCURACY_PERCENTILE, summed the maps, normalised, and saved the
    # overlay to dataset/gifs/raw/ for MP4 assembly — restore from VCS when
    # that pipeline is resumed.


exit()


# plt.figure(figsize=(16, 16))
# for i in range(36):
#     plt.subplot(6, 6, i + 1)
#     plt.imshow(cv2.imread(TEST_IMAGE))
#     plt.imshow(zoom(conv[0, :,:,i], zoom=(scale, scale)), cmap='jet', alpha=0.3)
#
# plt.show()