# CamoVision 2.0 — app.py
# Source: Hugging Face Space by Jazz1508 (commit 1a8568c, "Update app.py")
import torch
import cv2
import numpy as np
import gradio as gr
import segmentation_models_pytorch as smp
from albumentations import Compose, Resize, Normalize
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import tempfile
import os
# Model definition: binary-segmentation U-Net built on segmentation_models_pytorch.
class UNetCamouflage(torch.nn.Module):
    """U-Net wrapper for single-class (camouflage) segmentation.

    Produces one output channel of raw logits; callers apply sigmoid
    themselves to obtain per-pixel foreground probabilities.
    """

    def __init__(self, encoder_name="resnet50", encoder_weights=None):
        super().__init__()
        # NOTE: the attribute must stay named ``model`` — checkpoint
        # state_dict keys are prefixed with "model.".
        self.model = smp.Unet(
            encoder_name=encoder_name,
            encoder_weights=encoder_weights,
            in_channels=3,
            classes=1,
        )

    def forward(self, x):
        # Pure delegation to the wrapped smp U-Net.
        return self.model(x)
# Load model weights onto the best available device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNetCamouflage().to(device)
# weights_only=True restricts unpickling to tensors and plain containers —
# sufficient for a state_dict checkpoint and avoids arbitrary-code execution
# from a tampered .pth file.
state_dict = torch.load("CamoVision_Final.pth", map_location=device, weights_only=True)
model.load_state_dict(state_dict)
# Inference only: freeze batch-norm statistics and disable dropout.
model.eval()
# Image preprocessing: resize to the model's fixed 256x256 input, normalize
# with ImageNet mean/std (presumably matching training-time preprocessing —
# TODO confirm, since the encoder is built with encoder_weights=None here),
# then convert the HWC numpy array to a CHW torch tensor.
transform = Compose([
    Resize(256, 256),
    Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2()
])
def predict_image(image):
    """Segment camouflaged objects in a single frame.

    Parameters
    ----------
    image : np.ndarray
        HWC uint8 frame. NOTE(review): the channel swap below assumes a
        BGR (cv2-style) input; the Gradio image tab supplies RGB, so that
        path feeds the model swapped channels — confirm against training.

    Returns
    -------
    tuple[np.ndarray, np.ndarray]
        (overlay blended onto the input, binary mask scaled to 0/255,
        both at the input's original resolution).
    """
    rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    tensor = transform(image=rgb)["image"].unsqueeze(0).to(device)
    with torch.no_grad():
        probs = torch.sigmoid(model(tensor))
    # Threshold at 0.5 and scale to an 8-bit 0/255 mask.
    binary = (probs.squeeze().cpu().numpy() > 0.5).astype(np.uint8) * 255
    height, width = image.shape[:2]
    full_mask = cv2.resize(binary, (width, height))
    blended = cv2.addWeighted(
        image, 0.7, cv2.cvtColor(full_mask, cv2.COLOR_GRAY2BGR), 0.3, 0
    )
    return blended, full_mask
def predict_video(video_path):
    """Run camouflage segmentation over every frame of a video.

    Args:
        video_path: Path to the input video file.

    Returns:
        Path to a newly written .mp4 with the segmentation overlay, or
        None if the input video cannot be opened.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        return None
    # Unique temp file per request: the previous fixed name
    # ("processed_video.mp4") was clobbered by concurrent sessions.
    fd, temp_video_path = tempfile.mkstemp(suffix=".mp4")
    os.close(fd)
    # Guard against containers reporting 0 FPS, which would produce an
    # unplayable VideoWriter output.
    fps = cap.get(cv2.CAP_PROP_FPS)
    if not fps or fps <= 0:
        fps = 25.0
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(temp_video_path, fourcc, fps, (width, height))
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    progress_bar = tqdm(total=total_frames, desc="Processing Video", unit="frame")
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            # cv2 frames are BGR, matching predict_image's channel swap.
            overlay, _ = predict_image(frame)
            out.write(overlay)
            progress_bar.update(1)
    finally:
        # Always release codec/file handles, even if a frame raises.
        cap.release()
        out.release()
        progress_bar.close()
    return temp_video_path
# Custom CSS for the Gradio UI: hides Gradio's default footer and styles the
# #app-title / #app-description / #footer-text elements injected via gr.HTML.
custom_css = """
footer {visibility: hidden;}
#app-title {text-align: center; font-size: 32px; font-weight: bold; color: #FFFFFF; margin-top: 20px;}
#app-description {text-align: center; font-size: 18px; color: #CCCCCC;}
#footer-text {text-align: center; font-size: 16px; color: #BBBBBB; margin-top: 20px;}
"""
# Gradio Blocks UI: two tabs, one for still images and one for videos.
with gr.Blocks(css=custom_css, theme="soft") as app:
    gr.HTML('<div id="app-title">🕵️‍♂️ Camouflage Object Detection</div>')
    gr.HTML('<div id="app-description">Upload an image or video to detect hidden objects with AI-powered camouflage detection.</div>')
    with gr.Tabs():
        with gr.Tab("Image Detection"):
            with gr.Row():
                # type="numpy" delivers an RGB HWC array to predict_image.
                image_input = gr.Image(type="numpy", label="Upload Image")
                submit_img_btn = gr.Button("🔍 Detect")
            with gr.Row():
                output_overlay = gr.Image(type="numpy", label="Detected Camouflage (Overlay)")
                output_mask = gr.Image(type="numpy", label="Segmentation Mask")
            # predict_image returns (overlay, mask), matching the two outputs.
            submit_img_btn.click(predict_image, inputs=[image_input], outputs=[output_overlay, output_mask])
        with gr.Tab("Video Detection"):
            with gr.Row():
                video_input = gr.Video(label="Upload Video")
                submit_vid_btn = gr.Button("🔍 Process Video")
            with gr.Row():
                output_video = gr.Video(label="Processed Video")
                # NOTE(review): this label is static — it shows "Processing: 0%"
                # until the handler finishes, then "Processing Complete!". No
                # per-frame progress is actually streamed to the UI.
                progress_label = gr.Label("Processing: 0%", label="Progress")

            def process_video_with_progress(video_file):
                """Run predict_video and report a final completion message."""
                processed_video = predict_video(video_file)
                return processed_video, "Processing Complete!"

            submit_vid_btn.click(process_video_with_progress, inputs=[video_input], outputs=[output_video, progress_label])
    gr.HTML('<div id="footer-text">Made with ❤️ by Jaskaranjeet Singh</div>')
# Launch the app
if __name__ == "__main__":
    # NOTE(review): share=True requests a public gradio.live tunnel; it is
    # unnecessary (and warned about) when hosted on HF Spaces — confirm the
    # intended deployment target.
    app.launch(share=True)