File size: 1,233 Bytes
e1aaaac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import os
import sys
from typing import Optional, Union

import open_clip
import torch


def load_open_clip(
    model_name: str = "ViT-B-32-quickgelu",
    pretrained: Union[str, dict] = "laion400m_e32",
    cache_dir: Optional[str] = None,
    device: Union[str, torch.device] = "cpu",
):
    """Build an open_clip model together with its preprocessing transform and tokenizer.

    Two loading strategies are used:

    1. If ``pretrained`` is an already-loaded checkpoint object or a path to an
       existing file, the model is created with ``pretrained='openai'`` and only
       its vision tower (``model.visual``) is overwritten from that checkpoint.
       A checkpoint holding a ``'vision_encoder_state_dict'`` entry (TeCoA-style)
       is unwrapped first.
    2. Otherwise, or if strategy 1 raises, ``pretrained`` is passed unchanged to
       ``open_clip.create_model_and_transforms`` (e.g. a pretrained tag such as
       ``"laion400m_e32"``), loading the whole model that way.

    Args:
        model_name: open_clip architecture name; also used to pick the tokenizer.
        pretrained: Pretrained tag, checkpoint file path, or an in-memory
            checkpoint (state dict).
        cache_dir: Download cache directory forwarded to open_clip.
        device: Device the final model is moved to. All loading happens on CPU.

    Returns:
        ``(model, transform, tokenizer)``, where ``transform`` is the last
        transform returned by ``open_clip.create_model_and_transforms``.

    Raises:
        Whatever ``open_clip.create_model_and_transforms`` raises if the
        whole-model fallback also fails.
    """
    model = transform = None

    # Only try the "OpenAI model + custom vision tower" route when there is an
    # actual checkpoint to read. For a plain pretrained tag (like the default)
    # torch.load can never succeed, so attempting it first merely downloaded
    # the OpenAI weights for nothing and printed a misleading error.
    if not isinstance(pretrained, str) or os.path.isfile(pretrained):
        try:
            model, _, transform = open_clip.create_model_and_transforms(
                model_name, pretrained='openai', cache_dir=cache_dir, device='cpu'
            )
            if isinstance(pretrained, str):
                # NOTE: torch.load unpickles the file; only load trusted checkpoints.
                checkpoint = torch.load(pretrained, map_location=torch.device('cpu'))
            else:
                checkpoint = pretrained
            if 'vision_encoder_state_dict' in checkpoint:  # tecoa checkpoint
                checkpoint = checkpoint['vision_encoder_state_dict']
            model.visual.load_state_dict(checkpoint)
        except Exception as e:
            print(f'error: {e}', file=sys.stderr)
            print('retrying by loading whole model..', file=sys.stderr)
            # Drop the half-initialised model so it is not held in memory while
            # the fallback builds a new one.
            model = transform = None

    if model is None:
        # Load the whole model directly from `pretrained`.
        torch.cuda.empty_cache()
        model, _, transform = open_clip.create_model_and_transforms(
            model_name, pretrained=pretrained, cache_dir=cache_dir, device='cpu'
        )

    model = model.to(device)
    tokenizer = open_clip.get_tokenizer(model_name)
    return model, transform, tokenizer