# NOTE: the lines "Spaces: / Sleeping / Sleeping" were a Hugging Face Spaces
# status banner captured by the page scrape — they are not part of the program.
| import torch | |
| import cv2 | |
| import numpy as np | |
| import gradio as gr | |
| import segmentation_models_pytorch as smp | |
| from albumentations import Compose, Resize, Normalize | |
| from albumentations.pytorch import ToTensorV2 | |
| from tqdm import tqdm | |
| import tempfile | |
| import os | |
| # Load the trained model | |
| class UNetCamouflage(torch.nn.Module): | |
| def __init__(self, encoder_name="resnet50", encoder_weights=None): | |
| super().__init__() | |
| self.model = smp.Unet( | |
| encoder_name=encoder_name, | |
| encoder_weights=encoder_weights, | |
| in_channels=3, | |
| classes=1 | |
| ) | |
| def forward(self, x): | |
| return self.model(x) | |
| # Load model weights | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| model = UNetCamouflage().to(device) | |
| model.load_state_dict(torch.load("CamoVision_Final.pth", map_location=device)) | |
| model.eval() | |
| # Image preprocessing | |
| transform = Compose([ | |
| Resize(256, 256), | |
| Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), | |
| ToTensorV2() | |
| ]) | |
| def predict_image(image): | |
| """Process the image and detect camouflage objects.""" | |
| img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) | |
| augmented = transform(image=img) | |
| img_tensor = augmented["image"].unsqueeze(0).to(device) | |
| with torch.no_grad(): | |
| output = torch.sigmoid(model(img_tensor)) | |
| mask = (output.squeeze().cpu().numpy() > 0.5).astype(np.uint8) * 255 | |
| mask_resized = cv2.resize(mask, (image.shape[1], image.shape[0])) | |
| overlay = cv2.addWeighted(image, 0.7, cv2.cvtColor(mask_resized, cv2.COLOR_GRAY2BGR), 0.3, 0) | |
| return overlay, mask_resized | |
| def predict_video(video_path): | |
| """Process the video and detect camouflage objects in each frame.""" | |
| cap = cv2.VideoCapture(video_path) | |
| fourcc = cv2.VideoWriter_fourcc(*'mp4v') | |
| temp_video_path = os.path.join(tempfile.gettempdir(), "processed_video.mp4") | |
| out = cv2.VideoWriter(temp_video_path, fourcc, cap.get(cv2.CAP_PROP_FPS), | |
| (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))) | |
| total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) | |
| progress_bar = tqdm(total=total_frames, desc="Processing Video", unit="frame") | |
| while cap.isOpened(): | |
| ret, frame = cap.read() | |
| if not ret: | |
| break | |
| overlay, _ = predict_image(frame) | |
| out.write(overlay) | |
| progress_bar.update(1) | |
| cap.release() | |
| out.release() | |
| progress_bar.close() | |
| return temp_video_path | |
| # Custom CSS for professional UI | |
| custom_css = """ | |
| footer {visibility: hidden;} | |
| #app-title {text-align: center; font-size: 32px; font-weight: bold; color: #FFFFFF; margin-top: 20px;} | |
| #app-description {text-align: center; font-size: 18px; color: #CCCCCC;} | |
| #footer-text {text-align: center; font-size: 16px; color: #BBBBBB; margin-top: 20px;} | |
| """ | |
| # Gradio Blocks UI | |
| with gr.Blocks(css=custom_css, theme="soft") as app: | |
| gr.HTML('<div id="app-title">π΅οΈββοΈ Camouflage Object Detection</div>') | |
| gr.HTML('<div id="app-description">Upload an image or video to detect hidden objects with AI-powered camouflage detection.</div>') | |
| with gr.Tabs(): | |
| with gr.Tab("Image Detection"): | |
| with gr.Row(): | |
| image_input = gr.Image(type="numpy", label="Upload Image") | |
| submit_img_btn = gr.Button("π Detect") | |
| with gr.Row(): | |
| output_overlay = gr.Image(type="numpy", label="Detected Camouflage (Overlay)") | |
| output_mask = gr.Image(type="numpy", label="Segmentation Mask") | |
| submit_img_btn.click(predict_image, inputs=[image_input], outputs=[output_overlay, output_mask]) | |
| with gr.Tab("Video Detection"): | |
| with gr.Row(): | |
| video_input = gr.Video(label="Upload Video") | |
| submit_vid_btn = gr.Button("π Process Video") | |
| with gr.Row(): | |
| output_video = gr.Video(label="Processed Video") | |
| progress_label = gr.Label("Processing: 0%", label="Progress") | |
| def process_video_with_progress(video_file): | |
| processed_video = predict_video(video_file) | |
| return processed_video, "Processing Complete!" | |
| submit_vid_btn.click(process_video_with_progress, inputs=[video_input], outputs=[output_video, progress_label]) | |
| gr.HTML('<div id="footer-text">Made with β€οΈ by Jaskaranjeet Singh</div>') | |
| # Launch the app | |
| if __name__ == "__main__": | |
| app.launch(share=True) |