Spaces:
Sleeping
Sleeping
File size: 2,527 Bytes
4f55ca2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 |
import torch
import clip
import open_clip
import h5py
from huggingface_hub import hf_hub_download
from app_lib.utils import SUPPORTED_MODELS
def _get_open_clip_model(model_name, device):
    """Build an open_clip (model, preprocess, tokenizer) triple.

    ``model_name`` is expected to look like ``"open_clip:<backbone>"``: the
    token after the final colon selects the tokenizer, while the full name is
    looked up in SUPPORTED_MODELS to pick the pretrained checkpoint.
    """
    arch = model_name.rsplit(":", 1)[-1]
    checkpoint = SUPPORTED_MODELS[model_name]
    model, _, preprocess = open_clip.create_model_and_transforms(
        checkpoint, device=device
    )
    model.eval()
    return model, preprocess, open_clip.get_tokenizer(arch)
def _get_clip_model(model_name, device):
    """Build an OpenAI CLIP (model, preprocess, tokenizer) triple.

    The backbone after the final colon in ``model_name`` is handed straight to
    ``clip.load``; ``clip.tokenize`` serves as the tokenizer.
    """
    arch = model_name.rsplit(":", 1)[-1]
    model, preprocess = clip.load(arch, device=device)
    return model, preprocess, clip.tokenize
def load_dataset(dataset_name, model_name):
    """Fetch precomputed train embeddings from the IBYDMT dataset repo.

    Downloads ``{dataset_name}_{model_name}_train.h5`` from the Hugging Face
    Hub and returns the "embedding" array fully loaded into memory.
    """
    h5_path = hf_hub_download(
        repo_id="jacopoteneggi/IBYDMT",
        filename=f"{dataset_name}_{model_name}_train.h5",
        repo_type="dataset",
    )
    with h5py.File(h5_path, "r") as f:
        return f["embedding"][:]
def load_model(model_name, device):
    """Load a (model, preprocess, tokenizer) triple for the given model name.

    Dispatches on the model-name family: names containing "open_clip" use the
    open_clip loader, other names containing "clip" use the OpenAI CLIP loader.

    Args:
        model_name: Identifier such as ``"open_clip:<backbone>"`` or
            ``"clip:<backbone>"``.
        device: Torch device passed through to the underlying loader.

    Returns:
        Tuple of (model, preprocess, tokenizer).

    Raises:
        ValueError: If ``model_name`` matches neither supported family.
            (Previously an unmatched name fell through both branches and
            crashed with ``UnboundLocalError`` at the return statement.)
    """
    print(model_name)
    # The "open_clip" check must come first: "clip" is a substring of
    # "open_clip", so the elif would otherwise shadow it.
    if "open_clip" in model_name:
        model, preprocess, tokenizer = _get_open_clip_model(model_name, device)
    elif "clip" in model_name:
        model, preprocess, tokenizer = _get_clip_model(model_name, device)
    else:
        raise ValueError(f"Unsupported model: {model_name}")
    return model, preprocess, tokenizer
@torch.no_grad()
@torch.cuda.amp.autocast()
def encode_concepts(tokenizer, model, concepts, device):
    """Embed a list of concept strings and L2-normalize the features.

    Returns a numpy array of unit-norm text features, one row per concept.
    """
    tokens = tokenizer(concepts).to(device)
    features = model.encode_text(tokens)
    features = features / features.norm(dim=-1, keepdim=True)
    return features.cpu().numpy()
@torch.no_grad()
@torch.cuda.amp.autocast()
def encode_image(model, preprocess, image, device):
    """Embed a single image and L2-normalize the feature vector.

    The image is preprocessed, given a batch dimension, moved to ``device``,
    and encoded; returns a unit-norm numpy array of shape (1, dim).
    """
    batch = preprocess(image).unsqueeze(0).to(device)
    features = model.encode_image(batch)
    features = features / torch.linalg.norm(features, dim=-1, keepdim=True)
    return features.cpu().numpy()
@torch.no_grad()
@torch.cuda.amp.autocast()
def encode_class_name(tokenizer, model, class_name, device):
    """Embed the prompt "A photo of a <class_name>" and L2-normalize it.

    Returns a unit-norm numpy array of shape (1, dim) for the single prompt.
    """
    prompt = [f"A photo of a {class_name}"]
    tokens = tokenizer(prompt).to(device)
    features = model.encode_text(tokens)
    features = features / features.norm(dim=-1, keepdim=True)
    return features.cpu().numpy()
def test(image, class_name, concepts, cardinality, dataset_name, model_name, device):
model, preprocess = load_model(model_name, device)
print(f"loaded {model_name}")
|