building-detection / model_manager.py
yusef
Increase DETECTIONS_PER_IMAGE to 500 per tile
5d0fd2e
"""
Model Manager — Easy version swapping for Building Detection models.
To swap models:
1. Upload new model to HF repo (e.g., v6/model_final.pth)
2. Set MODEL_VERSION env var to "v6"
3. Restart the Space
"""
import os
import torch
from detectron2.config import get_cfg
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from huggingface_hub import hf_hub_download
# ------------------------------------------------------------------
# Configuration: every value except DEVICE can be overridden with an
# environment variable of the same name.
# ------------------------------------------------------------------
MODEL_REPO = os.getenv("MODEL_REPO", "yusef75/building-detection-models")
MODEL_VERSION = os.getenv("MODEL_VERSION", "v5")
MODEL_FILENAME = os.getenv("MODEL_FILENAME", "model_final.pth")
# Prefer the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
# Score threshold recorded in the model info (default 0.3).
SCORE_THRESHOLD = float(os.getenv("SCORE_THRESHOLD", "0.3"))

# Module-level singleton state, populated by load_model().
_predictor = None   # Detectron2 DefaultPredictor once loaded, else None
_model_info = {}    # summary of the loaded model; empty until loaded
def _build_cfg(model_path):
    """Build the Detectron2 inference config for the single-class building model.

    Args:
        model_path: Local path to the downloaded ``.pth`` checkpoint.

    Returns:
        A Detectron2 ``CfgNode`` ready to be passed to ``DefaultPredictor``.
    """
    cfg = get_cfg()
    cfg.merge_from_file(model_zoo.get_config_file(
        "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1  # the model predicts a single class
    cfg.MODEL.WEIGHTS = model_path
    cfg.MODEL.DEVICE = DEVICE
    # Shortest test edge resized to 512 and longest capped at 512;
    # presumably this matches the 512px tiles sent by the caller (TODO confirm).
    cfg.INPUT.MIN_SIZE_TEST = 512
    cfg.INPUT.MAX_SIZE_TEST = 512

    # === Detection quality settings ===
    # Low base threshold; the actual filtering happens in inference.py.
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.1
    # NMS IoU threshold: lower values suppress more overlapping boxes (stricter).
    cfg.MODEL.ROI_HEADS.NMS_THRESH_TEST = 0.3
    # Max detections per image, raised so that every building is kept
    # in dense areas.
    cfg.TEST.DETECTIONS_PER_IMAGE = 500
    return cfg


def load_model():
    """Download the configured checkpoint and (re)build the global predictor.

    Called once at startup, and lazily by ``get_predictor`` when nothing is
    loaded yet. The file ``<MODEL_VERSION>/<MODEL_FILENAME>`` is fetched from
    the Hugging Face repo ``MODEL_REPO`` using ``/tmp/models`` as the cache
    directory, wrapped in a Detectron2 ``DefaultPredictor`` and stored in the
    module-level ``_predictor``. ``_model_info`` is replaced with a summary of
    what was loaded.

    Returns:
        The newly created ``DefaultPredictor``.
    """
    global _predictor, _model_info

    print(f"🔍 Loading model: {MODEL_REPO} / {MODEL_VERSION} / {MODEL_FILENAME}")
    print(f"🖥️ Device: {DEVICE}")

    # Download the checkpoint from the HF Hub.
    model_path = hf_hub_download(
        repo_id=MODEL_REPO,
        filename=f"{MODEL_VERSION}/{MODEL_FILENAME}",
        cache_dir="/tmp/models",
    )
    print(f"✅ Model downloaded to: {model_path}")

    _predictor = DefaultPredictor(_build_cfg(model_path))
    _model_info = {
        "version": MODEL_VERSION,
        "repo": MODEL_REPO,
        "device": DEVICE,
        # SCORE_THRESHOLD (env-configurable), not the model's base 0.1 threshold.
        "threshold": SCORE_THRESHOLD,
    }
    print(f"🚀 Model {MODEL_VERSION} loaded on {DEVICE}!")
    return _predictor
def get_predictor():
    """Return the shared predictor, loading the model on first use.

    If no model has been loaded yet this triggers ``load_model()`` and
    returns the predictor it creates; otherwise the cached one is returned.
    """
    if _predictor is not None:
        return _predictor
    return load_model()
def get_model_info():
    """Return a snapshot of metadata about the currently loaded model.

    Returns:
        A shallow copy of the module's model-info dict. After ``load_model``
        has run it contains ``version``, ``repo``, ``device`` and
        ``threshold``; before that it is empty. A copy is returned so that
        callers cannot accidentally mutate the module's own record.
    """
    return dict(_model_info)
def set_threshold(threshold: float):
    """Change the model's base score threshold for subsequent predictions.

    ``DefaultPredictor`` builds its model once, at construction time, so
    editing ``_predictor.cfg`` alone does not change inference. The ROI box
    predictor keeps the threshold it was built with. This function updates
    the live box predictor(s) and mirrors the value into the predictor's cfg
    so that the two stay consistent.

    Does nothing if no model has been loaded yet.

    Args:
        threshold: New minimum detection score (coerced to ``float``).
    """
    if _predictor is None:
        return
    threshold = float(threshold)
    # Keep the stored config in sync (informational only after build).
    _predictor.cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = threshold

    roi_heads = getattr(getattr(_predictor, "model", None), "roi_heads", None)
    box_predictor = getattr(roi_heads, "box_predictor", None)
    if box_predictor is None:
        return
    # Standard ROI heads hold a single predictor; cascade heads hold a list.
    if isinstance(box_predictor, (list, tuple, torch.nn.ModuleList)):
        candidates = list(box_predictor)
    else:
        candidates = [box_predictor]
    for candidate in candidates:
        # This is the attribute the box predictor reads at inference time.
        if hasattr(candidate, "test_score_thresh"):
            candidate.test_score_thresh = threshold