# NishantAGI's picture
# Update app.py
# c330a5b verified
import gradio as gr
import torch
import torch.nn.functional as F
from transformers import SegformerForSemanticSegmentation
from torchvision import transforms
from PIL import Image
import numpy as np
import os
# 1. SETUP & MODEL LOADING
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# SegFormer-B2 configured for the four terrain classes. ignore_mismatched_sizes=True
# tolerates pretrained weights whose shapes don't match this 4-class configuration.
model = SegformerForSemanticSegmentation.from_pretrained(
    "nvidia/mit-b2",
    num_labels=4,
    id2label={0: "Soil", 1: "Bedrock", 2: "Sand", 3: "Big Rock"},
    label2id={"Soil": 0, "Bedrock": 1, "Sand": 2, "Big Rock": 3},
    ignore_mismatched_sizes=True,
)

# Load the fine-tuned Stirling weights. The file must sit in the same folder as
# app.py; resolve it relative to this script (not the CWD) so that holds even
# when the app is launched from another directory.
WEIGHTS_PATH = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    "SegFormer_B2_Final_Stirling_3456526New.pth",
)
try:
    # NOTE(review): torch.load unpickles the file -- only load checkpoints you trust.
    checkpoint = torch.load(WEIGHTS_PATH, map_location=device)
    # Accept both training checkpoints ({'model_state_dict': ...}) and bare state dicts.
    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
    else:
        state_dict = checkpoint
    model.load_state_dict(state_dict)
    print("Model Weights Loaded Successfully")
except Exception as e:
    # Deliberate best-effort: the demo still starts, but without the fine-tuned
    # weights the predictions are presumably not meaningful.
    print(f"⚠️ Weights missing or error: {e}. Running with base weights.")

model.to(device).eval()
# 2. COLOR MAP (brightened for visibility)
# Class id -> RGB overlay colour. Id -1 is used for pixels that the MXY/RNG
# masks exclude from the prediction; they are painted black.
_CLASS_COLOURS = (
    (0, (0, 255, 0)),      # Soil - neon green
    (1, (0, 0, 255)),      # Bedrock - electric blue
    (2, (255, 215, 0)),    # Sand - yellow
    (3, (255, 0, 0)),      # Big Rock - red
    (-1, (0, 0, 0)),       # masked-out pixels - black
)
COLOR_MAP = {cls_id: list(rgb) for cls_id, rgb in _CLASS_COLOURS}
def apply_mask_safe(preds, folder, img_path, suffix, w, h):
    """
    Exclude masked pixels from ``preds`` using a mask located by SCLK ID.

    The 9-digit SCLK ID is extracted from the image filename, and ``folder`` is
    searched for a ``.png`` whose name contains that ID and ``suffix`` (e.g.
    ``"mxy"``/``"rng"``, compared case-insensitively). This ignores naming
    variations such as 'EDR' or 'NLA'.

    Args:
        preds: Class-id array, indexed by an (h, w) boolean mask; modified in place.
        folder: Directory that holds the mask PNGs.
        img_path: Path of the source image; only its basename is used.
        suffix: Substring that must also appear in the mask filename.
        w, h: Size the mask is resized to (nearest-neighbour) before it is applied.

    Returns:
        ``preds`` itself, with pixels where the mask is non-zero set to -1.
        It is returned unchanged when no ID or no matching mask is found.
    """
    import re

    filename = os.path.basename(img_path)

    # 1. Extract the 9-digit numeric ID (SCLK),
    #    e.g. NLA_601686301EDR... -> 601686301
    match = re.search(r'\d{9}', filename)
    if not match:
        print(f"❌ Could not find a 9-digit ID in {filename}")
        return preds
    seq_id = match.group(0)
    print(f"DEBUG: Searching for ID {seq_id} in {folder}...")

    # 2. Find a PNG containing both the ID and the suffix. Lower-case the suffix
    #    too, so "MXY" and "mxy" both match. Sort the listing so the chosen file
    #    does not depend on os.listdir's arbitrary order.
    wanted = suffix.lower()
    target_file = None
    if os.path.isdir(folder):
        target_file = next(
            (
                f
                for f in sorted(os.listdir(folder))
                if seq_id in f and f.lower().endswith('.png') and wanted in f.lower()
            ),
            None,
        )

    # 3. Apply the mask if one was found.
    if target_file is None:
        print(f"❌ FAIL: No mask for ID {seq_id} found in {folder}")
        return preds

    path = os.path.join(folder, target_file)
    # Close the file handle promptly instead of leaving it to the GC.
    with Image.open(path) as mask_img:
        mask_np = np.array(mask_img.convert('L').resize((w, h), Image.NEAREST))
    preds[mask_np > 0] = -1
    print(f"✅ SUCCESS: Applied mask from {target_file}")
    return preds
def _mask_path(filename, code, subfolder):
    """Return the mask PNG path for ``filename`` with 'EDR' replaced by ``code``."""
    stem, _ext = os.path.splitext(filename)
    return os.path.join("stirling_masks_bundle", subfolder, stem.replace("EDR", code) + ".png")


def _predict_classes(raw_img, orig_w, orig_h):
    """Run SegFormer on ``raw_img`` and return an (orig_h, orig_w) class-id array."""
    preprocess = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    input_tensor = preprocess(raw_img).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(pixel_values=input_tensor).logits
        logits = F.interpolate(
            logits, size=(orig_h, orig_w), mode='bilinear', align_corners=False
        )
    # Softmax is monotonic, so argmax over raw logits gives the same classes.
    # Index the batch dim explicitly: squeeze() would also drop an H or W of 1.
    return logits.argmax(dim=1)[0].cpu().numpy()


def _apply_exclusion_mask(preds, mask_path, w, h, label):
    """Set ``preds`` to -1 wherever the mask PNG at ``mask_path`` is non-zero."""
    if not os.path.exists(mask_path):
        print(f"❌ {label} Not Found: {mask_path}")
        return
    with Image.open(mask_path) as mask_img:
        mask = mask_img.convert('L').resize((w, h), Image.NEAREST)
    preds[np.array(mask) > 0] = -1


def _render_overlay(raw_img, preds, orig_w, orig_h):
    """Blend ``raw_img`` 50/50 with the COLOR_MAP colouring of ``preds``."""
    mask_rgb = np.zeros((orig_h, orig_w, 3), dtype=np.uint8)
    for cls_id, color in COLOR_MAP.items():
        mask_rgb[preds == cls_id] = color
    overlay = (np.array(raw_img) * 0.5 + mask_rgb * 0.5).astype(np.uint8)
    return Image.fromarray(overlay)


def segment_mars(img_path):
    """
    Segment a Martian image and return a colour-coded overlay.

    The image is classified by SegFormer at full resolution. Pixels covered by
    the matching MXY (``stirling_masks_bundle/rover_mxy``) and RNG
    (``stirling_masks_bundle/range_rng``) mask PNGs are then set to -1. Mask
    names are derived by replacing 'EDR' in the image's basename and switching
    the extension to ``.png``.

    Args:
        img_path: Filesystem path of the input image (Gradio ``type="filepath"``).

    Returns:
        A PIL image blending the input 50/50 with the class colours from
        COLOR_MAP, or None if ``img_path`` is empty.
    """
    if not img_path:
        return None

    # 1. Load the image and release the file handle right away.
    with Image.open(img_path) as img:
        raw_img = img.convert('RGB')
    orig_w, orig_h = raw_img.size

    # 2. Inference (SegFormer)
    preds = _predict_classes(raw_img, orig_w, orig_h)

    # 3. Direct filename swap (EDR -> MXY / EDR -> RNG) and mask fusion
    filename = os.path.basename(img_path)
    _apply_exclusion_mask(preds, _mask_path(filename, "MXY", "rover_mxy"), orig_w, orig_h, "MXY")
    _apply_exclusion_mask(preds, _mask_path(filename, "RNG", "range_rng"), orig_w, orig_h, "RNG")

    # 4. Overlay
    return _render_overlay(raw_img, preds, orig_w, orig_h)
# 3. CUSTOM HTML LEGEND
# Swatch colours mirror COLOR_MAP; keep the two in sync when changing a class colour.
legend_html = """
<div style="text-align: center; margin-bottom: 20px;">
    <p style="font-size: 1.2em; font-weight: bold; margin-bottom: 15px;">
        🔍 Select an image to segment from the examples below
    </p>
    <div style="display: flex; flex-wrap: wrap; justify-content: center; gap: 15px; font-weight: bold;">
        <div style="display: flex; align-items: center; gap: 8px;">
            <div style="width: 18px; height: 18px; background-color: rgb(0, 255, 0); border: 1px solid #ccc; border-radius: 3px;"></div>
            <span>Soil</span>
        </div>
        <div style="display: flex; align-items: center; gap: 8px;">
            <div style="width: 18px; height: 18px; background-color: rgb(0, 0, 255); border: 1px solid #ccc; border-radius: 3px;"></div>
            <span>Bedrock</span>
        </div>
        <div style="display: flex; align-items: center; gap: 8px;">
            <div style="width: 18px; height: 18px; background-color: rgb(255, 215, 0); border: 1px solid #ccc; border-radius: 3px;"></div>
            <span>Sand</span>
        </div>
        <div style="display: flex; align-items: center; gap: 8px;">
            <div style="width: 18px; height: 18px; background-color: rgb(255, 0, 0); border: 1px solid #ccc; border-radius: 3px;"></div>
            <span>Big Rock</span>
        </div>
        <div style="display: flex; align-items: center; gap: 8px;">
            <div style="width: 18px; height: 18px; background-color: rgb(0, 0, 0); border: 1px solid #ccc; border-radius: 3px;"></div>
            <span>Rover/Background</span>
        </div>
    </div>
</div>
"""
# 4. GRADIO INTERFACE
with gr.Blocks() as demo:
    gr.Markdown("## NASA Curiosity Rover Terrain Classifier")
    gr.HTML(legend_html)
    with gr.Row():
        # Not user-editable: the input is picked from the examples listed below.
        img_input = gr.Image(type="filepath", label="Input Martian Image", interactive=False)
        img_output = gr.Image(type="pil", label="Fused Ground Truth Prediction")
    btn = gr.Button("Execute Data Fusion Segmentation")
    btn.click(segment_mars, inputs=img_input, outputs=img_output)
    # Example files are referenced by bare filename, i.e. resolved relative to the CWD.
    gr.Examples(
        examples=[
            ["NLA_601686301EDR_F0732112NCAM00353M1.JPG"],
            ["NLB_436292094EDR_F0211028NCAM00257M1.JPG"],
            ["NLB_486005519EDR_F0481570NCAM07813M1.JPG"],
            ["NLB_519658137EDR_F0550000NCAM00654M1.JPG"],
            ["NLB_541230242EDR_F0611140NCAM07753M1.JPG"],
            ["NLB_621571338EDR_F0763002NCAM00207M1.JPG"],
        ],
        inputs=img_input,
    )

if __name__ == "__main__":
    # NOTE(review): share=True requests a public Gradio share link when run
    # locally; confirm that exposure is intended.
    demo.launch(share=True)