File size: 2,527 Bytes
4f55ca2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import torch
import clip
import open_clip
import h5py

from huggingface_hub import hf_hub_download

from app_lib.utils import SUPPORTED_MODELS


def _get_open_clip_model(model_name, device):
    """Build an open_clip model together with its image transform and tokenizer.

    Args:
        model_name: Key into ``SUPPORTED_MODELS``. The text after the last
            ``":"`` (or the whole string when there is no ``":"``) is the
            name handed to ``open_clip.get_tokenizer``.
        device: Device forwarded to ``open_clip.create_model_and_transforms``.

    Returns:
        A ``(model, preprocess, tokenizer)`` tuple; the model has been put
        in eval mode.
    """
    tokenizer_name = model_name.rpartition(":")[2]

    # The factory returns three values; the middle one is not used here.
    net, _unused, transform = open_clip.create_model_and_transforms(
        SUPPORTED_MODELS[model_name], device=device
    )
    net.eval()

    text_tokenizer = open_clip.get_tokenizer(tokenizer_name)
    return net, transform, text_tokenizer


def _get_clip_model(model_name, device):
    """Load a model through the ``clip`` package.

    Args:
        model_name: Model identifier; only the text after the last ``":"``
            (or the whole string when there is no ``":"``) is passed to
            ``clip.load``.
        device: Device forwarded to ``clip.load``.

    Returns:
        A ``(model, preprocess, tokenizer)`` tuple, where the tokenizer is
        the ``clip.tokenize`` function itself.
    """
    clip_name = model_name.rpartition(":")[2]
    net, transform = clip.load(clip_name, device=device)
    return net, transform, clip.tokenize


def load_dataset(dataset_name, model_name):
    """Fetch the training embeddings stored for a dataset/model pair.

    Downloads ``{dataset_name}_{model_name}_train.h5`` from the
    ``jacopoteneggi/IBYDMT`` dataset repository on the Hugging Face Hub and
    returns the contents of its ``"embedding"`` entry.

    Args:
        dataset_name: First component of the HDF5 file name.
        model_name: Second component of the HDF5 file name.

    Returns:
        The full ``"embedding"`` dataset, read out with ``[:]`` while the
        file is still open.
    """
    h5_name = f"{dataset_name}_{model_name}_train.h5"
    local_path = hf_hub_download(
        repo_id="jacopoteneggi/IBYDMT",
        filename=h5_name,
        repo_type="dataset",
    )

    with h5py.File(local_path, "r") as h5_file:
        return h5_file["embedding"][:]


def load_model(model_name, device):
    """Load the model, image transform and tokenizer selected by ``model_name``.

    The name is printed, then dispatched on substring: names containing
    ``"open_clip"`` go to ``_get_open_clip_model``, other names containing
    ``"clip"`` go to ``_get_clip_model``.

    Args:
        model_name: Model identifier string.
        device: Device forwarded to the selected loader.

    Returns:
        The ``(model, preprocess, tokenizer)`` tuple produced by the loader.

    Raises:
        ValueError: If ``model_name`` contains neither ``"open_clip"`` nor
            ``"clip"``. Previously this case fell through to the ``return``
            and failed with an ``UnboundLocalError``.
    """
    print(model_name)
    # "open_clip" has to be tested first: every such name also contains "clip".
    if "open_clip" in model_name:
        return _get_open_clip_model(model_name, device)
    if "clip" in model_name:
        return _get_clip_model(model_name, device)
    raise ValueError(
        f"Unsupported model {model_name!r}: "
        "expected a name containing 'open_clip' or 'clip'"
    )


@torch.no_grad()
@torch.cuda.amp.autocast()
def encode_concepts(tokenizer, model, concepts, device):
    """Embed ``concepts`` with the text encoder and L2-normalize the result.

    Args:
        tokenizer: Callable applied to ``concepts``; its result must support
            ``.to(device)``.
        model: Object exposing ``encode_text``.
        concepts: Input handed to ``tokenizer`` as-is (presumably a list of
            strings; confirm against the caller).
        device: Device the tokens are moved to before encoding.

    Returns:
        NumPy array of text features, each scaled to unit Euclidean norm
        along the last axis.
    """
    tokens = tokenizer(concepts).to(device)
    features = model.encode_text(tokens)

    # Normalize in place so the features keep the dtype the encoder produced.
    features.div_(torch.linalg.norm(features, dim=-1, keepdim=True))
    return features.cpu().numpy()


@torch.no_grad()
@torch.cuda.amp.autocast()
def encode_image(model, preprocess, image, device):
    """Embed a single image and L2-normalize the result.

    Args:
        model: Object exposing ``encode_image``.
        preprocess: Callable turning ``image`` into a tensor.
        image: Raw image accepted by ``preprocess``.
        device: Device the preprocessed tensor is moved to.

    Returns:
        NumPy array of image features, scaled to unit norm along the last
        axis. The input is given a leading batch axis of size 1 before
        encoding.
    """
    # Preprocess, add the batch axis, then move to the target device.
    batch = preprocess(image).unsqueeze(0).to(device)

    features = model.encode_image(batch)
    # Normalize in place so the features keep the dtype the encoder produced.
    features.div_(features.norm(dim=-1, keepdim=True))
    return features.cpu().numpy()


@torch.no_grad()
@torch.cuda.amp.autocast()
def encode_class_name(tokenizer, model, class_name, device):
    """Embed the prompt ``"A photo of a {class_name}"`` and L2-normalize it.

    Args:
        tokenizer: Callable applied to a one-element list holding the
            prompt; its result must support ``.to(device)``.
        model: Object exposing ``encode_text``.
        class_name: Value interpolated into the prompt.
        device: Device the tokens are moved to before encoding.

    Returns:
        NumPy array of text features for the prompt, scaled to unit
        Euclidean norm along the last axis.
    """
    prompt = f"A photo of a {class_name}"
    tokens = tokenizer([prompt]).to(device)

    features = model.encode_text(tokens)
    # Normalize in place so the features keep the dtype the encoder produced.
    features.div_(torch.linalg.norm(features, dim=-1, keepdim=True))
    return features.cpu().numpy()


def test(image, class_name, concepts, cardinality, dataset_name, model_name, device):
    """Load the model named ``model_name`` and report that it was loaded.

    The body currently only loads the model; ``image``, ``class_name``,
    ``concepts``, ``cardinality`` and ``dataset_name`` are accepted but not
    used yet.

    Args:
        image: Unused by the current body.
        class_name: Unused by the current body.
        concepts: Unused by the current body.
        cardinality: Unused by the current body.
        dataset_name: Unused by the current body.
        model_name: Model identifier forwarded to ``load_model``.
        device: Device forwarded to ``load_model``.
    """
    # load_model returns (model, preprocess, tokenizer); unpacking it into
    # only two names raised "ValueError: too many values to unpack".
    model, preprocess, tokenizer = load_model(model_name, device)
    print(f"loaded {model_name}")