File size: 7,439 Bytes
97be488
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c330a5b
97be488
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
00a8495
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97be488
 
 
 
 
00a8495
97be488
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
import gradio as gr
import torch
import torch.nn.functional as F
from transformers import SegformerForSemanticSegmentation
from torchvision import transforms
from PIL import Image
import numpy as np
import os

# 1. SETUP & MODEL LOADING
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
_gpu_available = torch.cuda.is_available()
device = torch.device("cuda") if _gpu_available else torch.device("cpu")

# Terrain classes in label-id order; both lookup tables are derived from this
# single tuple so they cannot drift apart.
_TERRAIN_CLASSES = ("Soil", "Bedrock", "Sand", "Big Rock")
_ID2LABEL = dict(enumerate(_TERRAIN_CLASSES))
_LABEL2ID = {label: idx for idx, label in _ID2LABEL.items()}

# ignore_mismatched_sizes=True is passed so from_pretrained does not reject
# the checkpoint when its head shape differs from the 4-class head requested.
model = SegformerForSemanticSegmentation.from_pretrained(
    "nvidia/mit-b2",
    num_labels=len(_TERRAIN_CLASSES),
    id2label=_ID2LABEL,
    label2id=_LABEL2ID,
    ignore_mismatched_sizes=True,
)

# Fine-tuned Stirling weights. The path is relative, so the file is looked up
# in the process's current working directory.
_WEIGHTS_FILE = 'SegFormer_B2_Final_Stirling_3456526New.pth'

# Best-effort load: any failure (missing file, unexpected checkpoint layout,
# mismatched keys) is reported and the app keeps running on the base weights.
try:
    checkpoint = torch.load(_WEIGHTS_FILE, map_location=device)
    _fine_tuned_state = checkpoint['model_state_dict']
    model.load_state_dict(_fine_tuned_state)
    print("Model Weights Loaded Successfully")
except Exception as err:
    print(f"⚠️ Weights missing or error: {err}. Running with base weights.")

# Move the network to the selected device and switch it to inference mode.
model.to(device)
model.eval()

# 2. COLOR MAP (Brightened for visibility)
# One [R, G, B] triple per class id that can appear in a prediction map.
_SOIL_RGB = [0, 255, 0]        # neon green
_BEDROCK_RGB = [0, 0, 255]     # electric blue
_SAND_RGB = [255, 215, 0]      # yellow
_BIG_ROCK_RGB = [255, 0, 0]    # red
_MASKED_RGB = [0, 0, 0]        # black, used for pixels relabelled as -1

COLOR_MAP = {
    0: _SOIL_RGB,
    1: _BEDROCK_RGB,
    2: _SAND_RGB,
    3: _BIG_ROCK_RGB,
    -1: _MASKED_RGB,
}

def apply_mask_safe(preds, folder, img_path, suffix, w, h):
    """Blank out masked pixels in ``preds`` using a mask file found by image ID.

    The first run of nine digits in the image's file name (the SCLK ID, e.g.
    ``NLA_601686301EDR...`` -> ``601686301``) is used to search ``folder`` for
    a ``.png`` whose name contains both that ID and ``suffix``. Prefixes such
    as ``EDR`` / ``NLA`` and other naming variations are ignored. Every pixel
    where the mask is non-zero is set to ``-1`` in ``preds``.

    Args:
        preds: NumPy array of class ids, modified in place. It is indexed
            with an ``(h, w)`` boolean mask, so it is expected to have that
            shape.
        folder: Directory to search for the mask file.
        img_path: Path of the source image; only its base name is used.
        suffix: Substring identifying the mask type (e.g. ``"mxy"`` or
            ``"rng"``). Matched case-insensitively.
        w: Width the mask is resized to.
        h: Height the mask is resized to.

    Returns:
        The same ``preds`` array (unchanged when no ID or no mask is found).
    """
    # Local import: this function is the only user of `re` in the module.
    import re

    filename = os.path.basename(img_path)

    # 1. Extract the 9-digit numeric ID (SCLK)
    match = re.search(r'\d{9}', filename)
    if not match:
        print(f"❌ Could not find a 9-digit ID in {filename}")
        return preds

    seq_id = match.group(0)
    print(f"DEBUG: Searching for ID {seq_id} in {folder}...")

    # 2. Search the folder for a .png containing this ID and the suffix.
    #    - isdir (not exists) so a regular file passed as `folder` is treated
    #      as "no masks" instead of making os.listdir raise.
    #    - the suffix is lower-cased because it is compared against the
    #      lower-cased file name; "MXY" would otherwise never match.
    #    - sorted() makes the choice deterministic when several files match
    #      (os.listdir order is arbitrary).
    wanted = suffix.lower()
    target_file = None
    if os.path.isdir(folder):
        for name in sorted(os.listdir(folder)):
            lowered = name.lower()
            if seq_id in name and lowered.endswith('.png') and wanted in lowered:
                target_file = name
                break

    # 3. Apply if found
    if target_file:
        path = os.path.join(folder, target_file)
        # Context manager releases the file handle as soon as the mask has
        # been decoded; NEAREST keeps the mask binary after resizing.
        with Image.open(path) as mask_img:
            mask = mask_img.convert('L').resize((w, h), Image.NEAREST)
        mask_np = np.array(mask)
        preds[mask_np > 0] = -1
        print(f"✅ SUCCESS: Applied mask from {target_file}")
    else:
        print(f"❌ FAIL: No mask for ID {seq_id} found in {folder}")

    return preds

def _mask_out_region(preds, mask_path, size, label):
    """Set ``preds`` to -1 wherever the mask image at ``mask_path`` is non-zero.

    Args:
        preds: 2-D NumPy array of class ids; modified in place.
        mask_path: Path of the mask image. A pixel counts as masked when it
            is non-zero after conversion to 8-bit greyscale.
        size: ``(width, height)`` the mask is resized to so that it lines up
            with ``preds``.
        label: Short tag (``"MXY"`` / ``"RNG"``) used in the not-found message.
    """
    if not os.path.exists(mask_path):
        print(f"❌ {label} Not Found: {mask_path}")
        return
    # Context manager releases the file handle once the mask is decoded;
    # NEAREST keeps the mask binary after resizing.
    with Image.open(mask_path) as mask_img:
        mask = mask_img.convert('L').resize(size, Image.NEAREST)
    preds[np.array(mask) > 0] = -1


def segment_mars(img_path):
    """Segment an image and return the class-colour overlay.

    The image is resized to 256x256 and normalised, run through the
    module-level ``model``, and the logits are upsampled back to the original
    resolution before the per-pixel argmax. Pixels covered by the matching
    MXY and RNG mask files (looked up under ``stirling_masks_bundle``) are
    relabelled ``-1``. The class map is then coloured with ``COLOR_MAP`` and
    blended 50/50 with the input image.

    Args:
        img_path: Path of the image to segment. A falsy value (``None`` or
            an empty string) means "no image selected".

    Returns:
        A ``PIL.Image`` with the blended overlay, or ``None`` when
        ``img_path`` is falsy.
    """
    if not img_path:
        return None

    # 1. Load Image (convert() returns a fully loaded copy, so the file can
    #    be closed right away).
    with Image.open(img_path) as src:
        raw_img = src.convert('RGB')
    orig_w, orig_h = raw_img.size

    # 2. Inference (SegFormer)
    preprocess = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    input_tensor = preprocess(raw_img).unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(pixel_values=input_tensor)
        logits = F.interpolate(outputs.logits, size=(orig_h, orig_w), mode='bilinear')
        probs = F.softmax(logits, dim=1)
        _, preds = torch.max(probs, dim=1)
    # squeeze(0) removes only the batch axis; a bare squeeze() would also
    # collapse a height or width of 1 and break the 2-D mask indexing below.
    preds = preds.squeeze(0).cpu().numpy()

    # 3. DIRECT FILENAME SWAP (EDR -> MXY / EDR -> RNG)
    # Mask files share the image's name with 'EDR' swapped for the mask type
    # and a .png extension. splitext confines the extension swap to the real
    # extension and makes it case-insensitive.
    stem, ext = os.path.splitext(os.path.basename(img_path))
    if ext.lower() in (".jpg", ".jpeg"):
        ext = ".png"
    mask_base = stem + ext

    mxy_path = os.path.join("stirling_masks_bundle", "rover_mxy", mask_base.replace("EDR", "MXY"))
    rng_path = os.path.join("stirling_masks_bundle", "range_rng", mask_base.replace("EDR", "RNG"))

    _mask_out_region(preds, mxy_path, (orig_w, orig_h), "MXY")
    _mask_out_region(preds, rng_path, (orig_w, orig_h), "RNG")

    # 4. OVERLAY
    mask_rgb = np.zeros((orig_h, orig_w, 3), dtype=np.uint8)
    for cls_id, color in COLOR_MAP.items():
        mask_rgb[preds == cls_id] = color

    # Equal-weight blend of the photo and the class colours.
    overlay = (np.array(raw_img) * 0.5 + mask_rgb * 0.5).astype(np.uint8)
    return Image.fromarray(overlay)

# 3. CUSTOM HTML LEGEND
# Static HTML shown above the image panes: a prompt line plus one colour
# swatch per class. The swatch rgb() values match the entries in COLOR_MAP.
legend_html = """
<div style="text-align: center; margin-bottom: 20px;">
    <p style="font-size: 1.2em; font-weight: bold; margin-bottom: 15px;">
        🔍 Select an image to segment from the examples below
    </p>
    
    <div style="display: flex; flex-wrap: wrap; justify-content: center; gap: 15px; font-weight: bold;">
        <div style="display: flex; align-items: center; gap: 8px;">
            <div style="width: 18px; height: 18px; background-color: rgb(0, 255, 0); border: 1px solid #ccc; border-radius: 3px;"></div> 
            <span>Soil</span>
        </div>
        <div style="display: flex; align-items: center; gap: 8px;">
            <div style="width: 18px; height: 18px; background-color: rgb(0, 0, 255); border: 1px solid #ccc; border-radius: 3px;"></div> 
            <span>Bedrock</span>
        </div>
        <div style="display: flex; align-items: center; gap: 8px;">
            <div style="width: 18px; height: 18px; background-color: rgb(255, 215, 0); border: 1px solid #ccc; border-radius: 3px;"></div> 
            <span>Sand</span>
        </div>
        <div style="display: flex; align-items: center; gap: 8px;">
            <div style="width: 18px; height: 18px; background-color: rgb(255, 0, 0); border: 1px solid #ccc; border-radius: 3px;"></div> 
            <span>Big Rock</span>
        </div>
        <div style="display: flex; align-items: center; gap: 8px;">
            <div style="width: 18px; height: 18px; background-color: rgb(0, 0, 0); border: 1px solid #ccc; border-radius: 3px;"></div> 
            <span>Rover/Background</span>
        </div>
    </div>
</div>
"""
# 4. GRADIO INTERFACE
# Layout: title, colour legend, input/output images side by side, a run
# button, and a list of example images that populate the input.
with gr.Blocks() as demo:
    gr.Markdown("## NASA Curiosity Rover Terrain Classifier")
    gr.HTML(legend_html)
    with gr.Row():
        # interactive=False: the input pane is filled from the examples
        # below rather than by user upload. type="filepath" hands
        # segment_mars a path, which it needs to derive the mask file names.
        img_input = gr.Image(type="filepath", label="Input Martian Image", interactive=False)
        img_output = gr.Image(type="pil", label="Fused Ground Truth Prediction")

    btn = gr.Button("Execute Data Fusion Segmentation")
    btn.click(segment_mars, inputs=img_input, outputs=img_output)

    # Example paths are relative, so the images are expected alongside the
    # app in its working directory.
    gr.Examples(
        examples=[
            ["NLA_601686301EDR_F0732112NCAM00353M1.JPG"],
            ["NLB_436292094EDR_F0211028NCAM00257M1.JPG"],
            ["NLB_486005519EDR_F0481570NCAM07813M1.JPG"],
            ["NLB_519658137EDR_F0550000NCAM00654M1.JPG"],
            ["NLB_541230242EDR_F0611140NCAM07753M1.JPG"],
            ["NLB_621571338EDR_F0763002NCAM00207M1.JPG"]
        ],
        inputs=img_input
    )

# Start the Gradio server only when this file is run as a script, not when it
# is imported. share=True is forwarded to Gradio's launch(); per the Gradio
# docs it requests a temporary public share link in addition to the local URL.
if __name__ == "__main__":
    demo.launch(share=True)