# Spaces: Running / Running
# (appears to be status text captured from the hosting page; not part of the program)
import os
from datetime import datetime
from typing import Dict, List, Optional

import chromadb
import pandas as pd
from flask import Flask, jsonify, request
from flask_cors import CORS
from google import genai
from sentence_transformers import CrossEncoder, SentenceTransformer
# ======================================================================
# CONFIGURATION
# ======================================================================
# Question/answer dataset and the persistent vector store built from it.
DATA_FILE_PATH = "data/QR.csv"
CHROMA_DB_PATH = "data/bdd_ChromaDB"
COLLECTION_NAME = "qr_data_dual_embeddings"
Q_COLUMN_NAME = "Question"
R_COLUMN_NAME = "Reponse"
SYSTEM_PROMPT_PATH = "data/system_prompt.txt"

# Local model directories; load_models() switches to a model identifier
# string when these paths do not exist.
SRC_CROSS_ENCODER = "models/mmarco-mMiniLMv2-L12-H384-v1"
SRC_PARAPHRASE = "models/paraphrase-mpnet-base-v2"

# Number of candidates fetched from the vector store, then kept after reranking.
N_RESULTS_RETRIEVAL = 10
N_RESULTS_RERANK = 3

# SECURITY: the API key is read from the environment instead of being
# committed to the source. Any key that was previously hard-coded here should
# be treated as leaked and revoked.
# NOTE(review): when the variable is unset this is None; client creation is
# then expected to fail or to use the SDK's own lookup - confirm against the SDK.
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
GEMINI_MODEL = "gemini-2.5-flash"

# add_to_history() keeps at most twice this many messages per session.
MAX_CONVERSATION_HISTORY = 10

# Host and port used for external access.
API_HOST = '0.0.0.0'
API_PORT = 1212
# ======================================================================
# GLOBAL STATE
# ======================================================================
# Shared resources start as None and are assigned by
# initialize_global_resources(); the annotations are Optional to reflect that.
model_cross_encoder: Optional[CrossEncoder] = None
model_paraphrase: Optional[SentenceTransformer] = None
collection: Optional[chromadb.Collection] = None
system_prompt: Optional[str] = None
gemini_client: Optional[genai.Client] = None

# Per-session message lists, keyed by session id; each message is a dict with
# "role" and "content" keys (see add_to_history()).
conversation_histories: Dict[str, List[Dict[str, str]]] = {}
# NOTE(review): not referenced anywhere else in the visible source.
conversation_start_times: Dict[str, str] = {}
# ======================================================================
# RESOURCE LOADING
# ======================================================================
def load_models():
    """Load the cross-encoder reranker and the sentence-embedding model.

    Each model is loaded from its local directory when that path exists;
    otherwise the corresponding model identifier string is passed to the
    library instead.

    Returns:
        A ``(cross_encoder, paraphrase)`` tuple.

    Raises:
        Exception: any loading error is printed and then re-raised.
    """
    print("⏳ Chargement des modèles...")
    try:
        if os.path.exists(SRC_CROSS_ENCODER):
            reranker_source = SRC_CROSS_ENCODER
        else:
            reranker_source = "cross-encoder/mmarco-mMiniLMv2-L12-H384-v1"
        reranker = CrossEncoder(reranker_source)

        if os.path.exists(SRC_PARAPHRASE):
            embedder_source = SRC_PARAPHRASE
        else:
            embedder_source = "sentence-transformers/paraphrase-mpnet-base-v2"
        embedder = SentenceTransformer(embedder_source)
    except Exception as e:
        print(f"❌ Erreur chargement modèles: {e}")
        raise
    print("✅ Modèles chargés avec succès.")
    return reranker, embedder
def load_data():
    """Load the question/answer table from the CSV file.

    When the CSV file does not exist, a small two-row example table is
    returned instead so the service can still start.

    Returns:
        A DataFrame containing at least the question and answer columns.

    Raises:
        ValueError: the CSV exists but lacks a required column.
        Exception: any other loading error is printed and then re-raised.
    """
    try:
        if not os.path.exists(DATA_FILE_PATH):
            print(f"⚠️ Fichier {DATA_FILE_PATH} non trouvé. Utilisation d'exemple.")
            return pd.DataFrame({
                Q_COLUMN_NAME: ["Où est le soleil?", "Qui est l'IA?"],
                R_COLUMN_NAME: ["Le soleil est une étoile.", "L'IA est l'intelligence artificielle."],
            })

        df = pd.read_csv(DATA_FILE_PATH)
        # Fail here with an explicit message rather than with a KeyError later
        # while the vector store is being filled.
        missing = [
            column for column in (Q_COLUMN_NAME, R_COLUMN_NAME)
            if column not in df.columns
        ]
        if missing:
            raise ValueError(f"Colonnes manquantes dans {DATA_FILE_PATH}: {missing}")
        print(f"✅ {len(df)} lignes chargées depuis {DATA_FILE_PATH}.")
        return df
    except Exception as e:
        print(f"❌ Erreur chargement données: {e}")
        raise
def load_system_prompt():
    """Return the system prompt text, stripped of surrounding whitespace.

    When the prompt file does not exist, a built-in default prompt is
    returned and a warning is printed.
    """
    try:
        with open(SYSTEM_PROMPT_PATH, 'r', encoding='utf-8') as prompt_file:
            raw_prompt = prompt_file.read()
    except FileNotFoundError:
        print("⚠️ System prompt non trouvé. Utilisation du prompt par défaut.")
        return "Tu es un assistant utile et concis. Réponds à la requête de l'utilisateur."
    return raw_prompt.strip()
def initialize_gemini_client():
    """Create the Google Gemini client from the configured API key.

    Raises:
        Exception: any construction error is printed and then re-raised.
    """
    try:
        client = genai.Client(api_key=GEMINI_API_KEY)
    except Exception as err:
        print(f"❌ Erreur Gemini: {err}")
        raise
    return client
# ======================================================================
# CHROMADB SETUP
# ======================================================================
def setup_chromadb_collection(client, df, model_paraphrase):
    """Create the ChromaDB collection and fill it from the Q/A table if needed.

    Every row is stored twice: once with the question text as the document and
    once with the answer text. Both entries carry the full question/answer
    pair and the source row label in their metadata.

    The collection is treated as up to date when its document count equals
    ``2 * len(df)``; in that case nothing is re-embedded. Content changes that
    keep the same row count are therefore not detected.

    Args:
        client: ChromaDB client used to create, delete and fill the collection.
        df: table with the question and answer columns.
        model_paraphrase: embedding model exposing ``encode``.

    Returns:
        The ready-to-query collection (left untouched when ``df`` is empty).

    Raises:
        Exception: errors while obtaining the collection are printed and
            re-raised; embedding and insertion errors propagate unchanged.
    """
    expected_docs = len(df) * 2
    try:
        collection = client.get_or_create_collection(name=COLLECTION_NAME)
    except Exception as e:
        print(f"❌ Erreur ChromaDB: {e}")
        raise

    if expected_docs == 0:
        print("⚠️ DataFrame vide.")
        return collection

    current_count = collection.count()
    if current_count == expected_docs:
        print(f"✅ Collection déjà remplie ({current_count} docs).")
        return collection

    print(f"⏳ Remplissage de ChromaDB ({len(df)} lignes)...")
    docs, metadatas, ids = [], [], []
    for i, row in df.iterrows():
        question = str(row[Q_COLUMN_NAME])
        reponse = str(row[R_COLUMN_NAME])
        meta = {Q_COLUMN_NAME: question, R_COLUMN_NAME: reponse, "source_row": i}

        docs.append(question)
        metadatas.append({**meta, "type": "question"})
        ids.append(f"id_{i}_Q")

        docs.append(reponse)
        metadatas.append({**meta, "type": "reponse"})
        ids.append(f"id_{i}_R")

    embeddings = model_paraphrase.encode(docs, show_progress_bar=False).tolist()

    # Start from a clean collection. Deletion stays best-effort, but a failure
    # is now reported instead of being silently discarded.
    try:
        client.delete_collection(name=COLLECTION_NAME)
    except Exception as e:
        print(f"⚠️ Suppression de la collection impossible: {e}")

    collection = client.get_or_create_collection(name=COLLECTION_NAME)
    collection.add(embeddings=embeddings, documents=docs, metadatas=metadatas, ids=ids)
    print(f"✅ Collection remplie: {collection.count()} documents.")
    return collection
# ======================================================================
# RAG - RETRIEVAL & RERANKING
# ======================================================================
def retrieve_and_rerank(query_text, collection, model_paraphrase, model_cross_encoder):
    """Fetch candidate documents for a query and rerank them.

    The query is embedded with ``model_paraphrase`` and used to query the
    collection. Each returned document is then scored against the query with
    ``model_cross_encoder``. Candidates are sorted by that score in descending
    order, de-duplicated on their question/answer pair, and truncated.

    Returns:
        A DataFrame with the columns ``question``, ``reponse``, ``doc_type``,
        ``text_reranked``, ``initial_distance`` and ``rerank_score``, or an
        empty DataFrame when the query returns nothing.
    """
    print(f"🔍 Récupération pour: '{query_text[:40]}...'")
    query_vector = model_paraphrase.encode([query_text]).tolist()
    hits = collection.query(
        query_embeddings=query_vector,
        n_results=N_RESULTS_RETRIEVAL,
        include=['documents', 'metadatas', 'distances'],
    )
    if not hits['ids'][0]:
        print("⚠️ Aucun résultat trouvé.")
        return pd.DataFrame()

    # Results are nested per query; a single query was sent, so take index 0.
    documents = hits['documents'][0]
    metadatas = hits['metadatas'][0]
    distances = hits['distances'][0]

    candidates = [
        {
            'question': meta[Q_COLUMN_NAME],
            'reponse': meta[R_COLUMN_NAME],
            'doc_type': meta.get('type'),
            'text_reranked': doc,
            'initial_distance': distance,
        }
        for doc, meta, distance in zip(documents, metadatas, distances)
    ]
    pairs = [[query_text, doc] for doc in documents]

    scores = model_cross_encoder.predict(pairs)
    for candidate, score in zip(candidates, scores):
        candidate['rerank_score'] = score

    ranked = pd.DataFrame(candidates)
    ranked = ranked.sort_values('rerank_score', ascending=False)
    # The same Q/A pair can appear twice (question entry and answer entry);
    # keep only its best-scoring occurrence.
    ranked = ranked.drop_duplicates(subset=['question', 'reponse'], keep='first')
    return ranked.head(N_RESULTS_RERANK)
def generate_rag_prompt(query_text, df_results, conversation_history):
    """Assemble the final prompt sent to the language model.

    The prompt consists of, in order: the conversation history (only when it
    is non-empty), the user query, the retrieved question/answer pairs
    enclosed in square brackets, and a fixed list of answering instructions.

    Args:
        query_text: the user's current question.
        df_results: DataFrame with ``question`` and ``reponse`` columns; may
            be empty.
        conversation_history: list of dicts with ``role`` and ``content``
            keys; any role other than ``"user"`` is labelled ASSISTANT.

    Returns:
        The prompt as a single string.
    """
    qa_blocks = []
    if not df_results.empty:
        qa_blocks = [
            f"Q: {row['question']}\nR: {row['reponse']}"
            for _, row in df_results.iterrows()
        ]
    context_str = "\n---\n".join(qa_blocks)

    history_str = ""
    if conversation_history:
        turns = []
        for msg in conversation_history:
            speaker = "USER" if msg["role"] == "user" else "ASSISTANT"
            turns.append(f"{speaker}: {msg['content']}\n")
        # A blank line separates the history from the current query.
        history_str = "HISTORIQUE:\n" + "".join(turns) + "\n"

    prompt_lines = (
        f"{history_str}UTILISATEUR: {query_text}",
        "CONTEXTE (si utile):",
        f"[{context_str}]",
        "INSTRUCTIONS:",
        "- Réponds brièvement et clairement, en un minimum de 2 paragraphe et un maximum de 3 paragraphe, entre chaque paragraphe tu écrira un #",
        "- Reste professionnel, met des emojis dans ta réponse tout en restant profesionnel",
        "- Réponds uniquement à la question posée",
        "- Ne mentionne NOVA que si l'utilisateur l'évoque clairement",
        "- Mentionne obligatoirement Mathieu dans ta réponse",
    )
    return "\n".join(prompt_lines)
# ======================================================================
# CONVERSATION HISTORY MANAGEMENT
# ======================================================================
def get_conversation_history(session_id):
    """Return the stored message list for a session, or an empty list."""
    try:
        return conversation_histories[session_id]
    except KeyError:
        return []
def add_to_history(session_id, role, content):
    """Append one message to a session's history and cap its length.

    Only the most recent ``MAX_CONVERSATION_HISTORY * 2`` messages are kept.
    """
    max_messages = MAX_CONVERSATION_HISTORY * 2
    messages = conversation_histories.setdefault(session_id, [])
    messages.append({"role": role, "content": content})
    if len(messages) > max_messages:
        # Rebind to a trimmed copy holding only the newest messages.
        conversation_histories[session_id] = messages[-max_messages:]
def clear_history(session_id):
    """Forget the history of a session.

    The entry is removed rather than replaced by an empty list, so clearing
    unknown or one-off session ids does not leave empty entries behind.
    get_conversation_history() still returns an empty list afterwards.
    """
    conversation_histories.pop(session_id, None)
# ======================================================================
# GEMINI CALL
# ======================================================================
def call_gemini(rag_prompt, system_prompt, gemini_client):
    """Send the system prompt and RAG prompt to Gemini and return its text.

    The two prompts are joined with a blank line and sent as one content
    string.

    Returns:
        The generated text. On failure, or when the model returns no text, a
        string starting with ``"Erreur:"`` is returned instead of raising.
    """
    try:
        response = gemini_client.models.generate_content(
            model=GEMINI_MODEL,
            contents=f"{system_prompt}\n\n{rag_prompt}",
        )
        text = response.text
    except Exception as e:
        print(f"❌ Erreur Gemini: {e}")
        return f"Erreur: {str(e)}"

    # The response may carry no text; never hand None back to the caller,
    # which stores the value in the history and serialises it to the client.
    if not text:
        print("❌ Erreur Gemini: réponse vide")
        return "Erreur: réponse vide du modèle"
    return text
# ======================================================================
# ANSWER PROCESS
# ======================================================================
def get_answer(query_text, collection, model_paraphrase, model_cross_encoder, conversation_history):
    """Run retrieval and reranking, then build the prompt for the model.

    Despite its name, this returns the assembled RAG prompt, not the model's
    answer; the caller is responsible for sending it to the model.
    """
    separator = '=' * 50
    print(f"\n{separator}")
    print(f"🚀 Traitement: '{query_text}'")
    print(separator)
    reranked = retrieve_and_rerank(query_text, collection, model_paraphrase, model_cross_encoder)
    return generate_rag_prompt(query_text, reranked, conversation_history)
# ======================================================================
# GLOBAL INITIALISATION
# ======================================================================
def initialize_global_resources():
    """Load every model and resource into the module-level globals.

    Loads the models, the Q/A table, the system prompt and the Gemini client,
    then opens the persistent ChromaDB store and prepares the collection.

    Returns:
        True when everything was initialised, False otherwise. The cause of a
        failure is printed before returning. Globals assigned before the
        failing step keep their new values.
    """
    global model_cross_encoder, model_paraphrase, collection, system_prompt, gemini_client

    banner = "=" * 50
    print("\n" + banner)
    print("⚙️ INITIALISATION RAG")
    print(banner)

    try:
        os.makedirs(CHROMA_DB_PATH, exist_ok=True)
        model_cross_encoder, model_paraphrase = load_models()
        df = load_data()
        system_prompt = load_system_prompt()
        gemini_client = initialize_gemini_client()
    except Exception as e:
        # Some loaders raise without printing anything themselves.
        print(f"❌ Échec de l'initialisation des ressources: {e}")
        return False

    try:
        chroma_client = chromadb.PersistentClient(path=CHROMA_DB_PATH)
        collection = setup_chromadb_collection(chroma_client, df, model_paraphrase)
    except Exception as e:
        print(f"❌ Échec de l'initialisation ChromaDB: {e}")
        return False

    print("✅ INITIALISATION COMPLÈTE\n")
    return True
# ======================================================================
# FLASK API
# ======================================================================
# The WSGI application object served by app.run() in the main section.
app = Flask(__name__)

# CORS is enabled with default settings so that requests from any origin
# are accepted.
CORS(app)
# The startup banner advertises this endpoint at "/status", but no route was
# registered for it in the source; register it so the URL actually responds.
@app.route('/status', methods=['GET'])
def api_status():
    """Ping route used to check that the API is up."""
    return jsonify({"status": "everything is good"}), 200
# NOTE(review): no @app.route decorator is visible for this view in the
# source, and its URL is not stated anywhere; confirm how it is registered.
def api_get_answer():
    """Main endpoint: answer a user query with the RAG pipeline.

    Expects a JSON object with a ``query_text`` string and an optional
    ``session_id`` (defaults to ``"archive"``).

    Returns:
        200 with ``{"generated_response": ...}`` on success;
        400 when the body is not a JSON object or ``query_text`` is missing,
        blank or not a string;
        500 when resources are not loaded or processing fails.
    """
    resources = (model_cross_encoder, model_paraphrase, collection, system_prompt, gemini_client)
    if any(resource is None for resource in resources):
        return jsonify({"error": "Ressources non chargées"}), 500

    # silent=True returns None instead of raising on a missing or malformed
    # body, so a client mistake is reported as 400 rather than as 500.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Corps JSON invalide ou manquant"}), 400

    query_text = data.get('query_text')
    session_id = data.get('session_id', 'archive')
    if not isinstance(query_text, str) or not query_text.strip():
        return jsonify({"error": "Champ 'query_text' manquant"}), 400

    try:
        # Fetch the history before the new messages are appended to it.
        history = get_conversation_history(session_id)
        # Build the RAG prompt.
        rag_prompt = get_answer(query_text, collection, model_paraphrase, model_cross_encoder, history)
        # Query Gemini.
        response = call_gemini(rag_prompt, system_prompt, gemini_client)
        # Record the exchange.
        add_to_history(session_id, "user", query_text)
        add_to_history(session_id, "assistant", response)
        return jsonify({"generated_response": response})
    except Exception as e:
        print(f"❌ Erreur: {e}")
        return jsonify({"error": str(e)}), 500
# NOTE(review): no @app.route decorator is visible for this view in the
# source, and its URL is not stated anywhere; confirm how it is registered.
def api_clear_history():
    """Clear the conversation history of a session.

    Expects an optional JSON object with a ``session_id`` key; when the body
    is absent or the key is missing, the ``"archive"`` session is cleared.

    Returns:
        200 with a confirmation message; 400 when the body is JSON but not an
        object; 500 when clearing fails.
    """
    # silent=True returns None instead of raising on a missing or malformed
    # body; that case falls back to the default session.
    data = request.get_json(silent=True)
    if data is None:
        data = {}
    if not isinstance(data, dict):
        return jsonify({"error": "Corps JSON invalide"}), 400

    try:
        session_id = data.get('session_id', 'archive')
        clear_history(session_id)
        return jsonify({"message": f"Historique effacé: {session_id}"})
    except Exception as e:
        return jsonify({"error": str(e)}), 500
# ======================================================================
# MAIN
# ======================================================================
def _detect_local_ip():
    """Return the local IP address used for outbound traffic, for display only.

    A datagram socket is connected toward an external address and the local
    address chosen for it is read back. Falls back to ``127.0.0.1`` when that
    fails.
    """
    import socket

    try:
        # The context manager closes the socket even when connect() fails.
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as probe:
            probe.connect(("8.8.8.8", 80))
            return probe.getsockname()[0]
    except OSError:
        return "127.0.0.1"


if __name__ == '__main__':
    print("start app.py")
    if initialize_global_resources():
        local_ip = _detect_local_ip()

        print("\n" + "="*50)
        print("🌐 SERVEUR DÉMARRÉ")
        print(f"✅ API accessible à l'URL (via l'interface réseau locale): http://{local_ip}:{API_PORT}")
        print(f"✅ Route Status: http://{local_ip}:{API_PORT}/status")
        print(f"💡 Pour un accès depuis l'extérieur, utilisez l'adresse IP publique de votre machine et assurez-vous que le port {API_PORT} est ouvert.")
        print("="*50 + "\n")

        # Binding to API_HOST ('0.0.0.0') makes the server reachable from
        # other machines, not only from localhost.
        app.run(host=API_HOST, port=API_PORT, debug=False)
    else:
        print("❌ Impossible de démarrer le serveur")