"""Integration tests for the SegFormer helpers in ``model_utils``.

``setUpClass`` loads the checkpoint named below through
``model_utils.load_model``. That presumably fetches it from the Hugging Face
Hub, so network access or a local cache may be required (TODO confirm).
"""
import os
import sys
import unittest

import numpy as np
import torch  # noqa: F401  (kept from the original import list)
from PIL import Image
from transformers import (  # noqa: F401  (kept from the original import list)
    SegformerFeatureExtractor,
    SegformerForSemanticSegmentation,
)

# Put the directory above this file's directory on sys.path so that
# ``model_utils`` can be imported from there.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import model_utils  # noqa: E402  (must follow the sys.path tweak above)


def _find_label_id(id2label, name):
    """Return the first class id whose label contains *name*, case-insensitively.

    Args:
        id2label: Mapping of class id -> label string (e.g. a model config's
            ``id2label``).
        name: Substring to search for.

    Returns:
        The matching id as an ``int``, or ``None`` when no label matches.
    """
    needle = name.lower()
    for class_id, label in id2label.items():
        if needle in label.lower():
            return int(class_id)
    return None


class TestInference(unittest.TestCase):
    """Checks model loading, mask prediction, safety mapping and statistics."""

    @classmethod
    def setUpClass(cls):
        print("Setting up TestInference...")
        cls.model_name = "nvidia/segformer-b0-finetuned-ade-512-512"
        cls.feature_extractor, cls.model, cls.device = model_utils.load_model(cls.model_name)
        # Solid-colour dummy input; only its size matters for these tests.
        cls.image = Image.new('RGB', (512, 512), color='red')

    def test_model_loading(self):
        self.assertIsNotNone(self.model)
        self.assertIsNotNone(self.feature_extractor)

    def test_prediction_shape(self):
        mask = model_utils.predict_mask(
            self.image, (self.feature_extractor, self.model, self.device)
        )
        # The predicted mask is expected to have the input image's height x width.
        self.assertEqual(mask.shape, (512, 512))

    def test_safety_mapping(self):
        id2label = self.model.config.id2label

        # Resolve class ids by label name instead of hard-coding numbers:
        # numeric ids depend on the loaded checkpoint's label set.
        safe_lbl = "grass"
        hazard_lbl = "person"
        safe_id = _find_label_id(id2label, safe_lbl)
        hazard_id = _find_label_id(id2label, hazard_lbl)

        missing = [
            lbl for lbl, cid in ((safe_lbl, safe_id), (hazard_lbl, hazard_id))
            if cid is None
        ]
        if missing:
            # Previously this case passed silently with no assertions at all;
            # make it visible instead.
            self.skipTest(f"model id2label has no label matching: {missing}")

        # Top half uses the safe class, bottom half the hazard class.
        mask = np.zeros((100, 100), dtype=np.int64)
        mask[0:50, :] = safe_id
        mask[50:100, :] = hazard_id

        safety_mask = model_utils.map_classes_to_safety(mask, id2label)

        # Encoding these tests expect from map_classes_to_safety:
        # 1 = safe, 2 = hazard.
        np.testing.assert_array_equal(safety_mask[0:50, :], 1)
        np.testing.assert_array_equal(safety_mask[50:100, :], 2)

    def test_stats_computation(self):
        mask = np.zeros((10, 10), dtype=np.int64)
        safety_mask = np.zeros((10, 10), dtype=np.uint8)
        # 50% safe (top 5 rows = 1), 50% hazard (bottom 5 rows = 2).
        safety_mask[0:5, :] = 1
        safety_mask[5:10, :] = 2

        stats = model_utils.compute_stats(mask, safety_mask, self.model.config.id2label)

        self.assertEqual(stats["safe_pixels"], 50)
        self.assertEqual(stats["hazard_pixels"], 50)
        self.assertAlmostEqual(stats["safe_percentage"], 50.0)
        self.assertAlmostEqual(stats["hazard_percentage"], 50.0)


if __name__ == '__main__':
    unittest.main()