# Spaces: Sleeping  (page-status text captured with the listing; not part of the program)
import os
import re
import shutil

import ffmpeg
import streamlit as st
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModel

from src.frames import extract_frames, convert_to_mp4
from src.lstm_model import LSTMNetwork
# Lookup tables between class indices (positions in the classifier output)
# and cricket-shot names.
idx_to_class = {
    0: 'cover',
    1: 'defense',
    2: 'flick',
    3: 'hook',
    4: 'late_cut',
    5: 'lofted',
    6: 'pull',
    7: 'square_cut',
    8: 'straight',
    9: 'sweep',
}
# Inverse table, derived from idx_to_class so the two can never drift apart.
class_label_mapping = {label: idx for idx, label in idx_to_class.items()}

# Defining the paths: classifier checkpoints loaded with torch.load, and the
# model identifiers passed to from_pretrained for the embedding backbones.
CLIP_MODEL_PATH = "clip-cricket-classifier.pt"
SIGLIP_MODEL_PATH = "siglip-cricket-classifier.pt"
CLIP_MODEL_ID = "openai/clip-vit-base-patch32"
SIGLIP_MODEL_ID = "google/siglip-base-patch16-224"
def embeddings_creators(MODEL_ID):
    """Load the pretrained processor/model pair used to embed video frames.

    Args:
        MODEL_ID: Model identifier handed to ``from_pretrained`` for both the
            processor and the model.

    Returns:
        A ``(processor, model)`` tuple; the model has been moved to the
        module-level ``device``.
    """
    processor = AutoProcessor.from_pretrained(MODEL_ID)
    backbone = AutoModel.from_pretrained(MODEL_ID)
    # Move the weights onto the selected device before handing the model out.
    backbone.to(device)
    return (processor, backbone)
def load_model(MODEL_PATH):
    """Build the LSTM classifier and load its trained weights for inference.

    Args:
        MODEL_PATH: Checkpoint path; must be ``CLIP_MODEL_PATH`` or
            ``SIGLIP_MODEL_PATH``. The choice fixes the LSTM input size
            (512 for the CLIP checkpoint, 768 for the SigLIP checkpoint).

    Returns:
        The ``LSTMNetwork`` on the module-level ``device``, with weights
        loaded and switched to evaluation mode.

    Raises:
        ValueError: If ``MODEL_PATH`` is not one of the two known checkpoints.
    """
    if MODEL_PATH == CLIP_MODEL_PATH:
        input_size = 512
    elif MODEL_PATH == SIGLIP_MODEL_PATH:
        input_size = 768
    else:
        raise ValueError(f"Invalid model path: {MODEL_PATH}")

    model = LSTMNetwork(input_size=input_size, hidden_size=256, num_classes=10).to(device)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only host;
    # without it torch.load tries to restore tensors to their original device.
    # NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
    state_dict = torch.load(MODEL_PATH, map_location=device)
    model.load_state_dict(state_dict)
    # Inference only: disable training-time behaviour such as dropout.
    model.eval()
    return model
# Compute device: use the GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
def _natural_sort_key(name):
    """Return a sort key that orders digit runs in ``name`` numerically.

    ``re.split`` with a capturing group alternates text and digit runs (digit
    runs land on odd positions), so ``frame_10.jpg`` sorts after
    ``frame_2.jpg`` instead of before it.
    """
    parts = re.split(r"(\d+)", name)
    return [int(part) if pos % 2 else part.lower() for pos, part in enumerate(parts)]


def _remove_temp_files(frames_dir, *file_paths):
    """Best-effort removal of the frames folder and the temporary video files."""
    try:
        shutil.rmtree(frames_dir)
        print(f"Folder '{frames_dir}' and its contents have been deleted.")
    except FileNotFoundError:
        pass  # nothing was extracted
    except OSError as e:
        print(f"Error while deleting folder '{frames_dir}': {e}")

    # dict.fromkeys de-duplicates: both video paths are the same file when the
    # sample is already an .mp4.
    for path in dict.fromkeys(file_paths):
        try:
            os.remove(path)
        except FileNotFoundError:
            pass  # never created (e.g. conversion failed)
        except OSError as e:
            print(f"Error while deleting file '{path}': {e}")


def app():
    """Render the Streamlit page: pick a model and a sample video, then classify it.

    The selected ``.avi`` sample from ``assets`` is copied to ``./demo``,
    converted to ``.mp4`` when needed, displayed, and split into frames. The
    frames are embedded with the chosen model, the embeddings are fed to the
    LSTM classifier as one sequence, and every class is listed in descending
    probability order. Temporary files are removed afterwards, even when a
    step fails or the script is stopped.
    """
    st.image("assets/banner.png")
    st.title("Cricket Shot Classifier", anchor=False)

    model_choice = st.radio("Select a model", ["None", "CLIP", "SIGLIP"])
    if model_choice == "None":
        # The hint must be written before st.stop(), which halts the script.
        st.write("Please select a model")
        st.stop()

    if model_choice == "CLIP":
        model_id, model_path = CLIP_MODEL_ID, CLIP_MODEL_PATH
    else:  # "SIGLIP" is the only remaining radio option
        model_id, model_path = SIGLIP_MODEL_ID, SIGLIP_MODEL_PATH
    embedding_processor, embedding_model = embeddings_creators(model_id)
    model = load_model(model_path)
    model.eval()  # inference only; harmless if already in eval mode

    # List sample videos from assets folder (sorted: os.listdir order is arbitrary)
    sample_videos = sorted(f for f in os.listdir("assets") if f.endswith('.avi'))
    if not sample_videos:
        st.error("No sample videos found in assets folder")
        st.stop()

    selected_video = st.selectbox("Select a sample video", sample_videos)
    video_path = os.path.join("assets", selected_video)

    save_directory = './demo'
    os.makedirs(save_directory, exist_ok=True)
    new_video_path = f"{save_directory}/{selected_video}"
    if new_video_path.lower().endswith('.mp4'):
        final_video_path = new_video_path
    else:
        final_video_path = f"{save_directory}/{os.path.splitext(selected_video)[0]}.mp4"
    frames_dir = f"{save_directory}/frames"

    try:
        shutil.copy2(video_path, new_video_path)
        if final_video_path != new_video_path:
            convert_to_mp4(new_video_path, final_video_path)
        st.video(final_video_path)

        # Start from an empty folder so frames left behind by an earlier,
        # interrupted run cannot leak into this video's sequence.
        shutil.rmtree(frames_dir, ignore_errors=True)
        os.makedirs(frames_dir, exist_ok=True)
        extract_frames(final_video_path, frames_dir)
        st.write("Frames extracted from the video.")

        # The frames form a sequence, so their order matters; sort numerically.
        frame_names = sorted(
            (f for f in os.listdir(frames_dir)
             if f.lower().endswith(('.jpg', '.jpeg', '.png'))),
            key=_natural_sort_key,
        )
        if not frame_names:
            st.error("No frames could be extracted from the video.")
            st.stop()

        inference_images = []
        for frame_name in frame_names:
            # convert() returns a fully loaded copy, so the file handle can be
            # closed right away (also keeps the later cleanup from failing).
            with Image.open(os.path.join(frames_dir, frame_name)) as frame:
                inference_images.append(frame.convert("RGB"))

        # No gradients are needed anywhere in the inference path.
        with torch.no_grad():
            tokens = embedding_processor(
                text=None,
                images=inference_images,
                return_tensors="pt"
            ).to(device)
            inference_embeddings = embedding_model.get_image_features(**tokens)
            # Add a leading batch dimension: one video = one sequence of frames.
            output = model(inference_embeddings.unsqueeze(0))
            prob = output.softmax(dim=1)

        _, indices = torch.sort(prob[0], descending=True)
        for i in indices.tolist():
            st.write(f"Prediction: {idx_to_class[i]}")
            st.progress(int(prob[0][i].item() * 100))
    finally:
        _remove_temp_files(frames_dir, new_video_path, final_video_path)
# Render the page when this file is executed as the main script.
if __name__ == "__main__":
    app()