Leacb4 committed on
Commit
512487a
·
verified ·
1 Parent(s): 9715189

Upload example_usage.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. example_usage.py +37 -37
example_usage.py CHANGED
@@ -1,6 +1,6 @@
1
  #!/usr/bin/env python3
2
  """
3
- Exemple d'utilisation des modèles depuis Hugging Face
4
  """
5
 
6
  import torch
@@ -10,34 +10,34 @@ from huggingface_hub import hf_hub_download
10
  import json
11
  import os
12
 
13
- # Import des modèles locaux (à adapter selon votre structure)
14
  from color_model import ColorCLIP, SimpleTokenizer
15
  from hierarchy_model import Model as HierarchyModel, HierarchyExtractor
16
  from config import color_emb_dim, hierarchy_emb_dim
17
 
18
  def load_models_from_hf(repo_id: str, cache_dir: str = "./models_cache"):
19
  """
20
- Charger les modèles depuis Hugging Face
21
 
22
  Args:
23
- repo_id: ID du repository Hugging Face
24
- cache_dir: Dossier de cache local
25
  """
26
 
27
  os.makedirs(cache_dir, exist_ok=True)
28
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
29
 
30
- print(f"📥 Chargement des modèles depuis '{repo_id}'...")
31
 
32
- # 1. Charger le modèle de couleur
33
- print(" 📦 Chargement du modèle de couleur...")
34
  color_model_path = hf_hub_download(
35
  repo_id=repo_id,
36
  filename="color_model.pt",
37
  cache_dir=cache_dir
38
  )
39
 
40
- # Charger le vocabulaire
41
  vocab_path = hf_hub_download(
42
  repo_id=repo_id,
43
  filename="tokenizer_vocab.json",
@@ -56,10 +56,10 @@ def load_models_from_hf(repo_id: str, cache_dir: str = "./models_cache"):
56
  color_model.tokenizer = tokenizer
57
  color_model.load_state_dict(checkpoint)
58
  color_model.eval()
59
- print(" ✅ Modèle de couleur chargé")
60
 
61
- # 2. Charger le modèle de hiérarchie
62
- print(" 📦 Chargement du modèle de hiérarchie...")
63
  hierarchy_model_path = hf_hub_download(
64
  repo_id=repo_id,
65
  filename="hierarchy_model.pth",
@@ -78,13 +78,13 @@ def load_models_from_hf(repo_id: str, cache_dir: str = "./models_cache"):
78
  hierarchy_extractor = HierarchyExtractor(hierarchy_classes, verbose=False)
79
  hierarchy_model.set_hierarchy_extractor(hierarchy_extractor)
80
  hierarchy_model.eval()
81
- print(" ✅ Modèle de hiérarchie chargé")
82
 
83
- # 3. Charger le modèle principal CLIP
84
- print(" 📦 Chargement du modèle principal CLIP...")
85
  main_model_path = hf_hub_download(
86
  repo_id=repo_id,
87
- filename="laion_explicable_model.pth",
88
  cache_dir=cache_dir
89
  )
90
 
@@ -93,12 +93,12 @@ def load_models_from_hf(repo_id: str, cache_dir: str = "./models_cache"):
93
  )
94
  checkpoint = torch.load(main_model_path, map_location=device)
95
 
96
- # Gérer différentes structures de checkpoint
97
  if isinstance(checkpoint, dict):
98
  if 'model_state_dict' in checkpoint:
99
  clip_model.load_state_dict(checkpoint['model_state_dict'])
100
  else:
101
- # Si le checkpoint est directement le state_dict
102
  clip_model.load_state_dict(checkpoint)
103
  else:
104
  clip_model.load_state_dict(checkpoint)
@@ -107,9 +107,9 @@ def load_models_from_hf(repo_id: str, cache_dir: str = "./models_cache"):
107
  clip_model.eval()
108
 
109
  processor = CLIPProcessor.from_pretrained('laion/CLIP-ViT-B-32-laion2B-s34B-b79K')
110
- print(" ✅ Modèle principal CLIP chargé")
111
 
112
- print("\n✅ Tous les modèles sont chargés!")
113
 
114
  return {
115
  'color_model': color_model,
@@ -122,12 +122,12 @@ def load_models_from_hf(repo_id: str, cache_dir: str = "./models_cache"):
122
 
123
  def example_search(models, image_path: str = None, text_query: str = None):
124
  """
125
- Exemple de recherche avec les modèles
126
 
127
  Args:
128
- models: Dictionnaire des modèles chargés
129
- image_path: Chemin vers une image (optionnel)
130
- text_query: Requête textuelle (optionnel)
131
  """
132
 
133
  color_model = models['color_model']
@@ -136,17 +136,17 @@ def example_search(models, image_path: str = None, text_query: str = None):
136
  processor = models['processor']
137
  device = models['device']
138
 
139
- print("\n🔍 Exemple de recherche...")
140
 
141
  if text_query:
142
- print(f" 📝 Requête textuelle: '{text_query}'")
143
 
144
  # Obtenir les embeddings de couleur et hiérarchie
145
  color_emb = color_model.get_text_embeddings([text_query])
146
  hierarchy_emb = hierarchy_model.get_text_embeddings([text_query])
147
 
148
- print(f" 🎨 Embedding couleur: {color_emb.shape}")
149
- print(f" 📂 Embedding hiérarchie: {hierarchy_emb.shape}")
150
 
151
  # Obtenir les embeddings du modèle principal
152
  text_inputs = processor(text=[text_query], padding=True, return_tensors="pt")
@@ -156,13 +156,13 @@ def example_search(models, image_path: str = None, text_query: str = None):
156
  outputs = main_model(**text_inputs)
157
  text_features = outputs.text_embeds
158
 
159
- print(f" 🎯 Embedding principal: {text_features.shape}")
160
 
161
  if image_path and os.path.exists(image_path):
162
  print(f" 🖼️ Image: {image_path}")
163
  image = Image.open(image_path).convert("RGB")
164
 
165
- # Obtenir les embeddings d'image
166
  image_inputs = processor(images=[image], return_tensors="pt")
167
  image_inputs = {k: v.to(device) for k, v in image_inputs.items()}
168
 
@@ -170,37 +170,37 @@ def example_search(models, image_path: str = None, text_query: str = None):
170
  outputs = main_model(**image_inputs)
171
  image_features = outputs.image_embeds
172
 
173
- print(f" 🎯 Embedding image: {image_features.shape}")
174
 
175
 
176
if __name__ == "__main__":
    import argparse

    # CLI entry point: parse the options, load the models from the Hub,
    # then run one example search with the requested text/image.
    cli = argparse.ArgumentParser(description="Exemple d'utilisation des modèles")
    cli.add_argument("--repo-id", type=str, required=True,
                     help="ID du repository Hugging Face")
    cli.add_argument("--text", type=str, default="red dress",
                     help="Requête textuelle de recherche")
    cli.add_argument("--image", type=str, default=None,
                     help="Chemin vers une image")
    opts = cli.parse_args()

    # Download and initialise all models from the given repository
    models = load_models_from_hf(opts.repo_id)

    # Run the demonstration search on the parsed query/image
    example_search(models, image_path=opts.image, text_query=opts.text)
206
 
 
1
  #!/usr/bin/env python3
2
  """
3
+ Example usage of models from Hugging Face
4
  """
5
 
6
  import torch
 
10
  import json
11
  import os
12
 
13
+ # Import local models (to adapt to your structure)
14
  from color_model import ColorCLIP, SimpleTokenizer
15
  from hierarchy_model import Model as HierarchyModel, HierarchyExtractor
16
  from config import color_emb_dim, hierarchy_emb_dim
17
 
18
  def load_models_from_hf(repo_id: str, cache_dir: str = "./models_cache"):
19
  """
20
+ Load models from Hugging Face
21
 
22
  Args:
23
+ repo_id: ID of the Hugging Face repository
24
+ cache_dir: Local cache directory
25
  """
26
 
27
  os.makedirs(cache_dir, exist_ok=True)
28
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
29
 
30
+ print(f"📥 Loading models from '{repo_id}'...")
31
 
32
+ # 1. Loading color model
33
+ print(" 📦 Loading color model...")
34
  color_model_path = hf_hub_download(
35
  repo_id=repo_id,
36
  filename="color_model.pt",
37
  cache_dir=cache_dir
38
  )
39
 
40
+ # Loading vocabulary
41
  vocab_path = hf_hub_download(
42
  repo_id=repo_id,
43
  filename="tokenizer_vocab.json",
 
56
  color_model.tokenizer = tokenizer
57
  color_model.load_state_dict(checkpoint)
58
  color_model.eval()
59
+ print(" ✅ Color model loaded")
60
 
61
+ # 2. Loading hierarchy model
62
+ print(" 📦 Loading hierarchy model...")
63
  hierarchy_model_path = hf_hub_download(
64
  repo_id=repo_id,
65
  filename="hierarchy_model.pth",
 
78
  hierarchy_extractor = HierarchyExtractor(hierarchy_classes, verbose=False)
79
  hierarchy_model.set_hierarchy_extractor(hierarchy_extractor)
80
  hierarchy_model.eval()
81
+ print(" ✅ Hierarchy model loaded")
82
 
83
+ # 3. Loading main CLIP model
84
+ print(" 📦 Loading main CLIP model...")
85
  main_model_path = hf_hub_download(
86
  repo_id=repo_id,
87
+ filename="gap_clip.pth",
88
  cache_dir=cache_dir
89
  )
90
 
 
93
  )
94
  checkpoint = torch.load(main_model_path, map_location=device)
95
 
96
+ # Handle different checkpoint structures
97
  if isinstance(checkpoint, dict):
98
  if 'model_state_dict' in checkpoint:
99
  clip_model.load_state_dict(checkpoint['model_state_dict'])
100
  else:
101
+ # If the checkpoint is directly the state_dict
102
  clip_model.load_state_dict(checkpoint)
103
  else:
104
  clip_model.load_state_dict(checkpoint)
 
107
  clip_model.eval()
108
 
109
  processor = CLIPProcessor.from_pretrained('laion/CLIP-ViT-B-32-laion2B-s34B-b79K')
110
+ print(" ✅ Main CLIP model loaded")
111
 
112
+ print("\n✅ All models loaded!")
113
 
114
  return {
115
  'color_model': color_model,
 
122
 
123
  def example_search(models, image_path: str = None, text_query: str = None):
124
  """
125
+ Example search with the models
126
 
127
  Args:
128
+ models: Dictionary of loaded models
129
+ image_path: Path to an image (optional)
130
+ text_query: Text query (optional)
131
  """
132
 
133
  color_model = models['color_model']
 
136
  processor = models['processor']
137
  device = models['device']
138
 
139
+ print("\n🔍 Example search...")
140
 
141
  if text_query:
142
+ print(f" 📝 Text query: '{text_query}'")
143
 
144
  # Obtenir les embeddings de couleur et hiérarchie
145
  color_emb = color_model.get_text_embeddings([text_query])
146
  hierarchy_emb = hierarchy_model.get_text_embeddings([text_query])
147
 
148
+ print(f" 🎨 Color embedding: {color_emb.shape}")
149
+ print(f" 📂 Hierarchy embedding: {hierarchy_emb.shape}")
150
 
151
  # Obtenir les embeddings du modèle principal
152
  text_inputs = processor(text=[text_query], padding=True, return_tensors="pt")
 
156
  outputs = main_model(**text_inputs)
157
  text_features = outputs.text_embeds
158
 
159
+ print(f" 🎯 Main embedding: {text_features.shape}")
160
 
161
  if image_path and os.path.exists(image_path):
162
  print(f" 🖼️ Image: {image_path}")
163
  image = Image.open(image_path).convert("RGB")
164
 
165
+ # Get image embeddings
166
  image_inputs = processor(images=[image], return_tensors="pt")
167
  image_inputs = {k: v.to(device) for k, v in image_inputs.items()}
168
 
 
170
  outputs = main_model(**image_inputs)
171
  image_features = outputs.image_embeds
172
 
173
+ print(f" 🎯 Image embedding: {image_features.shape}")
174
 
175
 
176
if __name__ == "__main__":
    import argparse

    # CLI entry point: parse the options, load the models from the Hub,
    # then run one example search with the requested text/image.
    cli = argparse.ArgumentParser(description="Example usage of models")
    cli.add_argument("--repo-id", type=str, required=True,
                     help="ID of the Hugging Face repository")
    cli.add_argument("--text", type=str, default="red dress",
                     help="Text query for search")
    cli.add_argument("--image", type=str, default=None,
                     help="Path to an image")
    opts = cli.parse_args()

    # Download and initialise all models from the given repository
    models = load_models_from_hf(opts.repo_id)

    # Run the demonstration search on the parsed query/image
    example_search(models, image_path=opts.image, text_query=opts.text)
206