File size: 1,038 Bytes
b891e61 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 | """CLIPSeg model loading and freezing utilities."""
from transformers import CLIPSegForImageSegmentation, CLIPSegProcessor
def load_model_and_processor(
    model_name: str = "CIDAS/clipseg-rd64-refined",
    freeze_backbone: bool = True,
) -> "tuple[CLIPSegForImageSegmentation, CLIPSegProcessor]":
    """Load a CLIPSeg model and its processor, optionally freezing the backbone.

    When ``freeze_backbone`` is true, only parameters whose qualified name
    starts with ``decoder.`` keep ``requires_grad=True``; every other
    parameter (the backbone) gets ``requires_grad=False``. A one-line
    parameter-count summary is printed in either case.

    Args:
        model_name: Model id or path passed to ``from_pretrained`` for both
            the model and the processor.
        freeze_backbone: If true, freeze everything except the decoder.
            If false, ``requires_grad`` flags are left exactly as
            ``from_pretrained`` returned them and the total parameter count
            is reported as trainable.

    Returns:
        A ``(model, processor)`` tuple.
    """
    model = CLIPSegForImageSegmentation.from_pretrained(model_name)
    processor = CLIPSegProcessor.from_pretrained(model_name)

    if freeze_backbone:
        trainable, frozen = 0, 0
        for name, param in model.named_parameters():
            # Match the top-level ``decoder`` submodule by its dotted prefix
            # instead of a bare substring, so a backbone parameter whose path
            # happens to contain "decoder" is not accidentally left trainable.
            is_decoder = name.startswith("decoder.")
            param.requires_grad = is_decoder
            if is_decoder:
                trainable += param.numel()
            else:
                frozen += param.numel()
        print(f"Parameters — trainable (decoder): {trainable:,} | frozen (backbone): {frozen:,}")
    else:
        trainable = sum(p.numel() for p in model.parameters())
        print(f"Parameters — all trainable: {trainable:,}")

    return model, processor
|