shashnk committed on
Commit
480a079
·
1 Parent(s): 9c1354c

Upload 3 files

Browse files
Files changed (4) hide show
  1. .gitattributes +1 -0
  2. Trondheim Norway 4K.mp4 +3 -0
  3. main.py +110 -0
  4. requirements.txt +8 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ Trondheim[[:space:]]Norway[[:space:]]4K.mp4 filter=lfs diff=lfs merge=lfs -text
Trondheim Norway 4K.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ce4fc3f306411df158d7069b0a63aacb3fc2ea07379d2fdf35f6933713498084
3
+ size 32297332
main.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import tempfile
3
+ import matplotlib
4
+ matplotlib.use('Agg')
5
+ import matplotlib.pyplot as plt
6
+ import numpy as np
7
+ import cv2
8
+ import torch
9
+ import clip
10
+ import os
11
+ from tqdm import tqdm
12
+ from PIL import Image
13
+
14
# Select the compute device (GPU when available) and load the CLIP ViT-B/32
# model together with its image preprocessing transform.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
model, preprocess = clip.load("ViT-B/32", device)

# Module-level cache shared by the processing functions in this file.
state = dict(
    video_embedding=None,    # per-frame image embeddings of the last video
    text_embedding=None,     # embedding of the last text query
    similarity_graph=None,   # path of the last rendered similarity plot
    last_video_path=None,    # absolute path of the last processed video file
)
24
+
25
+
26
def process_video(video_file):
    """Embed every frame of a video with CLIP and cache the result in ``state``.

    Args:
        video_file: Uploaded file object exposing a ``.name`` path attribute;
            a plain path string is accepted as well.

    Raises:
        ValueError: If OpenCV cannot open the video file.

    Side effects:
        Updates ``state['video_embedding']``, ``state['fps']`` and
        ``state['last_video_path']``, then calls ``calculate_similarity()``
        so the cached graph is refreshed when a text embedding already exists.
    """
    video_file_path = os.path.abspath(getattr(video_file, 'name', video_file))

    cap = cv2.VideoCapture(video_file_path)
    if not cap.isOpened():
        cap.release()
        raise ValueError(f"Failed to open video file: {video_file_path}")

    frames_read = 0
    try:
        # Some containers report a non-positive frame count; never allocate
        # a negative-sized tensor.
        frame_count = max(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), 0)
        # Keep the fractional frame rate (e.g. 29.97) so the time axis is exact.
        fps = cap.get(cv2.CAP_PROP_FPS)

        # 512 is the embedding width of the ViT-B/32 model loaded in this file.
        image_vectors = torch.zeros((frame_count, 512), device=device)
        for i in tqdm(range(frame_count)):
            ret, frame = cap.read()
            if not ret:
                print(f"Failed to read frame {i}")
                break
            # OpenCV decodes to BGR, while PIL/CLIP preprocessing expects RGB.
            rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            with torch.no_grad():
                batch = preprocess(Image.fromarray(rgb_frame)).unsqueeze(0).to(device)
                image_vectors[i] = model.encode_image(batch)[0]
            frames_read = i + 1
    finally:
        # Always free the decoder handle, even if encoding raises.
        cap.release()

    # Drop rows for frames that were never decoded; all-zero rows would turn
    # into NaN similarities once normalised.
    state['video_embedding'] = image_vectors[:frames_read]
    if fps and fps > 0:
        state['fps'] = fps
    else:
        # Unknown frame rate: let the consumer fall back to its default.
        state.pop('fps', None)
    # Record the path only after success so a failed video is retried.
    state['last_video_path'] = video_file_path
    calculate_similarity()
51
+
52
+
53
def process_text(query_text):
    """Embed a text query with CLIP and cache the unit-normalised vector.

    Args:
        query_text: The query string to tokenize and encode.

    Side effects:
        Stores the embedding in ``state['text_embedding']`` and then calls
        ``calculate_similarity()`` so the cached graph is refreshed when a
        video embedding already exists.
    """
    tokens = clip.tokenize([query_text]).to(device)
    with torch.no_grad():
        # CLIP runs in half precision on CUDA; cast to float32 so the later
        # matmul against the float32 frame embeddings does not fail on a
        # dtype mismatch.
        text_features = model.encode_text(tokens).float()
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    state['text_embedding'] = text_features
    calculate_similarity()
60
+
61
+
62
def calculate_similarity(video_file=None, query_text=None):
    """Plot the per-frame similarity between the cached video and text query.

    Args:
        video_file: Optional uploaded file object (with a ``.name`` path) or
            path string. It is (re)processed only when its path differs from
            ``state['last_video_path']``.
        query_text: Optional text query; when given it is re-embedded.

    Returns:
        A ``(image_path, message)`` pair matching the two interface outputs:
        ``("output_plot.png", None)`` on success, or ``(None, <message>)``
        when an input is missing or no frame could be decoded.

    Side effects:
        Writes ``output_plot.png`` in the working directory and stores that
        path in ``state['similarity_graph']``.
    """
    if video_file:
        video_file_path = os.path.abspath(getattr(video_file, 'name', video_file))
        # Only process the video if the file path has changed
        if video_file_path != state['last_video_path']:
            process_video(video_file)
    if query_text:
        process_text(query_text)

    image_vectors = state['video_embedding']
    text_features = state['text_embedding']
    if image_vectors is None or text_features is None:
        # Two values are returned because the interface declares two outputs.
        return None, "Please provide both video and text input"

    frame_count = image_vectors.shape[0]
    if frame_count == 0:
        return None, "No frames could be read from the video"

    # Normalise out of place so the cached embedding is left untouched, and
    # clamp the norm so an all-zero row yields 0 instead of NaN.
    norms = torch.norm(image_vectors, dim=1, keepdim=True).clamp_min(1e-12)
    unit_vectors = image_vectors / norms
    # Align dtypes: the text embedding may be half precision on CUDA.
    text_features = text_features.to(unit_vectors.dtype)
    similarities = (unit_vectors @ text_features.T).squeeze(1)
    closest_idx = similarities.argmax().item()

    # Fall back to 30 fps when the frame rate is missing or zero.
    fps = state.get('fps') or 30
    time_in_seconds = np.arange(frame_count) / fps
    similarity_scores = similarities.cpu().numpy()

    plt.figure(figsize=(10, 5))
    try:
        plt.plot(time_in_seconds, similarity_scores, label='Similarity Score', linestyle='-', color='blue')
        plt.axvline(x=closest_idx/fps, color='red', linestyle='--', label=f'Closest Match at {closest_idx/fps:.2f} seconds')
        plt.xticks(np.arange(0, time_in_seconds[-1] + 10, 10))
        plt.xlabel('Video Time (seconds)')
        plt.ylabel('Similarity Score')
        plt.legend(loc='upper right')
        plt.title('Similarity Score vs Video Time')
        plt.grid(True)

        plt.savefig("output_plot.png")  # Save the plot to a file
    finally:
        plt.close()  # Release the figure even if plotting or saving fails

    state['similarity_graph'] = "output_plot.png"  # Save graph to state
    return "output_plot.png", None
100
+
101
def get_similarity_graph():
    """Return the value cached under ``state['similarity_graph']``."""
    cached_graph = state['similarity_graph']
    return cached_graph
103
+
104
# Define Gradio interface.
# Components are taken from the top-level ``gr`` namespace: the
# ``gr.inputs`` / ``gr.outputs`` namespaces are deprecated and absent from
# current Gradio releases, and requirements.txt does not pin a version.
iface = gr.Interface(
    fn=calculate_similarity,
    inputs=[gr.File(label="Upload a video"), gr.Textbox(label="Enter text")],
    outputs=[
        gr.Image(type="filepath", label="Similarity Graph"),
        gr.Textbox(label="Error Message"),
    ],
)
iface.launch()
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ gradio
2
+ matplotlib
3
+ numpy
4
+ opencv-python
5
+ torch
6
+ openai-clip
7
+ tqdm
8
+ Pillow