Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| from model import CNNLSTMClassifier | |
| from utils import extract_frames | |
| import shutil | |
| import os | |
| import cv2 | |
# Build the classifier once at import time and load the trained weights.
# map_location='cpu' remaps the saved tensors onto the CPU, so loading works
# without a GPU.
model = CNNLSTMClassifier()
model.load_state_dict(
    torch.load("lbw_classifier.pt", map_location='cpu')
)
# Switch to inference mode (affects layers such as dropout / batch-norm,
# if the architecture contains any).
model.eval()

# Human-readable labels, indexed by the model's predicted class index.
classes = ["Not LBW", "LBW"]
def _resolve_video_path(video_file):
    """Return the filesystem path carried by the Gradio video input value.

    If *video_file* is a dict with a ``"name"`` key, that entry is the path;
    otherwise *video_file* itself is treated as the path.
    """
    if isinstance(video_file, dict) and "name" in video_file:
        return video_file["name"]
    return video_file


def _classify(video_path):
    """Run the classifier on the clip at *video_path*.

    Returns:
        ``(pred, prob)``: the argmax class index into ``classes`` and the
        softmax probability assigned to that class.
    """
    frames = extract_frames(video_path)
    with torch.no_grad():
        output = model(frames)
    # `.item()` requires exactly one element, i.e. a single-clip batch.
    pred = torch.argmax(output, dim=1).item()
    prob = torch.softmax(output, dim=1)[0][pred].item()
    return pred, prob


def _annotate_video(video_path, label, color, default_fps=25.0):
    """Write a copy of *video_path* with *label* drawn on every frame.

    A fresh temporary ``.mp4`` is created per call, so concurrent requests
    never overwrite each other's output. The file is not deleted here because
    Gradio reads it after ``predict`` returns.

    Args:
        video_path: Path of the source clip.
        label: Text drawn at the top-left of each frame.
        color: BGR colour tuple for the text.
        default_fps: Frame rate used when the container reports none.

    Returns:
        Path of the annotated video.

    Raises:
        gr.Error: If the source cannot be opened or the writer cannot be created.
    """
    import tempfile

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        raise gr.Error(f"Could not open video: {video_path}")

    out = None
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        # Some containers report 0 (or NaN) FPS; `not (fps > 0)` catches both.
        if not (fps > 0):
            fps = default_fps
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        fd, out_path = tempfile.mkstemp(prefix="annotated_", suffix=".mp4")
        os.close(fd)  # VideoWriter opens the path itself.
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        out = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
        if not out.isOpened():
            raise gr.Error("Could not create the annotated output video.")

        font = cv2.FONT_HERSHEY_SIMPLEX
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            cv2.putText(frame, label, (30, 60), font, 2, color, 4, cv2.LINE_AA)
            out.write(frame)
    finally:
        # Release both handles even if reading/writing raised mid-way.
        cap.release()
        if out is not None:
            out.release()
    return out_path


def predict(video_file):
    """Classify an uploaded cricket clip and return an annotated copy.

    Args:
        video_file: Value from the ``gr.Video`` input — a path, or a dict
            holding the path under ``"name"``.

    Returns:
        Path to an ``.mp4`` whose frames are overlaid with
        ``"<class> (<probability>)"``.
    """
    video_path = _resolve_video_path(video_file)
    pred, prob = _classify(video_path)
    label = f"{classes[pred]} ({prob:.2%})"
    # OpenCV colours are BGR: green for class index 1 ("LBW"), red otherwise.
    color = (0, 255, 0) if pred == 1 else (0, 0, 255)
    return _annotate_video(video_path, label, color)
# UI text shown above the upload widget.
_TITLE = "Smart LBW Classifier"
_DESCRIPTION = (
    "Upload a cricket video. The model will analyze the frames and "
    "overlay the LBW prediction."
)

# Wire `predict` into an upload-a-clip -> annotated-clip web UI.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Video(),   # clip uploaded by the user
    outputs=gr.Video(),  # annotated clip path returned by predict()
    title=_TITLE,
    description=_DESCRIPTION,
)

# Start the web server; note this runs whenever the module is executed.
iface.launch()