File size: 2,936 Bytes
df64c50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5d0fd2e
 
df64c50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
"""
Model Manager — Easy version swapping for Building Detection models.

To swap models:
1. Upload new model to HF repo (e.g., v6/model_final.pth)
2. Set MODEL_VERSION env var to "v6"
3. Restart the Space
"""

import os
import torch
from detectron2.config import get_cfg
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from huggingface_hub import hf_hub_download

# ==========================================
# === Configuration ===
# ==========================================
# Each setting falls back to the default shown here when its environment
# variable is unset.
MODEL_REPO = os.getenv("MODEL_REPO", "yusef75/building-detection-models")
MODEL_VERSION = os.getenv("MODEL_VERSION", "v5")
MODEL_FILENAME = os.getenv("MODEL_FILENAME", "model_final.pth")
SCORE_THRESHOLD = float(os.getenv("SCORE_THRESHOLD", "0.3"))
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Module-level state, populated by load_model()
_predictor = None   # DefaultPredictor once the model has been loaded
_model_info = {}    # version / repo / device / threshold of the loaded model


def _build_cfg(model_path):
    """Return the Detectron2 config for the 1-class Mask R-CNN using *model_path* weights."""
    cfg = get_cfg()
    cfg.merge_from_file(model_zoo.get_config_file(
        "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1
    cfg.MODEL.WEIGHTS = model_path
    cfg.MODEL.DEVICE = DEVICE
    cfg.INPUT.MIN_SIZE_TEST = 512
    cfg.INPUT.MAX_SIZE_TEST = 512

    # === Detection quality settings ===
    # Low base threshold — actual filtering happens in inference.py
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.1

    # NMS: Aggressively remove overlapping detections (lower = stricter)
    cfg.MODEL.ROI_HEADS.NMS_THRESH_TEST = 0.3

    # Max detections per image (raised so that every building is detected
    # in densely built-up areas)
    cfg.TEST.DETECTIONS_PER_IMAGE = 500
    return cfg


def load_model():
    """Download the configured model from the Hugging Face Hub and build the predictor.

    Intended to be called once at startup; calling it again re-downloads (or
    re-reads from the cache) and replaces the current predictor.

    Side effects:
        Sets the module-level ``_predictor`` and ``_model_info``.

    Returns:
        The newly created Detectron2 ``DefaultPredictor``.
    """
    global _predictor, _model_info

    print(f"🔍 Loading model: {MODEL_REPO} / {MODEL_VERSION} / {MODEL_FILENAME}")
    print(f"🖥️ Device: {DEVICE}")

    # Weights are stored per version, e.g. "<version>/model_final.pth".
    model_path = hf_hub_download(
        repo_id=MODEL_REPO,
        filename=f"{MODEL_VERSION}/{MODEL_FILENAME}",
        cache_dir="/tmp/models",
    )
    print(f"✅ Model downloaded to: {model_path}")

    _predictor = DefaultPredictor(_build_cfg(model_path))
    _model_info = {
        "version": MODEL_VERSION,
        "repo": MODEL_REPO,
        "device": DEVICE,
        "threshold": SCORE_THRESHOLD,
    }
    print(f"🚀 Model {MODEL_VERSION} loaded on {DEVICE}!")
    return _predictor


def get_predictor():
    """Return the shared predictor, triggering load_model() on first use."""
    # Fast path: already loaded, nothing to do.
    if _predictor is not None:
        return _predictor
    load_model()
    return _predictor


def get_model_info():
    """Return metadata about the currently loaded model.

    Before ``load_model()`` has run this is an empty dict. A shallow copy is
    returned so callers cannot accidentally mutate the module's own state.
    """
    return dict(_model_info)


def set_threshold(threshold: float):
    """Update the model's base detection score threshold at runtime.

    Detectron2 copies ``MODEL.ROI_HEADS.SCORE_THRESH_TEST`` into the ROI box
    predictor's ``test_score_thresh`` when the model is built, so changing the
    config alone does not affect inference. This writes the new value to the
    live box predictor(s) and mirrors it into the stored config.

    Does nothing if the model has not been loaded yet.

    Args:
        threshold: Minimum score a detection needs to be kept by the model.
    """
    if _predictor is None:
        return
    threshold = float(threshold)

    # Keep the stored config consistent with what the model actually uses.
    _predictor.cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = threshold

    roi_heads = getattr(getattr(_predictor, "model", None), "roi_heads", None)
    box_predictor = getattr(roi_heads, "box_predictor", None)
    if box_predictor is None:
        return
    # Cascade ROI heads keep one box predictor per stage in a ModuleList.
    if isinstance(box_predictor, torch.nn.ModuleList):
        stages = list(box_predictor)
    else:
        stages = [box_predictor]
    for stage in stages:
        if hasattr(stage, "test_score_thresh"):
            stage.test_score_thresh = threshold