# CamoVision / app.py
# Author: Jazz1508 (Hugging Face Space) — commit fe4e370 "Update app.py" (verified)
import torch
import cv2
import numpy as np
import gradio as gr
import segmentation_models_pytorch as smp
from albumentations import Compose, Resize, Normalize
from albumentations.pytorch import ToTensorV2
# Network architecture used for camouflaged-object segmentation.
class UNetCamouflage(torch.nn.Module):
    """U-Net segmenter producing a single-channel map for an RGB input.

    Thin wrapper around ``segmentation_models_pytorch.Unet`` with 3 input
    channels and 1 output class. No ``activation`` is passed to ``smp.Unet``,
    so its output is expected to be raw logits (the caller applies
    ``torch.sigmoid``) — confirm against the installed smp version.

    Args:
        encoder_name: smp encoder backbone identifier (default ``"resnet50"``).
        encoder_weights: pretrained encoder weights to download, or ``None``
            to start from random init (weights are restored from a checkpoint
            afterwards in this app).
    """

    def __init__(self, encoder_name="resnet50", encoder_weights=None):
        super().__init__()
        unet_config = dict(
            encoder_name=encoder_name,
            encoder_weights=encoder_weights,
            in_channels=3,
            classes=1,
        )
        # The attribute must stay named ``model``: state_dict keys are
        # prefixed with it, so renaming would break checkpoint loading.
        self.model = smp.Unet(**unet_config)

    def forward(self, x):
        """Run the wrapped U-Net on ``x`` and return its output unchanged."""
        unet = self.model
        return unet(x)
# Select the GPU when one is visible, otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Build the network, restore the trained weights, and switch to inference mode.
# NOTE: torch.load unpickles the file; only load checkpoints you trust.
model = UNetCamouflage()
model = model.to(device)
_checkpoint_state = torch.load("CamoVision_Final.pth", map_location=device)
model.load_state_dict(_checkpoint_state)
del _checkpoint_state  # weights now live in the module; drop the extra reference
model.eval()
# Preprocessing pipeline applied to each input image before inference:
# resize to 256x256, normalize with the ImageNet RGB mean/std (pixel values
# scaled from 0-255), then convert the HWC numpy array to a CHW torch tensor.
transform = Compose(
    [
        Resize(height=256, width=256),
        Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ]
)
def predict(image):
    """Detect camouflaged objects in an image and return overlay + mask.

    Args:
        image: uint8 numpy array as delivered by ``gr.Image(type="numpy")``,
            i.e. ``H x W x 3`` in RGB channel order. Grayscale (``H x W`` or
            ``H x W x 1``) and RGBA (``H x W x 4``) arrays are converted to
            RGB first. ``None`` (Detect clicked with nothing uploaded) is
            accepted.

    Returns:
        ``(overlay, mask)``. ``mask`` is an ``H x W`` uint8 array at the
        input resolution, containing only 0 and 255, where the model's sigmoid
        probability exceeds 0.5. ``overlay`` is the RGB image blended 70/30
        with the mask rendered in white. When ``image`` is ``None``,
        returns ``(None, None)`` so the UI outputs are simply cleared.
    """
    if image is None:
        return None, None

    # Normalize channel layout to 3-channel RGB so the model input and the
    # overlay blend (which needs matching shapes) both work.
    if image.ndim == 2 or image.shape[2] == 1:
        rgb = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    elif image.shape[2] == 4:
        rgb = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
    else:
        # Gradio already supplies RGB; the ImageNet mean/std in `transform`
        # are RGB-ordered, so no BGR<->RGB swap must be applied here.
        rgb = image

    img_tensor = transform(image=rgb)["image"].unsqueeze(0).to(device)
    with torch.no_grad():
        probs = torch.sigmoid(model(img_tensor))

    # Threshold probabilities into a 0/255 binary mask at network resolution.
    mask = (probs.squeeze().cpu().numpy() > 0.5).astype(np.uint8) * 255

    # Back to the original size; nearest-neighbour keeps the mask strictly
    # binary (bilinear would introduce intermediate grey values at edges).
    height, width = rgb.shape[:2]
    mask_resized = cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)

    mask_rgb = cv2.cvtColor(mask_resized, cv2.COLOR_GRAY2RGB)
    overlay = cv2.addWeighted(rgb, 0.7, mask_rgb, 0.3, 0)
    return overlay, mask_resized
# Custom CSS passed to gr.Blocks(css=...): hides the page <footer> element and
# styles the title, description, and footer-text <div>s, which are matched by
# the element ids used in the gr.HTML calls of the layout.
custom_css = """
footer {visibility: hidden;}
#app-title {text-align: center; font-size: 32px; font-weight: bold; color: #FFFFFF; margin-top: 20px;}
#app-description {text-align: center; font-size: 18px; color: #CCCCCC;}
#footer-text {text-align: center; font-size: 16px; color: #BBBBBB; margin-top: 20px;}
"""
# Gradio Blocks UI: header, an input column and a results column, and a footer.
with gr.Blocks(theme="soft", css=custom_css) as app:
    gr.HTML('<div id="app-title">🕵️‍♂️ Camouflage Object Detection</div>')
    gr.HTML('<div id="app-description">Upload an image and detect hidden objects with AI-powered camouflage detection.</div>')

    with gr.Row():
        # Left: image upload and the button that triggers inference.
        with gr.Column():
            image_input = gr.Image(label="Upload Image", type="numpy")
            submit_btn = gr.Button("🔍 Detect")
        # Right: blended overlay and the raw segmentation mask.
        with gr.Column():
            output_overlay = gr.Image(label="Detected Camouflage (Overlay)", type="numpy")
            output_mask = gr.Image(label="Segmentation Mask", type="numpy")

    # predict(image) -> (overlay, mask) feeds the two output components in order.
    submit_btn.click(
        fn=predict,
        inputs=[image_input],
        outputs=[output_overlay, output_mask],
    )

    gr.HTML('<div id="footer-text">Made with ❤️ by Jaskaranjeet Singh</div>')

# Start the web server only when executed as a script, not on import.
if __name__ == "__main__":
    app.launch()