"""EAGLE A7 mission-control dashboard.

Serves a Gradio UI with two tabs:

* a YOLO object detector for daytime (visual) imagery, and
* a single-channel ResNet-18 binary classifier for thermal imagery, with a
  class-activation-map (CAM) heatmap overlaid on the input.
"""

import gradio as gr
import torch
import torch.nn as nn
from torchvision import models
from ultralytics import YOLO
import cv2
import numpy as np

# ==========================================
# 1. SETUP & MODEL LOADING
# ==========================================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"--- System Boot: Using {device} ---")

# --- LOAD VISUAL SYSTEM (YOLO) ---
try:
    yolo_model = YOLO("best.pt")
    print("✅ Visual System: Custom EAGLE A7 Model Loaded")
except Exception as e:
    # Fall back to the stock checkpoint, but report why the custom one failed
    # instead of swallowing the error silently.
    print(f"⚠️ Visual System: Could not load best.pt ({e}); using yolo11n.pt")
    yolo_model = YOLO("yolo11n.pt")


# --- LOAD THERMAL SYSTEM (ResNet-18) ---
def get_thermal_model():
    """Build an untrained ResNet-18 with a 1-channel input and a 1-logit output."""
    model = models.resnet18(weights=None)
    # Replace the stem so the network accepts single-channel images.
    model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
    # Single output logit; a sigmoid is applied at inference time.
    model.fc = nn.Linear(model.fc.in_features, 1)
    return model


thermal_model = get_thermal_model().to(device)
MODEL_PATH = "thermal_landmine_scanner.pth"
try:
    thermal_model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
    print(f"✅ Thermal System: Loaded {MODEL_PATH}")
except Exception as e:
    # NOTE(review): the app keeps running with randomly initialised weights
    # after this message; thermal predictions are meaningless in that state.
    print(f"❌ CRITICAL ERROR: Could not load thermal model. {e}")

# Always switch to inference mode, even when the weights failed to load, so
# BatchNorm layers do not run in training mode during scans.
thermal_model.eval()

# ==========================================
# 2. SETUP GRAD-CAM (THE "X-RAY" HOOK)
# ==========================================
# Holds the most recent layer4 activations captured during a forward pass.
features_blob = []


def hook_feature(module, input, output):
    """Forward hook: keep only the latest output of the hooked layer as a NumPy array."""
    features_blob.clear()  # Drop activations from the previous forward pass
    features_blob.append(output.detach().cpu().numpy())


# Capture the feature maps produced by the last residual stage (layer4).
thermal_model.layer4.register_forward_hook(hook_feature)

# Weights of the final linear layer, used to weight each feature channel
# when building the class activation map.
weight_softmax = thermal_model.fc.weight.detach().cpu().numpy()


# ==========================================
# 3. PROCESSING FUNCTIONS
# ==========================================
def run_visual_detection(image):
    """Run YOLO on an image and return the annotated image plus a status string.

    Args:
        image: NumPy image from the Gradio input (RGB channel order), or None.

    Returns:
        ``(annotated_image, status_text)``; ``(None, "Waiting for feed...")``
        when no image was supplied.
    """
    if image is None:
        return None, "Waiting for feed..."

    # Ultralytics treats NumPy inputs as BGR and Results.plot() returns BGR,
    # while Gradio supplies and displays RGB; convert at both boundaries.
    is_rgb = image.ndim == 3 and image.shape[2] == 3
    frame = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) if is_rgb else image

    results = yolo_model.predict(frame, conf=0.65)
    annotated = results[0].plot()
    if is_rgb:
        annotated = cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB)
    return annotated, f"Objects Detected: {len(results[0].boxes)}"


def run_thermal_scan(image):
    """Classify a thermal image and overlay a class-activation heatmap.

    Args:
        image: NumPy image from the Gradio input (grayscale or RGB), or None.

    Returns:
        ``(overlay, label, conf_text)`` where ``overlay`` is a 224x224 RGB
        image, ``label`` is the verdict text for the HTML widget and
        ``conf_text`` is the formatted probability. Returns
        ``(None, "No Signal", "N/A")`` when no image was supplied.
    """
    if image is None:
        return None, "No Signal", "N/A"

    # --- PREPROCESSING (Standard) ---
    if len(image.shape) == 3:
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    else:
        gray = image

    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    enhanced_img = clahe.apply(gray)
    resized = cv2.resize(enhanced_img, (224, 224))
    normalized_img = resized.astype(np.float32) / 255.0

    # Add batch and channel axes: (224, 224) -> (1, 1, 224, 224).
    tensor = torch.from_numpy(normalized_img).float().unsqueeze(0).unsqueeze(0)
    tensor = tensor.to(device)

    # --- INFERENCE ---
    with torch.no_grad():
        output = thermal_model(tensor)
    prob = torch.sigmoid(output).item()
    is_mine = prob > 0.50

    # --- GENERATE HEATMAP (Explainable AI) ---
    # 1. Activations captured by the forward hook during the call above.
    feature_data = features_blob[0]

    # 2. Weighted sum of the feature channels using the final-layer weights
    #    (contracts the channel axis, leaving the spatial map).
    cam = np.tensordot(weight_softmax[0], feature_data[0], axes=1).astype(np.float32)

    # 3. Process the CAM
    cam = np.maximum(cam, 0)  # ReLU: drop negative contributions
    cam = cv2.resize(cam, (224, 224))  # Upsample to the image size
    cam = cam - np.min(cam)
    peak = np.max(cam)
    if peak != 0:
        cam = cam / peak  # Normalize to 0-1

    # 4. Colorize. applyColorMap returns BGR, so convert to RGB to match the
    #    RGB base image and Gradio's display convention.
    heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

    # 5. Overlay on the preprocessed image (Grayscale -> RGB).
    orig_rgb = cv2.cvtColor(resized, cv2.COLOR_GRAY2RGB)

    # Heatmap weight is 50% for a positive result and 20% otherwise, so a
    # "safe" frame is not dominated by the colour map.
    intensity = 0.5 if is_mine else 0.2
    overlay = cv2.addWeighted(orig_rgb, 1 - intensity, heatmap, intensity, 0)

    # --- DECISION ---
    # NOTE(review): any HTML markup that originally wrapped these labels is
    # not present in the source; restore it here if it is known.
    if is_mine:
        label = "\n\n🔴 MINE DETECTED\n\n"
        conf_text = f"CONFIDENCE: {prob*100:.1f}%"
    else:
        label = "\n\n🟢 SAFE SOIL\n\n"
        conf_text = f"Risk Level: {prob*100:.1f}%"

    return overlay, label, conf_text


# ==========================================
# 4. DASHBOARD UI
# ==========================================
custom_css = ".gradio-container {background-color: #1e1e1e; color: white}"

with gr.Blocks(css=custom_css, title="EAGLE A7 Mission Control") as demo:
    gr.Markdown("# 🦅 EAGLE A7: Autonomous Demining Interface")

    with gr.Tabs():
        with gr.TabItem("☀️ Daytime Vision"):
            with gr.Row():
                vis_input = gr.Image(label="Input", type="numpy")
                vis_output = gr.Image(label="YOLO Detections")
            vis_btn = gr.Button("SCAN", variant="primary")
            vis_status = gr.Textbox(label="Status")
            vis_btn.click(
                run_visual_detection,
                inputs=vis_input,
                outputs=[vis_output, vis_status],
            )

        with gr.TabItem("🌙 Night Vision (X-Ray)"):
            gr.Markdown("### Thermal Anomaly Localization")
            with gr.Row():
                with gr.Column():
                    therm_input = gr.Image(label="Thermal Feed", type="numpy")
                    therm_btn = gr.Button("ANALYZE & LOCATE", variant="stop")
                with gr.Column():
                    therm_output = gr.Image(label="Target Localization (Heatmap)")
                    therm_label = gr.HTML(label="Result")
                    therm_conf = gr.Textbox(label="Telemetry")
            therm_btn.click(
                run_thermal_scan,
                inputs=therm_input,
                outputs=[therm_output, therm_label, therm_conf],
            )

print("--- Launching Dashboard with X-Ray Vision ---")
demo.launch(server_name="0.0.0.0", share=True)