File size: 2,589 Bytes
2a034f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import torch
import torch.nn as nn
from torchvision import models

class PairNet(nn.Module):
    """
    A lightweight backbone + dual-head network:
      - Regression head for days-to-delivery
      - Classification head for preterm probability
    This is a scaffold and not the proprietary model from the paper.

    Args:
        backbone_name: "efficientnet_b0" selects torchvision's EfficientNet-B0;
            any other value falls back to ResNet-18 (a warning is emitted
            unless the value is exactly "resnet18").
        pretrained: request torchvision's ``DEFAULT`` pretrained weights. If
            the installed torchvision does not expose the weights enum, a
            warning is emitted and the backbone is randomly initialised.

    ``forward(x)`` returns ``(days, logits)``. Each comes from a single-output
    ``nn.Linear`` head with no activation applied. The logits are presumably
    passed through a sigmoid by the caller to obtain the preterm
    probability (confirm against the training/inference code).
    """
    def __init__(self, backbone_name: str = "efficientnet_b0", pretrained: bool = True):
        super().__init__()
        import warnings

        if backbone_name == "efficientnet_b0":
            weights = self._default_weights("EfficientNet_B0_Weights", pretrained)
            backbone = models.efficientnet_b0(weights=weights)
            # The final Linear layer sits at classifier[1]; its input width is
            # the pooled feature size shared by both heads.
            in_feats = backbone.classifier[1].in_features
            backbone.classifier = nn.Identity()
        else:
            # Any other name falls back to resnet18; make an unexpected name
            # (e.g. a typo) visible instead of silently switching architecture.
            if backbone_name != "resnet18":
                warnings.warn(
                    f"Unknown backbone_name {backbone_name!r}; falling back to resnet18.",
                    stacklevel=2,
                )
            weights = self._default_weights("ResNet18_Weights", pretrained)
            backbone = models.resnet18(weights=weights)
            in_feats = backbone.fc.in_features
            backbone.fc = nn.Identity()

        self.backbone = backbone
        self.reg_head = nn.Linear(in_feats, 1)
        self.cls_head = nn.Linear(in_feats, 1)

    @staticmethod
    def _default_weights(enum_name: str, pretrained: bool):
        """Return ``torchvision.models.<enum_name>.DEFAULT``, or None.

        Returns None when ``pretrained`` is False, or when the installed
        torchvision has no such weights enum. In the latter case a warning
        is emitted so the missing pretrained weights are not silent.
        """
        if not pretrained:
            return None
        try:
            return getattr(models, enum_name).DEFAULT
        except AttributeError:
            import warnings
            warnings.warn(
                f"torchvision.models.{enum_name} is unavailable; "
                "building the backbone without pretrained weights.",
                stacklevel=3,
            )
            return None

    def forward(self, x):
        """Return ``(days, logits)`` computed from shared backbone features."""
        feats = self.backbone(x)
        days = self.reg_head(feats)   # unconstrained regression output
        logits = self.cls_head(feats) # unconstrained; no sigmoid applied here
        return days, logits

def load_weights_if_any(model: nn.Module, weights_path: str | None):
    if not weights_path:
        return False, "No weights path provided"
    import os
    if os.path.isfile(weights_path):
        state = torch.load(weights_path, map_location="cpu")
        if "state_dict" in state:
            state = state["state_dict"]
        missing, unexpected = model.load_state_dict(state, strict=False)
        return True, f"Loaded local weights. missing={len(missing)} unexpected={len(unexpected)}"
    # Try huggingface hub repo id
    try:
        from huggingface_hub import hf_hub_download
        fp = hf_hub_download(repo_id=weights_path, filename="pytorch_model.bin", local_dir="weights")
        state = torch.load(fp, map_location="cpu")
        if "state_dict" in state:
            state = state["state_dict"]
        missing, unexpected = model.load_state_dict(state, strict=False)
        return True, f"Loaded HF weights from {weights_path}. missing={len(missing)} unexpected={len(unexpected)}"
    except Exception as e:
        return False, f"Failed to load weights: {e}"