|
|
|
|
|
""" |
|
|
Example usage of the models published on the Hugging Face Hub.

This script shows how to download and load the three models (color, hierarchy
and main CLIP) from a Hugging Face Hub repository, extract text and image
embeddings with them, and run simple similarity comparisons (for example, the
color/hierarchy embeddings versus the matching slices of the main model's
text embedding).
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
from PIL import Image |
|
|
from transformers import CLIPProcessor, CLIPModel as CLIPModel_transformers |
|
|
from huggingface_hub import hf_hub_download |
|
|
import json |
|
|
import os |
|
|
|
|
|
|
|
|
from color_model import ColorCLIP, Tokenizer |
|
|
from hierarchy_model import Model as HierarchyModel, HierarchyExtractor |
|
|
import config |
|
|
|
|
|
def _load_color_model(download, device):
    """Download the color checkpoint and vocabulary and return a ColorCLIP in eval mode."""
    print("  📦 Loading color model...")
    color_model_path = download("color_model.pt")
    vocab_path = download(config.tokeniser_path)

    # The vocabulary is JSON; read it explicitly as UTF-8 so non-ASCII tokens
    # load identically regardless of the platform's default encoding.
    with open(vocab_path, "r", encoding="utf-8") as f:
        vocab_dict = json.load(f)

    tokenizer = Tokenizer()
    tokenizer.load_vocab(vocab_dict)

    # NOTE(review): torch.load unpickles the file, and this checkpoint comes from a
    # remote repo — only load repos you trust (or try weights_only=True).
    checkpoint = torch.load(color_model_path, map_location=device)
    # Size the embedding table from the saved weights so load_state_dict matches.
    vocab_size = checkpoint['text_encoder.embedding.weight'].shape[0]
    color_model = ColorCLIP(vocab_size=vocab_size, embedding_dim=config.color_emb_dim).to(device)
    color_model.tokenizer = tokenizer
    color_model.load_state_dict(checkpoint)
    color_model.eval()
    print("  ✅ Color model loaded")
    return color_model


def _load_hierarchy_model(download, device):
    """Download the hierarchy checkpoint and return a HierarchyModel (with extractor) in eval mode."""
    print("  📦 Loading hierarchy model...")
    hierarchy_model_path = download(config.hierarchy_model_path)
    # NOTE(review): same trust caveat as the color checkpoint (torch.load unpickles).
    hierarchy_checkpoint = torch.load(hierarchy_model_path, map_location=device)

    hierarchy_classes = hierarchy_checkpoint.get('hierarchy_classes', [])
    if not hierarchy_classes:
        # Without a class list the model is built with 0 classes and the extractor
        # has nothing to match — surface that instead of failing obscurely later.
        print("  ⚠️ No 'hierarchy_classes' found in hierarchy checkpoint; using 0 classes")

    hierarchy_model = HierarchyModel(
        num_hierarchy_classes=len(hierarchy_classes),
        embed_dim=config.hierarchy_emb_dim,
    ).to(device)
    hierarchy_model.load_state_dict(hierarchy_checkpoint['model_state'])

    hierarchy_extractor = HierarchyExtractor(hierarchy_classes, verbose=False)
    hierarchy_model.set_hierarchy_extractor(hierarchy_extractor)
    hierarchy_model.eval()
    print("  ✅ Hierarchy model loaded")
    return hierarchy_model


def _load_main_model(download, device, base_model_id):
    """Download the main checkpoint, load it into the base CLIP architecture, and return (model, processor)."""
    print("  📦 Loading main CLIP model...")
    main_model_path = download(config.main_model_path)

    # Build the architecture from the pretrained base model, then overwrite its
    # weights with the checkpoint downloaded from the repo.
    clip_model = CLIPModel_transformers.from_pretrained(base_model_id)
    # NOTE(review): same trust caveat as the other checkpoints (torch.load unpickles).
    checkpoint = torch.load(main_model_path, map_location=device)

    # The checkpoint is either a bare state dict or a training checkpoint that
    # wraps it under 'model_state_dict'.
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    else:
        state_dict = checkpoint
    clip_model.load_state_dict(state_dict)

    clip_model = clip_model.to(device)
    clip_model.eval()

    processor = CLIPProcessor.from_pretrained(base_model_id)
    print("  ✅ Main CLIP model loaded")
    return clip_model, processor


def load_models_from_hf(
    repo_id: str,
    cache_dir: str = "./models_cache",
    base_model_id: str = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K",
):
    """
    Download and load the color, hierarchy and main CLIP models from the Hugging Face Hub.

    Args:
        repo_id: ID of the Hugging Face repository holding the checkpoints
        cache_dir: Local cache directory for downloaded files (created if missing)
        base_model_id: Hugging Face ID of the pretrained CLIP model whose
            architecture and processor the main checkpoint is loaded into

    Returns:
        dict with keys 'color_model', 'hierarchy_model', 'main_model',
        'processor' and 'device'. Every model is moved to ``config.device``
        and switched to eval mode.
    """
    os.makedirs(cache_dir, exist_ok=True)
    device = config.device

    print(f"📥 Loading models from '{repo_id}'...")

    def download(filename):
        # Fetch one file from the repo; hf_hub_download reuses cache_dir when possible.
        return hf_hub_download(repo_id=repo_id, filename=filename, cache_dir=cache_dir)

    color_model = _load_color_model(download, device)
    hierarchy_model = _load_hierarchy_model(download, device)
    clip_model, processor = _load_main_model(download, device, base_model_id)

    print("\n✅ All models loaded!")

    return {
        'color_model': color_model,
        'hierarchy_model': hierarchy_model,
        'main_model': clip_model,
        'processor': processor,
        'device': device,
    }
|
|
|
|
|
|
|
|
def example_search(models, image_path: str = None, text_query: str = None):
    """
    Run an example encoding with the loaded models and print diagnostics.

    Text query: the color and hierarchy models each embed the text, and the main
    CLIP model embeds it too (L2-normalised). The specialised embeddings are then
    compared by cosine similarity with slices of the main embedding:
    dims [0, color_emb_dim) for color and
    [color_emb_dim, color_emb_dim + hierarchy_emb_dim) for hierarchy.

    Image path: if the file exists, the main CLIP model embeds the image. When a
    text query was also given, the text-image cosine similarity is printed.

    Args:
        models: Dictionary returned by ``load_models_from_hf`` (keys 'color_model',
            'hierarchy_model', 'main_model', 'processor', 'device')
        image_path: Path to an image (optional); reported and skipped if missing
        text_query: Text query (optional)
    """
    color_model = models['color_model']
    hierarchy_model = models['hierarchy_model']
    main_model = models['main_model']
    processor = models['processor']
    device = models['device']

    color_dim = config.color_emb_dim
    hierarchy_end = color_dim + config.hierarchy_emb_dim

    print("\n🔍 Example search...")

    # Kept so the image branch can compare against the text embedding.
    text_features = None

    if text_query:
        print(f"  📝 Text query: '{text_query}'")

        # Inference only: no autograd graph is needed for the specialised models.
        with torch.no_grad():
            color_emb = color_model.get_text_embeddings([text_query])
            hierarchy_emb = hierarchy_model.get_text_embeddings([text_query])

        print(f"  🎨 Color embedding: {color_emb.shape}")
        print(f"color_emb: {color_emb}")
        print(f"  📂 Hierarchy embedding: {hierarchy_emb.shape}")
        print(f"hierarchy_emb: {hierarchy_emb}")

        text_inputs = processor(text=[text_query], padding=True, return_tensors="pt")
        text_inputs = {k: v.to(device) for k, v in text_inputs.items()}

        with torch.no_grad():
            # Pooled text output -> projection, then L2-normalised.
            text_outputs = main_model.text_model(**text_inputs)
            text_features = main_model.text_projection(text_outputs.pooler_output)
            text_features = F.normalize(text_features, dim=-1)

        print(f"  🎯 Main embedding: {text_features.shape}")
        # text_features is (1, dim) for the single query: index the batch first so
        # only the first 10 embedding values are shown (a [0:10] slice would
        # select batch rows and print the entire vector).
        print(f"  🎯 First 10 values of main embedding: {text_features[0, :10]}")

        # NOTE(review): assumes the main embedding's leading dims are meant to align
        # with the color model and the following ones with the hierarchy model.
        main_color_emb = text_features[:, :color_dim]
        main_hierarchy_emb = text_features[:, color_dim:hierarchy_end]

        print("\n  📊 Comparison:")
        print(f"  🎨 Color embedding from color model: {color_emb[0]}")
        print(f"  🎨 Color embedding from main model (first {color_dim} dims): {main_color_emb[0]}")
        print(f"  📂 Hierarchy embedding from hierarchy model: {hierarchy_emb[0]}")
        print(f"  📂 Hierarchy embedding from main model (dims {color_dim}-{hierarchy_end}): {main_hierarchy_emb[0]}")

        # Single query, so each similarity tensor holds one value (.item()).
        color_cosine_sim = F.cosine_similarity(color_emb, main_color_emb, dim=1)
        print(f"\n  📐 Cosine similarity between color embeddings: {color_cosine_sim.item():.4f}")

        hierarchy_cosine_sim = F.cosine_similarity(hierarchy_emb, main_hierarchy_emb, dim=1)
        print(f"  📐 Cosine similarity between hierarchy embeddings: {hierarchy_cosine_sim.item():.4f}")

    if not image_path:
        return
    if not os.path.exists(image_path):
        print(f"  ⚠️ Image not found, skipping: {image_path}")
        return

    print(f"  🖼️ Image: {image_path}")
    # convert() returns a new image, so the file handle can be closed right away.
    with Image.open(image_path) as img:
        image = img.convert("RGB")

    image_inputs = processor(images=[image], return_tensors="pt")
    image_inputs = {k: v.to(device) for k, v in image_inputs.items()}

    with torch.no_grad():
        # Pooled vision output -> projection, then L2-normalised.
        vision_outputs = main_model.vision_model(**image_inputs)
        image_features = main_model.visual_projection(vision_outputs.pooler_output)
        image_features = F.normalize(image_features, dim=-1)

    print(f"  🎯 Image embedding: {image_features.shape}")

    if text_features is not None:
        # Both embeddings are L2-normalised, so their dot product is the cosine similarity.
        text_image_sim = (text_features @ image_features.T).item()
        print(f"  📐 Text-image cosine similarity: {text_image_sim:.4f}")
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: load every model from the given repo, then run
    # the example on the requested text query and image path.
    import argparse

    cli = argparse.ArgumentParser(description="Example usage of models")
    cli.add_argument("--repo-id", type=str, required=True,
                     help="ID of the Hugging Face repository")
    cli.add_argument("--text", type=str, default="red dress",
                     help="Text query for search")
    cli.add_argument("--image", type=str, default="red_dress.png",
                     help="Path to an image")
    cli_args = cli.parse_args()

    loaded = load_models_from_hf(cli_args.repo_id)
    example_search(
        loaded,
        image_path=cli_args.image,
        text_query=cli_args.text,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|