File size: 1,038 Bytes
b891e61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
"""CLIPSeg model loading and freezing utilities."""

from transformers import CLIPSegForImageSegmentation, CLIPSegProcessor


def load_model_and_processor(model_name: str = "CIDAS/clipseg-rd64-refined", freeze_backbone: bool = True):
    """Return a ``(model, processor)`` pair for a CLIPSeg checkpoint.

    Both objects are built with ``from_pretrained(model_name)``, model first.

    When ``freeze_backbone`` is true, every parameter whose qualified name
    contains the substring ``"decoder"`` gets ``requires_grad = True``. Every
    other parameter gets ``requires_grad = False``. The trainable and frozen
    element counts are then printed.

    When ``freeze_backbone`` is false, no ``requires_grad`` flag is modified.
    Only the total element count is printed.

    Args:
        model_name: Hub id or local path forwarded to both ``from_pretrained`` calls.
        freeze_backbone: Whether to freeze all parameters outside the decoder.

    Returns:
        Tuple ``(CLIPSegForImageSegmentation, CLIPSegProcessor)``.
    """
    model = CLIPSegForImageSegmentation.from_pretrained(model_name)
    processor = CLIPSegProcessor.from_pretrained(model_name)

    if not freeze_backbone:
        # Flags are left exactly as from_pretrained set them; just report size.
        total = sum(p.numel() for p in model.parameters())
        print(f"Parameters — all trainable: {total:,}")
        return model, processor

    # Tally element counts keyed by "is this a decoder parameter?".
    # The name test is a plain substring match on the dotted parameter name.
    counts = {True: 0, False: 0}
    for param_name, tensor in model.named_parameters():
        is_decoder = "decoder" in param_name
        tensor.requires_grad = is_decoder
        counts[is_decoder] += tensor.numel()

    print(
        f"Parameters — trainable (decoder): {counts[True]:,}"
        f" | frozen (backbone): {counts[False]:,}"
    )
    return model, processor