Spaces:
Sleeping
Sleeping
| """ | |
| Model Manager โ Easy version swapping for Building Detection models. | |
| To swap models: | |
| 1. Upload new model to HF repo (e.g., v6/model_final.pth) | |
| 2. Set MODEL_VERSION env var to "v6" | |
| 3. Restart the Space | |
| """ | |
| import os | |
| import torch | |
| from detectron2.config import get_cfg | |
| from detectron2 import model_zoo | |
| from detectron2.engine import DefaultPredictor | |
| from huggingface_hub import hf_hub_download | |
# ==========================================
# === Configuration (all overridable via environment variables) ===
# ==========================================
MODEL_REPO = os.getenv("MODEL_REPO", "yusef75/building-detection-models")
MODEL_VERSION = os.getenv("MODEL_VERSION", "v5")
MODEL_FILENAME = os.getenv("MODEL_FILENAME", "model_final.pth")

# Prefer the GPU whenever torch can see one.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SCORE_THRESHOLD = float(os.getenv("SCORE_THRESHOLD", "0.3"))

# Lazily-initialized module-level predictor state (see load_model / get_predictor).
_predictor = None
_model_info = {}
def load_model():
    """Download the configured checkpoint and build the global predictor.

    Reads the module-level MODEL_REPO / MODEL_VERSION / MODEL_FILENAME
    settings, downloads the weights from the Hugging Face Hub, configures a
    Mask R-CNN R50-FPN-3x model for single-class instance segmentation, and
    stores the resulting DefaultPredictor plus its metadata in the module
    globals. Called once at startup.

    Returns:
        The constructed detectron2 DefaultPredictor.
    """
    global _predictor, _model_info
    # NOTE: the original printed strings contained mojibake (corrupted emoji
    # bytes); replaced with plain ASCII so logs render on any terminal.
    print(f"Loading model: {MODEL_REPO} / {MODEL_VERSION} / {MODEL_FILENAME}")
    print(f"Device: {DEVICE}")
    # Download model from HF Hub (cached under /tmp/models between calls).
    model_path = hf_hub_download(
        repo_id=MODEL_REPO,
        filename=f"{MODEL_VERSION}/{MODEL_FILENAME}",
        cache_dir="/tmp/models",
    )
    print(f"Model downloaded to: {model_path}")
    # Configure Detectron2 from the standard COCO instance-segmentation recipe.
    cfg = get_cfg()
    cfg.merge_from_file(model_zoo.get_config_file(
        "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1  # single class (building detection)
    cfg.MODEL.WEIGHTS = model_path
    cfg.MODEL.DEVICE = DEVICE
    cfg.INPUT.MIN_SIZE_TEST = 512
    cfg.INPUT.MAX_SIZE_TEST = 512
    # === Detection quality settings ===
    # Low base threshold -- actual filtering happens in inference.py
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.1
    # NMS: aggressively remove overlapping detections (lower = stricter)
    cfg.MODEL.ROI_HEADS.NMS_THRESH_TEST = 0.3
    # Max detections per image (raised so nothing is missed in dense areas)
    cfg.TEST.DETECTIONS_PER_IMAGE = 500
    _predictor = DefaultPredictor(cfg)
    _model_info = {
        "version": MODEL_VERSION,
        "repo": MODEL_REPO,
        "device": DEVICE,
        "threshold": SCORE_THRESHOLD,
    }
    print(f"Model {MODEL_VERSION} loaded on {DEVICE}!")
    return _predictor
def get_predictor():
    """Return the shared predictor, loading the model lazily on first call."""
    global _predictor
    if _predictor is not None:
        return _predictor
    # load_model() assigns the global and returns the new predictor.
    return load_model()
def get_model_info():
    """Return metadata about the currently loaded model.

    Empty dict until load_model() has run; afterwards holds the
    version / repo / device / threshold of the active model.
    """
    return _model_info
def set_threshold(threshold: float):
    """Update the detection score threshold on the live predictor.

    Args:
        threshold: new minimum score for a detection to be kept.

    No-op when the model has not been loaded yet.
    """
    global _predictor
    if _predictor is not None:
        # Keep the predictor's (cloned) cfg in sync for introspection.
        _predictor.cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = threshold
        # Bug fix: DefaultPredictor clones the cfg, and the ROI heads' box
        # predictor reads SCORE_THRESH_TEST once at construction time, so
        # mutating the cfg alone did NOT change inference behavior. Update
        # the live module attribute as well.
        box_predictor = getattr(_predictor.model.roi_heads, "box_predictor", None)
        if box_predictor is not None:
            box_predictor.test_score_thresh = threshold
        # Keep the reported metadata consistent with the active threshold.
        _model_info["threshold"] = threshold