import gradio as gr import torch import torch.nn.functional as F from transformers import SegformerForSemanticSegmentation from torchvision import transforms from PIL import Image import numpy as np import os # 1. SETUP & MODEL LOADING device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = SegformerForSemanticSegmentation.from_pretrained( "nvidia/mit-b2", num_labels=4, id2label={0: "Soil", 1: "Bedrock", 2: "Sand", 3: "Big Rock"}, label2id={"Soil": 0, "Bedrock": 1, "Sand": 2, "Big Rock": 3}, ignore_mismatched_sizes=True ) # Load your Stirling weights (Must be in the same folder as app.py) try: checkpoint = torch.load('SegFormer_B2_Final_Stirling_3456526New.pth', map_location=device) model.load_state_dict(checkpoint['model_state_dict']) print("Model Weights Loaded Successfully") except Exception as e: print(f"⚠️ Weights missing or error: {e}. Running with base weights.") model.to(device).eval() # 2. COLOR MAP (Brightened for visibility) COLOR_MAP = { 0: [0, 255, 0], # Soil (Neon Green) - UPDATED 1: [0, 0, 255], # Bedrock (Electric Blue) - UPDATED 2: [255, 215, 0], # Sand (Yellow) 3: [255, 0, 0], # Big Rock (Red) -1: [0, 0, 0] } def apply_mask_safe(preds, folder, img_path, suffix, w, h): """ Finds the mask by searching for the 9-digit SCLK ID inside the folder, ignoring 'EDR', 'NLA', or other naming variations. """ filename = os.path.basename(img_path) # 1. Extract the 9-digit numeric ID (SCLK) # Example: NLA_601686301EDR... -> 601686301 import re match = re.search(r'\d{9}', filename) if not match: print(f"❌ Could not find a 9-digit ID in {filename}") return preds seq_id = match.group(0) print(f"DEBUG: Searching for ID {seq_id} in {folder}...") # 2. Search the folder for a file containing this ID and the suffix (mxy/rng) target_file = None if os.path.exists(folder): for f in os.listdir(folder): # Check if the 9-digit ID is in the filename AND it's a .png if seq_id in f and f.lower().endswith('.png'): # Also check if it matches the specific suffix if needed if suffix in f.lower(): target_file = f break # 3. Apply if found if target_file: path = os.path.join(folder, target_file) mask = Image.open(path).convert('L').resize((w, h), Image.NEAREST) mask_np = np.array(mask) preds[mask_np > 0] = -1 print(f"✅ SUCCESS: Applied mask from {target_file}") else: print(f"❌ FAIL: No mask for ID {seq_id} found in {folder}") return preds def segment_mars(img_path): if not img_path: return None # 1. Load Image raw_img = Image.open(img_path).convert('RGB') orig_w, orig_h = raw_img.size # 2. Inference (SegFormer) preprocess = transforms.Compose([ transforms.Resize((256, 256)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) input_tensor = preprocess(raw_img).unsqueeze(0).to(device) with torch.no_grad(): outputs = model(pixel_values=input_tensor) logits = F.interpolate(outputs.logits, size=(orig_h, orig_w), mode='bilinear') probs = F.softmax(logits, dim=1) confidences, preds = torch.max(probs, dim=1) preds = preds.squeeze().cpu().numpy() # 3. DIRECT FILENAME SWAP (EDR -> MXY / EDR -> RNG) filename = os.path.basename(img_path) # Generate mask names by replacing 'EDR' with 'MXY' or 'RNG' # and changing extension to .png mxy_filename = filename.replace("EDR", "MXY").replace(".JPG", ".png").replace(".jpg", ".png") rng_filename = filename.replace("EDR", "RNG").replace(".JPG", ".png").replace(".jpg", ".png") mxy_path = os.path.join("stirling_masks_bundle", "rover_mxy", mxy_filename) rng_path = os.path.join("stirling_masks_bundle", "range_rng", rng_filename) # Apply MXY Mask if os.path.exists(mxy_path): mxy = Image.open(mxy_path).convert('L').resize((orig_w, orig_h), Image.NEAREST) preds[np.array(mxy) > 0] = -1 else: print(f"❌ MXY Not Found: {mxy_path}") # Apply RNG Mask if os.path.exists(rng_path): rng = Image.open(rng_path).convert('L').resize((orig_w, orig_h), Image.NEAREST) preds[np.array(rng) > 0] = -1 else: print(f"❌ RNG Not Found: {rng_path}") # 4. OVERLAY mask_rgb = np.zeros((orig_h, orig_w, 3), dtype=np.uint8) for cls_id, color in COLOR_MAP.items(): mask_rgb[preds == cls_id] = color overlay = (np.array(raw_img) * 0.5 + mask_rgb * 0.5).astype(np.uint8) return Image.fromarray(overlay) # 3. CUSTOM HTML LEGEND legend_html = """
🔍 Select an image to segment from the examples below