Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import tempfile | |
| import matplotlib | |
| matplotlib.use('Agg') | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import cv2 | |
| import torch | |
| import clip | |
| import os | |
| from tqdm import tqdm | |
| from PIL import Image | |
| # Load the CLIP model | |
| device = "cuda" if torch.cuda.is_available() else "cpu" | |
| model, preprocess = clip.load("ViT-B/32", device) | |
| state = { | |
| 'video_embedding': None, | |
| 'text_embedding': None, | |
| 'similarity_graph': None, | |
| 'last_video_path': None # Add this line to store the last processed video file path | |
| } | |
| def process_video(video_file): | |
| video_file_path = os.path.abspath(video_file.name) | |
| state['last_video_path'] = video_file_path | |
| cap = cv2.VideoCapture(video_file_path) | |
| if not cap.isOpened(): | |
| raise ValueError(f"Failed to open video file: {video_file}") | |
| frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) | |
| fps = int(cap.get(cv2.CAP_PROP_FPS)) | |
| image_vectors = torch.zeros((frame_count, 512), device=device) | |
| for i in tqdm(range(frame_count)): | |
| ret, frame = cap.read() | |
| if ret: | |
| with torch.no_grad(): | |
| image_vectors[i] = model.encode_image( | |
| preprocess(Image.fromarray(frame)).unsqueeze(0).to(device) | |
| ) | |
| else: | |
| print(f"Failed to read frame {i}") | |
| break | |
| state['video_embedding'] = image_vectors | |
| calculate_similarity() | |
| def process_text(query_text): | |
| text_inputs = torch.cat([clip.tokenize([query_text]).to(device)]) | |
| with torch.no_grad(): | |
| text_features = model.encode_text(text_inputs) | |
| text_features /= text_features.norm(dim=-1, keepdim=True) | |
| state['text_embedding'] = text_features # | |
| calculate_similarity() | |
| def calculate_similarity(video_file=None, query_text=None): | |
| if video_file: | |
| video_file_path = os.path.abspath(video_file.name) | |
| # Only process the video if the file path has changed | |
| if video_file_path != state['last_video_path']: | |
| process_video(video_file) | |
| if query_text: | |
| process_text(query_text) | |
| image_vectors = state['video_embedding'] | |
| text_features = state['text_embedding'] | |
| if image_vectors is None or text_features is None: | |
| return "Please provide both video and text input" # or return an error image | |
| image_vectors /= torch.norm(image_vectors, dim=1, keepdim=True) | |
| similarities = (image_vectors @ text_features.T).squeeze(1) | |
| closest_idx = similarities.argmax().item() | |
| frame_count = image_vectors.shape[0] | |
| fps = state.get('fps', 30) | |
| time_in_seconds = np.arange(frame_count) / fps | |
| similarity_scores = similarities.cpu().numpy() | |
| plt.figure(figsize=(10, 5)) | |
| plt.plot(time_in_seconds, similarity_scores, label='Similarity Score', linestyle='-', color='blue') | |
| plt.axvline(x=closest_idx/fps, color='red', linestyle='--', label=f'Closest Match at {closest_idx/fps:.2f} seconds') | |
| plt.xticks(np.arange(0, time_in_seconds[-1] + 10, 10)) | |
| plt.xlabel('Video Time (seconds)') | |
| plt.ylabel('Similarity Score') | |
| plt.legend(loc='upper right') | |
| plt.title('Similarity Score vs Video Time') | |
| plt.grid(True) | |
| plt.savefig("output_plot.png") # Save the plot to a file | |
| plt.close() # Close the plot to free up memory | |
| state['similarity_graph'] = "output_plot.png" # Save graph to state | |
| return "output_plot.png", None | |
| def get_similarity_graph(): | |
| return state['similarity_graph'] # Return the saved graph | |
| # Define Gradio interface | |
| iface = gr.Interface( | |
| fn=calculate_similarity, | |
| inputs=[gr.inputs.File(label="Upload a video"), gr.Textbox(label="Enter text")], | |
| outputs=[gr.outputs.Image(type="filepath", label="Similarity Graph"), gr.outputs.Textbox(label="Error Message")] | |
| ) | |
| iface.launch() |