File size: 2,840 Bytes
982555b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
664a71f
982555b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import cv2
import torch
import torchvision.transforms.functional as F
import os
from .tokenizer import Tokenizer
from .vocabulary import vocabulary
import numpy as np

# Module-level tokenizer shared by the decoding helpers below
# (used as the default vocab in ctc_greedy_decoder).
tokenizer = Tokenizer()

def load_video_for_gif(path):
    """Read a video and return its mouth-region crops as uint8 RGB frames.

    Opens *path* with OpenCV, converts each frame BGR -> RGB, crops the
    fixed mouth window (rows 190:236, cols 80:220), and returns the crops
    as a list of [H, W, 3] uint8 numpy arrays suitable for GIF writers.
    """
    capture = cv2.VideoCapture(path)
    total = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))

    cropped = []
    for _ in range(total):
        ok, bgr = capture.read()
        if not ok:
            break
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        # Fixed crop around the speaker's mouth region: [H, W, C].
        cropped.append(rgb[190:236, 80:220, :])
    capture.release()

    # Normalize to plain uint8 numpy arrays.
    return [np.asarray(f, dtype=np.uint8) for f in cropped]

def load_video(path):
    """Load a video as a per-video-normalized grayscale mouth-crop tensor.

    Each decodable frame is converted to RGB, turned into a float tensor,
    grayscaled, and cropped to the fixed mouth window (rows 190:236,
    cols 80:220).

    Args:
        path: path to a video file readable by OpenCV.

    Returns:
        Float tensor of shape [T, 1, 46, 140], normalized to zero mean and
        unit std over the whole video.

    Raises:
        ValueError: if no frames could be read from the file.
    """
    cap = cv2.VideoCapture(path)
    frames = []
    try:
        for _ in range(int(cap.get(cv2.CAP_PROP_FRAME_COUNT))):
            ret, frame = cap.read()
            # Bug fix: CAP_PROP_FRAME_COUNT can overstate the decodable
            # frames; without this check a failed read passes None into
            # cvtColor and crashes.
            if not ret:
                break
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

            # [H, W, C] -> [C, H, W] float tensor.
            frame = torch.from_numpy(frame).permute(2, 0, 1).float()

            frame = F.rgb_to_grayscale(frame)   # [1, H, W]
            frame = frame[:, 190:236, 80:220]   # mouth crop, keep channel dim
            frames.append(frame)
    finally:
        # Release the capture even if decoding raises mid-loop.
        cap.release()

    if not frames:
        raise ValueError(f"No frames could be read from video: {path}")

    frames = torch.stack(frames)   # Shape: [T, 1, H, W]

    # Normalize per video; epsilon guards against zero variance.
    mean = frames.mean()
    std = frames.std()
    frames = (frames - mean) / (std + 1e-8)

    return frames  # Shape: [T, 1, 46, 140]

def load_alignments(path: str) -> str:
    """Parse a GRID-style .align file into a space-separated word string.

    Each line has the form "<start> <end> <token>". Silence tokens
    ("sil") are dropped. Blank or malformed lines (fewer than three
    fields) are skipped instead of raising IndexError, which the
    previous version did.

    Args:
        path: path to the .align transcript file.

    Returns:
        The spoken words joined by single spaces, e.g. "bin blue at f two now".
    """
    words = []
    with open(path, 'r') as f:
        for line in f:
            parts = line.split()
            if len(parts) < 3:
                continue  # robustness: tolerate blank/short lines
            if parts[2] != "sil":  # skip silence
                words.append(parts[2])

    return " ".join(words)

def load_data(path: str):
    """Resolve a sample path to its (video frames, alignment text) pair.

    The sample's base name is looked up under ./data for both the .mpg
    video and the .align transcript.

    Args:
        path: any path whose base name identifies the sample
            (e.g. 'foo/bral8p.mpg' -> 'bral8p').

    Returns:
        (frames, alignments): the [T, 1, 46, 140] tensor from load_video
        and the transcript string from load_alignments.
    """
    # os.path.basename handles both '/' and the OS separator, unlike the
    # previous manual split('/') which broke on Windows-style paths.
    file_name = os.path.splitext(os.path.basename(path))[0]
    video_path = os.path.join('data', f'{file_name}.mpg')
    alignment_path = os.path.join('data', f'{file_name}.align')
    frames = load_video(video_path)
    alignments = load_alignments(alignment_path)

    return frames, alignments

def ctc_greedy_decoder(logits_batch, vocab = tokenizer.int_to_str, blank_id=0):
    """Greedy (best-path) CTC decoding for a batch of logits.

    Args:
        logits_batch: tensor of shape (batch, time, vocab_size).
        vocab: mapping from class index to its string token.
        blank_id: index of the CTC blank label.

    Returns:
        (texts, token_ids): for each batch element, the decoded string and
        the list of surviving class indices after CTC collapse.
    """
    # Best path: most likely class at every timestep -> (batch, time).
    best_path = torch.argmax(logits_batch, dim=-1)

    texts = []
    token_ids = []
    for sequence in best_path.tolist():
        chars = []
        ids = []
        previous = None
        for label in sequence:
            # CTC collapse rule: drop blanks and consecutive repeats.
            if label != blank_id and label != previous:
                ids.append(label)
                chars.append(vocab[label])
            previous = label
        texts.append("".join(chars))
        token_ids.append(ids)

    return texts, token_ids