Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import cv2 | |
| import PostProcess | |
| import time | |
| from huggingface_hub import hf_hub_download | |
# Pick the GPU when torch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Fetch the TorchScript weights and the demo clip from the same Hub repo.
# The returned values are used below as local file paths.
model_checkpoint = hf_hub_download(repo_id="sunjuice/firearm_checkpoint", filename="best.torchscript")
video = hf_hub_download(repo_id="sunjuice/firearm_checkpoint", filename="evaluation.mp4")

# Load the scripted model straight onto the chosen device and switch it to
# evaluation mode before any inference happens.
torchmodel = torch.jit.load(model_checkpoint, map_location=device)
torchmodel.eval()
def preprocess_image(image_ori) -> tuple:
    """Turn one video frame into a model-ready batch tensor.

    The frame is resized to 640x640, converted with ``COLOR_BGR2RGB`` and
    scaled to the ``[0, 1]`` range.

    Args:
        image_ori: Frame as an array in HWC layout. It is treated as BGR
            (the conversion below assumes that channel order).

    Returns:
        A ``(image_tensor, image_rgb)`` pair:

        * ``image_tensor`` - float ``torch.Tensor`` of shape
          ``(1, C, 640, 640)`` with values divided by 255.
        * ``image_rgb`` - the resized RGB array the tensor was built from,
          returned so the caller can draw on / display the same pixels.
    """
    image = cv2.resize(image_ori, (640, 640))
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image_tensor = (
        torch.from_numpy(image_rgb)
        .float()
        .permute(2, 0, 1)  # HWC -> CHW
        .div(255.0)        # normalize to [0, 1]
        .unsqueeze(0)      # add the batch dimension
    )
    return image_tensor, image_rgb
def run_model(image):
    """Run the TorchScript model on one preprocessed batch and time it.

    Args:
        image: Input tensor, already moved to the device the model lives on.

    Returns:
        A ``(result, elapsed_ms)`` pair: element ``[0]`` of the model output
        and the inference time in milliseconds.
    """
    # CUDA kernels are launched asynchronously; without synchronizing on both
    # sides of the call the timer would mostly measure launch overhead.
    on_cuda = getattr(image, "is_cuda", False)
    if on_cuda:
        torch.cuda.synchronize()

    # perf_counter is monotonic and high-resolution, unlike time.time().
    start_time = time.perf_counter()
    result = torchmodel(image)[0]
    if on_cuda:
        torch.cuda.synchronize()
    elapsed_ms = (time.perf_counter() - start_time) * 1000
    return result, elapsed_ms
def run_video():
    """Stream annotated frames of the evaluation video in an endless loop.

    Each frame is preprocessed, passed through the model, handed to a
    ``PostProcess.PostProcessor`` together with the inference time, and the
    frame it produces is yielded. When the video ends it is rewound to the
    first frame and playback continues.

    Yields:
        The post-processed frame after a ``COLOR_BGR2RGB`` conversion.

    Raises:
        RuntimeError: If the video file cannot be opened.
    """
    cap = cv2.VideoCapture(video)
    postprocessor = PostProcess.PostProcessor()
    # True right after a rewind; a failed read in that state means the video
    # has no readable frames, so looping again would spin forever.
    just_rewound = False
    try:
        if not cap.isOpened():
            raise RuntimeError(f"Could not open video: {video}")

        while True:
            ret, frame = cap.read()
            if not ret:
                if just_rewound:
                    break
                cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
                just_rewound = True
                continue
            just_rewound = False

            image, image_rgb = preprocess_image(frame)
            image = image.to(device)

            # Grad mode is thread-local, so keep no_grad scoped to the
            # inference call instead of holding it open across `yield`.
            with torch.no_grad():
                result, duration_ms = run_model(image)

            postprocessor.set_image(image_rgb)
            postprocessor.set_time(duration_ms)
            postprocessor.set_result(result)
            output_frame = postprocessor.get_frame()

            # Convert BGR -> RGB for Gradio.
            # NOTE(review): image_rgb is already RGB; this swap is only correct
            # if PostProcessor.get_frame() returns BGR - confirm.
            output_frame = cv2.cvtColor(output_frame, cv2.COLOR_BGR2RGB)
            yield output_frame
    finally:
        # Runs on normal exit, on errors, and when the consumer closes the
        # generator, so the capture handle is always freed.
        cap.release()
# The interface takes no inputs; run_video is a generator whose yielded
# frames are shown in a single image component.
output_view = gr.Image(streaming=True)
demo = gr.Interface(fn=run_video, inputs=[], outputs=output_view, live=True)

demo.launch()