Spaces:
Sleeping
Sleeping
| import cv2 | |
| import torch | |
| import torchvision.transforms.functional as F | |
| import os | |
| from .tokenizer import Tokenizer | |
| from .vocabulary import vocabulary | |
| import numpy as np | |
| tokenizer = Tokenizer() | |
def load_video_for_gif(path):
    """Load a video and return the cropped mouth region as RGB uint8 frames.

    Args:
        path: path to a video file readable by OpenCV.

    Returns:
        List of numpy uint8 arrays of shape [46, 140, 3] (RGB), one per
        decoded frame. Empty list if the video cannot be opened.
    """
    cap = cv2.VideoCapture(path)
    frames = []
    try:
        # Read until decode fails; CAP_PROP_FRAME_COUNT is unreliable for
        # some containers, so don't trust it as a loop bound.
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            # Crop the mouth region — assumes GRID-style framing; TODO confirm
            # the crop box for other datasets.
            frames.append(frame[190:236, 80:220, :])  # [H, W, C]
    finally:
        cap.release()
    # cv2 already yields contiguous uint8 ndarrays, so no extra copy/cast
    # (the original rebuilt each frame with np.array(...).astype(np.uint8)).
    return frames  # List of [H, W, 3]
def load_video(path):
    """Load a video as normalized grayscale mouth-region frames.

    Args:
        path: path to a video file readable by OpenCV.

    Returns:
        Float tensor of shape [T, 1, 46, 140], z-normalized over the whole
        video (per-video mean/std, epsilon-guarded).
    """
    cap = cv2.VideoCapture(path)
    frames = []
    try:
        while True:
            ret, frame = cap.read()
            # Bug fix: the original never checked `ret`, so a failed read
            # passed frame=None into cvtColor and crashed. Stop cleanly,
            # matching load_video_for_gif.
            if not ret:
                break
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            # [H, W, C] -> [C, H, W] float tensor
            frame = torch.from_numpy(frame).permute(2, 0, 1).float()
            frame = F.rgb_to_grayscale(frame)
            # Crop mouth region, keeping the channel dimension.
            frames.append(frame[:, 190:236, 80:220])
    finally:
        # Release the capture even if decoding raises.
        cap.release()
    frames = torch.stack(frames)  # [T, 1, H, W]
    # Normalize per video.
    mean = frames.mean()
    std = frames.std()
    frames = (frames - mean) / (std + 1e-8)
    return frames  # [T, 1, 46, 140]
def load_alignments(path: str) -> str:
    """Parse a GRID-style ``.align`` file into a space-separated word string.

    Each non-empty line is expected to look like ``start end word``; the
    ``sil`` (silence) entries are skipped.

    Args:
        path: path to the alignment file.

    Returns:
        Words joined by single spaces, e.g. ``"bin blue at f two now"``.
    """
    tokens = []
    with open(path, 'r') as f:
        for line in f:
            parts = line.split()
            # Bug fix: blank or malformed lines made `line[2]` raise
            # IndexError in the original; skip them instead.
            if len(parts) < 3:
                continue
            if parts[2] != "sil":  # skip silence
                tokens.append(" ")
                tokens.append(parts[2])
    # Leading space from the first word is removed by strip().
    return "".join(tokens).strip()
def load_data(path: str):
    """Load the video/alignment pair corresponding to *path*.

    The stem of *path* (text before the first ``.`` in the basename) is used
    to locate ``data/<stem>.mpg`` and ``data/<stem>.align``.

    Args:
        path: any path whose basename identifies the sample.

    Returns:
        Tuple of (frames tensor from load_video, alignment string from
        load_alignments).
    """
    # Bug fix: the original split on '/' only, which breaks on Windows
    # separators; os.path.basename handles both. Keep the first-dot split
    # to preserve the original stem behavior.
    file_name = os.path.basename(path).split('.')[0]
    video_path = os.path.join('data', f'{file_name}.mpg')
    alignment_path = os.path.join('data', f'{file_name}.align')
    frames = load_video(video_path)
    alignments = load_alignments(alignment_path)
    return frames, alignments
def ctc_greedy_decoder(logits_batch, vocab = tokenizer.int_to_str, blank_id=0):
    """Greedy (best-path) CTC decoding.

    Takes the argmax token at every timestep, then applies the CTC rule:
    collapse consecutive repeats and drop blanks. Repeats separated by a
    blank are kept, since the repeat check compares against the raw
    previous index (blank included).

    Args:
        logits_batch: tensor of shape (batch, time, vocab_size).
        vocab: index -> character lookup used to render the text.
        blank_id: index of the CTC blank symbol.

    Returns:
        Tuple (texts, tokens): decoded strings and the corresponding lists
        of kept vocabulary indices, one entry per batch element.
    """
    best_path = torch.argmax(logits_batch, dim=-1)  # (batch, time)
    texts = []
    tokens = []
    for seq in best_path.tolist():
        # Pair each index with its raw predecessor (None for the first step)
        # and keep it only when it is neither blank nor a repeat.
        kept = [idx for idx, prev in zip(seq, [None] + seq[:-1])
                if idx != blank_id and idx != prev]
        tokens.append(kept)
        texts.append("".join(vocab[i] for i in kept))
    return texts, tokens