import os
import re

import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
from transformers import SegformerForSemanticSegmentation
|
|
| |
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# SegFormer built from the "nvidia/mit-b2" checkpoint with a 4-class terrain
# head. `ignore_mismatched_sizes=True` keeps loading from failing when a
# checkpoint tensor's shape differs from what `num_labels=4` requires.
model = SegformerForSemanticSegmentation.from_pretrained(
    "nvidia/mit-b2",
    num_labels=4,
    id2label={0: "Soil", 1: "Bedrock", 2: "Sand", 3: "Big Rock"},
    label2id={"Soil": 0, "Bedrock": 1, "Sand": 2, "Big Rock": 3},
    ignore_mismatched_sizes=True,
)

# Overlay the fine-tuned weights. This is deliberately best effort: if the
# checkpoint is missing or incompatible the app still starts, using whatever
# weights `from_pretrained` produced, and says so on stdout.
try:
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from a trusted source. If the checkpoint holds nothing but tensors and
    # plain containers, `weights_only=True` would be safer - confirm first.
    checkpoint = torch.load(
        "SegFormer_B2_Final_Stirling_3456526New.pth", map_location=device
    )
    model.load_state_dict(checkpoint["model_state_dict"])
except Exception as e:
    print(f"⚠️ Weights missing or error: {e}. Running with base weights.")
else:
    print("Model Weights Loaded Successfully")

# Inference only: move to the selected device and switch to eval mode.
model.to(device).eval()
|
|
| |
# RGB colour painted for each class id in the overlay. Ids 0-3 follow the
# model's `id2label` mapping; -1 is the value written into the prediction map
# for pixels covered by a mask image, and is drawn black.
COLOR_MAP = {
    0: [0, 255, 0],      # Soil
    1: [0, 0, 255],      # Bedrock
    2: [255, 215, 0],    # Sand
    3: [255, 0, 0],      # Big Rock
    -1: [0, 0, 0],       # masked out
}
|
|
def apply_mask_safe(preds, folder, img_path, suffix, w, h):
    """Mark pixels covered by a mask image as class ``-1`` in ``preds``.

    The mask is located by the 9-digit SCLK ID embedded in the image file
    name rather than by the full name, so naming variations such as ``EDR``
    or ``NLA`` do not have to match. A candidate must be a ``.png`` file in
    ``folder`` whose name contains both that ID and ``suffix``.

    Args:
        preds: Integer NumPy array of class ids, modified in place. It is
            indexed with an ``(h, w)`` boolean mask, so it must have that
            shape.
        folder: Directory searched (non-recursively) for the mask file.
        img_path: Path of the source image; only its base name is used.
        suffix: Substring identifying the mask type. Compared
            case-insensitively against the file name.
        w: Width the mask is resized to.
        h: Height the mask is resized to.

    Returns:
        The same ``preds`` array. It is unchanged when the file name has no
        9-digit ID or when no matching mask file is found.
    """
    filename = os.path.basename(img_path)

    match = re.search(r"\d{9}", filename)
    if match is None:
        print(f"❌ Could not find a 9-digit ID in {filename}")
        return preds

    seq_id = match.group(0)
    print(f"DEBUG: Searching for ID {seq_id} in {folder}...")

    # File names are compared lower-cased, so the suffix must be lower-cased
    # too; otherwise a caller passing e.g. "MXY" could never match.
    wanted = suffix.lower()
    target_file = None
    if os.path.isdir(folder):
        # Sorted so the choice is deterministic when several files match
        # (os.listdir returns entries in arbitrary order).
        for name in sorted(os.listdir(folder)):
            lowered = name.lower()
            if seq_id in name and lowered.endswith(".png") and wanted in lowered:
                target_file = name
                break

    if target_file is None:
        print(f"❌ FAIL: No mask for ID {seq_id} found in {folder}")
        return preds

    # Nearest-neighbour resampling keeps the mask binary after resizing.
    with Image.open(os.path.join(folder, target_file)) as mask_img:
        mask_np = np.array(mask_img.convert("L").resize((w, h), Image.NEAREST))

    # Any non-zero mask pixel overrides the predicted class.
    preds[mask_np > 0] = -1
    print(f"✅ SUCCESS: Applied mask from {target_file}")
    return preds
|
|
def _apply_mask_file(preds, mask_path, size, label):
    """Set ``preds`` to -1 wherever the mask image at ``mask_path`` is non-zero.

    ``size`` is the ``(width, height)`` the mask is resized to and ``label``
    is the tag used in the "not found" message. ``preds`` is modified in
    place; a missing file is reported on stdout and otherwise ignored.
    """
    if not os.path.exists(mask_path):
        print(f"❌ {label} Not Found: {mask_path}")
        return
    # Nearest-neighbour resampling keeps the mask binary after resizing.
    with Image.open(mask_path) as mask_img:
        mask = np.array(mask_img.convert("L").resize(size, Image.NEAREST))
    preds[mask > 0] = -1


def segment_mars(img_path):
    """Segment one image and return it blended with the colour-coded classes.

    The image is resized to 256x256 and normalised, run through ``model``,
    and the logits are upsampled back to the original resolution before the
    per-pixel class is taken. Pixels covered by the matching mask images under
    ``stirling_masks_bundle/rover_mxy`` and ``stirling_masks_bundle/range_rng``
    are set to class ``-1``, which ``COLOR_MAP`` draws black.

    Args:
        img_path: Path of the input image. A falsy value (``None`` or ``""``)
            yields ``None``.

    Returns:
        A ``PIL.Image`` holding a 50/50 blend of the input and the class
        colours, or ``None`` when no image was supplied.
    """
    if not img_path:
        return None

    # convert() returns an independent image, so the file can be closed.
    with Image.open(img_path) as src:
        raw_img = src.convert("RGB")
    orig_w, orig_h = raw_img.size

    # NOTE(review): presumably this mirrors the preprocessing used at
    # training time - confirm before changing size or statistics.
    preprocess = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    input_tensor = preprocess(raw_img).unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(pixel_values=input_tensor)
        logits = F.interpolate(
            outputs.logits,
            size=(orig_h, orig_w),
            mode="bilinear",
            align_corners=False,
        )
        # argmax over logits equals argmax over softmax probabilities, so the
        # softmax is skipped. Index the batch axis explicitly: squeeze() would
        # also drop a height or width of 1 and break the (H, W) indexing below.
        preds = logits.argmax(dim=1)[0].cpu().numpy()

    # Mask files share the image's name with "EDR" swapped for the mask code
    # and a .png extension. splitext handles any extension and letter case,
    # not only ".JPG" / ".jpg".
    stem, _ = os.path.splitext(os.path.basename(img_path))
    mxy_path = os.path.join(
        "stirling_masks_bundle", "rover_mxy", stem.replace("EDR", "MXY") + ".png"
    )
    rng_path = os.path.join(
        "stirling_masks_bundle", "range_rng", stem.replace("EDR", "RNG") + ".png"
    )

    size = (orig_w, orig_h)
    _apply_mask_file(preds, mxy_path, size, "MXY")
    _apply_mask_file(preds, rng_path, size, "RNG")

    # Paint each class with its colour, then blend with the source image.
    mask_rgb = np.zeros((orig_h, orig_w, 3), dtype=np.uint8)
    for cls_id, color in COLOR_MAP.items():
        mask_rgb[preds == cls_id] = color

    overlay = (np.array(raw_img) * 0.5 + mask_rgb * 0.5).astype(np.uint8)
    return Image.fromarray(overlay)
|
|
| |
# Static HTML shown above the images: a usage hint plus one colour swatch per
# class. The swatch colours repeat the RGB values in COLOR_MAP and must be
# kept in sync with it by hand.
legend_html = """
<div style="text-align: center; margin-bottom: 20px;">
    <p style="font-size: 1.2em; font-weight: bold; margin-bottom: 15px;">
        👇 Select an image to segment from the examples below
    </p>

    <div style="display: flex; flex-wrap: wrap; justify-content: center; gap: 15px; font-weight: bold;">
        <div style="display: flex; align-items: center; gap: 8px;">
            <div style="width: 18px; height: 18px; background-color: rgb(0, 255, 0); border: 1px solid #ccc; border-radius: 3px;"></div>
            <span>Soil</span>
        </div>
        <div style="display: flex; align-items: center; gap: 8px;">
            <div style="width: 18px; height: 18px; background-color: rgb(0, 0, 255); border: 1px solid #ccc; border-radius: 3px;"></div>
            <span>Bedrock</span>
        </div>
        <div style="display: flex; align-items: center; gap: 8px;">
            <div style="width: 18px; height: 18px; background-color: rgb(255, 215, 0); border: 1px solid #ccc; border-radius: 3px;"></div>
            <span>Sand</span>
        </div>
        <div style="display: flex; align-items: center; gap: 8px;">
            <div style="width: 18px; height: 18px; background-color: rgb(255, 0, 0); border: 1px solid #ccc; border-radius: 3px;"></div>
            <span>Big Rock</span>
        </div>
        <div style="display: flex; align-items: center; gap: 8px;">
            <div style="width: 18px; height: 18px; background-color: rgb(0, 0, 0); border: 1px solid #ccc; border-radius: 3px;"></div>
            <span>Rover/Background</span>
        </div>
    </div>
</div>
"""
| |
# Images offered as one-click examples. These are bare file names, so they
# are presumably resolved relative to the directory the app is started from.
example_files = [
    "NLA_601686301EDR_F0732112NCAM00353M1.JPG",
    "NLB_436292094EDR_F0211028NCAM00257M1.JPG",
    "NLB_486005519EDR_F0481570NCAM07813M1.JPG",
    "NLB_519658137EDR_F0550000NCAM00654M1.JPG",
    "NLB_541230242EDR_F0611140NCAM07753M1.JPG",
    "NLB_621571338EDR_F0763002NCAM00207M1.JPG",
]

# Page layout: title, colour legend, input/output images side by side, a run
# button, and the example picker that fills the input image.
with gr.Blocks() as demo:
    gr.Markdown("## NASA Curiosity Rover Terrain Classifier")
    gr.HTML(legend_html)

    with gr.Row():
        # interactive=False: the input is filled from the examples, not
        # uploaded. The callback receives it as a file path.
        img_input = gr.Image(
            type="filepath",
            label="Input Martian Image",
            interactive=False,
        )
        img_output = gr.Image(
            type="pil",
            label="Fused Ground Truth Prediction",
        )

    btn = gr.Button("Execute Data Fusion Segmentation")
    btn.click(fn=segment_mars, inputs=img_input, outputs=img_output)

    # gr.Examples expects one row per example, each row a list of input values.
    gr.Examples(
        examples=[[name] for name in example_files],
        inputs=img_input,
    )


if __name__ == "__main__":
    # share=True asks Gradio for a public share link in addition to the
    # local server.
    demo.launch(share=True)