Leacb4 committed on
Commit
161be00
·
verified ·
1 Parent(s): 79a1985

Upload example_usage.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. example_usage.py +49 -17
example_usage.py CHANGED
@@ -1,9 +1,13 @@
1
  #!/usr/bin/env python3
2
  """
3
- Example usage of models from Hugging Face
 
 
 
4
  """
5
 
6
  import torch
 
7
  from PIL import Image
8
  from transformers import CLIPProcessor, CLIPModel as CLIPModel_transformers
9
  from huggingface_hub import hf_hub_download
@@ -11,9 +15,9 @@ import json
11
  import os
12
 
13
  # Import local models (to adapt to your structure)
14
- from color_model import ColorCLIP, SimpleTokenizer
15
  from hierarchy_model import Model as HierarchyModel, HierarchyExtractor
16
- from config import color_emb_dim, hierarchy_emb_dim
17
 
18
  def load_models_from_hf(repo_id: str, cache_dir: str = "./models_cache"):
19
  """
@@ -25,7 +29,7 @@ def load_models_from_hf(repo_id: str, cache_dir: str = "./models_cache"):
25
  """
26
 
27
  os.makedirs(cache_dir, exist_ok=True)
28
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
29
 
30
  print(f"πŸ“₯ Loading models from '{repo_id}'...")
31
 
@@ -40,19 +44,19 @@ def load_models_from_hf(repo_id: str, cache_dir: str = "./models_cache"):
40
  # Loading vocabulary
41
  vocab_path = hf_hub_download(
42
  repo_id=repo_id,
43
- filename="tokenizer_vocab.json",
44
  cache_dir=cache_dir
45
  )
46
 
47
  with open(vocab_path, 'r') as f:
48
  vocab_dict = json.load(f)
49
 
50
- tokenizer = SimpleTokenizer()
51
  tokenizer.load_vocab(vocab_dict)
52
 
53
  checkpoint = torch.load(color_model_path, map_location=device)
54
  vocab_size = checkpoint['text_encoder.embedding.weight'].shape[0]
55
- color_model = ColorCLIP(vocab_size=vocab_size, embedding_dim=color_emb_dim).to(device)
56
  color_model.tokenizer = tokenizer
57
  color_model.load_state_dict(checkpoint)
58
  color_model.eval()
@@ -62,7 +66,7 @@ def load_models_from_hf(repo_id: str, cache_dir: str = "./models_cache"):
62
  print(" πŸ“¦ Loading hierarchy model...")
63
  hierarchy_model_path = hf_hub_download(
64
  repo_id=repo_id,
65
- filename="hierarchy_model.pth",
66
  cache_dir=cache_dir
67
  )
68
 
@@ -71,7 +75,7 @@ def load_models_from_hf(repo_id: str, cache_dir: str = "./models_cache"):
71
 
72
  hierarchy_model = HierarchyModel(
73
  num_hierarchy_classes=len(hierarchy_classes),
74
- embed_dim=hierarchy_emb_dim
75
  ).to(device)
76
  hierarchy_model.load_state_dict(hierarchy_checkpoint['model_state'])
77
 
@@ -84,7 +88,7 @@ def load_models_from_hf(repo_id: str, cache_dir: str = "./models_cache"):
84
  print(" πŸ“¦ Loading main CLIP model...")
85
  main_model_path = hf_hub_download(
86
  repo_id=repo_id,
87
- filename="gap_clip.pth",
88
  cache_dir=cache_dir
89
  )
90
 
@@ -141,22 +145,45 @@ def example_search(models, image_path: str = None, text_query: str = None):
141
  if text_query:
142
  print(f" πŸ“ Text query: '{text_query}'")
143
 
144
- # Obtenir les embeddings de couleur et hiΓ©rarchie
145
  color_emb = color_model.get_text_embeddings([text_query])
146
  hierarchy_emb = hierarchy_model.get_text_embeddings([text_query])
147
 
148
  print(f" 🎨 Color embedding: {color_emb.shape}")
 
149
  print(f" πŸ“‚ Hierarchy embedding: {hierarchy_emb.shape}")
 
150
 
151
- # Obtenir les embeddings du modèle principal
152
  text_inputs = processor(text=[text_query], padding=True, return_tensors="pt")
153
  text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
154
 
155
  with torch.no_grad():
156
- outputs = main_model(**text_inputs)
157
- text_features = outputs.text_embeds
 
 
158
 
159
  print(f" 🎯 Main embedding: {text_features.shape}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
 
161
  if image_path and os.path.exists(image_path):
162
  print(f" πŸ–ΌοΈ Image: {image_path}")
@@ -167,8 +194,10 @@ def example_search(models, image_path: str = None, text_query: str = None):
167
  image_inputs = {k: v.to(device) for k, v in image_inputs.items()}
168
 
169
  with torch.no_grad():
170
- outputs = main_model(**image_inputs)
171
- image_features = outputs.image_embeds
 
 
172
 
173
  print(f" 🎯 Image embedding: {image_features.shape}")
174
 
@@ -192,7 +221,7 @@ if __name__ == "__main__":
192
  parser.add_argument(
193
  "--image",
194
  type=str,
195
- default=None,
196
  help="Path to an image"
197
  )
198
 
@@ -204,3 +233,6 @@ if __name__ == "__main__":
204
  # Example search
205
  example_search(models, image_path=args.image, text_query=args.text)
206
 
 
 
 
 
1
  #!/usr/bin/env python3
2
  """
3
+ Example usage of models from Hugging Face.
4
+ This file provides example code for loading and using the models (color, hierarchy, main)
5
+ from the Hugging Face Hub. It shows how to load models, extract embeddings,
6
+ and perform searches or similarity comparisons.
7
  """
8
 
9
  import torch
10
+ import torch.nn.functional as F
11
  from PIL import Image
12
  from transformers import CLIPProcessor, CLIPModel as CLIPModel_transformers
13
  from huggingface_hub import hf_hub_download
 
15
  import os
16
 
17
  # Import local models (to adapt to your structure)
18
+ from color_model import ColorCLIP, Tokenizer
19
  from hierarchy_model import Model as HierarchyModel, HierarchyExtractor
20
+ import config
21
 
22
  def load_models_from_hf(repo_id: str, cache_dir: str = "./models_cache"):
23
  """
 
29
  """
30
 
31
  os.makedirs(cache_dir, exist_ok=True)
32
+ device = config.device
33
 
34
  print(f"πŸ“₯ Loading models from '{repo_id}'...")
35
 
 
44
  # Loading vocabulary
45
  vocab_path = hf_hub_download(
46
  repo_id=repo_id,
47
+ filename=config.tokeniser_path,
48
  cache_dir=cache_dir
49
  )
50
 
51
  with open(vocab_path, 'r') as f:
52
  vocab_dict = json.load(f)
53
 
54
+ tokenizer = Tokenizer()
55
  tokenizer.load_vocab(vocab_dict)
56
 
57
  checkpoint = torch.load(color_model_path, map_location=device)
58
  vocab_size = checkpoint['text_encoder.embedding.weight'].shape[0]
59
+ color_model = ColorCLIP(vocab_size=vocab_size, embedding_dim=config.color_emb_dim).to(device)
60
  color_model.tokenizer = tokenizer
61
  color_model.load_state_dict(checkpoint)
62
  color_model.eval()
 
66
  print(" πŸ“¦ Loading hierarchy model...")
67
  hierarchy_model_path = hf_hub_download(
68
  repo_id=repo_id,
69
+ filename=config.hierarchy_model_path,
70
  cache_dir=cache_dir
71
  )
72
 
 
75
 
76
  hierarchy_model = HierarchyModel(
77
  num_hierarchy_classes=len(hierarchy_classes),
78
+ embed_dim=config.hierarchy_emb_dim
79
  ).to(device)
80
  hierarchy_model.load_state_dict(hierarchy_checkpoint['model_state'])
81
 
 
88
  print(" πŸ“¦ Loading main CLIP model...")
89
  main_model_path = hf_hub_download(
90
  repo_id=repo_id,
91
+ filename=config.main_model_path,
92
  cache_dir=cache_dir
93
  )
94
 
 
145
  if text_query:
146
  print(f" πŸ“ Text query: '{text_query}'")
147
 
148
+ # Get color and hierarchy embeddings
149
  color_emb = color_model.get_text_embeddings([text_query])
150
  hierarchy_emb = hierarchy_model.get_text_embeddings([text_query])
151
 
152
  print(f" 🎨 Color embedding: {color_emb.shape}")
153
+ print(f"color_emb: {color_emb}")
154
  print(f" πŸ“‚ Hierarchy embedding: {hierarchy_emb.shape}")
155
+ print(f"hierarchy_emb: {hierarchy_emb}")
156
 
157
+ # Get main model embeddings
158
  text_inputs = processor(text=[text_query], padding=True, return_tensors="pt")
159
  text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
160
 
161
  with torch.no_grad():
162
+ # Use text_model directly for text-only processing
163
+ text_outputs = main_model.text_model(**text_inputs)
164
+ text_features = main_model.text_projection(text_outputs.pooler_output)
165
+ text_features = F.normalize(text_features, dim=-1)
166
 
167
  print(f" 🎯 Main embedding: {text_features.shape}")
168
+ print(f" 🎯 First logits of main embedding: {text_features[0:10]}")
169
+
170
+ # Extract color and hierarchy embeddings from main embedding
171
+ main_color_emb = text_features[:, :config.color_emb_dim]
172
+ main_hierarchy_emb = text_features[:, config.color_emb_dim:config.color_emb_dim+config.hierarchy_emb_dim]
173
+
174
+ print(f"\n πŸ“Š Comparison:")
175
+ print(f" 🎨 Color embedding from color model: {color_emb[0]}")
176
+ print(f" 🎨 Color embedding from main model (first {config.color_emb_dim} dims): {main_color_emb[0]}")
177
+ print(f" πŸ“‚ Hierarchy embedding from hierarchy model: {hierarchy_emb[0]}")
178
+ print(f" πŸ“‚ Hierarchy embedding from main model (dims {config.color_emb_dim}-{config.color_emb_dim+config.hierarchy_emb_dim}): {main_hierarchy_emb[0]}")
179
+
180
+ # Calculate cosine similarity between color embeddings
181
+ color_cosine_sim = F.cosine_similarity(color_emb, main_color_emb, dim=1)
182
+ print(f"\n πŸ” Cosine similarity between color embeddings: {color_cosine_sim.item():.4f}")
183
+
184
+ # Calculate cosine similarity between hierarchy embeddings
185
+ hierarchy_cosine_sim = F.cosine_similarity(hierarchy_emb, main_hierarchy_emb, dim=1)
186
+ print(f" πŸ” Cosine similarity between hierarchy embeddings: {hierarchy_cosine_sim.item():.4f}")
187
 
188
  if image_path and os.path.exists(image_path):
189
  print(f" πŸ–ΌοΈ Image: {image_path}")
 
194
  image_inputs = {k: v.to(device) for k, v in image_inputs.items()}
195
 
196
  with torch.no_grad():
197
+ # Use vision_model directly for image-only processing
198
+ vision_outputs = main_model.vision_model(**image_inputs)
199
+ image_features = main_model.visual_projection(vision_outputs.pooler_output)
200
+ image_features = F.normalize(image_features, dim=-1)
201
 
202
  print(f" 🎯 Image embedding: {image_features.shape}")
203
 
 
221
  parser.add_argument(
222
  "--image",
223
  type=str,
224
+ default="red_dress.png",
225
  help="Path to an image"
226
  )
227
 
 
233
  # Example search
234
  example_search(models, image_path=args.image, text_query=args.text)
235
 
236
+
237
+
238
+