Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| from torchvision import models | |
| from ultralytics import YOLO | |
| import cv2 | |
| import numpy as np | |
# ==========================================
# 1. SETUP & MODEL LOADING
# ==========================================
# Use the GPU when torch can see one, otherwise run everything on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"--- System Boot: Using {device} ---")

# --- LOAD VISUAL SYSTEM (YOLO) ---
# Prefer the custom-trained weights; if they cannot be loaded, fall back to the
# stock yolo11n checkpoint so the dashboard still starts.
try:
    yolo_model = YOLO("best.pt")
    print("✅ Visual System: Custom EAGLE A7 Model Loaded")
except Exception as e:
    # Catch Exception (not a bare except) so Ctrl-C / SystemExit still work,
    # and report the fallback: yolo11n.pt is a generic detector, not the
    # custom model (presumably COCO classes — confirm for your checkpoint).
    print(f"⚠️ Visual System: Could not load best.pt ({e}); falling back to yolo11n.pt")
    yolo_model = YOLO("yolo11n.pt")
# --- LOAD THERMAL SYSTEM (ResNet-18) ---
def get_thermal_model():
    """Build an untrained ResNet-18 adapted to single-channel thermal frames.

    The stem convolution is replaced by a 1-input-channel version and the
    classifier head by a single-logit linear layer.

    Returns:
        An ``nn.Module`` with randomly initialised weights (``weights=None``).
    """
    backbone = models.resnet18(weights=None)
    # One input channel instead of three: frames are fed in as grayscale.
    stem = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
    backbone.conv1 = stem
    # Single output logit; the caller applies a sigmoid to get a probability.
    head_inputs = backbone.fc.in_features
    backbone.fc = nn.Linear(head_inputs, 1)
    return backbone
# Instantiate the thermal classifier and try to restore its trained weights.
thermal_model = get_thermal_model().to(device)
MODEL_PATH = "thermal_landmine_scanner.pth"
try:
    # A state_dict only contains tensors and plain containers, so restrict
    # unpickling to those with weights_only=True.
    thermal_model.load_state_dict(
        torch.load(MODEL_PATH, map_location=device, weights_only=True)
    )
    print(f"✅ Thermal System: Loaded {MODEL_PATH}")
except Exception as e:
    # Best-effort: the app keeps running, but with the untrained weights from
    # get_thermal_model() every thermal verdict is meaningless.
    print(f"❌ CRITICAL ERROR: Could not load thermal model. {e}")
# Switch to inference mode unconditionally so BatchNorm uses running statistics
# whether or not the checkpoint loaded.
thermal_model.eval()
# ==========================================
# 2. SETUP CAM (CLASS ACTIVATION MAPPING) HOOK
# ==========================================
# Capture layer4's feature maps on every forward pass so run_thermal_scan can
# weight them by the classifier head and build a localisation heatmap.
# Holds at most one entry: the latest layer4 output as a NumPy array.
features_blob = []


def hook_feature(module, input, output):
    """Forward hook: replace the cached feature maps with this pass's output."""
    features_blob.clear()  # keep only the most recent forward pass
    features_blob.append(output.detach().cpu().numpy())


# Attach the hook to the last residual stage (layer4); keep the handle so the
# hook can be removed later if ever needed.
_layer4_hook_handle = thermal_model.layer4.register_forward_hook(hook_feature)

# Weights of the final linear layer, one row per output logit (a single row
# here). Read directly from ``fc`` instead of relying on parameter ordering.
# The name is historical: the head feeds a sigmoid, not a softmax.
weight_softmax = thermal_model.fc.weight.detach().cpu().numpy()
| # ========================================== | |
| # 3. PROCESSING FUNCTIONS | |
| # ========================================== | |
def run_visual_detection(image):
    """Run the YOLO detector on one frame and return an annotated copy.

    Args:
        image: Frame from ``gr.Image(type="numpy")`` (RGB channel order), or
            ``None`` when no image has been supplied.

    Returns:
        Tuple ``(annotated_rgb_image, status_text)``, or
        ``(None, "Waiting for feed...")`` when ``image`` is ``None``.
    """
    if image is None:
        return None, "Waiting for feed..."

    # Ultralytics treats NumPy input as BGR (OpenCV convention) while Gradio
    # delivers RGB, so swap channels before inference.
    frame = image
    if frame.ndim == 3 and frame.shape[2] == 3:
        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)

    result = yolo_model.predict(frame, conf=0.65)[0]

    # Results.plot() returns a BGR array; convert back to RGB for Gradio.
    annotated = result.plot()
    if annotated.ndim == 3 and annotated.shape[2] == 3:
        annotated = cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB)

    # ``boxes`` is None for non-detection models; report zero in that case.
    boxes = result.boxes
    count = 0 if boxes is None else len(boxes)
    return annotated, f"Objects Detected: {count}"
def _preprocess_thermal(image):
    """Turn a frame into a CLAHE-enhanced 224x224 grayscale image plus model tensor.

    Returns ``(resized, tensor)``: ``resized`` is the 224x224 grayscale image,
    ``tensor`` is float32 of shape ``[1, 1, 224, 224]`` on ``device`` with
    values in ``[0, 1]``.
    """
    if image.ndim == 3 and image.shape[2] == 1:
        # Single-channel image stored with a trailing axis; cvtColor rejects it.
        gray = np.ascontiguousarray(image[:, :, 0])
    elif image.ndim == 3:
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    else:
        gray = image
    # Local contrast enhancement (CLAHE) before resizing to the model input size.
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    enhanced = clahe.apply(gray)
    resized = cv2.resize(enhanced, (224, 224))
    normalized = resized.astype(np.float32) / 255.0
    tensor = torch.from_numpy(normalized).unsqueeze(0).unsqueeze(0).to(device)
    return resized, tensor


def _class_activation_map(feature_maps, class_weights, size):
    """Build a class activation map normalised to ``[0, 1]``.

    Args:
        feature_maps: Array ``[C, h, w]`` of conv features for one image.
        class_weights: Array ``[C]`` of classifier weights for the target logit.
        size: ``(width, height)`` of the returned map (cv2.resize convention).
    """
    # Weighted sum over channels: sum_c w[c] * F[c, :, :]
    cam = np.tensordot(class_weights, feature_maps, axes=1).astype(np.float32)
    cam = np.maximum(cam, 0)  # keep only evidence *for* the target logit
    cam = cv2.resize(cam, size)
    cam -= cam.min()
    peak = cam.max()
    if peak != 0:
        cam /= peak
    return cam


def run_thermal_scan(image):
    """Classify a thermal frame as mine/safe and localise the evidence with CAM.

    Args:
        image: Frame from ``gr.Image(type="numpy")``, or ``None``.

    Returns:
        ``(overlay_rgb, label_html, telemetry_text)`` where ``overlay_rgb`` is the
        224x224 enhanced frame blended with a JET heatmap. Returns
        ``(None, "No Signal", "N/A")`` when ``image`` is ``None``.
    """
    if image is None:
        return None, "No Signal", "N/A"

    resized, tensor = _preprocess_thermal(image)

    # --- INFERENCE --- (the layer4 forward hook refreshes features_blob here)
    with torch.no_grad():
        output = thermal_model(tensor)
    prob = torch.sigmoid(output).item()
    is_mine = prob > 0.50

    # --- HEATMAP (Class Activation Mapping) ---
    # features_blob[0] is [1, C, h, w]; weight_softmax[0] holds the C weights
    # of the single output logit.
    cam = _class_activation_map(features_blob[0][0], weight_softmax[0], (224, 224))
    heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
    # applyColorMap returns BGR; convert so it blends correctly with the RGB
    # frame and hot regions render red when shown in Gradio.
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

    orig_rgb = cv2.cvtColor(resized, cv2.COLOR_GRAY2RGB)
    # Blend 50/50 for a mine verdict; 80/20 for "safe" so the map is subdued.
    intensity = 0.5 if is_mine else 0.2
    overlay = cv2.addWeighted(orig_rgb, 1 - intensity, heatmap, intensity, 0)

    # --- DECISION ---
    if is_mine:
        label = "<h2 style='color: red; text-align: center;'>🔴 MINE DETECTED</h2>"
        conf_text = f"CONFIDENCE: {prob*100:.1f}%"
    else:
        label = "<h2 style='color: green; text-align: center;'>🟢 SAFE SOIL</h2>"
        conf_text = f"Risk Level: {prob*100:.1f}%"
    return overlay, label, conf_text
# ==========================================
# 4. DASHBOARD UI
# ==========================================
custom_css = ".gradio-container {background-color: #1e1e1e; color: white}"
with gr.Blocks(css=custom_css, title="EAGLE A7 Mission Control") as demo:
    gr.Markdown("# 🦅 EAGLE A7: Autonomous Demining Interface")
    with gr.Tabs():
        # Tab 1: object detection with the YOLO model.
        with gr.TabItem("☀️ Daytime Vision"):
            with gr.Row():
                vis_input = gr.Image(label="Input", type="numpy")
                vis_output = gr.Image(label="YOLO Detections")
            vis_btn = gr.Button("SCAN", variant="primary")
            vis_status = gr.Textbox(label="Status")
            vis_btn.click(
                run_visual_detection,
                inputs=vis_input,
                outputs=[vis_output, vis_status],
            )
        # Tab 2: thermal classifier with CAM heatmap localisation.
        with gr.TabItem("🌙 Night Vision (X-Ray)"):
            gr.Markdown("### Thermal Anomaly Localization")
            with gr.Row():
                with gr.Column():
                    therm_input = gr.Image(label="Thermal Feed", type="numpy")
                    therm_btn = gr.Button("ANALYZE & LOCATE", variant="stop")
                with gr.Column():
                    therm_output = gr.Image(label="Target Localization (Heatmap)")
                    therm_label = gr.HTML(label="Result")
                    therm_conf = gr.Textbox(label="Telemetry")
            therm_btn.click(
                run_thermal_scan,
                inputs=therm_input,
                outputs=[therm_output, therm_label, therm_conf],
            )

# Launch only when run as a script, so importing this module (e.g. Gradio's
# reload mode, which looks up ``demo``) does not start a second server.
if __name__ == "__main__":
    print("--- Launching Dashboard with X-Ray Vision ---")
    demo.launch(server_name="0.0.0.0", share=True)