FarhanAK128's picture
Update model_utils/utils.py
5635c4f verified
import cv2
import torch
import torchvision.transforms.functional as F
import os
from .tokenizer import Tokenizer
from .vocabulary import vocabulary
import numpy as np
# Module-level tokenizer shared by the decoding helpers below; its
# int_to_str table is the default vocabulary for ctc_greedy_decoder.
tokenizer = Tokenizer()
def load_video_for_gif(path):
    """Load a video and return its mouth-region crops as RGB uint8 frames.

    Args:
        path: Path to a video file readable by OpenCV.

    Returns:
        List of np.uint8 arrays of shape [46, 140, 3] (RGB), one per
        successfully decoded frame. Suitable for GIF writers such as
        imageio. May be empty if the video cannot be read.
    """
    cap = cv2.VideoCapture(path)
    frames = []
    for _ in range(int(cap.get(cv2.CAP_PROP_FRAME_COUNT))):
        ret, frame = cap.read()
        if not ret:
            # Frame count metadata can overshoot the decodable frames.
            break
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        # Crop the mouth region (rows 190:236, cols 80:220) — fixed crop
        # assumes GRID-style pre-aligned face videos; TODO confirm.
        frames.append(frame[190:236, 80:220, :])  # [H, W, C]
    cap.release()
    # cv2 already yields uint8 ndarrays; no per-frame copy/cast needed.
    return frames  # List of [46, 140, 3]
def load_video(path):
    """Load a video as normalized grayscale mouth-crop frames.

    Args:
        path: Path to a video file readable by OpenCV.

    Returns:
        Float tensor of shape [T, 1, 46, 140], normalized to zero mean
        and unit variance over the whole clip.

    Raises:
        RuntimeError: if no frame could be decoded (torch.stack on empty
            list).
    """
    cap = cv2.VideoCapture(path)
    frames = []
    for _ in range(int(cap.get(cv2.CAP_PROP_FRAME_COUNT))):
        ret, frame = cap.read()
        if not ret:
            # Bug fix: without this guard a short/corrupt video makes
            # cv2.cvtColor(None, ...) raise. Mirrors load_video_for_gif.
            break
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        # [H, W, C] uint8 → [C, H, W] float tensor
        frame = torch.from_numpy(frame).permute(2, 0, 1).float()
        frame = F.rgb_to_grayscale(frame)  # [1, H, W]
        # Mouth crop; channel dimension is kept.
        frame = frame[:, 190:236, 80:220]
        frames.append(frame)
    cap.release()
    frames = torch.stack(frames)  # [T, 1, 46, 140]
    # Per-video normalization; epsilon guards against zero variance.
    mean = frames.mean()
    std = frames.std()
    frames = (frames - mean) / (std + 1e-8)
    return frames  # [T, 1, 46, 140]
def load_alignments(path: str) -> str:
    """Parse a .align transcript file into a space-separated word string.

    Each line is expected to look like ``<start> <end> <word>`` (GRID
    corpus format). Silence markers ("sil") are dropped.

    Args:
        path: Path to the .align file.

    Returns:
        The spoken words joined by single spaces, e.g. "bin blue at f two now".
    """
    words = []
    with open(path, 'r') as f:
        for line in f:
            parts = line.split()
            # Skip blank/malformed lines (avoids IndexError on a trailing
            # newline) and silence markers.
            if len(parts) < 3 or parts[2] == "sil":
                continue
            words.append(parts[2])
    return " ".join(words)
def load_data(path: str):
    """Load the video frames and transcript for one sample.

    The sample's base name is taken from *path* and both the ``.mpg``
    video and ``.align`` transcript are looked up in the ``data`` folder.

    Args:
        path: Any path whose stem identifies the sample (e.g. "x/bbaf2n.mpg").

    Returns:
        Tuple (frames, alignments): the [T, 1, 46, 140] tensor from
        load_video and the transcript string from load_alignments.
    """
    # os.path handles both '/' and '\\' separators; splitext is safer
    # than split('.') for names containing extra dots.
    file_name = os.path.splitext(os.path.basename(path))[0]
    video_path = os.path.join('data', f'{file_name}.mpg')
    alignment_path = os.path.join('data', f'{file_name}.align')
    frames = load_video(video_path)
    alignments = load_alignments(alignment_path)
    return frames, alignments
def ctc_greedy_decoder(logits_batch, vocab=None, blank_id=0):
    """Greedy (best-path) CTC decoding of a batch of logit sequences.

    Args:
        logits_batch: Tensor of shape (batch, time, vocab_size).
        vocab: Mapping from token index to string. Defaults to the
            module-level tokenizer's int_to_str table, resolved lazily so
            the function is usable (and importable) without it when the
            caller supplies an explicit vocab.
        blank_id: Index of the CTC blank token (default 0).

    Returns:
        Tuple (texts, tokens):
            texts: list of decoded strings, one per batch item.
            tokens: list of decoded token-index lists, one per batch item.
    """
    if vocab is None:
        # Lazy default: the original evaluated tokenizer.int_to_str at
        # import time, coupling every caller to the module tokenizer.
        vocab = tokenizer.int_to_str
    # Best path: most likely token per timestep — shape (batch, time).
    best_path = torch.argmax(logits_batch, dim=-1)
    texts = []
    token_lists = []
    for seq in best_path:
        chars = []
        indices = []
        prev = None
        for idx in seq.tolist():
            # CTC collapse rule: drop blanks and consecutive repeats.
            if idx != blank_id and idx != prev:
                indices.append(idx)
                chars.append(vocab[idx])
            prev = idx
        texts.append("".join(chars))
        token_lists.append(indices)
    return texts, token_lists