adel67460 commited on
Commit
d108bfc
·
verified ·
1 Parent(s): e6dd2e8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -78
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import os
2
  import json
 
3
  import gradio as gr
4
  import torch
5
  import pandas as pd
@@ -7,19 +8,15 @@ from scipy.sparse import csr_matrix
7
  from sklearn.feature_extraction.text import TfidfVectorizer
8
  import open_clip
9
 
10
- # ==========================
11
- # 🔥 DEVICE
12
- # ==========================
13
  device = "cuda" if torch.cuda.is_available() else "cpu"
14
  print(f"🔹 Utilisation du périphérique : {device}")
15
 
16
- # ==========================
17
- # 🔥 CHARGEMENT OPENCLIP ViT-H/14
18
- # ==========================
19
  print("🔄 Chargement du modèle OpenCLIP ViT-H/14...")
20
 
21
  model_name = "ViT-H-14"
22
- pretrained = "laion2b_s32b_b79k" # meilleur checkpoint
23
 
24
  model, _, preprocess = open_clip.create_model_and_transforms(
25
  model_name,
@@ -31,114 +28,109 @@ tokenizer = open_clip.get_tokenizer(model_name)
31
  model = model.to(device)
32
  model.eval()
33
 
34
- print("✅ OpenCLIP chargé avec succès !")
35
 
36
- # ==========================
37
- # 🔥 JSON produits
38
- # ==========================
39
  PRODUCTS_FILE = "products.json"
40
  QA_FILE = "qa_sequences_output.json"
41
 
42
- def safe_load_json(path):
43
- if not os.path.exists(path):
44
- print(f"⛔ Fichier introuvable : {path}")
 
45
  return []
46
  try:
47
- with open(path, "r", encoding="utf-8") as f:
48
  data = json.load(f)
49
  return data.get("products", []) if "products" in data else data
50
- except:
51
- print(f"⚠️ Erreur JSON dans {path}")
52
  return []
53
 
54
  products_data = safe_load_json(PRODUCTS_FILE)
55
  qa_data = safe_load_json(QA_FILE)
56
 
57
- # ==========================
58
- # 🔥 EMBEDDINGS TEXTE (OpenCLIP)
59
- # ==========================
60
  def get_text_embeddings(texts):
 
61
  with torch.no_grad():
 
62
  tokens = tokenizer(texts).to(device)
63
 
64
- features = model.encode_text(tokens)
 
65
 
66
- # Normalisation L2 → très important pour cosine
67
- features /= features.norm(dim=-1, keepdim=True)
68
 
69
- return features.cpu().numpy()
70
-
71
- # ==========================
72
- # 🔥 EMBEDDING PRODUITS
73
- # ==========================
74
- print("🛠️ Génération embeddings produits...")
75
 
 
 
76
  product_embeddings = get_text_embeddings([
77
- p["title"] + " " + p["description"]
78
- for p in products_data
79
  ])
 
80
 
81
- print("✅ Embeddings produits générés !")
82
-
83
- # ==========================
84
- # 🔥 TF-IDF
85
- # ==========================
86
  vectorizer = TfidfVectorizer(stop_words="english")
87
-
88
  tfidf_matrix = vectorizer.fit_transform([
89
- p["title"] + " " + p["description"]
90
- for p in products_data
91
  ])
92
 
93
- # ==========================
94
- # 🔥 RECHERCHE HYBRIDE
95
- # ==========================
96
  def search_products(query, category, min_price, max_price,
97
- weight_tfidf=0.5, weight_embed=0.5):
98
 
99
  if not query.strip():
100
- return "❌ Veuillez entrer un terme valide."
101
 
102
- min_price = float(min_price) if min_price else 0
103
- max_price = float(max_price) if max_price else float("inf")
104
 
105
- # Embedding requête
106
- q_emb = get_text_embeddings([query])[0]
107
 
108
- # Cosine similarity = dot product car vecteurs normalisés
109
- clip_scores = (product_embeddings @ q_emb).tolist()
110
 
111
- # TF-IDF similarity
112
- query_vec = vectorizer.transform([query])
113
- tfidf_scores = (tfidf_matrix @ query_vec.T).toarray().flatten()
114
 
115
- # Normalisation
116
- def norm(x):
117
- return (x - x.min()) / (x.max() - x.min() + 1e-6)
 
118
 
119
- clip_scores = norm(pd.Series(clip_scores))
120
- tfidf_scores = norm(pd.Series(tfidf_scores))
121
 
122
- final = weight_tfidf * tfidf_scores + weight_embed * clip_scores
 
123
 
124
- df = pd.DataFrame(products_data)
125
- df["score"] = final.values
 
126
 
127
- # Filtres
128
- df = df[
129
- (df["price"].astype(float).fillna(0) >= min_price) &
130
- (df["price"].astype(float).fillna(0) <= max_price) &
131
- (df["availability"].fillna("").str.lower() == "in stock")
132
  ]
133
 
134
- if category and category.lower() != "toutes":
135
- df = df[df["category"].str.contains(category, case=False, na=False)]
 
 
 
136
 
137
- return df.sort_values("score", ascending=False).head(20)
138
 
139
- # ==========================
140
- # 🔥 INTERFACE GRADIO
141
- # ==========================
142
  app = gr.Interface(
143
  fn=search_products,
144
  inputs=[
@@ -151,11 +143,7 @@ app = gr.Interface(
151
  gr.Dataframe(headers=[
152
  "ID", "Titre", "Description", "Prix", "Disponibilité", "Score"
153
  ])
154
- ],
155
- title="🔍 Recherche IA e-commerce avec OpenCLIP",
156
- description="Moteur de recherche hybride basé sur OpenCLIP ViT-H/14 + TF-IDF"
157
  )
158
 
159
- if __name__ == "__main__":
160
- print("🚀 Lancement interface...")
161
- app.launch()
 
1
  import os
2
  import json
3
+ import re
4
  import gradio as gr
5
  import torch
6
  import pandas as pd
 
8
  from sklearn.feature_extraction.text import TfidfVectorizer
9
  import open_clip
10
 
11
+ # 📌 Vérifier si CUDA est disponible
 
 
12
  device = "cuda" if torch.cuda.is_available() else "cpu"
13
  print(f"🔹 Utilisation du périphérique : {device}")
14
 
15
+ # 📌 Chargement du modèle OpenCLIP
 
 
16
  print("🔄 Chargement du modèle OpenCLIP ViT-H/14...")
17
 
18
  model_name = "ViT-H-14"
19
+ pretrained = "laion2b_s32b_b79k"
20
 
21
  model, _, preprocess = open_clip.create_model_and_transforms(
22
  model_name,
 
28
  model = model.to(device)
29
  model.eval()
30
 
31
+ print("✅ Modèle OpenCLIP chargé avec succès !")
32
 
33
+ # 📌 Définition des fichiers JSON
 
 
34
  PRODUCTS_FILE = "products.json"
35
  QA_FILE = "qa_sequences_output.json"
36
 
37
+ # 📌 Fonction pour charger les fichiers JSON
38
+ def safe_load_json(file_path):
39
+ if not os.path.exists(file_path):
40
+ print(f"⛔ Fichier introuvable : {file_path}")
41
  return []
42
  try:
43
+ with open(file_path, "r", encoding="utf-8") as f:
44
  data = json.load(f)
45
  return data.get("products", []) if "products" in data else data
46
+ except json.JSONDecodeError:
47
+ print(f"⚠️ Erreur de décodage JSON dans {file_path}")
48
  return []
49
 
50
  products_data = safe_load_json(PRODUCTS_FILE)
51
  qa_data = safe_load_json(QA_FILE)
52
 
53
+ # 📌 Générer des embeddings pour les produits
 
 
54
  def get_text_embeddings(texts):
55
+ """Génère des embeddings via OpenCLIP (même logique que ton code Marqo)."""
56
  with torch.no_grad():
57
+ # Tokenisation
58
  tokens = tokenizer(texts).to(device)
59
 
60
+ # Encodage texte
61
+ embeddings = model.encode_text(tokens)
62
 
63
+ # Normalisation
64
+ embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)
65
 
66
+ return embeddings.cpu().numpy()
 
 
 
 
 
67
 
68
+ # Création des embeddings pour tous les produits
69
+ print("🛠️ Génération des embeddings des produits...")
70
  product_embeddings = get_text_embeddings([
71
+ prod["title"] + " " + prod["description"]
72
+ for prod in products_data
73
  ])
74
+ print("✅ Embeddings générés et sauvegardés !")
75
 
76
+ # 📌 TF-IDF Vectorizer pour une recherche hybride
 
 
 
 
77
  vectorizer = TfidfVectorizer(stop_words="english")
 
78
  tfidf_matrix = vectorizer.fit_transform([
79
+ prod["title"] + " " + prod["description"]
80
+ for prod in products_data
81
  ])
82
 
83
+ # 📌 Recherche hybride avec OpenCLIP embeddings + TF-IDF
 
 
84
  def search_products(query, category, min_price, max_price,
85
+ weight_tfidf=0.5, weight_openclip=0.5):
86
 
87
  if not query.strip():
88
+ return "❌ Veuillez entrer un terme de recherche valide."
89
 
90
+ min_price = float(min_price) if isinstance(min_price, (int, float)) else 0
91
+ max_price = float(max_price) if isinstance(max_price, (int, float)) else float("inf")
92
 
93
+ # 📌 Embedding requête
94
+ query_embedding = get_text_embeddings([query])[0]
95
 
96
+ # 📌 Cosine similarity (dot product car vecteurs normalisés)
97
+ clip_scores = (product_embeddings @ query_embedding).tolist()
98
 
99
+ # 📌 TF-IDF Similarité
100
+ query_vector_sparse = csr_matrix(vectorizer.transform([query]))
101
+ tfidf_scores = (tfidf_matrix * query_vector_sparse.T).toarray().flatten()
102
 
103
+ # 📌 Normalisation
104
+ def normalize(v):
105
+ v = pd.Series(v)
106
+ return (v - v.min()) / (v.max() - v.min() + 1e-6)
107
 
108
+ clip_scores = normalize(clip_scores)
109
+ tfidf_scores = normalize(tfidf_scores)
110
 
111
+ # 📌 Fusion
112
+ final_scores = weight_tfidf * tfidf_scores + weight_openclip * clip_scores
113
 
114
+ # 📌 DataFrame résultats
115
+ results_df = pd.DataFrame(products_data)
116
+ results_df["score"] = final_scores
117
 
118
+ # 📌 Filtrage prix + dispo
119
+ results_df = results_df[
120
+ (results_df["price"].fillna(0).astype(float) >= min_price) &
121
+ (results_df["price"].fillna(0).astype(float) <= max_price) &
122
+ (results_df["availability"].fillna("").str.lower() == "in stock")
123
  ]
124
 
125
+ # 📌 Filtrer par catégorie
126
+ if category and category != "Toutes":
127
+ results_df = results_df[
128
+ results_df["category"].str.contains(category, case=False, na=False)
129
+ ]
130
 
131
+ return results_df.sort_values(by="score", ascending=False).head(20)
132
 
133
+ # 📌 Interface Gradio
 
 
134
  app = gr.Interface(
135
  fn=search_products,
136
  inputs=[
 
143
  gr.Dataframe(headers=[
144
  "ID", "Titre", "Description", "Prix", "Disponibilité", "Score"
145
  ])
146
+ ]
 
 
147
  )
148
 
149
+ app.launch()