refactor_eh (#2)
Browse files
- Added exception handling, fixed models and tokenizer being on different devices, fixed boilerplate code (e8f31c7e08bdb0472e2071aa660fea5ffea2eda7)
- README.md +1 -3
- app.py +0 -2
- model_inference.py +26 -111
- pages.py +38 -36
- runtime.txt +0 -1
- utils.py +2 -5
README.md
CHANGED
|
@@ -4,10 +4,8 @@ emoji: 👁
|
|
| 4 |
colorFrom: indigo
|
| 5 |
colorTo: indigo
|
| 6 |
sdk: streamlit
|
| 7 |
-
python_version: 3.9.6
|
| 8 |
sdk_version: 1.42.0
|
| 9 |
app_file: app.py
|
| 10 |
pinned: false
|
| 11 |
---
|
| 12 |
-
|
| 13 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
| 4 |
colorFrom: indigo
|
| 5 |
colorTo: indigo
|
| 6 |
sdk: streamlit
|
| 7 |
+
python_version: 3.9.6-slim
|
| 8 |
sdk_version: 1.42.0
|
| 9 |
app_file: app.py
|
| 10 |
pinned: false
|
| 11 |
---
|
|
|
|
|
|
app.py
CHANGED
|
@@ -1,7 +1,5 @@
|
|
| 1 |
-
import streamlit as st
|
| 2 |
from streamlit import session_state as sst
|
| 3 |
import asyncio
|
| 4 |
-
import torch
|
| 5 |
|
| 6 |
from pages import landing_page, model_inference_page
|
| 7 |
|
|
|
|
|
|
|
| 1 |
from streamlit import session_state as sst
|
| 2 |
import asyncio
|
|
|
|
| 3 |
|
| 4 |
from pages import landing_page, model_inference_page
|
| 5 |
|
model_inference.py
CHANGED
|
@@ -1,75 +1,19 @@
|
|
| 1 |
-
from transformers import pipeline
|
| 2 |
-
import torch
|
| 3 |
-
from PIL import Image
|
| 4 |
-
|
| 5 |
-
import torch.nn as nn
|
| 6 |
-
import torchvision.models as models
|
| 7 |
-
import torch.nn.functional as F
|
| 8 |
-
|
| 9 |
-
from PIL import Image
|
| 10 |
-
from utils import prompt_frame_summarization, assistant_role, prompt_audio_summarization
|
| 11 |
import streamlit as st
|
| 12 |
-
|
| 13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
import numpy as np
|
| 15 |
import whisper
|
| 16 |
-
from utils import batch_generator, cosine_sim
|
| 17 |
from streamlit import session_state as sst
|
| 18 |
import onnxruntime
|
| 19 |
|
| 20 |
|
| 21 |
-
|
| 22 |
-
class SiameseNetwork(nn.Module):
|
| 23 |
-
def __init__(self, model_name="vit_b_16"):
|
| 24 |
-
super(SiameseNetwork, self).__init__()
|
| 25 |
-
|
| 26 |
-
self.encoder = models.vit_b_16(weights="IMAGENET1K_V1") # Pretrained ViT
|
| 27 |
-
self.encoder.heads = nn.Identity() # Remove classification head
|
| 28 |
-
|
| 29 |
-
self.fc = nn.Linear(768, 128) # Reduce to 128-d embedding
|
| 30 |
-
|
| 31 |
-
def forward(self, video_frames1, video_frames2):
|
| 32 |
-
"""
|
| 33 |
-
video1: (B, nf, H, W, C) # Batch of videos (50 frames each)
|
| 34 |
-
video2: (B, nf, H, W, C)
|
| 35 |
-
"""
|
| 36 |
-
B,num_frames,H,W,C = video_frames1.shape # (Batch, Channels, H, W)
|
| 37 |
-
|
| 38 |
-
# Flatten frames into batch dimension for ViT
|
| 39 |
-
video_frames1 = video_frames1.permute(0,1,4,2,3).reshape(B * num_frames, C,H,W)
|
| 40 |
-
video_frames2 = video_frames2.reshape(0,1,4,2,3).reshape(B * num_frames, C,H,W)
|
| 41 |
-
|
| 42 |
-
# Extract frame-level embeddings
|
| 43 |
-
emb1 = self.encoder(video_frames1) # (B*num_frames, 768)
|
| 44 |
-
emb2 = self.encoder(video_frames2)
|
| 45 |
-
|
| 46 |
-
# Reshape back to (B, T, 768) and average over T
|
| 47 |
-
#TODO: Change this to use LSTM instead of averaging
|
| 48 |
-
emb1 = emb1.reshape(B, num_frames, -1).mean(dim=1) # (B, 768)
|
| 49 |
-
emb2 = emb2.reshape(B, num_frames, -1).mean(dim=1)
|
| 50 |
-
|
| 51 |
-
# Pass through fully connected layer
|
| 52 |
-
emb1 = self.fc(emb1) # (B, 128)
|
| 53 |
-
emb2 = self.fc(emb2)
|
| 54 |
-
|
| 55 |
-
return emb1, emb2
|
| 56 |
-
|
| 57 |
-
def inference(self, video_frames):
|
| 58 |
-
"""
|
| 59 |
-
video: (B, 50, C, H, W)
|
| 60 |
-
"""
|
| 61 |
-
B, num_frames, H, W, C = video_frames.shape
|
| 62 |
-
|
| 63 |
-
video_frames = video_frames.permute(0,1,4,2,3).reshape(B * num_frames, C,H,W)
|
| 64 |
-
emb = self.encoder(video_frames)
|
| 65 |
-
emb = emb.reshape(B, num_frames, -1).mean(dim=1)
|
| 66 |
-
emb = self.fc(emb)
|
| 67 |
-
|
| 68 |
-
return emb
|
| 69 |
-
|
| 70 |
-
|
| 71 |
@timer
|
| 72 |
-
def get_text_from_audio(audio_tensors):
|
| 73 |
"""Transcribe multiple audio tensors in parallel using Whisper's batch processing."""
|
| 74 |
# Transcribe the in-memory audio
|
| 75 |
audio_tensors = audio_tensors.to(sst['device'])
|
|
@@ -80,52 +24,21 @@ def get_text_from_audio(audio_tensors):
|
|
| 80 |
|
| 81 |
@timer
|
| 82 |
def summarize_from_text(raw_transcription):
|
| 83 |
-
|
| 84 |
-
summary = text_summarizer(prompt_audio_summarization + raw_transcription,
|
| 85 |
-
max_length=108,
|
| 86 |
-
min_length=36, do_sample=False)[0]['summary_text']
|
| 87 |
-
|
| 88 |
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
"""
|
| 104 |
-
|
| 105 |
-
processor = None
|
| 106 |
-
messages = None
|
| 107 |
-
model = None
|
| 108 |
-
tokenizer = None
|
| 109 |
-
|
| 110 |
-
if video_frames is None or len(video_frames) == 0:
|
| 111 |
-
return "Error: No video frames available."
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
# Ensure frames are properly formatted
|
| 115 |
-
video_frames = [Image.fromarray(frame.astype("uint8")) for frame in video_frames]
|
| 116 |
-
|
| 117 |
-
# Ensure correct format for processor
|
| 118 |
-
inputs = processor(messages, images=None, videos=[video_frames])
|
| 119 |
-
|
| 120 |
-
inputs.update({
|
| 121 |
-
"tokenizer": tokenizer,
|
| 122 |
-
"max_new_tokens": 54,
|
| 123 |
-
"decode_text": True,
|
| 124 |
-
})
|
| 125 |
-
|
| 126 |
-
summary_text = model.generate(**inputs)
|
| 127 |
-
|
| 128 |
-
return summary_text
|
| 129 |
|
| 130 |
@timer
|
| 131 |
def rate_video_frames(video_frames):
|
|
@@ -154,7 +67,9 @@ def rate_video_frames(video_frames):
|
|
| 154 |
def load_models():
|
| 155 |
sst['device'] = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 156 |
transcriber = whisper.load_model("base", device = sst['device'])
|
| 157 |
-
|
|
|
|
|
|
|
| 158 |
|
| 159 |
base_frame_emb = torch.tensor(
|
| 160 |
np.load('base_frame_medoid.npz')['arr'],
|
|
@@ -167,7 +82,7 @@ def load_models():
|
|
| 167 |
)
|
| 168 |
|
| 169 |
return (
|
| 170 |
-
transcriber,
|
| 171 |
)
|
| 172 |
|
| 173 |
audio_transcriber_model, text_summarizer, video_rating_model,base_frame_emb = load_models()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
+
import torch
|
| 3 |
+
from utils import (
|
| 4 |
+
prompt_audio_summarization,
|
| 5 |
+
timer,
|
| 6 |
+
cosine_sim
|
| 7 |
+
)
|
| 8 |
+
from transformers import BartForConditionalGeneration, BartTokenizer
|
| 9 |
import numpy as np
|
| 10 |
import whisper
|
|
|
|
| 11 |
from streamlit import session_state as sst
|
| 12 |
import onnxruntime
|
| 13 |
|
| 14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
@timer
|
| 16 |
+
def get_text_from_audio(audio_tensors) -> str:
|
| 17 |
"""Transcribe multiple audio tensors in parallel using Whisper's batch processing."""
|
| 18 |
# Transcribe the in-memory audio
|
| 19 |
audio_tensors = audio_tensors.to(sst['device'])
|
|
|
|
| 24 |
|
| 25 |
@timer
|
| 26 |
def summarize_from_text(raw_transcription):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
+
inputs = text_summarizer[0](prompt_audio_summarization + raw_transcription,
|
| 29 |
+
return_tensors="pt",
|
| 30 |
+
max_length=1024,
|
| 31 |
+
truncation=True)\
|
| 32 |
+
.to(sst['device'])
|
| 33 |
+
|
| 34 |
+
summary_ids = text_summarizer[1].generate(**inputs,
|
| 35 |
+
max_length=150,
|
| 36 |
+
min_length=30,
|
| 37 |
+
length_penalty=2.0,
|
| 38 |
+
num_beams=4
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
return text_summarizer[0].decode(summary_ids[0], skip_special_tokens=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
|
| 43 |
@timer
|
| 44 |
def rate_video_frames(video_frames):
|
|
|
|
| 67 |
def load_models():
|
| 68 |
sst['device'] = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 69 |
transcriber = whisper.load_model("base", device = sst['device'])
|
| 70 |
+
|
| 71 |
+
model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn").to(sst['device'])
|
| 72 |
+
tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
|
| 73 |
|
| 74 |
base_frame_emb = torch.tensor(
|
| 75 |
np.load('base_frame_medoid.npz')['arr'],
|
|
|
|
| 82 |
)
|
| 83 |
|
| 84 |
return (
|
| 85 |
+
transcriber, (tokenizer, model), session, base_frame_emb
|
| 86 |
)
|
| 87 |
|
| 88 |
audio_transcriber_model, text_summarizer, video_rating_model,base_frame_emb = load_models()
|
pages.py
CHANGED
|
@@ -1,13 +1,9 @@
|
|
| 1 |
import streamlit as st
|
| 2 |
from streamlit import session_state as sst
|
| 3 |
-
import time
|
| 4 |
-
|
| 5 |
-
import pandas as pd
|
| 6 |
from utils import navigate_to
|
| 7 |
|
| 8 |
from model_inference import rate_video_frames,get_text_from_audio, summarize_from_text
|
| 9 |
from utils import read_important_frames, extract_audio
|
| 10 |
-
import numpy as np
|
| 11 |
|
| 12 |
|
| 13 |
# Define size limits (adjust based on your system)
|
|
@@ -33,25 +29,35 @@ async def landing_page():
|
|
| 33 |
else:
|
| 34 |
# bytes object which can be translated to audio or video
|
| 35 |
video_bytes = uploaded_file.read()
|
| 36 |
-
|
|
|
|
| 37 |
with st.spinner("Getting most important moments from your video."):
|
| 38 |
-
important_frames = read_important_frames(video_bytes, 100)
|
| 39 |
-
st.success(f"Got important moments.")
|
| 40 |
|
| 41 |
-
|
|
|
|
| 42 |
|
| 43 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
with st.spinner("Getting audio transcript from your video for summary"):
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
|
| 50 |
-
|
| 51 |
-
|
|
|
|
|
|
|
|
|
|
| 52 |
|
| 53 |
-
# add audio transcript to session state
|
| 54 |
-
sst["audio_transcript"] = audio_transcript_bytes
|
| 55 |
|
| 56 |
st.button("Summarize & Analyze Video",
|
| 57 |
on_click = navigate_to,
|
|
@@ -67,13 +73,11 @@ async def model_inference_page():
|
|
| 67 |
|
| 68 |
important_frames = sst["important_frames"]
|
| 69 |
with st.spinner("Generating Movie Scale rating for your video"):
|
| 70 |
-
|
|
|
|
|
|
|
|
|
|
| 71 |
|
| 72 |
-
if len(video_rating_scale) > 0:
|
| 73 |
-
pass
|
| 74 |
-
else:
|
| 75 |
-
video_rating_scale = "Sorry, we couldn't find any images from your video, hence couldn't generate any summary"
|
| 76 |
-
|
| 77 |
st.toast("Done")
|
| 78 |
st.header("Movie Scale Rating of Your Video: ", divider = True)
|
| 79 |
st.write(video_rating_scale)
|
|
@@ -84,21 +88,19 @@ async def model_inference_page():
|
|
| 84 |
if "audio_transcript" in sst:
|
| 85 |
|
| 86 |
with st.spinner("Extracting text from audio file"):
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
video_summary_text = summarize_from_text(video_raw_text)
|
| 92 |
st.toast("Done")
|
| 93 |
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
print("Time taken to generate text summary from raw text in seconds: ", summarize_from_text.total_time)
|
| 102 |
|
| 103 |
st.header("Audio Transcript summary of your video: ", divider = True)
|
| 104 |
st.write(video_summary_text)
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
from streamlit import session_state as sst
|
|
|
|
|
|
|
|
|
|
| 3 |
from utils import navigate_to
|
| 4 |
|
| 5 |
from model_inference import rate_video_frames,get_text_from_audio, summarize_from_text
|
| 6 |
from utils import read_important_frames, extract_audio
|
|
|
|
| 7 |
|
| 8 |
|
| 9 |
# Define size limits (adjust based on your system)
|
|
|
|
| 29 |
else:
|
| 30 |
# bytes object which can be translated to audio or video
|
| 31 |
video_bytes = uploaded_file.read()
|
| 32 |
+
|
| 33 |
+
# Try to get important frames from this video, if not don't add this key for further inference processing
|
| 34 |
with st.spinner("Getting most important moments from your video."):
|
|
|
|
|
|
|
| 35 |
|
| 36 |
+
try:
|
| 37 |
+
important_frames = read_important_frames(video_bytes, 100)
|
| 38 |
|
| 39 |
+
st.success(f"Got important moments.")
|
| 40 |
+
|
| 41 |
+
# add important frames to session state and redirect to model inference page
|
| 42 |
+
sst["important_frames"] = important_frames
|
| 43 |
+
|
| 44 |
+
except Exception as e:
|
| 45 |
+
st.write(f"Sorry couldn't extract important frames from this video & can't rate this on movie scale, because of error: {e}")
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# Try to get audio from this video, if not don't add this key for further inference processing
|
| 49 |
with st.spinner("Getting audio transcript from your video for summary"):
|
| 50 |
+
try:
|
| 51 |
+
audio_transcript_bytes = extract_audio(video_bytes)
|
| 52 |
+
|
| 53 |
+
st.success(f"Got audio transcript.")
|
| 54 |
|
| 55 |
+
# add audio transcript to session state
|
| 56 |
+
sst["audio_transcript"] = audio_transcript_bytes
|
| 57 |
+
|
| 58 |
+
except Exception as e:
|
| 59 |
+
st.write(f"Sorry couldn't extract audio from this video & can't rate summarize it, because of error: {e}")
|
| 60 |
|
|
|
|
|
|
|
| 61 |
|
| 62 |
st.button("Summarize & Analyze Video",
|
| 63 |
on_click = navigate_to,
|
|
|
|
| 73 |
|
| 74 |
important_frames = sst["important_frames"]
|
| 75 |
with st.spinner("Generating Movie Scale rating for your video"):
|
| 76 |
+
try:
|
| 77 |
+
video_rating_scale = rate_video_frames(important_frames)
|
| 78 |
+
except Exception as e:
|
| 79 |
+
video_rating_scale = f"Sorry, we couldn't generate rating of your video because of this error: {e} "
|
| 80 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
st.toast("Done")
|
| 82 |
st.header("Movie Scale Rating of Your Video: ", divider = True)
|
| 83 |
st.write(video_rating_scale)
|
|
|
|
| 88 |
if "audio_transcript" in sst:
|
| 89 |
|
| 90 |
with st.spinner("Extracting text from audio file"):
|
| 91 |
+
try:
|
| 92 |
+
video_summary_text = get_text_from_audio(sst["audio_transcript"])
|
| 93 |
+
except Exception as e:
|
| 94 |
+
video_summary_text = f"Sorry, we couldn't extract text from audio of this file because of this error: {e} "
|
|
|
|
| 95 |
st.toast("Done")
|
| 96 |
|
| 97 |
+
if video_summary_text[:5] != "Sorry":
|
| 98 |
+
with st.spinner("Summarizing text from entire transcript"):
|
| 99 |
+
try:
|
| 100 |
+
video_summary_text = summarize_from_text(video_summary_text)
|
| 101 |
+
except Exception as e:
|
| 102 |
+
video_summary_text = f"Sorry, we couldn't summarize text from audio of this file because of this error: {e} "
|
| 103 |
+
st.toast("Done")
|
|
|
|
| 104 |
|
| 105 |
st.header("Audio Transcript summary of your video: ", divider = True)
|
| 106 |
st.write(video_summary_text)
|
runtime.txt
DELETED
|
@@ -1 +0,0 @@
|
|
| 1 |
-
3.9.*
|
|
|
|
|
|
utils.py
CHANGED
|
@@ -10,17 +10,14 @@ import numpy as np
|
|
| 10 |
from preprocessing import preprocess_images
|
| 11 |
import time
|
| 12 |
|
| 13 |
-
import io
|
| 14 |
from io import BytesIO
|
| 15 |
import torch
|
| 16 |
import soundfile as sf
|
| 17 |
import subprocess
|
| 18 |
from typing import List
|
| 19 |
|
| 20 |
-
|
| 21 |
-
prompt_frame_summarization = "These are important frames of a video file. Please generate summary such that end user gets gist of what the video is about."
|
| 22 |
prompt_audio_summarization = "This is a video transcript, tell me what is this about: "
|
| 23 |
-
|
| 24 |
|
| 25 |
def timer(func):
|
| 26 |
def wrapper(*args, **kwargs):
|
|
@@ -52,7 +49,7 @@ def navigate_to(page: str) -> None:
|
|
| 52 |
def read_important_frames(video_bytes, top_k_frames) -> List:
|
| 53 |
|
| 54 |
# reading uploaded vidoe in memory
|
| 55 |
-
video_io =
|
| 56 |
|
| 57 |
# opening uploaded video frames
|
| 58 |
container = av.open(video_io, format='mp4')
|
|
|
|
| 10 |
from preprocessing import preprocess_images
|
| 11 |
import time
|
| 12 |
|
|
|
|
| 13 |
from io import BytesIO
|
| 14 |
import torch
|
| 15 |
import soundfile as sf
|
| 16 |
import subprocess
|
| 17 |
from typing import List
|
| 18 |
|
|
|
|
|
|
|
| 19 |
prompt_audio_summarization = "This is a video transcript, tell me what is this about: "
|
| 20 |
+
|
| 21 |
|
| 22 |
def timer(func):
|
| 23 |
def wrapper(*args, **kwargs):
|
|
|
|
| 49 |
def read_important_frames(video_bytes, top_k_frames) -> List:
|
| 50 |
|
| 51 |
# reading uploaded vidoe in memory
|
| 52 |
+
video_io = BytesIO(video_bytes)
|
| 53 |
|
| 54 |
# opening uploaded video frames
|
| 55 |
container = av.open(video_io, format='mp4')
|