# TerraEye / io-app-backend / evaluate_system.py
# (last change: "quick fix" by janpiechota, commit de0410f)
import torch
import terramindFunctions as tm
from terratorch import FULL_MODEL_REGISTRY
from metrics import calculate_precision,calculate_recall,calculate_accuracy, calculate_miou, calculate_fw_iou, calculate_dice_score
# [CONFIGURATION] Device setup
# Reuse the torch device selected by the terramindFunctions module so that
# model placement here matches the tensors prepared by tm.* helpers.
DEVICE = tm.device
def load_model(model_name, fallback_name="terramind_v1_large_generate"):
    """Load a segmentation model from the TerraTorch registry.

    Tries ``model_name`` first and, if that fails, falls back to
    ``fallback_name``. The fallback is skipped when it is the same name as
    the primary — retrying an identical build would just fail again.

    Args:
        model_name: registry name of the model to load.
        fallback_name: registry name to try when the primary load fails
            (default: "terramind_v1_large_generate", the behavior of the
            original hard-coded fallback).

    Returns:
        The model moved to ``DEVICE`` and set to eval mode, or ``None``
        when every candidate failed to load.
    """
    candidates = [model_name]
    if fallback_name != model_name:
        candidates.append(fallback_name)
    last_error = None
    for position, name in enumerate(candidates):
        if position > 0:
            print(f"[FALLBACK] Attempting {name}...")
        print(f"[LOADING] Loading model: {name}...")
        try:
            model = FULL_MODEL_REGISTRY.build(
                name,
                modalities=["S2L2A"],
                output_modalities=["LULC"],
                pretrained=True,
                standardize=True,
            ).to(DEVICE)
            model.eval()
            return model
        except Exception as e:  # registry may raise many error types; report and try next
            print(f"[WARNING] Model loading error {name}: {e}")
            last_error = e
    print(f"[ERROR] Fallback model loading error: {last_error}")
    return None
def _segment_with_model(model_name, input_tensor, indices):
    """Run one model end-to-end: load, infer, decode, apply spectral corrections.

    Args:
        model_name: registry name passed to load_model().
        input_tensor: model-ready tensor from tm.prepare_input().
        indices: precomputed spectral indices from tm.calculate_spectral_indices().

    Returns:
        (raw_class_map, corrected_class_map) tuple, or ``None`` when the
        model (including its fallback) could not be loaded. The model object
        is deleted immediately after use to keep peak memory low.
    """
    print(f"[PROCESSING] Processing: {model_name}...")
    model = load_model(model_name)
    if model is None:
        return None
    raw_output = tm.run_inference(model, input_tensor)
    class_map_raw = tm.decode_output(raw_output)
    # [CORRECTIONS] Refine the raw class map with spectral-index corrections.
    class_map, _ = tm.apply_hybrid_corrections(class_map_raw, indices)
    del model
    return class_map_raw, class_map


def _agreement_metrics(map_x, map_y):
    """Compute inter-model agreement metrics between two class maps.

    NOTE: these are agreement scores between two model outputs, not accuracy
    against ground truth — the second map plays the role of the reference.

    Returns:
        dict with overall scores plus per-class iou/precision/recall details,
        keyed exactly as the "raw"/"corrected" sections of the result payload.
    """
    accuracy = calculate_accuracy(map_x, map_y)
    miou, iou_details = calculate_miou(map_x, map_y)
    fw_iou = calculate_fw_iou(map_x, map_y)
    dice = calculate_dice_score(map_x, map_y)
    mean_precision, precision_details = calculate_precision(map_x, map_y)
    mean_recall, recall_details = calculate_recall(map_x, map_y)
    # Merge per-class details; IoU details carry the authoritative class names
    # (always computed), precision/recall may lack a class, hence .get(..., 0.0).
    class_details = {
        class_name: {
            "iou": iou_details.get(class_name, 0.0),
            "precision": precision_details.get(class_name, 0.0),
            "recall": recall_details.get(class_name, 0.0),
        }
        for class_name in iou_details
    }
    return {
        "accuracy": accuracy,
        "miou": miou,
        "fw_iou": fw_iou,
        "dice": dice,
        "mean_precision": mean_precision,
        "mean_recall": mean_recall,
        "class_details": class_details,
    }


def run_evaluation_with_models(lat, lon, buffer_km=5, max_cloud_cover=20, days_back=120, model_a_name=None, model_b_name=None):
    """
    Runs comparison of two selected models on satellite imagery.

    Downloads data once and processes both models with spectral corrections.
    Computes agreement metrics for both raw and corrected outputs.

    Args:
        lat: latitude coordinate
        lon: longitude coordinate
        buffer_km: search radius in kilometers
        max_cloud_cover: maximum acceptable cloud cover percentage for a scene
        days_back: how many days back to search for imagery
        model_a_name: name of first model (default: terramind_v1_small_generate)
        model_b_name: name of second model (default: terramind_v1_large_generate)

    Returns:
        dict with comparison metrics, class maps, and imagery data, or a
        dict with an "error" key when data download or model loading fails.
    """
    if model_a_name is None:
        model_a_name = 'terramind_v1_small_generate'
    if model_b_name is None:
        model_b_name = 'terramind_v1_large_generate'
    print(f"[COMPARE] Model comparison for: {lat}, {lon}")
    print(f" Model A: {model_a_name}")
    print(f" Model B: {model_b_name}")

    # [DOWNLOAD] Download the scene once and reuse it for both models (time efficient).
    dl_result = tm.download_sentinel2(lat, lon, buffer_km, max_cloud_cover, days_back)
    if dl_result is None:
        return {"error": "Satellite data unavailable for given analysis parameters."}
    raw_data, date, scene_id = dl_result

    # [DIMENSIONS] Remember the native image size before any model-side resizing.
    original_height, original_width = raw_data.shape[1], raw_data.shape[2]
    print(f"[DIMENSIONS] Original image size: {original_width}x{original_height}")

    # [PREPARE] Shared model input and spectral indices — computed once so both
    # models receive identical data and identical corrections (consistency).
    input_tensor = tm.prepare_input(raw_data)
    indices = tm.calculate_spectral_indices(input_tensor)

    # [MODEL_A] / [MODEL_B] Sequential processing keeps only one model in memory.
    result_a = _segment_with_model(model_a_name, input_tensor, indices)
    if result_a is None:
        return {"error": f"Error loading model {model_a_name}"}
    map_a_raw, map_a = result_a

    result_b = _segment_with_model(model_b_name, input_tensor, indices)
    if result_b is None:
        return {"error": f"Error loading model {model_b_name}"}
    map_b_raw, map_b = result_b

    # [CLEANUP] Release cached GPU memory now that both models are deleted.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # [METRICS] Same metric suite on raw and on spectrally-corrected maps.
    print("[METRICS] Computing metrics...")
    print(" [RAW] Computing metrics for raw segmentation (without spectral indices)...")
    raw_metrics = _agreement_metrics(map_a_raw, map_b_raw)
    print(" [CORRECTED] Computing metrics for corrected segmentation (with spectral indices)...")
    corrected_metrics = _agreement_metrics(map_a, map_b)

    return {
        "status": "success",
        "metrics": {
            "raw": raw_metrics,
            "corrected": corrected_metrics,
        },
        "maps": {
            "modelA": map_a,
            "modelB": map_b,
            "modelA_raw": map_a_raw,
            "modelB_raw": map_b_raw,
        },
        "raw_data": raw_data,
        "input_tensor": input_tensor,
        "indices": indices,
        "date": date,
        "image_width": original_width,
        "image_height": original_height,
    }