sravan837 commited on
Commit
64ce653
Β·
verified Β·
1 Parent(s): 9ff7581

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +72 -83
app.py CHANGED
@@ -14,11 +14,9 @@ print(f"--- System Boot: Using {device} ---")
14
 
15
  # --- LOAD VISUAL SYSTEM (YOLO) ---
16
  try:
17
- # Try loading your custom trained model first
18
  yolo_model = YOLO("best.pt")
19
  print("βœ… Visual System: Custom EAGLE A7 Model Loaded")
20
  except:
21
- print("⚠️ Visual System: Custom model not found. Loading standard YOLOv11n...")
22
  yolo_model = YOLO("yolo11n.pt")
23
 
24
  # --- LOAD THERMAL SYSTEM (ResNet-18) ---
@@ -29,139 +27,130 @@ def get_thermal_model():
29
  return model
30
 
31
  thermal_model = get_thermal_model().to(device)
32
-
33
  MODEL_PATH = "thermal_landmine_scanner.pth"
 
34
  try:
35
  thermal_model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
36
  thermal_model.eval()
37
  print(f"βœ… Thermal System: Loaded {MODEL_PATH}")
38
  except Exception as e:
39
  print(f"❌ CRITICAL ERROR: Could not load thermal model. {e}")
40
- print(" Make sure 'thermal_landmine_scanner.pth' is in this folder!")
41
 
42
  # ==========================================
43
- # 2. CORE PROCESSING FUNCTIONS
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  # ==========================================
45
 
46
  def run_visual_detection(image):
47
- """
48
- Daytime RGB Detection using YOLOv11
49
- """
50
- if image is None: return None, "Waiting for video feed..."
51
-
52
- # YOLO handles preprocessing internally
53
- results = yolo_model.predict(image, conf=0.40) # 40% Confidence Threshold
54
-
55
- # Render detections
56
- annotated_frame = results[0].plot()
57
-
58
- # Status Logic
59
- count = len(results[0].boxes)
60
- if count > 0:
61
- status = f"⚠️ DANGER: {count} Threat(s) Identified"
62
- else:
63
- status = "βœ… Sector Clear"
64
-
65
- return annotated_frame, status
66
 
67
  def run_thermal_scan(image):
68
- """
69
- Nighttime Thermal Analysis using ResNet-18
70
- Strict preprocessing to match training data.
71
- """
72
- if image is None: return "N/A", "No Signal"
73
-
74
- # --- STEP 1: FORCE GRAYSCALE ---
75
- # Gradio sends RGB. We need Grayscale.
76
- # We use RGB2GRAY because Gradio is RGB by default.
77
  if len(image.shape) == 3:
78
  gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
79
  else:
80
  gray = image
81
 
82
- # --- STEP 2: CLAHE ENHANCEMENT ---
83
- # This MUST match the training/debug script exactly.
84
  clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
85
  enhanced_img = clahe.apply(gray)
86
-
87
- # --- STEP 3: RESIZE ---
88
  resized = cv2.resize(enhanced_img, (224, 224))
89
-
90
- # --- STEP 4: NORMALIZE TO 0.0 - 1.0 ---
91
- # Convert to float32 and divide by 255
92
  normalized_img = resized.astype(np.float32) / 255.0
93
 
94
- # --- STEP 5: TENSOR CONVERSION ---
95
- # Shape: [1, 1, 224, 224]
96
  tensor = torch.from_numpy(normalized_img).float().unsqueeze(0).unsqueeze(0)
97
  tensor = tensor.to(device)
98
 
99
- # --- STEP 6: INFERENCE ---
100
  with torch.no_grad():
101
  output = thermal_model(tensor)
102
  prob = torch.sigmoid(output).item()
103
 
104
- # --- DEBUG LOGGING (Check your terminal!) ---
105
- # This proves if the "App" sees what the "Debug Script" saw.
106
- print(f"πŸ”Ž Thermal Scan | Input Mean: {normalized_img.mean():.4f} | Prob: {prob:.4f}")
107
-
108
- # --- STEP 7: DECISION LOGIC ---
109
- # Based on your perfect debug run (Mine=0.99, Safe=0.00),
110
- # we use a standard 0.50 threshold.
111
- THRESHOLD = 0.50
112
 
113
- is_mine = prob > THRESHOLD
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
 
 
 
 
 
 
 
 
 
 
 
115
  if is_mine:
116
- label = "πŸ”΄ MINE DETECTED"
117
- color_hex = "red"
118
- conf_text = f"THREAT CONFIDENCE: {prob*100:.1f}%"
119
  else:
120
- label = "🟒 SAFE SOIL"
121
- color_hex = "green"
122
- conf_text = f"Safety Confidence: {(1-prob)*100:.1f}%"
123
 
124
- # Return HTML for colored text
125
- html_output = f"<h2 style='color: {color_hex}; text-align: center;'>{label}</h2>"
126
- return html_output, conf_text
127
 
128
  # ==========================================
129
- # 3. DASHBOARD UI (Gradio)
130
  # ==========================================
131
- custom_css = """
132
- .gradio-container {background-color: #1e1e1e; color: white}
133
- """
134
 
135
  with gr.Blocks(css=custom_css, title="EAGLE A7 Mission Control") as demo:
136
  gr.Markdown("# πŸ¦… EAGLE A7: Autonomous Demining Interface")
137
- gr.Markdown("**Status:** System Online | **AI Engines:** YOLOv11 + Thermal ResNet-18")
138
 
139
  with gr.Tabs():
140
- # --- TAB 1: VISUAL ---
141
  with gr.TabItem("β˜€οΈ Daytime Vision"):
142
  with gr.Row():
143
- with gr.Column():
144
- vis_input = gr.Image(label="Drone Camera Feed", type="numpy")
145
- vis_btn = gr.Button("SCAN SECTOR", variant="primary")
146
- with gr.Column():
147
- vis_output = gr.Image(label="AI Analysis")
148
- vis_status = gr.Textbox(label="Mission Status", interactive=False)
149
-
150
  vis_btn.click(run_visual_detection, inputs=vis_input, outputs=[vis_output, vis_status])
151
 
152
- # --- TAB 2: THERMAL ---
153
- with gr.TabItem("πŸŒ™ Night Vision (Thermal)"):
154
  with gr.Row():
155
  with gr.Column():
156
- therm_input = gr.Image(label="Thermal Sensor Feed (Grayscale)", type="numpy")
157
- therm_btn = gr.Button("ANALYZE HEAT SIGNATURE", variant="stop")
158
  with gr.Column():
159
- # We use HTML for big colored alerts
160
  therm_label = gr.HTML(label="Result")
161
- therm_conf = gr.Textbox(label="Sensor Telemetry", interactive=False)
162
 
163
- therm_btn.click(run_thermal_scan, inputs=therm_input, outputs=[therm_label, therm_conf])
164
 
165
- # Launch
166
- print("--- Launching Dashboard ---")
167
  demo.launch(server_name="0.0.0.0", share=True)
 
14
 
15
  # --- LOAD VISUAL SYSTEM (YOLO) ---
16
  try:
 
17
  yolo_model = YOLO("best.pt")
18
  print("βœ… Visual System: Custom EAGLE A7 Model Loaded")
19
  except:
 
20
  yolo_model = YOLO("yolo11n.pt")
21
 
22
  # --- LOAD THERMAL SYSTEM (ResNet-18) ---
 
27
  return model
28
 
29
  thermal_model = get_thermal_model().to(device)
 
30
  MODEL_PATH = "thermal_landmine_scanner.pth"
31
+
32
  try:
33
  thermal_model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
34
  thermal_model.eval()
35
  print(f"βœ… Thermal System: Loaded {MODEL_PATH}")
36
  except Exception as e:
37
  print(f"❌ CRITICAL ERROR: Could not load thermal model. {e}")
 
38
 
39
  # ==========================================
40
+ # 2. SETUP GRAD-CAM (THE "X-RAY" HOOK)
41
+ # ==========================================
42
+ # We need to steal the features from inside the model while it thinks
43
+ features_blob = []
44
+ def hook_feature(module, input, output):
45
+ features_blob.clear() # Clear old data
46
+ features_blob.append(output.data.cpu().numpy())
47
+
48
+ # Attach the spy hook to the last layer (Layer 4)
49
+ thermal_model.layer4.register_forward_hook(hook_feature)
50
+
51
+ # Get weights from the final decision layer
52
+ params = list(thermal_model.parameters())
53
+ weight_softmax = params[-2].data.cpu().numpy() # The weights connecting features to "Mine/Safe"
54
+
55
+ # ==========================================
56
+ # 3. PROCESSING FUNCTIONS
57
  # ==========================================
58
 
59
  def run_visual_detection(image):
60
+ if image is None: return None, "Waiting for feed..."
61
+ results = yolo_model.predict(image, conf=0.40)
62
+ return results[0].plot(), f"Objects Detected: {len(results[0].boxes)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
  def run_thermal_scan(image):
65
+ if image is None: return None, "No Signal", "N/A"
66
+
67
+ # --- PREPROCESSING (Standard) ---
 
 
 
 
 
 
68
  if len(image.shape) == 3:
69
  gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
70
  else:
71
  gray = image
72
 
 
 
73
  clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
74
  enhanced_img = clahe.apply(gray)
 
 
75
  resized = cv2.resize(enhanced_img, (224, 224))
 
 
 
76
  normalized_img = resized.astype(np.float32) / 255.0
77
 
 
 
78
  tensor = torch.from_numpy(normalized_img).float().unsqueeze(0).unsqueeze(0)
79
  tensor = tensor.to(device)
80
 
81
+ # --- INFERENCE ---
82
  with torch.no_grad():
83
  output = thermal_model(tensor)
84
  prob = torch.sigmoid(output).item()
85
 
86
+ # --- GENERATE HEATMAP (Explainable AI) ---
87
+ # 1. Get the features captured by our hook [1, 512, 7, 7]
88
+ feature_data = features_blob[0]
 
 
 
 
 
89
 
90
+ # 2. Calculate the "Attention Map"
91
+ cam = np.zeros((7, 7), dtype=np.float32)
92
+ # Use the weights for the "Mine" class to weight the features
93
+ for i, w in enumerate(weight_softmax[0]):
94
+ cam += w * feature_data[0, i, :, :]
95
+
96
+ # 3. Process the CAM
97
+ cam = np.maximum(cam, 0) # ReLU (Remove negative influence)
98
+ cam = cv2.resize(cam, (224, 224)) # Resize to image size
99
+ cam = cam - np.min(cam)
100
+ if np.max(cam) != 0:
101
+ cam = cam / np.max(cam) # Normalize 0-1
102
+
103
+ # 4. Colorize
104
+ heatmap = cv2.applyColorMap(np.uint8(255*cam), cv2.COLORMAP_JET)
105
 
106
+ # 5. Overlay on original (Grayscale -> RGB)
107
+ orig_rgb = cv2.cvtColor(resized, cv2.COLOR_GRAY2RGB)
108
+
109
+ # Mix: 60% Original Image + 40% Heatmap
110
+ # If prob is low (Safe), we show less heatmap so it doesn't look scary
111
+ intensity = 0.5 if prob > 0.5 else 0.2
112
+ overlay = cv2.addWeighted(orig_rgb, 1-intensity, heatmap, intensity, 0)
113
+
114
+ # --- DECISION ---
115
+ is_mine = prob > 0.50
116
  if is_mine:
117
+ label = f"<h2 style='color: red; text-align: center;'>πŸ”΄ MINE DETECTED</h2>"
118
+ conf_text = f"CONFIDENCE: {prob*100:.1f}%"
 
119
  else:
120
+ label = f"<h2 style='color: green; text-align: center;'>🟒 SAFE SOIL</h2>"
121
+ conf_text = f"Risk Level: {prob*100:.1f}%"
 
122
 
123
+ return overlay, label, conf_text
 
 
124
 
125
  # ==========================================
126
+ # 4. DASHBOARD UI
127
  # ==========================================
128
+ custom_css = ".gradio-container {background-color: #1e1e1e; color: white}"
 
 
129
 
130
  with gr.Blocks(css=custom_css, title="EAGLE A7 Mission Control") as demo:
131
  gr.Markdown("# πŸ¦… EAGLE A7: Autonomous Demining Interface")
 
132
 
133
  with gr.Tabs():
 
134
  with gr.TabItem("β˜€οΈ Daytime Vision"):
135
  with gr.Row():
136
+ vis_input = gr.Image(label="Input", type="numpy")
137
+ vis_output = gr.Image(label="YOLO Detections")
138
+ vis_btn = gr.Button("SCAN", variant="primary")
139
+ vis_status = gr.Textbox(label="Status")
 
 
 
140
  vis_btn.click(run_visual_detection, inputs=vis_input, outputs=[vis_output, vis_status])
141
 
142
+ with gr.TabItem("πŸŒ™ Night Vision (X-Ray)"):
143
+ gr.Markdown("### Thermal Anomaly Localization")
144
  with gr.Row():
145
  with gr.Column():
146
+ therm_input = gr.Image(label="Thermal Feed", type="numpy")
147
+ therm_btn = gr.Button("ANALYZE & LOCATE", variant="stop")
148
  with gr.Column():
149
+ therm_output = gr.Image(label="Target Localization (Heatmap)")
150
  therm_label = gr.HTML(label="Result")
151
+ therm_conf = gr.Textbox(label="Telemetry")
152
 
153
+ therm_btn.click(run_thermal_scan, inputs=therm_input, outputs=[therm_output, therm_label, therm_conf])
154
 
155
+ print("--- Launching Dashboard with X-Ray Vision ---")
 
156
  demo.launch(server_name="0.0.0.0", share=True)