Spaces:
Sleeping
Sleeping
chuanenlin
committed on
Commit
·
83c81a5
1
Parent(s):
a9cbf7c
Revamp
Browse files- .DS_Store +0 -0
- SessionState.py +0 -70
- cached_data/example_features.pt +3 -0
- cached_data/example_fps.npy +3 -0
- cached_data/example_frame_indices.npy +3 -0
- cached_data/example_frames.npy +3 -0
- requirements.txt +4 -2
- whichframe.py +309 -105
.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
SessionState.py
DELETED
|
@@ -1,70 +0,0 @@
|
|
| 1 |
-
import streamlit.report_thread as ReportThread
|
| 2 |
-
from streamlit.server.server import Server
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
class SessionState():
|
| 6 |
-
"""SessionState: Add per-session state to Streamlit."""
|
| 7 |
-
def __init__(self, **kwargs):
|
| 8 |
-
"""A new SessionState object.
|
| 9 |
-
|
| 10 |
-
Parameters
|
| 11 |
-
----------
|
| 12 |
-
**kwargs : any
|
| 13 |
-
Default values for the session state.
|
| 14 |
-
|
| 15 |
-
Example
|
| 16 |
-
-------
|
| 17 |
-
>>> session_state = SessionState(user_name='', favorite_color='black')
|
| 18 |
-
>>> session_state.user_name = 'Mary'
|
| 19 |
-
''
|
| 20 |
-
>>> session_state.favorite_color
|
| 21 |
-
'black'
|
| 22 |
-
|
| 23 |
-
"""
|
| 24 |
-
for key, val in kwargs.items():
|
| 25 |
-
setattr(self, key, val)
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
def get(**kwargs):
|
| 29 |
-
"""Gets a SessionState object for the current session.
|
| 30 |
-
|
| 31 |
-
Creates a new object if necessary.
|
| 32 |
-
|
| 33 |
-
Parameters
|
| 34 |
-
----------
|
| 35 |
-
**kwargs : any
|
| 36 |
-
Default values you want to add to the session state, if we're creating a
|
| 37 |
-
new one.
|
| 38 |
-
|
| 39 |
-
Example
|
| 40 |
-
-------
|
| 41 |
-
>>> session_state = get(user_name='', favorite_color='black')
|
| 42 |
-
>>> session_state.user_name
|
| 43 |
-
''
|
| 44 |
-
>>> session_state.user_name = 'Mary'
|
| 45 |
-
>>> session_state.favorite_color
|
| 46 |
-
'black'
|
| 47 |
-
|
| 48 |
-
Since you set user_name above, next time your script runs this will be the
|
| 49 |
-
result:
|
| 50 |
-
>>> session_state = get(user_name='', favorite_color='black')
|
| 51 |
-
>>> session_state.user_name
|
| 52 |
-
'Mary'
|
| 53 |
-
|
| 54 |
-
"""
|
| 55 |
-
# Hack to get the session object from Streamlit.
|
| 56 |
-
|
| 57 |
-
session_id = ReportThread.get_report_ctx().session_id
|
| 58 |
-
session_info = Server.get_current()._get_session_info(session_id)
|
| 59 |
-
|
| 60 |
-
if session_info is None:
|
| 61 |
-
raise RuntimeError('Could not get Streamlit session object.')
|
| 62 |
-
|
| 63 |
-
this_session = session_info.session
|
| 64 |
-
|
| 65 |
-
# Got the session object! Now let's attach some state into it.
|
| 66 |
-
|
| 67 |
-
if not hasattr(this_session, '_custom_session_state'):
|
| 68 |
-
this_session._custom_session_state = SessionState(**kwargs)
|
| 69 |
-
|
| 70 |
-
return this_session._custom_session_state
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
cached_data/example_features.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:acb80bcbcb93af49b4bfc874f9823402fd30802aadefa21a4bb10ae13853fee9
|
| 3 |
+
size 695497
|
cached_data/example_fps.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:531bf47f7c8d488f38892c54649751f669325416158545dadb696ea8875456ef
|
| 3 |
+
size 136
|
cached_data/example_frame_indices.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:eb6e12a8f3e0a3a71d30a1c9adcbfa686403a9a3fee8d0dfd38320e1a6840b0a
|
| 3 |
+
size 2840
|
cached_data/example_frames.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:dfdc073b8a2236707ed3e75bdac537163dc0f3b1d65fc92834b9c352491e895d
|
| 3 |
+
size 234316928
|
requirements.txt
CHANGED
|
@@ -1,6 +1,8 @@
|
|
|
|
|
| 1 |
Pillow
|
| 2 |
-
|
| 3 |
opencv-python-headless
|
| 4 |
torch
|
| 5 |
git+https://github.com/openai/CLIP.git
|
| 6 |
-
humanfriendly
|
|
|
|
|
|
| 1 |
+
streamlit>=1.1.0
|
| 2 |
Pillow
|
| 3 |
+
yt-dlp
|
| 4 |
opencv-python-headless
|
| 5 |
torch
|
| 6 |
git+https://github.com/openai/CLIP.git
|
| 7 |
+
humanfriendly
|
| 8 |
+
numpy
|
whichframe.py
CHANGED
|
@@ -6,124 +6,328 @@ from PIL import Image
|
|
| 6 |
import clip as openai_clip
|
| 7 |
import torch
|
| 8 |
import math
|
| 9 |
-
import SessionState
|
| 10 |
from humanfriendly import format_timespan
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
def fetch_video(url):
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
|
| 52 |
def img_to_bytes(img):
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
def
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
time = format_timespan(seconds)
|
| 67 |
-
if ss.input == "file":
|
| 68 |
-
st.write("Seen at " + str(time) + " into the video.")
|
| 69 |
else:
|
| 70 |
-
|
| 71 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
|
| 73 |
-
def
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
|
| 82 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
|
| 84 |
hide_streamlit_style = """
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
st.markdown(hide_streamlit_style, unsafe_allow_html=True)
|
| 96 |
|
| 97 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
|
| 99 |
-
st.title("Which Frame?")
|
| 100 |
-
st.markdown("
|
| 101 |
-
|
|
|
|
|
|
|
| 102 |
|
| 103 |
-
|
|
|
|
| 104 |
|
| 105 |
-
|
| 106 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
ss.video, ss.video_name = fetch_video(url)
|
| 114 |
-
ss.id = extract.video_id(url)
|
| 115 |
-
ss.url = "https://www.youtube.com/watch?v=" + ss.id
|
| 116 |
-
else:
|
| 117 |
-
st.error("Please upload a video or link to a valid YouTube video")
|
| 118 |
-
st.stop()
|
| 119 |
-
ss.video_frames, ss.fps = extract_frames(ss.video_name)
|
| 120 |
-
ss.video_features = encode_frames(ss.video_frames)
|
| 121 |
-
st.video(ss.url)
|
| 122 |
-
ss.progress = 2
|
| 123 |
-
|
| 124 |
-
if ss.progress == 2:
|
| 125 |
-
ss.text_query = st.text_input("Enter search query (Example: a person with sunglasses and earphones)")
|
| 126 |
-
|
| 127 |
-
if st.button("Submit"):
|
| 128 |
-
if ss.text_query is not None:
|
| 129 |
-
text_search(ss.text_query)
|
|
|
|
| 6 |
import clip as openai_clip
|
| 7 |
import torch
|
| 8 |
import math
|
|
|
|
| 9 |
from humanfriendly import format_timespan
|
| 10 |
+
from moviepy.video.io.VideoFileClip import VideoFileClip
|
| 11 |
+
import numpy as np
|
| 12 |
+
import time
|
| 13 |
+
import os
|
| 14 |
+
import yt_dlp
|
| 15 |
+
import io
|
| 16 |
+
|
| 17 |
+
# YouTube URL whose processed data is cached on disk (see load_cached_data /
# save_cached_data, which only act when given exactly this URL).
EXAMPLE_URL = "https://www.youtube.com/watch?v=zTvJJnoWIPk"
# Cache directory; the trailing slash matters because file paths are built by
# plain string concatenation with this prefix.
CACHED_DATA_PATH = "cached_data/"

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Module-level statement: the ViT-B/32 CLIP model and its image preprocessing
# transform are loaded whenever this script is executed.
# NOTE(review): presumably Streamlit re-executes the script on every
# interaction, which would repeat this load - confirm and consider caching.
model, preprocess = openai_clip.load("ViT-B/32", device=device)
|
| 22 |
|
| 23 |
def fetch_video(url):
    """Resolve a video page URL to a direct stream URL without downloading.

    Only metadata is requested from yt-dlp (``download=False``). Videos whose
    reported duration is 300 seconds or more are rejected.

    Args:
        url: Address of the video page, passed to yt-dlp unchanged.

    Returns:
        A ``(None, stream_url)`` tuple. The first element is always ``None``;
        the second is the ``url`` field of the extracted info.

    On rejection or on any ``Exception`` an error is shown through Streamlit
    and ``st.stop()`` is called instead of returning.
    """
    ydl_opts = {
        'format': 'bestvideo[height<=360][ext=mp4]',
        'quiet': True,
        'no_warnings': True,
    }
    try:
        with yt_dlp.YoutubeDL(ydl_opts) as ydl:
            info = ydl.extract_info(url, download=False)
            # A duration key that is present but None would make the
            # comparison below raise TypeError; treat a missing/None value as 0.
            duration = info.get('duration') or 0
            if duration >= 300:  # 5 minutes
                st.error("Please find a YouTube video shorter than 5 minutes.")
                st.stop()
            video_url = info['url']
            return None, video_url

    except Exception as e:
        st.error(f"Error fetching video: {str(e)}")
        st.error("Try another YouTube video or check if the URL is correct.")
        st.stop()
|
| 43 |
+
|
| 44 |
+
def extract_frames(video, status_text, progress_bar):
    """Sample about two frames per second from a video as RGB PIL images.

    Args:
        video: Path or URL handed to ``cv2.VideoCapture``.
        status_text: Object whose ``text(str)`` method receives progress
            messages.
        progress_bar: Object whose ``progress(float)`` method receives a
            fraction between 0.0 and 1.0.

    Returns:
        ``(frames, fps, frame_indices)``: the sampled frames as PIL images,
        the FPS value reported by OpenCV, and the source-frame index of each
        sampled frame (same order as ``frames``).
    """
    cap = cv2.VideoCapture(video)
    frames = []
    frame_indices = []
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # Half the frame rate => one sample every ~0.5 s; never less than 1.
        step = max(1, round(fps / 2))
        # Number of positions the loop below visits. Ceil matches
        # range(0, frame_count, step); the floor of 1 keeps the progress
        # fraction from dividing by zero on clips shorter than one step.
        total_frames = max(1, math.ceil(frame_count / step))
        for i in range(0, frame_count, step):
            cap.set(cv2.CAP_PROP_POS_FRAMES, i)
            ret, frame = cap.read()
            if ret:
                # OpenCV yields BGR; PIL expects RGB.
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(Image.fromarray(frame_rgb))
                frame_indices.append(i)

            current_frame = len(frames)
            status_text.text(f'Extracting frames... ({min(current_frame, total_frames)}/{total_frames})')
            progress = min(current_frame / total_frames, 1.0)
            progress_bar.progress(progress)
    finally:
        # Release the capture even if reading or a progress callback raised.
        cap.release()
    return frames, fps, frame_indices
|
| 67 |
+
|
| 68 |
+
def encode_frames(video_frames, status_text):
    """Encode frames with the CLIP image encoder in batches of 256.

    Args:
        video_frames: Sequence of images accepted by ``preprocess``.
        status_text: Object whose ``text(str)`` method receives progress
            messages.

    Returns:
        A float32 tensor on ``device`` with one L2-normalised feature row per
        frame; an empty ``[0, 512]`` tensor when ``video_frames`` is empty.
    """
    batch_size = 256
    total = len(video_frames)
    feature_batches = []

    for start in range(0, total, batch_size):
        batch_frames = video_frames[start:start + batch_size]
        batch_preprocessed = torch.stack([preprocess(frame) for frame in batch_frames]).to(device)
        with torch.no_grad():
            batch_features = model.encode_image(batch_preprocessed)
            batch_features = batch_features.float()
            batch_features /= batch_features.norm(dim=-1, keepdim=True)
        feature_batches.append(batch_features)
        # Clamp so the final, partial batch does not report more frames than
        # exist (e.g. "256/120").
        status_text.text(f'Encoding frames... ({min(start + batch_size, total)}/{total})')

    if not feature_batches:
        return torch.empty([0, 512], dtype=torch.float32).to(device)
    # Concatenate once instead of re-copying the growing tensor every batch.
    return torch.cat(feature_batches)
|
| 84 |
|
| 85 |
def img_to_bytes(img):
    """Serialise an image to JPEG and return the encoded bytes.

    Args:
        img: Object exposing ``save(file_obj, format=...)``.

    Returns:
        The bytes written by ``img.save`` into an in-memory buffer.
    """
    with io.BytesIO() as buffer:
        img.save(buffer, format='JPEG')
        return buffer.getvalue()
|
| 90 |
+
|
| 91 |
+
def get_youtube_timestamp_url(url, frame_idx, frame_indices):
    """Build a ``youtu.be`` link that starts playback at a sampled frame.

    The timestamp is the frame's source index divided by the FPS stored in
    ``st.session_state.fps``.

    Args:
        url: YouTube URL of the processed video. Both ``...watch?v=<id>`` and
            ``youtu.be/<id>`` forms are recognised.
        frame_idx: Position of the frame within ``frame_indices``.
        frame_indices: Source-frame numbers of the sampled frames.

    Returns:
        ``(timestamp_url, seconds)``, or ``(None, None)`` when the FPS is
        missing/zero or no video id can be extracted from ``url``.
    """
    from urllib.parse import urlparse, parse_qs

    fps = st.session_state.fps
    if not fps:
        # An FPS of 0 or None cannot be converted into a time offset.
        return None, None

    # int() accepts plain ints as well as single-element index objects.
    frame_count = frame_indices[int(frame_idx)]
    seconds = frame_count / fps
    seconds_rounded = int(seconds)

    if url == EXAMPLE_URL:
        video_id = "zTvJJnoWIPk"
    else:
        try:
            parsed_url = urlparse(url)
            host = (parsed_url.hostname or "").lower()
            if host == "youtu.be":
                # Short links carry the id as the first path segment.
                video_id = parsed_url.path.strip("/").split("/")[0]
            else:
                video_id = parse_qs(parsed_url.query)['v'][0]
        except (KeyError, IndexError, ValueError):
            return None, None
        if not video_id:
            return None, None

    return f"https://youtu.be/{video_id}?t={seconds_rounded}", seconds
|
| 108 |
+
|
| 109 |
+
def display_results(best_photo_idx, video_frames):
    """Render the matched frames, each followed by a link into the video.

    Args:
        best_photo_idx: Sized iterable of indices into ``video_frames``, in
            display order.
        video_frames: Indexable collection of frames accepted by
            ``st.image``.

    Reads ``st.session_state.url`` and ``st.session_state.frame_indices`` to
    build the per-frame timestamp link; the link is omitted when
    ``get_youtube_timestamp_url`` returns no URL.
    """
    # Report the real number of results; fewer than 10 are shown for very
    # short videos.
    st.subheader(f"Top {len(best_photo_idx)} Results")
    for frame_id in best_photo_idx:
        result = video_frames[frame_id]
        st.image(result, width=400)

        timestamp_url, seconds = get_youtube_timestamp_url(st.session_state.url, frame_id, st.session_state.frame_indices)
        if timestamp_url:
            st.markdown(f"[▶️ Play video at {format_timespan(int(seconds))}]({timestamp_url})")
|
| 118 |
+
|
| 119 |
+
def text_search(search_query, video_features, video_frames, display_results_count=10):
    """Rank frames against a text query and display the best matches.

    Args:
        search_query: Text passed to the CLIP tokenizer.
        video_features: Per-frame feature tensor (one row per frame).
        video_frames: Frames corresponding to the rows of ``video_features``.
        display_results_count: Maximum number of results; capped at
            ``len(video_frames)``.
    """
    top_k = min(display_results_count, len(video_frames))

    with torch.no_grad():
        tokens = openai_clip.tokenize(search_query).to(device)
        query_embedding = model.encode_text(tokens)
        query_embedding = query_embedding.float()
        # Unit-normalise the query embedding before scoring.
        query_embedding /= query_embedding.norm(dim=-1, keepdim=True)

    frame_embeddings = video_features.float()

    scores = 100.0 * frame_embeddings @ query_embedding.T
    _, best_photo_idx = scores.topk(top_k, dim=0)
    display_results(best_photo_idx, video_frames)
|
| 133 |
|
| 134 |
+
def image_search(query_image, video_features, video_frames, display_results_count=10):
    """Rank frames against a query image and display the best matches.

    Args:
        query_image: Image accepted by ``preprocess``.
        video_features: Per-frame feature tensor (one row per frame).
        video_frames: Frames corresponding to the rows of ``video_features``.
        display_results_count: Maximum number of results; capped at
            ``len(video_frames)``.
    """
    # Same clamp as text_search: topk raises when asked for more rows than
    # exist.
    display_results_count = min(display_results_count, len(video_frames))

    query_image = preprocess(query_image).unsqueeze(0).to(device)

    with torch.no_grad():
        image_features = model.encode_image(query_image)
        image_features = image_features.float()
        image_features /= image_features.norm(dim=-1, keepdim=True)

    video_features = video_features.float()

    similarities = (100.0 * video_features @ image_features.T)
    values, best_photo_idx = similarities.topk(display_results_count, dim=0)
    display_results(best_photo_idx, video_frames)
|
| 147 |
+
|
| 148 |
+
def text_and_image_search(search_query, query_image, video_features, video_frames, display_results_count=10):
    """Rank frames against the average of a text and an image embedding.

    Args:
        search_query: Text passed to the CLIP tokenizer.
        query_image: Image accepted by ``preprocess``.
        video_features: Per-frame feature tensor (one row per frame).
        video_frames: Frames corresponding to the rows of ``video_features``.
        display_results_count: Maximum number of results; capped at
            ``len(video_frames)``.
    """
    # Same clamp as text_search: topk raises when asked for more rows than
    # exist.
    display_results_count = min(display_results_count, len(video_frames))

    with torch.no_grad():
        text_tokens = openai_clip.tokenize(search_query).to(device)
        text_features = model.encode_text(text_tokens)
        text_features = text_features.float()
        text_features /= text_features.norm(dim=-1, keepdim=True)

    query_image = preprocess(query_image).unsqueeze(0).to(device)
    with torch.no_grad():
        image_features = model.encode_image(query_image)
        image_features = image_features.float()
        image_features /= image_features.norm(dim=-1, keepdim=True)

    # Equal-weight mean of the two unit-normalised query embeddings.
    combined_features = (text_features + image_features) / 2

    video_features = video_features.float()
    similarities = (100.0 * video_features @ combined_features.T)
    values, best_photo_idx = similarities.topk(display_results_count, dim=0)
    display_results(best_photo_idx, video_frames)
|
| 167 |
+
|
| 168 |
+
def load_cached_data(url):
    """Load the on-disk cache for the example video, if available.

    Args:
        url: URL being processed; only ``EXAMPLE_URL`` has a cache.

    Returns:
        ``(video_frames, video_features, fps, frame_indices)`` read from
        ``CACHED_DATA_PATH``, or four ``None`` values when ``url`` is not the
        example URL or any cache file cannot be loaded.
    """
    if url != EXAMPLE_URL:
        return None, None, None, None

    try:
        # NOTE(review): allow_pickle=True lets np.load unpickle arbitrary
        # objects; only safe while the cache directory holds trusted files.
        video_frames = np.load(f"{CACHED_DATA_PATH}example_frames.npy", allow_pickle=True)
        # map_location lets a cache written on a CUDA host load on a CPU one.
        video_features = torch.load(f"{CACHED_DATA_PATH}example_features.pt", map_location=device)
        fps = np.load(f"{CACHED_DATA_PATH}example_fps.npy")
        frame_indices = np.load(f"{CACHED_DATA_PATH}example_frame_indices.npy")
    except Exception as e:
        # Best effort: a missing or unreadable cache means "not cached", but
        # the reason is logged rather than silently discarded.
        print(f"Error loading cache: {e}")
        return None, None, None, None
    return video_frames, video_features, fps, frame_indices
|
| 179 |
+
|
| 180 |
+
def save_cached_data(url, video_frames, video_features, fps, frame_indices):
    """Write the processed example video to the on-disk cache.

    Does nothing unless ``url`` equals ``EXAMPLE_URL``. Otherwise the cache
    directory is created if needed and the four artefacts are written under
    ``CACHED_DATA_PATH``.
    """
    if url != EXAMPLE_URL:
        return

    os.makedirs(CACHED_DATA_PATH, exist_ok=True)
    frames_path = f"{CACHED_DATA_PATH}example_frames.npy"
    features_path = f"{CACHED_DATA_PATH}example_features.pt"
    fps_path = f"{CACHED_DATA_PATH}example_fps.npy"
    indices_path = f"{CACHED_DATA_PATH}example_frame_indices.npy"

    np.save(frames_path, video_frames)
    torch.save(video_features, features_path)
    np.save(fps_path, fps)
    np.save(indices_path, frame_indices)
|
| 187 |
|
| 188 |
+
def clear_cached_data():
    """Delete the files in the cache directory, then the directory itself.

    Only regular files directly inside ``CACHED_DATA_PATH`` are removed. Any
    ``Exception`` raised along the way is printed and not propagated.
    """
    if not os.path.exists(CACHED_DATA_PATH):
        return

    try:
        for entry in os.listdir(CACHED_DATA_PATH):
            candidate = os.path.join(CACHED_DATA_PATH, entry)
            if os.path.isfile(candidate):
                os.remove(candidate)
        os.rmdir(CACHED_DATA_PATH)
    except Exception as exc:
        print(f"Error clearing cache: {exc}")
|
| 198 |
+
|
| 199 |
+
# Page metadata and layout: centered content, sidebar collapsed at start.
st.set_page_config(page_title="Which Frame? ποΈπ", page_icon = "π", layout = "centered", initial_sidebar_state = "collapsed")

# Custom CSS injected into the page. It hides the Streamlit main menu and
# footer, sets the font, constrains the content width, and restyles the text
# input, radio group, heading and links.
hide_streamlit_style = """
<style>
/* Hide Streamlit elements */
#MainMenu {visibility: hidden;}
footer {visibility: hidden;}
* {
font-family: Avenir;
}
.block-container {
max-width: 800px;
padding: 2rem 1rem;
}
.stTextInput input {
border-radius: 8px;
border: 1px solid #E0E0E0;
padding: 0.75rem;
font-size: 1rem;
}
.stRadio [role="radiogroup"] {
background: #F8F8F8;
padding: 1rem;
border-radius: 12px;
}
h1 {text-align: center;}
.css-gma2qf {display: flex; justify-content: center; font-size: 36px; font-weight: bold;}
a:link {text-decoration: none;}
a:hover {text-decoration: none;}
.st-ba {font-family: Avenir;}
.st-button {text-align: center;}
</style>
"""
# unsafe_allow_html=True is passed so the <style> block is rendered as HTML.
st.markdown(hide_streamlit_style, unsafe_allow_html=True)
|
| 233 |
|
| 234 |
+
# Per-session defaults: a key is written only when it is absent, so values
# already stored in the session are left untouched.
_SESSION_DEFAULTS = {
    'progress': 1,
    'video_frames': None,
    'video_features': None,
    'fps': None,
    'video_name': 'videos/example.mp4',
    'url': '',
}
for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default

# Page heading followed by a short description of what the app does.
st.title("Which Frame? ποΈπ")
st.markdown("""
Search a video semantically. For example, which frame has "a person with sunglasses"?
Search using text, images, or a mix of text + image. WhichFrame uses [CLIP](https://github.com/openai/CLIP) for zero-shot frame classification.
""")
|
| 253 |
|
| 254 |
+
url = st.text_input("Enter a YouTube URL (e.g., https://www.youtube.com/watch?v=zTvJJnoWIPk)", key="url_input")

if st.button("Process Video"):
    if not url:
        st.error("Please enter a YouTube URL first")
    else:
        # Everything is computed into locals first and written to the session
        # only after all steps succeed. A failure midway therefore cannot
        # leave a new url/frames paired with features of a previous video.
        progress_bar = None
        status_text = None
        processed = None
        success_message = None
        try:
            cached_frames, cached_features, cached_fps, cached_frame_indices = load_cached_data(url)

            if cached_frames is not None:
                processed = (cached_frames, cached_features, cached_fps, cached_frame_indices)
                success_message = "Loaded cached video data!"
            else:
                with st.spinner('Fetching video...'):
                    _unused_video, video_url = fetch_video(url)

                progress_bar = st.progress(0)
                status_text = st.empty()

                # Extract frames
                new_frames, new_fps, new_frame_indices = extract_frames(video_url, status_text, progress_bar)
                if not new_frames:
                    # Reported through the generic handler below instead of
                    # being announced as a successful run.
                    raise ValueError("No frames could be extracted from the video.")

                # Encode frames
                new_features = encode_frames(new_frames, status_text)

                save_cached_data(url, new_frames, new_features, new_fps, new_frame_indices)
                status_text.text('Finalizing...')
                progress_bar.progress(100)
                processed = (new_frames, new_features, new_fps, new_frame_indices)
                success_message = "Video processed successfully!"

        except Exception as e:
            processed = None
            st.error(f"Error processing video: {str(e)}")
        finally:
            # Clear the transient progress widgets on success and on failure.
            if status_text is not None:
                status_text.empty()
            if progress_bar is not None:
                progress_bar.empty()

        if processed is not None:
            (st.session_state.video_frames,
             st.session_state.video_features,
             st.session_state.fps,
             st.session_state.frame_indices) = processed
            st.session_state.url = url
            st.session_state.progress = 2
            st.success(success_message)
|
| 295 |
+
|
| 296 |
+
# The search controls are shown only once a video has been processed
# (progress == 2).
if st.session_state.progress == 2:
    search_type = st.radio("Search Method", ["Text Search", "Image Search", "Text + Image Search"], index=0)

    if search_type == "Text Search":
        text_query = st.text_input("Type a search query (e.g., 'red car' or 'person with sunglasses')")
        if st.button("Search"):
            if text_query:
                text_search(text_query, st.session_state.video_features, st.session_state.video_frames)
            else:
                st.error("Please enter a search query first")

    elif search_type == "Image Search":
        uploaded_file = st.file_uploader("Upload a query image", type=['png', 'jpg', 'jpeg'])
        if uploaded_file is not None:
            # Preview the uploaded image before searching.
            query_image = Image.open(uploaded_file).convert('RGB')
            st.image(query_image, caption="Query Image", width=200)
        if st.button("Search"):
            if uploaded_file is not None:
                image_search(query_image, st.session_state.video_features, st.session_state.video_frames)
            else:
                st.error("Please upload an image first")

    else:
        # Remaining option: "Text + Image Search".
        text_query = st.text_input("Type a search query")
        uploaded_file = st.file_uploader("Upload a query image", type=['png', 'jpg', 'jpeg'])
        if uploaded_file is not None:
            query_image = Image.open(uploaded_file).convert('RGB')
            st.image(query_image, caption="Query Image", width=200)

        if st.button("Search"):
            if text_query and uploaded_file is not None:
                text_and_image_search(text_query, query_image, st.session_state.video_features, st.session_state.video_frames)
            else:
                st.error("Please provide both text query and image")

# Footer: horizontal rule, then author credit and repository link.
st.markdown("---")
_footer_text = (
    "By [David Chuan-En Lin](https://chuanenlin.com/). "
    "Play with the code at [https://github.com/chuanenlin/whichframe](https://github.com/chuanenlin/whichframe)."
)
st.markdown(_footer_text)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|