File size: 2,456 Bytes
1be92b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7f0920d
 
 
1be92b1
 
 
7f0920d
1be92b1
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
"""
model.py — helpers to load BDAPPV baseline models from HuggingFace Hub.

Usage:
    from huggingface_hub import hf_hub_download
    import importlib.util, sys

    path = hf_hub_download("gabrielkasmi/bdappv-models", "model.py")
    spec = importlib.util.spec_from_file_location("bdappv_model", path)
    mod  = importlib.util.module_from_spec(spec); spec.loader.exec_module(mod)

    seg = mod.load_segmentation_model("google")   # or "ign"
    clf = mod.load_classification_model("google") # or "ign"
"""

import warnings

import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from torchvision.models.segmentation import deeplabv3_resnet101
from torchvision.models import inception_v3

# HuggingFace Hub repository holding the BDAPPV checkpoints (and this model.py).
REPO: str = "gabrielkasmi/bdappv-models"


def load_segmentation_model(provider: str = "google", device: str = "cpu"):
    """
    Load the DeepLabV3-ResNet101 segmentation model.

    The network is built with a single-channel output head, then filled with
    the checkpoint ``deeplab_{provider}_best.pth`` downloaded from ``REPO``.
    Checkpoint entries whose name or shape does not match the model are
    ignored.

    Args:
        provider : "google" or "ign"
        device   : "cpu", "cuda", or "mps"

    Returns:
        model in eval mode, moved to ``device``

    Raises:
        ValueError   : if ``provider`` is not "google" or "ign"
        RuntimeError : if no checkpoint tensor matches the model at all
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if provider not in ("google", "ign"):
        raise ValueError("provider must be 'google' or 'ign'")

    path = hf_hub_download(REPO, f"deeplab_{provider}_best.pth")
    # weights_backbone=None: torchvision otherwise downloads ImageNet
    # ResNet-101 weights by default, which the checkpoint overwrites anyway
    # (and which breaks loading on machines without network access).
    model = deeplabv3_resnet101(
        weights=None, weights_backbone=None, aux_loss=False
    )
    # Replace the final 1x1 conv so the head emits a single output channel.
    model.classifier[-1] = nn.Conv2d(256, 1, kernel_size=1)

    # SECURITY NOTE(review): weights_only=False unpickles arbitrary objects,
    # so a tampered checkpoint could execute code. Kept for compatibility with
    # the published files; switch to weights_only=True if they are plain
    # tensor state dicts.
    state = torch.load(path, map_location=device, weights_only=False)
    model_dict = model.state_dict()
    # Keep only checkpoint entries whose name and shape match this model.
    compatible = {
        k: v for k, v in state.items()
        if k in model_dict and v.shape == model_dict[k].shape
    }
    if not compatible:
        # Nothing matched: returning the model would silently yield an
        # untrained network.
        raise RuntimeError(
            f"checkpoint {path!r} contains no tensors matching the "
            "DeepLabV3-ResNet101 model"
        )
    missing = sorted(model_dict.keys() - compatible.keys())
    if missing:
        # These parameters/buffers keep their initial (untrained) values.
        preview = ", ".join(missing[:5])
        warnings.warn(
            f"{len(missing)} model tensor(s) not found in checkpoint and left "
            f"at initialization (e.g. {preview})",
            stacklevel=2,
        )
    model_dict.update(compatible)
    model.load_state_dict(model_dict)

    return model.eval().to(device)


def load_classification_model(provider: str = "google", device: str = "cpu"):
    """
    Load the InceptionV3 classification model (panel / no panel).

    The network is built with single-logit main and auxiliary heads, then
    strictly filled with the checkpoint ``inception_{provider}_best.pth``
    downloaded from ``REPO``.

    Args:
        provider : "google" or "ign"
        device   : "cpu", "cuda", or "mps"

    Returns:
        model in eval mode, moved to ``device``

    Raises:
        ValueError : if ``provider`` is not "google" or "ign"
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if provider not in ("google", "ign"):
        raise ValueError("provider must be 'google' or 'ign'")

    path = hf_hub_download(REPO, f"inception_{provider}_best.pth")
    # init_weights=False: every parameter is overwritten by the strict
    # load_state_dict below, so skip torchvision's slow random init (and the
    # FutureWarning emitted when init_weights is left unset).
    model = inception_v3(weights=None, aux_logits=True, init_weights=False)
    # The checkpoint includes the auxiliary head (the load below is strict),
    # so both heads are rebuilt with a single output logit to match it.
    model.fc = nn.Linear(model.fc.in_features, 1)
    model.AuxLogits.fc = nn.Linear(model.AuxLogits.fc.in_features, 1)

    # SECURITY NOTE(review): weights_only=False unpickles arbitrary objects,
    # so a tampered checkpoint could execute code. Kept for compatibility with
    # the published files; switch to weights_only=True if they are plain
    # tensor state dicts.
    state = torch.load(path, map_location=device, weights_only=False)
    model.load_state_dict(state)
    # Disable the auxiliary head for inference (its weights remain loaded).
    model.aux_logits = False

    return model.eval().to(device)