Spaces:
Sleeping
Sleeping
| import torch | |
| import clip | |
| import open_clip | |
| import h5py | |
| from huggingface_hub import hf_hub_download | |
| from app_lib.utils import SUPPORTED_MODELS | |
def _get_open_clip_model(model_name, device):
    """Load an OpenCLIP model with its preprocessing transform and tokenizer.

    `model_name` is expected to contain a backbone id after the last ":";
    the full name is mapped through SUPPORTED_MODELS to pick the checkpoint,
    while the backbone id alone selects the tokenizer.
    """
    backbone_id = model_name.split(":")[-1]
    clip_model, _, transform = open_clip.create_model_and_transforms(
        SUPPORTED_MODELS[model_name], device=device
    )
    # Inference only: switch off dropout/batch-norm training behavior.
    clip_model.eval()
    text_tokenizer = open_clip.get_tokenizer(backbone_id)
    return clip_model, transform, text_tokenizer
def _get_clip_model(model_name, device):
    """Load an OpenAI CLIP model with its preprocessing transform and tokenizer.

    The backbone id is the text after the last ":" in `model_name`.
    """
    backbone_id = model_name.split(":")[-1]
    loaded_model, transform = clip.load(backbone_id, device=device)
    return loaded_model, transform, clip.tokenize
def load_dataset(dataset_name, model_name):
    """Download precomputed train embeddings for (dataset, model) from the Hub.

    Fetches the HDF5 file named "{dataset_name}_{model_name}_train.h5" from
    the "jacopoteneggi/IBYDMT" dataset repo and returns its "embedding" array
    (materialized in memory, so the file can be closed before returning).
    """
    h5_path = hf_hub_download(
        repo_id="jacopoteneggi/IBYDMT",
        filename=f"{dataset_name}_{model_name}_train.h5",
        repo_type="dataset",
    )
    with h5py.File(h5_path, "r") as h5_file:
        # [:] copies the dataset out of the file into a numpy array.
        data = h5_file["embedding"][:]
    return data
def load_model(model_name, device):
    """Load the (model, preprocess, tokenizer) triple for a supported model.

    Dispatches on the model-name prefix: names containing "open_clip" use
    OpenCLIP; names containing "clip" use OpenAI CLIP. (The "open_clip" check
    must come first, since "clip" is a substring of "open_clip".)

    Raises:
        ValueError: if `model_name` matches neither family. (Previously an
            unmatched name fell through and crashed with UnboundLocalError
            at the return statement.)
    """
    if "open_clip" in model_name:
        model, preprocess, tokenizer = _get_open_clip_model(model_name, device)
    elif "clip" in model_name:
        model, preprocess, tokenizer = _get_clip_model(model_name, device)
    else:
        raise ValueError(f"Unsupported model name: {model_name}")
    return model, preprocess, tokenizer
def encode_concepts(tokenizer, model, concepts, device):
    """Embed a list of concept strings with the model's text encoder.

    Args:
        tokenizer: callable mapping a list of strings to a token tensor.
        model: object exposing `encode_text(tokens) -> tensor`.
        concepts: list of concept strings to embed.
        device: torch device the token tensor is moved to.

    Returns:
        numpy array of L2-normalized text features, one row per concept.
    """
    concepts_text = tokenizer(concepts).to(device)
    # no_grad: this is inference only — avoids building an autograd graph,
    # and .numpy() would raise on a grad-tracking tensor when the model's
    # parameters require gradients.
    with torch.no_grad():
        concept_features = model.encode_text(concepts_text)
        concept_features = concept_features / torch.linalg.norm(
            concept_features, dim=-1, keepdim=True
        )
    return concept_features.cpu().numpy()
def encode_image(model, preprocess, image, device):
    """Embed a single image with the model's image encoder.

    Args:
        model: object exposing `encode_image(batch) -> tensor`.
        preprocess: callable mapping the raw image to a tensor.
        image: raw input image (e.g. a PIL image) accepted by `preprocess`.
        device: torch device the image tensor is moved to.

    Returns:
        numpy array of shape (1, dim) with L2-normalized image features.
    """
    image = preprocess(image)
    image = image.unsqueeze(0)  # add a batch dimension of 1
    image = image.to(device)
    # no_grad: inference only — no autograd graph, and .numpy() would raise
    # on a grad-tracking tensor if model parameters require gradients.
    with torch.no_grad():
        image_features = model.encode_image(image)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    return image_features.cpu().numpy()
def encode_class_name(tokenizer, model, class_name, device):
    """Embed the prompt "A photo of a {class_name}" with the text encoder.

    Args:
        tokenizer: callable mapping a list of strings to a token tensor.
        model: object exposing `encode_text(tokens) -> tensor`.
        class_name: class label inserted into the prompt template.
        device: torch device the token tensor is moved to.

    Returns:
        numpy array of shape (1, dim) with L2-normalized text features.
    """
    class_text = tokenizer([f"A photo of a {class_name}"]).to(device)
    # no_grad: inference only — no autograd graph, and .numpy() would raise
    # on a grad-tracking tensor if model parameters require gradients.
    with torch.no_grad():
        class_features = model.encode_text(class_text)
        class_features = class_features / torch.linalg.norm(
            class_features, dim=-1, keepdim=True
        )
    return class_features.cpu().numpy()
| def test(image, class_name, concepts, cardinality, dataset_name, model_name, device): | |
| model, preprocess = load_model(model_name, device) | |
| print(f"loaded {model_name}") | |