LCA commited on
Commit
145ebdf
·
verified ·
1 Parent(s): 852958b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +170 -0
app.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import pandas as pd
4
+ import numpy as np
5
+ import faiss
6
+ import gradio as gr
7
+ from sentence_transformers import SentenceTransformer
8
+ from huggingface_hub import InferenceClient
9
+ from datasets import load_dataset
10
+ import json
11
+
12
+
13
+ DATASET_REPO = "LCA/HACKATHON_PARTS"
14
+
15
+ dataset = load_dataset(DATASET_REPO, split="train")
16
+ df = dataset.to_pandas()
17
+
18
+ descriptions = df['DESIGNATION'].tolist()
19
+ codes = df["CODE"].astype(str).tolist()
20
+
21
+
22
+ # --- Embedding model ---
23
+ embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
24
+
25
+ #--- Load or compute embeddings + FAISS index ---
26
+ #For start, test perf without caching this
27
+ if os.path.exists("embeddings.npy") and os.path.exists("faiss.index"):
28
+ embeddings = np.load("embeddings.npy")
29
+ index = faiss.read_index("faiss.index")
30
+ else:
31
+ embeddings = embedding_model.encode(descriptions, convert_to_numpy=True)
32
+ faiss.normalize_L2(embeddings)
33
+ index = faiss.IndexFlatIP(embeddings.shape[1])
34
+ index.add(embeddings)
35
+ # Save embeddings and index for future use
36
+ np.save("embeddings.npy", embeddings)
37
+ faiss.write_index(index, "faiss.index")
38
+
39
+ # --- Inference API client ---
40
+ # client = InferenceClient("HuggingFaceH4/zephyr-7b-beta", token=os.getenv("HF_TOKEN"))
41
+
42
+ def rechercher_article(articleSource):
43
+
44
+ source = articleSource["designation"]
45
+ query_embedding = embedding_model.encode([source], convert_to_numpy=True)
46
+ faiss.normalize_L2(query_embedding)
47
+ # Recherche du/des voisin(s) le(s) plus proche(s)
48
+ similarity_scores, indices = index.search(query_embedding, k=1)
49
+ # Gérer la qualité du retour avec un seuil de similarité
50
+ threshold = 0.7 # à ajuster selon vos tests
51
+ if similarity_scores[0][0] < threshold:
52
+ print(f"Score de similarité trop faible ({similarity_scores[0][0]:.2f}) pour '{source}'")
53
+ return "UNKNOWN"
54
+
55
+
56
+ article = {}
57
+ article["code"] = codes[indices[0][0]]
58
+ article["designation"] = descriptions[indices[0][0]]
59
+ article["source"] = source
60
+ article["quantite"] = articleSource.get("quantite", None)
61
+ print(f"Code trouvé pour '{source}': {article['code']} / {article['designation']}")
62
+
63
+
64
+ return article
65
+
66
+ def respond(message):
67
+
68
+ # Prompt par défaut
69
+ custom_prompt = """Tu es un analyseur de messages expert.
70
+ Ta mission est de déterminer dans le messages fourni quels sont les articles qui sont demandés et pour quelle quantité.
71
+ La réponse est au format json et donne 2 informations par article identifié : la désignation et le nombre
72
+ La désignation est formé du type d'article et des caractéristiques comme la matière ou les dimensions
73
+
74
+ Ne retourne que le JSON.
75
+
76
+ """
77
+ # query_embedding = embedding_model.encode([message], convert_to_numpy=True)
78
+ # faiss.normalize_L2(query_embedding)
79
+ # _, indices = index.search(query_embedding, k=5)
80
+ # context = "\n".join([f"{codes[i]}: {descriptions[i]}" for i in indices[0]])
81
+
82
+ # Utilise le prompt personnalisé
83
+
84
+
85
+ # message = custom_prompt.format(message=message)
86
+
87
+ messages = [{"role": "system", "content": custom_prompt}]
88
+ messages += [{"role": "user", "content": message}]
89
+
90
+ # full_response = client.text_generation(message)
91
+ client = InferenceClient("HuggingFaceH4/zephyr-7b-beta", token=os.getenv("HF_TOKEN"))
92
+ # client = InferenceClient(
93
+ # "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",
94
+ # token=os.getenv("HF_TOKEN"),
95
+ # #provider="auto" # or choose a supported provider from the error message
96
+ # )
97
+
98
+
99
+ full_response = ""
100
+ for chunk in client.chat_completion(
101
+ messages,
102
+ max_tokens=512,
103
+ stream=True,
104
+ temperature=0.1,
105
+ top_p=0.8,
106
+ ):
107
+ token = chunk.choices[0].delta.content
108
+ if token:
109
+ full_response += token
110
+ # yield full_response.replace("\n", "\n\n")
111
+
112
+ # If you expect a JSON response, you can try to parse it here
113
+ # import json
114
+ # try:
115
+ order = {}
116
+ try:
117
+ data = json.loads(full_response)
118
+ articles = []
119
+ for article in data.get("articles", []):
120
+ found_article = rechercher_article(article)
121
+ if found_article != "UNKNOWN":
122
+ articles.append(found_article)
123
+ order["articles"] = articles
124
+ # Ajouter les champs destinataire et delai avec des valeurs figées
125
+ order["destinataire"] = {
126
+ "societe": "Société Exemple",
127
+ "nom": "Dupont",
128
+ "prenom": "Jean",
129
+ "email": "jean.dupont@exemple.com"
130
+ }
131
+ order["delai"] = "2024-07-15"
132
+ except Exception as e:
133
+ print("Could not parse articles:", e)
134
+ order = {}
135
+
136
+ return order
137
+
138
+
139
+ with gr.Blocks() as demo:
140
+ gr.Markdown("# Part identification Assistant")
141
+ #prompt_box = gr.Textbox(label="Prompt système", value=DEFAULT_PROMPT, lines=8)
142
+ #temperature_slider = gr.Slider(label="Température", minimum=0.0, maximum=1.0, value=0.1, step=0.01)
143
+ #top_p_slider = gr.Slider(label="Top-p", minimum=0.0, maximum=1.0, value=0.8, step=0.01)
144
+ message_box = gr.Textbox(label="Votre question")
145
+ response_box = gr.Textbox(label="Réponse de l'assistant", interactive=False, lines=30)
146
+ send_btn = gr.Button("Envoyer")
147
+
148
+ def chat(message):
149
+ history = [] # ou récupère l'historique si tu veux le gérer
150
+ gen = respond(message)
151
+ # full_response = ""
152
+ # for response in gen:
153
+ # full_response = full_response + response
154
+ # On renvoie la dernière réponse et le contexte utilisé
155
+ # Il faut recalculer le contexte ici pour l'afficher
156
+ # query_embedding = embedding_model.encode([message], convert_to_numpy=True)
157
+ # faiss.normalize_L2(query_embedding)
158
+ # _, indices = index.search(query_embedding, k=5)
159
+ # context = "\n".join([f"{codes[i]}: {descriptions[i]}" for i in indices[0]])
160
+ return json.dumps(gen, indent=2, ensure_ascii=False)
161
+
162
+ send_btn.click(
163
+ chat,
164
+ inputs=[message_box],
165
+ outputs=[response_box]
166
+ )
167
+
168
+ if __name__ == "__main__":
169
+ demo.launch(share=True)
170
+