"""Gradio app that classifies a cricket video as "LBW" / "Not LBW".

The uploaded video is passed through a CNN-LSTM classifier, and a copy of the
video is returned with the predicted label and confidence drawn on every frame.
"""

import os
import shutil
import tempfile

import cv2
import gradio as gr
import torch

from model import CNNLSTMClassifier
from utils import extract_frames

# Load the trained weights once at import time and switch to inference mode.
model = CNNLSTMClassifier()
model.load_state_dict(torch.load("lbw_classifier.pt", map_location="cpu"))
model.eval()

# Index i is the label for model output column i (argmax is taken over dim=1).
classes = ["Not LBW", "LBW"]

# Used when OpenCV cannot read a frame rate from the container (it reports 0).
_DEFAULT_FPS = 25.0


def _resolve_video_path(video_file):
    """Return a filesystem path from the value handed over by ``gr.Video``.

    Some Gradio versions pass a dict with a ``"name"`` key holding the path;
    others pass the path directly.
    """
    if isinstance(video_file, dict) and "name" in video_file:
        return video_file["name"]
    return video_file


def _classify(video_path):
    """Run the classifier on frames extracted from ``video_path``.

    Returns:
        (pred, label): the predicted class index and a display string such as
        ``"LBW (87.50%)"``.
    """
    frames = extract_frames(video_path)
    with torch.no_grad():
        output = model(frames)
    pred = torch.argmax(output, dim=1).item()
    prob = torch.softmax(output, dim=1)[0][pred].item()
    return pred, f"{classes[pred]} ({prob:.2%})"


def _annotate(video_path, label, color):
    """Write a copy of the video with ``label`` drawn on every frame.

    Each call writes to its own temporary .mp4 file, so concurrent requests do
    not overwrite each other's output.

    Returns:
        Path to the annotated video.

    Raises:
        gr.Error: if the input cannot be opened or the output cannot be created.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        raise gr.Error("Could not open the uploaded video.")

    # NOTE(review): these temp files are not deleted here; Gradio presumably
    # copies returned files into its own cache, so a periodic /tmp cleanup may
    # be needed on long-running hosts — confirm.
    fd, out_path = tempfile.mkstemp(suffix=".mp4")
    os.close(fd)  # cv2.VideoWriter opens the path itself; we only need the name.

    out = None
    try:
        fps = cap.get(cv2.CAP_PROP_FPS) or _DEFAULT_FPS
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        out = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
        if not out.isOpened():
            raise gr.Error("Could not create the annotated output video.")

        font = cv2.FONT_HERSHEY_SIMPLEX
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            cv2.putText(frame, label, (30, 60), font, 2, color, 4, cv2.LINE_AA)
            out.write(frame)
    finally:
        # Release both handles even if reading/writing fails part-way.
        cap.release()
        if out is not None:
            out.release()
    return out_path


def predict(video_file):
    """Classify an uploaded video and return it annotated with the prediction.

    Args:
        video_file: Value from ``gr.Video`` — a path string, or a dict with a
            ``"name"`` key on some Gradio versions. ``None`` if nothing was
            uploaded.

    Returns:
        Path to the annotated video file.

    Raises:
        gr.Error: if no video was uploaded or it cannot be processed.
    """
    if video_file is None:
        raise gr.Error("Please upload a video first.")
    video_path = _resolve_video_path(video_file)

    pred, label = _classify(video_path)

    # OpenCV colors are BGR: green for "LBW" (class 1), red otherwise.
    color = (0, 255, 0) if pred == 1 else (0, 0, 255)
    return _annotate(video_path, label, color)


iface = gr.Interface(
    fn=predict,
    inputs=gr.Video(),
    outputs=gr.Video(),  # the annotated video returned by predict()
    title="Smart LBW Classifier",
    description="Upload a cricket video. The model will analyze the frames and overlay the LBW prediction."
)

if __name__ == "__main__":
    # Only start the server when run as a script, not when imported.
    iface.launch()