# Video_Summary_Beta / utils.py
# Last commit: "change_medoid_inference (#3)" by AmithAdiraju1694 (56a8330, verified)
from streamlit import session_state as sst
import time
import torch.nn.functional as F
import cv2
import av
import heapq
import numpy as np
from preprocessing import preprocess_images
import time
from io import BytesIO
import torch
import torchvision.models as models
import torch.nn as nn
import soundfile as sf
import subprocess
from typing import List
# Instruction prefix for audio-transcript summarization. Not referenced in this
# module; presumably prepended to a transcript before it is sent to a
# summarization model elsewhere — verify against callers.
prompt_audio_summarization = "This is a video transcript, tell me what is this about: "
class SiameseNetwork(nn.Module):
    """
    Encode a clip of video frames into a single 128-d embedding.

    Each frame is passed independently through a pretrained ViT-B/16 whose
    classification head has been replaced by an identity. The resulting 768-d
    frame features are averaged over time and projected to 128 dimensions.

    Parameters:
        model_name: str, optional. Currently NOT used - the encoder is always
            ``vit_b_16`` regardless of the value passed.
    """

    def __init__(self, model_name="vit_b_16"):
        super().__init__()
        # NOTE(review): `model_name` is ignored; the backbone is hard-coded.
        backbone = models.vit_b_16(weights="IMAGENET1K_V1")  # pretrained ViT
        backbone.heads = nn.Identity()  # strip the classifier, keep 768-d features
        self.encoder = backbone
        self.fc = nn.Linear(768, 128)  # 768-d ViT feature -> 128-d embedding

    def forward(self, frames):
        """
        Parameters:
            frames: tensor of shape (batch, num_frames, H, W, C), channels-last.
        Returns:
            tensor of shape (batch, 128).
        """
        batch, steps, height, width, channels = frames.shape
        # Channels-last -> channels-first, then fold time into the batch axis so
        # the encoder sees an ordinary (N, C, H, W) image batch.
        flat_frames = frames.permute(0, 1, 4, 2, 3).reshape(
            batch * steps, channels, height, width
        )
        per_frame = self.encoder(flat_frames)
        # Restore the time axis and average frame features into one clip vector.
        # TODO: Change this to use LSTM instead of averaging
        clip_feature = per_frame.reshape(batch, steps, -1).mean(dim=1)
        return self.fc(clip_feature)
def timer(func):
    """
    Decorator that prints the elapsed time of every call to ``func``.

    The returned wrapper carries a ``total_time`` attribute with the cumulative
    duration (seconds) of all calls that returned normally; calls that raise
    are neither printed nor accumulated.

    Parameters:
        func: callable, required. The function to time.
    Returns:
        The wrapper function, with ``func``'s name and docstring preserved.
    """
    from functools import wraps

    @wraps(func)  # keep __name__/__doc__ so decorated functions stay introspectable
    def wrapper(*args, **kwargs):
        # perf_counter is monotonic and high-resolution; time.time() follows the
        # system clock and can jump, producing wrong or negative durations.
        start = time.perf_counter()
        result = func(*args, **kwargs)
        duration = time.perf_counter() - start
        wrapper.total_time += duration
        print(f"Execution time of {func}: {duration}")
        return result

    wrapper.total_time = 0
    return wrapper
def navigate_to(page: str) -> None:
    """
    Mark ``page`` as the current page in Streamlit's session state.

    A helper for simulating navigation in Streamlit: the page name is stored
    under the ``"page"`` key of the session state (presumably read elsewhere
    to decide which page to render).

    Parameters:
        page: str, required. Name of the page to switch to.
    Returns:
        None
    """
    page_key = "page"
    sst[page_key] = page
@timer
def read_important_frames(video_bytes, top_k_frames) -> List:
    """
    Select the ``top_k_frames`` frames with the most motion from an MP4 video.

    Each frame after the first is scored by the summed grayscale absolute
    difference to its predecessor. A min-heap of at most ``top_k_frames``
    entries keeps the highest-scoring frames seen so far. Only frames that
    actually enter the heap are run through ``preprocess_images``.

    Parameters:
        video_bytes: bytes, required. Raw MP4 file contents.
        top_k_frames: int, required. Maximum number of frames to return.
    Returns:
        List of preprocessed frames (``preprocess_images(rgb_frame, 224, 224)``)
        in chronological order. The first decoded frame is never selected,
        since it has no predecessor to diff against.
    """
    important_frames = []  # min-heap of (movement_score, frameId, processed_frame)
    prev_frame = None
    # Open the uploaded video from memory; the context manager closes the
    # container even if decoding fails part-way.
    with av.open(BytesIO(video_bytes), format='mp4') as container:
        for frameId, frame in enumerate(container.decode(video=0)):
            img = frame.to_ndarray(format="bgr24")  # BGR image array
            assert len(img.shape) == 3, f"Instead it is: {img.shape}"
            if prev_frame is not None:
                # Motion score: total grayscale pixel difference to the previous frame.
                diff = cv2.absdiff(prev_frame, img)
                gray_diff = cv2.cvtColor(diff, cv2.COLOR_BGR2GRAY)
                movement_score = np.sum(gray_diff)

                heap_has_room = len(important_frames) < top_k_frames
                # Once full, heappushpop keeps a new entry only if it compares above
                # the current minimum. frameId only grows, so a tie on score is won
                # by the new frame, i.e. the condition reduces to score >= min score.
                beats_weakest = (
                    bool(important_frames)
                    and movement_score >= important_frames[0][0]
                )
                if heap_has_room or beats_weakest:
                    # Preprocess lazily: frames that would be discarded are skipped.
                    processed_frame = preprocess_images(
                        frame.to_ndarray(format="rgb24"), 224, 224
                    )
                    entry = (movement_score, frameId, processed_frame)
                    if heap_has_room:
                        heapq.heappush(important_frames, entry)
                    else:
                        heapq.heappushpop(important_frames, entry)
            prev_frame = img  # Update previous frame

    # Return the kept frames in chronological order of their appearance.
    return [item[2] for item in sorted(important_frames, key=lambda x: x[1])]
@timer
def extract_audio(video_bytes):
    """
    Extract the audio track of an in-memory video as a mono 16 kHz waveform.

    FFmpeg reads the video from stdin and writes 16-bit PCM WAV to stdout, so
    no temporary files are written.

    Parameters:
        video_bytes: bytes, required. Raw contents of the video file.
    Returns:
        torch.Tensor of float32 samples (mono, 16 kHz).
    Raises:
        RuntimeError: if FFmpeg exits with a non-zero status (for example when
            the input has no audio stream or cannot be decoded).
    """
    process = subprocess.run(
        [
            "ffmpeg", "-hide_banner", "-loglevel", "error",  # keep stderr to real errors
            "-i", "pipe:0",
            "-ac", "1", "-ar", "16000", "-c:a", "pcm_s16le",
            "-f", "wav", "pipe:1",
        ],
        input=video_bytes,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,  # captured so a failure can be reported, not discarded
    )
    if process.returncode != 0:
        # Without this check the empty stdout would reach soundfile, which fails
        # with a less informative decoding error that hides FFmpeg's diagnostics.
        details = process.stderr.decode("utf-8", errors="replace").strip()[-1000:]
        raise RuntimeError(
            f"ffmpeg failed to extract audio (exit code {process.returncode}): {details}"
        )

    # Read the WAV stream into a NumPy array; the sample rate is fixed by -ar.
    audio_array, _sample_rate = sf.read(BytesIO(process.stdout), dtype="float32")
    # Convert to PyTorch tensor (Whisper expects a torch.Tensor)
    return torch.tensor(audio_array)
def batch_generator(array_list, batch_size=5, *, drop_last=True):
    """
    Generator that yields consecutive batches of NumPy arrays stacked along a
    new first dimension.

    Parameters:
        array_list (list of np.ndarray): Arrays of identical shape, e.g. (H, W, C).
        batch_size (int): Number of arrays per batch (default is 5). Must be >= 1.
        drop_last (bool): If True (default, the original behavior), a trailing
            group with fewer than ``batch_size`` arrays is discarded. If False,
            it is yielded as a smaller final batch.
    Yields:
        np.ndarray: A batch of shape (batch_size, H, W, C), or smaller for the
        final batch when ``drop_last`` is False.
    Raises:
        ValueError: if ``batch_size`` < 1 (raised when iteration starts).
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    for start in range(0, len(array_list), batch_size):
        batch = array_list[start:start + batch_size]
        # Slices are never empty here because start < len(array_list).
        if len(batch) == batch_size or not drop_last:
            yield np.stack(batch, axis=0)
@timer
def cosine_sim(emb1, emb2, threshold = 0.5):
    """
    Compare two sets of embeddings with cosine similarity along dim 1.

    Parameters:
        emb1, emb2: torch.Tensor, required. Embeddings to compare; similarity
            is computed over dim 1 (``F.cosine_similarity`` default).
        threshold: float, optional. Similarity strictly above which a pair is
            counted (default 0.5).
    Returns:
        tuple: (mean similarity as a 0-d tensor,
                number of pairs with similarity > threshold as a 0-d numpy array)
    """
    similarities = F.cosine_similarity(emb1, emb2)
    # Move to CPU before .numpy(): .numpy() raises on CUDA tensors, so
    # embeddings coming from a GPU model would otherwise crash here.
    counts = torch.count_nonzero(similarities > threshold).cpu().numpy()
    return (similarities.mean(), counts)