File size: 3,764 Bytes
480a079
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63c8085
480a079
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import gradio as gr
import tempfile
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import cv2
import torch
import clip
import os
from tqdm import tqdm
from PIL import Image

# Run CLIP on the GPU when torch can see one, otherwise on the CPU, and load
# the ViT-B/32 checkpoint together with its matching image preprocessor.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model, preprocess = clip.load("ViT-B/32", device)

# Module-level cache shared by the Gradio callbacks:
#   video_embedding  -- per-frame CLIP image embeddings of the current video
#   text_embedding   -- CLIP embedding of the current text query
#   similarity_graph -- path of the most recently saved plot
#   last_video_path  -- absolute path of the last processed video, used to skip
#                       re-encoding the same upload
state = dict(
    video_embedding=None,
    text_embedding=None,
    similarity_graph=None,
    last_video_path=None,
)


def process_video(video_file):
    """Encode every frame of an uploaded video with CLIP and cache the result.

    Args:
        video_file: The upload handed over by Gradio. Older Gradio versions
            pass a tempfile-like object exposing ``.name``; newer versions pass
            the path as a plain string. Both are accepted.

    Side effects:
        Writes into the module-level ``state`` dict:
        ``video_embedding`` -- a (frames_read, 512) tensor on ``device`` with one
        (un-normalised) CLIP image embedding per successfully decoded frame;
        ``fps`` -- the frame rate reported by OpenCV, or 30.0 if unavailable;
        ``last_video_path`` -- the absolute path, set only after encoding
        succeeded so a failed file can be retried.
        The caller (``calculate_similarity``) is responsible for plotting.

    Raises:
        ValueError: If OpenCV cannot open the file.
    """
    video_file_path = os.path.abspath(getattr(video_file, "name", video_file))

    cap = cv2.VideoCapture(video_file_path)
    try:
        if not cap.isOpened():
            raise ValueError(f"Failed to open video file: {video_file_path}")

        # CAP_PROP_FRAME_COUNT is only an estimate for some containers; treat it
        # as an upper bound and trim to the frames actually decoded below.
        frame_count = max(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), 0)
        # Keep fractional rates (e.g. 29.97); int() would skew the time axis.
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:  # also catches 0.0 and NaN from unknown containers
            fps = 30.0

        image_vectors = torch.zeros((frame_count, 512), device=device)
        frames_read = 0
        with torch.no_grad():
            for i in tqdm(range(frame_count)):
                ret, frame = cap.read()
                if not ret:
                    print(f"Failed to read frame {i}")
                    break
                # OpenCV decodes to BGR, while CLIP's preprocessing expects RGB.
                rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                image_vectors[i] = model.encode_image(
                    preprocess(Image.fromarray(rgb_frame)).unsqueeze(0).to(device)
                )
                frames_read += 1
    finally:
        cap.release()

    # Drop trailing all-zero rows left by a short read; normalising them later
    # would otherwise produce NaN similarities.
    state['video_embedding'] = image_vectors[:frames_read]
    state['fps'] = fps
    state['last_video_path'] = video_file_path


def process_text(query_text):
    """Encode a text query with CLIP and cache its unit-length embedding.

    Args:
        query_text: Free-text query typed by the user. Prompts longer than
            CLIP's context window are truncated instead of raising.

    Side effects:
        Stores a (1, D) L2-normalised float32 tensor in
        ``state['text_embedding']``. The caller (``calculate_similarity``) is
        responsible for plotting.
    """
    # Without truncate=True, clip.tokenize raises on over-long input.
    text_inputs = clip.tokenize([query_text], truncate=True).to(device)
    with torch.no_grad():
        text_features = model.encode_text(text_inputs)
    # clip.load keeps fp16 weights on CUDA (it only upcasts on CPU). Cast so the
    # later matmul against the float32 frame embeddings has matching dtypes.
    text_features = text_features.float()
    text_features /= text_features.norm(dim=-1, keepdim=True)
    state['text_embedding'] = text_features


def calculate_similarity(video_file=None, query_text=None):
    """Gradio callback: plot per-frame CLIP similarity between a video and a text query.

    Args:
        video_file: Uploaded video (tempfile-like object with ``.name`` or a
            path string). It is re-encoded only when its absolute path differs
            from the last processed one; ``None`` reuses the cached embedding.
        query_text: Text query. Re-encoded when non-empty; otherwise the cached
            text embedding (if any) is reused.

    Returns:
        A ``(plot_path, error_message)`` pair matching the interface's two
        outputs: ``("output_plot.png", None)`` on success, or
        ``(None, message)`` when an input is missing or unusable.
    """
    if video_file:
        video_file_path = os.path.abspath(getattr(video_file, "name", video_file))
        # Only process the video if the file path has changed
        if video_file_path != state['last_video_path']:
            try:
                process_video(video_file)
            except ValueError as exc:
                # Surface unreadable uploads in the "Error Message" output.
                return None, str(exc)
    if query_text:
        process_text(query_text)

    image_vectors = state['video_embedding']
    text_features = state['text_embedding']
    if image_vectors is None or text_features is None:
        return None, "Please provide both video and text input"
    if image_vectors.shape[0] == 0:
        return None, "No frames could be read from the video"

    # Normalise out of place so the cached embeddings are not mutated;
    # F.normalize clamps the denominator, so all-zero rows give 0, not NaN.
    image_vectors = torch.nn.functional.normalize(image_vectors, dim=1)
    # Align dtypes: the cached text features may be fp16 on CUDA.
    text_features = text_features.to(image_vectors.dtype)
    similarities = (image_vectors @ text_features.T).squeeze(1)
    closest_idx = similarities.argmax().item()

    frame_count = image_vectors.shape[0]
    fps = state.get('fps', 30)
    time_in_seconds = np.arange(frame_count) / fps
    similarity_scores = similarities.cpu().numpy()
    closest_time = closest_idx / fps

    # Use an explicit figure rather than the global pyplot state, and always
    # close it so figures do not accumulate across requests.
    fig, ax = plt.subplots(figsize=(10, 5))
    try:
        ax.plot(time_in_seconds, similarity_scores, label='Similarity Score', linestyle='-', color='blue')
        ax.axvline(x=closest_time, color='red', linestyle='--', label=f'Closest Match at {closest_time:.2f} seconds')
        ax.set_xticks(np.arange(0, time_in_seconds[-1] + 10, 10))
        ax.set_xlabel('Video Time (seconds)')
        ax.set_ylabel('Similarity Score')
        ax.legend(loc='upper right')
        ax.set_title('Similarity Score vs Video Time')
        ax.grid(True)
        fig.savefig("output_plot.png")
    finally:
        plt.close(fig)

    state['similarity_graph'] = "output_plot.png"  # Save graph path to state
    return "output_plot.png", None

def get_similarity_graph():
    """Return the path of the most recently saved similarity plot (None before the first plot)."""
    graph_path = state['similarity_graph']
    return graph_path

# Define the Gradio interface with the top-level component classes.
# gr.inputs.* / gr.outputs.* were deprecated in Gradio 3 and removed in
# Gradio 4; the file already used the new-style gr.Textbox for the input.
iface = gr.Interface(
    fn=calculate_similarity,
    inputs=[
        gr.File(label="Upload a video"),
        gr.Textbox(label="Enter text"),
    ],
    outputs=[
        gr.Image(type="filepath", label="Similarity Graph"),
        gr.Textbox(label="Error Message"),
    ],
)
iface.launch()