import os import sys import pandas as pd import numpy as np import faiss import gradio as gr from sentence_transformers import SentenceTransformer from huggingface_hub import InferenceClient from datasets import load_dataset import json import re DATASET_REPO = "LCA/HACKATHON_PARTS" dataset = load_dataset(DATASET_REPO, split="train") df = dataset.to_pandas() descriptions = df['DESIGNATION'].tolist() codes = df["CODE"].astype(str).tolist() # --- Embedding model --- embedding_model = SentenceTransformer("all-MiniLM-L6-v2") #--- Load or compute embeddings + FAISS index --- #For start, test perf without caching this if os.path.exists("embeddings.npy") and os.path.exists("faiss.index"): embeddings = np.load("embeddings.npy") index = faiss.read_index("faiss.index") else: embeddings = embedding_model.encode(descriptions, convert_to_numpy=True) faiss.normalize_L2(embeddings) index = faiss.IndexFlatIP(embeddings.shape[1]) index.add(embeddings) # Save embeddings and index for future use np.save("embeddings.npy", embeddings) faiss.write_index(index, "faiss.index") # --- Inference API client --- # client = InferenceClient("HuggingFaceH4/zephyr-7b-beta", token=os.getenv("HF_TOKEN")) def rechercher_article(articleSource): print(f"Recherch article pour {articleSource}") article = {} source = articleSource["designation"] query_embedding = embedding_model.encode([source], convert_to_numpy=True) faiss.normalize_L2(query_embedding) # Recherche du/des voisin(s) le(s) plus proche(s) similarity_scores, indices = index.search(query_embedding, k=1) # Gérer la qualité du retour avec un seuil de similarité threshold = 0.7 # à ajuster selon vos tests print(f"Score de similarité ({similarity_scores[0][0]:.2f}) pour '{source}'") if similarity_scores[0][0] < threshold: article["code"] = "Inconnu" article["designation"] = source article["source"] = source article["quantite"] = articleSource.get("quantite", None) print(f"Code non trouvé pour '{source}'") else: article["code"] = codes[indices[0][0]] article["designation"] = descriptions[indices[0][0]] article["source"] = source article["quantite"] = articleSource.get("quantite", None) print(f"Code trouvé pour '{source}': {article['code']} / {article['designation']}") return article def extract_json_from_response(response): """ Extrait le premier bloc JSON valide d'une chaîne de texte contenant potentiellement du texte en vrac. Gère les dialogues USER/INST et autres artefacts de modèles de chat. Retourne un objet Python (dict) ou None si extraction impossible. """ # Nettoyer la réponse des balises de dialogue communes cleaned_response = response # Supprimer les balises de dialogue courantes patterns_to_remove = [ r'USER:.*?(?=\{|$)', r'INST:.*?(?=\{|$)', r'ASSISTANT:.*?(?=\{|$)', r'AI:.*?(?=\{|$)', r'```json', r'```', r'Here is the JSON:', r'The JSON response is:', r'Response:', ] for pattern in patterns_to_remove: cleaned_response = re.sub(pattern, '', cleaned_response, flags=re.IGNORECASE | re.DOTALL) # Recherche tous les blocs JSON potentiels dans la réponse nettoyée json_candidates = re.findall(r'({[\s\S]*?})', cleaned_response) for candidate in json_candidates: try: # Nettoyer le candidat des caractères parasites candidate = candidate.strip() parsed = json.loads(candidate) # Vérifier que c'est un objet avec la structure attendue if isinstance(parsed, dict): return parsed except Exception: continue # Si aucun bloc JSON valide trouvé, essayer de corriger les crochets manquants try: start = cleaned_response.index('{') end = cleaned_response.rindex('}') + 1 json_str = cleaned_response[start:end] return json.loads(json_str) except Exception as e: print("Erreur lors du parsing JSON extrait:", e) print("Réponse brute:", response) print("Réponse nettoyée:", cleaned_response) return None def respond(message): print(" ------------------ ") print(message) print(" ------------------ ") # Prompt par défaut custom_prompt = """Tu es un analyseur de texte qui extrait des informations d'articles. Tu dois analyser le message et identifier les articles demandés avec leurs quantités. IMPORTANT: Réponds UNIQUEMENT avec un objet JSON valide, sans texte supplémentaire. Format de réponse attendu: { "articles": [ { "designation": "description de l'article", "quantite": nombre_ou_null } ] } Règles: - Pas de texte avant ou après le JSON - Pas de commentaires - Pas de dialogue USER/INST - Juste le JSON brut """ messages = [{"role": "system", "content": custom_prompt}] messages += [{"role": "user", "content": message}] # Utiliser zephyr avec des paramètres plus stricts pour éviter les dialogues client = InferenceClient("HuggingFaceH4/zephyr-7b-beta", token=os.getenv("HF_TOKEN")) # client = InferenceClient( # "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B", # token=os.getenv("HF_TOKEN"), # #provider="auto" # or choose a supported provider from the error message # ) full_response = "" for chunk in client.chat_completion( messages, max_tokens=256, # Réduire pour éviter les dialogues longs stream=True, temperature=0.05, # Très faible pour plus de déterminisme top_p=0.3, # Plus restrictif stop=["\n\n", "USER:", "Assistant:", "###"] ): token = chunk.choices[0].delta.content if token: full_response += token # yield full_response.replace("\n", "\n\n") print("---- retour de l'analyse") print(full_response) print("--") json_response = extract_json_from_response(full_response) print(json_response) # If you expect a JSON response, you can try to parse it here # import json # try: order = {} try: if json_response is None: print("Aucun JSON valide trouvé dans la réponse") return {"articles": [], "erreur": "Impossible de parser la réponse"} articles = [] # Vérifier si la réponse a la structure attendue if "articles" in json_response: articles_data = json_response["articles"] else: # Si pas de clé "articles", essayer d'utiliser la réponse directement si c'est une liste if isinstance(json_response, list): articles_data = json_response else: print("Structure JSON inattendue:", json_response) return {"articles": [], "erreur": "Structure JSON inattendue"} for article in articles_data: if isinstance(article, dict) and "designation" in article: found_article = rechercher_article(article) articles.append(found_article) else: print("Article mal formaté:", article) order["articles"] = articles # Ajouter les champs destinataire et delai avec des valeurs figées order["destinataire"] = { "societe": "Société Exemple", "nom": "Dupont", "prenom": "Jean", "email": "jean.dupont@exemple.com" } order["delai"] = "2024-07-15" except Exception as e: print("Could not parse articles:", e) order = {} return order with gr.Blocks() as demo: gr.Markdown("# Part identification Assistant") #prompt_box = gr.Textbox(label="Prompt système", value=DEFAULT_PROMPT, lines=8) #temperature_slider = gr.Slider(label="Température", minimum=0.0, maximum=1.0, value=0.1, step=0.01) #top_p_slider = gr.Slider(label="Top-p", minimum=0.0, maximum=1.0, value=0.8, step=0.01) message_box = gr.Textbox(label="Votre question") response_box = gr.Textbox(label="Réponse de l'assistant", interactive=False, lines=30) send_btn = gr.Button("Envoyer") def chat(message): history = [] # ou récupère l'historique si tu veux le gérer gen = respond(message) return json.dumps(gen, indent=2, ensure_ascii=False) send_btn.click( chat, inputs=[message_box], outputs=[response_box] ) if __name__ == "__main__": demo.launch(share=True)