Upload example_usage.py with huggingface_hub
Browse files- example_usage.py +37 -37
example_usage.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
"""
|
| 3 |
-
|
| 4 |
"""
|
| 5 |
|
| 6 |
import torch
|
|
@@ -10,34 +10,34 @@ from huggingface_hub import hf_hub_download
|
|
| 10 |
import json
|
| 11 |
import os
|
| 12 |
|
| 13 |
-
# Import
|
| 14 |
from color_model import ColorCLIP, SimpleTokenizer
|
| 15 |
from hierarchy_model import Model as HierarchyModel, HierarchyExtractor
|
| 16 |
from config import color_emb_dim, hierarchy_emb_dim
|
| 17 |
|
| 18 |
def load_models_from_hf(repo_id: str, cache_dir: str = "./models_cache"):
|
| 19 |
"""
|
| 20 |
-
|
| 21 |
|
| 22 |
Args:
|
| 23 |
-
repo_id: ID
|
| 24 |
-
cache_dir:
|
| 25 |
"""
|
| 26 |
|
| 27 |
os.makedirs(cache_dir, exist_ok=True)
|
| 28 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 29 |
|
| 30 |
-
print(f"📥
|
| 31 |
|
| 32 |
-
# 1.
|
| 33 |
-
print(" 📦
|
| 34 |
color_model_path = hf_hub_download(
|
| 35 |
repo_id=repo_id,
|
| 36 |
filename="color_model.pt",
|
| 37 |
cache_dir=cache_dir
|
| 38 |
)
|
| 39 |
|
| 40 |
-
#
|
| 41 |
vocab_path = hf_hub_download(
|
| 42 |
repo_id=repo_id,
|
| 43 |
filename="tokenizer_vocab.json",
|
|
@@ -56,10 +56,10 @@ def load_models_from_hf(repo_id: str, cache_dir: str = "./models_cache"):
|
|
| 56 |
color_model.tokenizer = tokenizer
|
| 57 |
color_model.load_state_dict(checkpoint)
|
| 58 |
color_model.eval()
|
| 59 |
-
print(" ✅
|
| 60 |
|
| 61 |
-
# 2.
|
| 62 |
-
print(" 📦
|
| 63 |
hierarchy_model_path = hf_hub_download(
|
| 64 |
repo_id=repo_id,
|
| 65 |
filename="hierarchy_model.pth",
|
|
@@ -78,13 +78,13 @@ def load_models_from_hf(repo_id: str, cache_dir: str = "./models_cache"):
|
|
| 78 |
hierarchy_extractor = HierarchyExtractor(hierarchy_classes, verbose=False)
|
| 79 |
hierarchy_model.set_hierarchy_extractor(hierarchy_extractor)
|
| 80 |
hierarchy_model.eval()
|
| 81 |
-
print(" ✅
|
| 82 |
|
| 83 |
-
# 3.
|
| 84 |
-
print(" 📦
|
| 85 |
main_model_path = hf_hub_download(
|
| 86 |
repo_id=repo_id,
|
| 87 |
-
filename="
|
| 88 |
cache_dir=cache_dir
|
| 89 |
)
|
| 90 |
|
|
@@ -93,12 +93,12 @@ def load_models_from_hf(repo_id: str, cache_dir: str = "./models_cache"):
|
|
| 93 |
)
|
| 94 |
checkpoint = torch.load(main_model_path, map_location=device)
|
| 95 |
|
| 96 |
-
#
|
| 97 |
if isinstance(checkpoint, dict):
|
| 98 |
if 'model_state_dict' in checkpoint:
|
| 99 |
clip_model.load_state_dict(checkpoint['model_state_dict'])
|
| 100 |
else:
|
| 101 |
-
#
|
| 102 |
clip_model.load_state_dict(checkpoint)
|
| 103 |
else:
|
| 104 |
clip_model.load_state_dict(checkpoint)
|
|
@@ -107,9 +107,9 @@ def load_models_from_hf(repo_id: str, cache_dir: str = "./models_cache"):
|
|
| 107 |
clip_model.eval()
|
| 108 |
|
| 109 |
processor = CLIPProcessor.from_pretrained('laion/CLIP-ViT-B-32-laion2B-s34B-b79K')
|
| 110 |
-
print(" ✅
|
| 111 |
|
| 112 |
-
print("\n✅
|
| 113 |
|
| 114 |
return {
|
| 115 |
'color_model': color_model,
|
|
@@ -122,12 +122,12 @@ def load_models_from_hf(repo_id: str, cache_dir: str = "./models_cache"):
|
|
| 122 |
|
| 123 |
def example_search(models, image_path: str = None, text_query: str = None):
|
| 124 |
"""
|
| 125 |
-
|
| 126 |
|
| 127 |
Args:
|
| 128 |
-
models:
|
| 129 |
-
image_path:
|
| 130 |
-
text_query:
|
| 131 |
"""
|
| 132 |
|
| 133 |
color_model = models['color_model']
|
|
@@ -136,17 +136,17 @@ def example_search(models, image_path: str = None, text_query: str = None):
|
|
| 136 |
processor = models['processor']
|
| 137 |
device = models['device']
|
| 138 |
|
| 139 |
-
print("\n🔍
|
| 140 |
|
| 141 |
if text_query:
|
| 142 |
-
print(f" 📝
|
| 143 |
|
| 144 |
# Obtenir les embeddings de couleur et hiérarchie
|
| 145 |
color_emb = color_model.get_text_embeddings([text_query])
|
| 146 |
hierarchy_emb = hierarchy_model.get_text_embeddings([text_query])
|
| 147 |
|
| 148 |
-
print(f" 🎨
|
| 149 |
-
print(f" 📂
|
| 150 |
|
| 151 |
# Obtenir les embeddings du modèle principal
|
| 152 |
text_inputs = processor(text=[text_query], padding=True, return_tensors="pt")
|
|
@@ -156,13 +156,13 @@ def example_search(models, image_path: str = None, text_query: str = None):
|
|
| 156 |
outputs = main_model(**text_inputs)
|
| 157 |
text_features = outputs.text_embeds
|
| 158 |
|
| 159 |
-
print(f" 🎯
|
| 160 |
|
| 161 |
if image_path and os.path.exists(image_path):
|
| 162 |
print(f" 🖼️ Image: {image_path}")
|
| 163 |
image = Image.open(image_path).convert("RGB")
|
| 164 |
|
| 165 |
-
#
|
| 166 |
image_inputs = processor(images=[image], return_tensors="pt")
|
| 167 |
image_inputs = {k: v.to(device) for k, v in image_inputs.items()}
|
| 168 |
|
|
@@ -170,37 +170,37 @@ def example_search(models, image_path: str = None, text_query: str = None):
|
|
| 170 |
outputs = main_model(**image_inputs)
|
| 171 |
image_features = outputs.image_embeds
|
| 172 |
|
| 173 |
-
print(f" 🎯
|
| 174 |
|
| 175 |
|
| 176 |
if __name__ == "__main__":
|
| 177 |
import argparse
|
| 178 |
|
| 179 |
-
parser = argparse.ArgumentParser(description="
|
| 180 |
parser.add_argument(
|
| 181 |
"--repo-id",
|
| 182 |
type=str,
|
| 183 |
required=True,
|
| 184 |
-
help="ID
|
| 185 |
)
|
| 186 |
parser.add_argument(
|
| 187 |
"--text",
|
| 188 |
type=str,
|
| 189 |
default="red dress",
|
| 190 |
-
help="
|
| 191 |
)
|
| 192 |
parser.add_argument(
|
| 193 |
"--image",
|
| 194 |
type=str,
|
| 195 |
default=None,
|
| 196 |
-
help="
|
| 197 |
)
|
| 198 |
|
| 199 |
args = parser.parse_args()
|
| 200 |
|
| 201 |
-
#
|
| 202 |
models = load_models_from_hf(args.repo_id)
|
| 203 |
|
| 204 |
-
#
|
| 205 |
example_search(models, image_path=args.image, text_query=args.text)
|
| 206 |
|
|
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
"""
|
| 3 |
+
Example usage of models from Hugging Face
|
| 4 |
"""
|
| 5 |
|
| 6 |
import torch
|
|
|
|
| 10 |
import json
|
| 11 |
import os
|
| 12 |
|
| 13 |
+
# Import local models (to adapt to your structure)
|
| 14 |
from color_model import ColorCLIP, SimpleTokenizer
|
| 15 |
from hierarchy_model import Model as HierarchyModel, HierarchyExtractor
|
| 16 |
from config import color_emb_dim, hierarchy_emb_dim
|
| 17 |
|
| 18 |
def load_models_from_hf(repo_id: str, cache_dir: str = "./models_cache"):
|
| 19 |
"""
|
| 20 |
+
Load models from Hugging Face
|
| 21 |
|
| 22 |
Args:
|
| 23 |
+
repo_id: ID of the Hugging Face repository
|
| 24 |
+
cache_dir: Local cache directory
|
| 25 |
"""
|
| 26 |
|
| 27 |
os.makedirs(cache_dir, exist_ok=True)
|
| 28 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 29 |
|
| 30 |
+
print(f"📥 Loading models from '{repo_id}'...")
|
| 31 |
|
| 32 |
+
# 1. Loading color model
|
| 33 |
+
print(" 📦 Loading color model...")
|
| 34 |
color_model_path = hf_hub_download(
|
| 35 |
repo_id=repo_id,
|
| 36 |
filename="color_model.pt",
|
| 37 |
cache_dir=cache_dir
|
| 38 |
)
|
| 39 |
|
| 40 |
+
# Loading vocabulary
|
| 41 |
vocab_path = hf_hub_download(
|
| 42 |
repo_id=repo_id,
|
| 43 |
filename="tokenizer_vocab.json",
|
|
|
|
| 56 |
color_model.tokenizer = tokenizer
|
| 57 |
color_model.load_state_dict(checkpoint)
|
| 58 |
color_model.eval()
|
| 59 |
+
print(" ✅ Color model loaded")
|
| 60 |
|
| 61 |
+
# 2. Loading hierarchy model
|
| 62 |
+
print(" 📦 Loading hierarchy model...")
|
| 63 |
hierarchy_model_path = hf_hub_download(
|
| 64 |
repo_id=repo_id,
|
| 65 |
filename="hierarchy_model.pth",
|
|
|
|
| 78 |
hierarchy_extractor = HierarchyExtractor(hierarchy_classes, verbose=False)
|
| 79 |
hierarchy_model.set_hierarchy_extractor(hierarchy_extractor)
|
| 80 |
hierarchy_model.eval()
|
| 81 |
+
print(" ✅ Hierarchy model loaded")
|
| 82 |
|
| 83 |
+
# 3. Loading main CLIP model
|
| 84 |
+
print(" 📦 Loading main CLIP model...")
|
| 85 |
main_model_path = hf_hub_download(
|
| 86 |
repo_id=repo_id,
|
| 87 |
+
filename="gap_clip.pth",
|
| 88 |
cache_dir=cache_dir
|
| 89 |
)
|
| 90 |
|
|
|
|
| 93 |
)
|
| 94 |
checkpoint = torch.load(main_model_path, map_location=device)
|
| 95 |
|
| 96 |
+
# Handle different checkpoint structures
|
| 97 |
if isinstance(checkpoint, dict):
|
| 98 |
if 'model_state_dict' in checkpoint:
|
| 99 |
clip_model.load_state_dict(checkpoint['model_state_dict'])
|
| 100 |
else:
|
| 101 |
+
# If the checkpoint is directly the state_dict
|
| 102 |
clip_model.load_state_dict(checkpoint)
|
| 103 |
else:
|
| 104 |
clip_model.load_state_dict(checkpoint)
|
|
|
|
| 107 |
clip_model.eval()
|
| 108 |
|
| 109 |
processor = CLIPProcessor.from_pretrained('laion/CLIP-ViT-B-32-laion2B-s34B-b79K')
|
| 110 |
+
print(" ✅ Main CLIP model loaded")
|
| 111 |
|
| 112 |
+
print("\n✅ All models loaded!")
|
| 113 |
|
| 114 |
return {
|
| 115 |
'color_model': color_model,
|
|
|
|
| 122 |
|
| 123 |
def example_search(models, image_path: str = None, text_query: str = None):
|
| 124 |
"""
|
| 125 |
+
Example search with the models
|
| 126 |
|
| 127 |
Args:
|
| 128 |
+
models: Dictionary of loaded models
|
| 129 |
+
image_path: Path to an image (optional)
|
| 130 |
+
text_query: Text query (optional)
|
| 131 |
"""
|
| 132 |
|
| 133 |
color_model = models['color_model']
|
|
|
|
| 136 |
processor = models['processor']
|
| 137 |
device = models['device']
|
| 138 |
|
| 139 |
+
print("\n🔍 Example search...")
|
| 140 |
|
| 141 |
if text_query:
|
| 142 |
+
print(f" 📝 Text query: '{text_query}'")
|
| 143 |
|
| 144 |
        # Get the color and hierarchy embeddings
|
| 145 |
color_emb = color_model.get_text_embeddings([text_query])
|
| 146 |
hierarchy_emb = hierarchy_model.get_text_embeddings([text_query])
|
| 147 |
|
| 148 |
+
print(f" 🎨 Color embedding: {color_emb.shape}")
|
| 149 |
+
print(f" 📂 Hierarchy embedding: {hierarchy_emb.shape}")
|
| 150 |
|
| 151 |
        # Get the main model embeddings
|
| 152 |
text_inputs = processor(text=[text_query], padding=True, return_tensors="pt")
|
|
|
|
| 156 |
outputs = main_model(**text_inputs)
|
| 157 |
text_features = outputs.text_embeds
|
| 158 |
|
| 159 |
+
print(f" 🎯 Main embedding: {text_features.shape}")
|
| 160 |
|
| 161 |
if image_path and os.path.exists(image_path):
|
| 162 |
print(f" 🖼️ Image: {image_path}")
|
| 163 |
image = Image.open(image_path).convert("RGB")
|
| 164 |
|
| 165 |
+
# Get image embeddings
|
| 166 |
image_inputs = processor(images=[image], return_tensors="pt")
|
| 167 |
image_inputs = {k: v.to(device) for k, v in image_inputs.items()}
|
| 168 |
|
|
|
|
| 170 |
outputs = main_model(**image_inputs)
|
| 171 |
image_features = outputs.image_embeds
|
| 172 |
|
| 173 |
+
print(f" 🎯 Image embedding: {image_features.shape}")
|
| 174 |
|
| 175 |
|
| 176 |
if __name__ == "__main__":
|
| 177 |
import argparse
|
| 178 |
|
| 179 |
+
parser = argparse.ArgumentParser(description="Example usage of models")
|
| 180 |
parser.add_argument(
|
| 181 |
"--repo-id",
|
| 182 |
type=str,
|
| 183 |
required=True,
|
| 184 |
+
help="ID of the Hugging Face repository"
|
| 185 |
)
|
| 186 |
parser.add_argument(
|
| 187 |
"--text",
|
| 188 |
type=str,
|
| 189 |
default="red dress",
|
| 190 |
+
help="Text query for search"
|
| 191 |
)
|
| 192 |
parser.add_argument(
|
| 193 |
"--image",
|
| 194 |
type=str,
|
| 195 |
default=None,
|
| 196 |
+
help="Path to an image"
|
| 197 |
)
|
| 198 |
|
| 199 |
args = parser.parse_args()
|
| 200 |
|
| 201 |
+
# Load models
|
| 202 |
models = load_models_from_hf(args.repo_id)
|
| 203 |
|
| 204 |
+
# Example search
|
| 205 |
example_search(models, image_path=args.image, text_query=args.text)
|
| 206 |
|