Spaces:
Sleeping
Sleeping
| import os | |
| import sys | |
| import pandas as pd | |
| import numpy as np | |
| import faiss | |
| import gradio as gr | |
| from sentence_transformers import SentenceTransformer | |
| from huggingface_hub import InferenceClient | |
| from datasets import load_dataset | |
| import json | |
| import re | |
| DATASET_REPO = "LCA/HACKATHON_PARTS" | |
| dataset = load_dataset(DATASET_REPO, split="train") | |
| df = dataset.to_pandas() | |
| descriptions = df['DESIGNATION'].tolist() | |
| codes = df["CODE"].astype(str).tolist() | |
| # --- Embedding model --- | |
| embedding_model = SentenceTransformer("all-MiniLM-L6-v2") | |
| #--- Load or compute embeddings + FAISS index --- | |
| #For start, test perf without caching this | |
| if os.path.exists("embeddings.npy") and os.path.exists("faiss.index"): | |
| embeddings = np.load("embeddings.npy") | |
| index = faiss.read_index("faiss.index") | |
| else: | |
| embeddings = embedding_model.encode(descriptions, convert_to_numpy=True) | |
| faiss.normalize_L2(embeddings) | |
| index = faiss.IndexFlatIP(embeddings.shape[1]) | |
| index.add(embeddings) | |
| # Save embeddings and index for future use | |
| np.save("embeddings.npy", embeddings) | |
| faiss.write_index(index, "faiss.index") | |
| # --- Inference API client --- | |
| # client = InferenceClient("HuggingFaceH4/zephyr-7b-beta", token=os.getenv("HF_TOKEN")) | |
| def rechercher_article(articleSource): | |
| print(f"Recherch article pour {articleSource}") | |
| article = {} | |
| source = articleSource["designation"] | |
| query_embedding = embedding_model.encode([source], convert_to_numpy=True) | |
| faiss.normalize_L2(query_embedding) | |
| # Recherche du/des voisin(s) le(s) plus proche(s) | |
| similarity_scores, indices = index.search(query_embedding, k=1) | |
| # Gérer la qualité du retour avec un seuil de similarité | |
| threshold = 0.7 # à ajuster selon vos tests | |
| print(f"Score de similarité ({similarity_scores[0][0]:.2f}) pour '{source}'") | |
| if similarity_scores[0][0] < threshold: | |
| article["code"] = "Inconnu" | |
| article["designation"] = source | |
| article["source"] = source | |
| article["quantite"] = articleSource.get("quantite", None) | |
| print(f"Code non trouvé pour '{source}'") | |
| else: | |
| article["code"] = codes[indices[0][0]] | |
| article["designation"] = descriptions[indices[0][0]] | |
| article["source"] = source | |
| article["quantite"] = articleSource.get("quantite", None) | |
| print(f"Code trouvé pour '{source}': {article['code']} / {article['designation']}") | |
| return article | |
| def extract_json_from_response(response): | |
| """ | |
| Extrait le premier bloc JSON valide d'une chaîne de texte contenant potentiellement du texte en vrac. | |
| Gère les dialogues USER/INST et autres artefacts de modèles de chat. | |
| Retourne un objet Python (dict) ou None si extraction impossible. | |
| """ | |
| # Nettoyer la réponse des balises de dialogue communes | |
| cleaned_response = response | |
| # Supprimer les balises de dialogue courantes | |
| patterns_to_remove = [ | |
| r'USER:.*?(?=\{|$)', | |
| r'INST:.*?(?=\{|$)', | |
| r'ASSISTANT:.*?(?=\{|$)', | |
| r'AI:.*?(?=\{|$)', | |
| r'```json', | |
| r'```', | |
| r'Here is the JSON:', | |
| r'The JSON response is:', | |
| r'Response:', | |
| ] | |
| for pattern in patterns_to_remove: | |
| cleaned_response = re.sub(pattern, '', cleaned_response, flags=re.IGNORECASE | re.DOTALL) | |
| # Recherche tous les blocs JSON potentiels dans la réponse nettoyée | |
| json_candidates = re.findall(r'({[\s\S]*?})', cleaned_response) | |
| for candidate in json_candidates: | |
| try: | |
| # Nettoyer le candidat des caractères parasites | |
| candidate = candidate.strip() | |
| parsed = json.loads(candidate) | |
| # Vérifier que c'est un objet avec la structure attendue | |
| if isinstance(parsed, dict): | |
| return parsed | |
| except Exception: | |
| continue | |
| # Si aucun bloc JSON valide trouvé, essayer de corriger les crochets manquants | |
| try: | |
| start = cleaned_response.index('{') | |
| end = cleaned_response.rindex('}') + 1 | |
| json_str = cleaned_response[start:end] | |
| return json.loads(json_str) | |
| except Exception as e: | |
| print("Erreur lors du parsing JSON extrait:", e) | |
| print("Réponse brute:", response) | |
| print("Réponse nettoyée:", cleaned_response) | |
| return None | |
| def respond(message): | |
| print(" ------------------ ") | |
| print(message) | |
| print(" ------------------ ") | |
| # Prompt par défaut | |
| custom_prompt = """Tu es un analyseur de texte qui extrait des informations d'articles. | |
| Tu dois analyser le message et identifier les articles demandés avec leurs quantités. | |
| IMPORTANT: Réponds UNIQUEMENT avec un objet JSON valide, sans texte supplémentaire. | |
| Format de réponse attendu: | |
| { | |
| "articles": [ | |
| { | |
| "designation": "description de l'article", | |
| "quantite": nombre_ou_null | |
| } | |
| ] | |
| } | |
| Règles: | |
| - Pas de texte avant ou après le JSON | |
| - Pas de commentaires | |
| - Pas de dialogue USER/INST | |
| - Juste le JSON brut | |
| """ | |
| messages = [{"role": "system", "content": custom_prompt}] | |
| messages += [{"role": "user", "content": message}] | |
| # Utiliser zephyr avec des paramètres plus stricts pour éviter les dialogues | |
| client = InferenceClient("HuggingFaceH4/zephyr-7b-beta", token=os.getenv("HF_TOKEN")) | |
| # client = InferenceClient( | |
| # "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B", | |
| # token=os.getenv("HF_TOKEN"), | |
| # #provider="auto" # or choose a supported provider from the error message | |
| # ) | |
| full_response = "" | |
| for chunk in client.chat_completion( | |
| messages, | |
| max_tokens=256, # Réduire pour éviter les dialogues longs | |
| stream=True, | |
| temperature=0.05, # Très faible pour plus de déterminisme | |
| top_p=0.3, # Plus restrictif | |
| stop=["\n\n", "USER:", "Assistant:", "###"] | |
| ): | |
| token = chunk.choices[0].delta.content | |
| if token: | |
| full_response += token | |
| # yield full_response.replace("\n", "\n\n") | |
| print("---- retour de l'analyse") | |
| print(full_response) | |
| print("--") | |
| json_response = extract_json_from_response(full_response) | |
| print(json_response) | |
| # If you expect a JSON response, you can try to parse it here | |
| # import json | |
| # try: | |
| order = {} | |
| try: | |
| if json_response is None: | |
| print("Aucun JSON valide trouvé dans la réponse") | |
| return {"articles": [], "erreur": "Impossible de parser la réponse"} | |
| articles = [] | |
| # Vérifier si la réponse a la structure attendue | |
| if "articles" in json_response: | |
| articles_data = json_response["articles"] | |
| else: | |
| # Si pas de clé "articles", essayer d'utiliser la réponse directement si c'est une liste | |
| if isinstance(json_response, list): | |
| articles_data = json_response | |
| else: | |
| print("Structure JSON inattendue:", json_response) | |
| return {"articles": [], "erreur": "Structure JSON inattendue"} | |
| for article in articles_data: | |
| if isinstance(article, dict) and "designation" in article: | |
| found_article = rechercher_article(article) | |
| articles.append(found_article) | |
| else: | |
| print("Article mal formaté:", article) | |
| order["articles"] = articles | |
| # Ajouter les champs destinataire et delai avec des valeurs figées | |
| order["destinataire"] = { | |
| "societe": "Société Exemple", | |
| "nom": "Dupont", | |
| "prenom": "Jean", | |
| "email": "jean.dupont@exemple.com" | |
| } | |
| order["delai"] = "2024-07-15" | |
| except Exception as e: | |
| print("Could not parse articles:", e) | |
| order = {} | |
| return order | |
| with gr.Blocks() as demo: | |
| gr.Markdown("# Part identification Assistant") | |
| #prompt_box = gr.Textbox(label="Prompt système", value=DEFAULT_PROMPT, lines=8) | |
| #temperature_slider = gr.Slider(label="Température", minimum=0.0, maximum=1.0, value=0.1, step=0.01) | |
| #top_p_slider = gr.Slider(label="Top-p", minimum=0.0, maximum=1.0, value=0.8, step=0.01) | |
| message_box = gr.Textbox(label="Votre question") | |
| response_box = gr.Textbox(label="Réponse de l'assistant", interactive=False, lines=30) | |
| send_btn = gr.Button("Envoyer") | |
| def chat(message): | |
| history = [] # ou récupère l'historique si tu veux le gérer | |
| gen = respond(message) | |
| return json.dumps(gen, indent=2, ensure_ascii=False) | |
| send_btn.click( | |
| chat, | |
| inputs=[message_box], | |
| outputs=[response_box] | |
| ) | |
| if __name__ == "__main__": | |
| demo.launch(share=True) | |