NOVA_API / app.py
MathieuGAL's picture
Update app.py
8491d5b verified
raw
history blame
18 kB
import os
import pandas as pd
import chromadb
from google import genai
from sentence_transformers import SentenceTransformer, CrossEncoder
from typing import List, Dict
from flask import Flask, request, jsonify
from flask_cors import CORS 
from datetime import datetime
# ======================================================================
# CONFIGURATION
# ======================================================================
DATA_FILE_PATH = "data/QR.csv"
# CORRECTION CRITIQUE: Déplacement de la DB vers /tmp
# Ce répertoire est le seul garanti en écriture sur Hugging Face Spaces.
CHROMA_DB_PATH = "/tmp/bdd_ChromaDB" 
COLLECTION_NAME = "qr_data_dual_embeddings"
Q_COLUMN_NAME = "Question"
R_COLUMN_NAME = "Reponse"
SYSTEM_PROMPT_PATH = "data/system_prompt.txt"
# Les chemins des modèles sont conservés (ils se mettront en cache dans /tmp grâce au Dockerfile)
SRC_CROSS_ENCODER = "models/mmarco-mMiniLMv2-L12-H384-v1"
SRC_PARAPHRASE = "models/paraphrase-mpnet-base-v2"
N_RESULTS_RETRIEVAL = 10
N_RESULTS_RERANK = 3
# Récupération de la clé depuis l'environnement (Hugging Face Secrets)
# Si non trouvée, utilise la clé de placeholder.
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY", "AIzaSyDXXY7uSXryTxZ51jQFsSLcPnC_Ivt9V1g"
GEMINI_MODEL = "gemini-2.5-flash"
MAX_CONVERSATION_HISTORY = 10
# Configuration pour l'accès externe (host et port)
API_HOST = '0.0.0.0'
API_PORT = 1212 # Le port 1212 est conservé, il doit être configuré dans le README.md
# ======================================================================
# VARIABLES GLOBALES
# ======================================================================
model_cross_encoder: CrossEncoder = None
model_paraphrase: SentenceTransformer = None
collection: chromadb.Collection = None
system_prompt: str = None
gemini_client: genai.Client = None
conversation_histories: Dict[str, List[Dict[str, str]]] = {}
conversation_start_times: Dict[str, str] = {}
# ======================================================================
# CHARGEMENT DES RESSOURCES
# ======================================================================
def load_models():
    """Charge les modèles SentenceTransformer et CrossEncoder."""
    print("⏳ Chargement des modèles...")
    try:
        # Tente de charger localement, sinon télécharge (le cache se fera dans /tmp)
        cross_encoder = CrossEncoder(
            SRC_CROSS_ENCODER if os.path.exists(SRC_CROSS_ENCODER) 
            else "cross-encoder/mmarco-mMiniLMv2-L12-H384-v1"
        )
        paraphrase = SentenceTransformer(
            SRC_PARAPHRASE if os.path.exists(SRC_PARAPHRASE) 
            else "sentence-transformers/paraphrase-mpnet-base-v2"
        )
        print("✅ Modèles chargés avec succès.")
        return cross_encoder, paraphrase
    except Exception as e:
        print(f"❌ Erreur chargement modèles: {e}")
        # Note: L'erreur de PermissionError est maintenant gérée par le Dockerfile
        raise
def load_data():
    """Charge le DataFrame depuis le CSV."""
    try:
        if not os.path.exists(DATA_FILE_PATH):
            print(f"⚠️ Fichier {DATA_FILE_PATH} non trouvé. Utilisation d'exemple.")
            df = pd.DataFrame({
                Q_COLUMN_NAME: ["Où est le soleil?", "Qui est l'IA?"],
                R_COLUMN_NAME: ["Le soleil est une étoile.", "L'IA est l'intelligence artificielle."]
            })
        else:
            df = pd.read_csv(DATA_FILE_PATH)
            print(f"✅ {len(df)} lignes chargées depuis {DATA_FILE_PATH}.")
        return df
    except Exception as e:
        print(f"❌ Erreur chargement données: {e}")
        raise
def load_system_prompt():
    """Charge le system prompt."""
    try:
        with open(SYSTEM_PROMPT_PATH, 'r', encoding='utf-8') as f:
            return f.read().strip()
    except FileNotFoundError:
        default = "Tu es un assistant utile et concis. Réponds à la requête de l'utilisateur."
        print(f"⚠️ System prompt non trouvé à {SYSTEM_PROMPT_PATH}. Utilisation du prompt par défaut.")
        return default
def initialize_gemini_client():
    """Initialise le client Google Gemini."""
    if GEMINI_API_KEY == "AIzaSyDXXY7uSXryTxZ51jQFsSLcPnC_Ivt9V1g":
        print("⚠️ AVIS: Clé Gemini par défaut/placeholder détectée. Veuillez la remplacer par un secret d'environnement nommé 'GEMINI_API_KEY' pour la production.")
    try:
        return genai.Client(api_key=GEMINI_API_KEY)
    except Exception as e:
        print(f"❌ Erreur lors de l'initialisation du client Gemini: {e}")
        raise
# ======================================================================
# CHROMADB SETUP
# ======================================================================
def setup_chromadb_collection(client, df, model_paraphrase):
    """Configure et remplit la collection ChromaDB."""
    total_docs = len(df) * 2
    
    # S'assurer que le répertoire de la DB existe
    os.makedirs(CHROMA_DB_PATH, exist_ok=True)
    
    try:
        collection = client.get_or_create_collection(name=COLLECTION_NAME)
    except Exception as e:
        print(f"❌ Erreur lors de l'accès à la collection ChromaDB: {e}")
        raise
    
    if collection.count() == total_docs and total_docs > 0:
        print(f"✅ Collection déjà remplie ({collection.count()} docs) dans {CHROMA_DB_PATH}.")
        return collection
    
    if total_docs == 0:
        print("⚠️ DataFrame vide. Collection non remplie.")
        return collection
    
    print(f"⏳ Remplissage de ChromaDB ({len(df)} lignes) à l'emplacement: {CHROMA_DB_PATH}...")
    
    docs, metadatas, ids = [], [], []
    
    for i, row in df.iterrows():
        question = str(row[Q_COLUMN_NAME])
        reponse = str(row[R_COLUMN_NAME])
        meta = {Q_COLUMN_NAME: question, R_COLUMN_NAME: reponse, "source_row": i}
        
        docs.append(question)
        metadatas.append({**meta, "type": "question"})
        ids.append(f"id_{i}_Q")
        
        docs.append(reponse)
        metadatas.append({**meta, "type": "reponse"})
        ids.append(f"id_{i}_R")
    
    embeddings = model_paraphrase.encode(docs, show_progress_bar=False).tolist()
    
    # Nettoyage et recréation (pour le cas où les données CSV ont changé)
    try:
        client.delete_collection(name=COLLECTION_NAME)
    except:
        pass
    
    collection = client.get_or_create_collection(name=COLLECTION_NAME)
    collection.add(embeddings=embeddings, documents=docs, metadatas=metadatas, ids=ids)
    
    print(f"✅ Collection remplie: {collection.count()} documents.")
    return collection
# ======================================================================
# RAG - RETRIEVAL & RERANKING
# ======================================================================
def retrieve_and_rerank(query_text, collection, model_paraphrase, model_cross_encoder):
    """Récupère et rerank les résultats."""
    print(f"🔍 Récupération pour: '{query_text[:40]}...'")
    
    query_emb = model_paraphrase.encode([query_text]).tolist()
    results = collection.query(
        query_embeddings=query_emb,
        n_results=N_RESULTS_RETRIEVAL,
        include=['documents', 'metadatas', 'distances']
    )
    
    if not results['ids'][0]:
        print("⚠️ Aucun résultat trouvé.")
        return pd.DataFrame()
    
    candidates = []
    cross_input = []
    
    for i, doc in enumerate(results['documents'][0]):
        meta = results['metadatas'][0][i]
        candidates.append({
            'question': meta[Q_COLUMN_NAME],
            'reponse': meta[R_COLUMN_NAME],
            'doc_type': meta.get('type'),
            'text_reranked': doc,
            'initial_distance': results['distances'][0][i]
        })
        cross_input.append([query_text, doc])
    
    scores = model_cross_encoder.predict(cross_input)
    for i, score in enumerate(scores):
        candidates[i]['rerank_score'] = score
    
    df = pd.DataFrame(candidates).sort_values('rerank_score', ascending=False)
    df = df.drop_duplicates(subset=['question', 'reponse'], keep='first')
    
    return df.head(N_RESULTS_RERANK)
def generate_rag_prompt(query_text, df_results, conversation_history):
    """Génère le prompt RAG final."""
    context = []
    if not df_results.empty:
        for _, row in df_results.iterrows():
            context.append(f"Q: {row['question']}\nR: {row['reponse']}")
    
    context_str = "\n---\n".join(context)
    
    history_str = ""
    if conversation_history:
        history_str = "HISTORIQUE:\n"
        # Ajout du contexte pour le LLM, mais on ne veut pas l'historique complet
        # On va limiter l'historique à l'affichage si on dépasse MAX_CONVERSATION_HISTORY
        display_history = conversation_history[-(MAX_CONVERSATION_HISTORY * 2):] 
        for msg in display_history:
            role = "USER" if msg["role"] == "user" else "ASSISTANT"
            # On utilise 'content' pour le texte du message
            history_str += f"{role}: {msg['content']}\n"
        history_str += "\n"
    
    return f"""{history_str}UTILISATEUR: {query_text}
CONTEXTE (si utile):
[{context_str}]
INSTRUCTIONS:
- Réponds brièvement et clairement, en un minimum de 2 paragraphe et un maximum de 3 paragraphe, entre chaque paragraphe tu écrira un #
- Reste professionnel, met des emojis dans ta réponse tout en restant profesionnel
- Réponds uniquement à la question posée
- Ne mentionne NOVA que si l'utilisateur l'évoque clairement
- Mentionne obligatoirement Mathieu dans ta réponse"""
# ======================================================================
# GESTION HISTORIQUE
# ======================================================================
def get_conversation_history(session_id):
    """Récupère l'historique d'une session."""
    return conversation_histories.get(session_id, [])
def add_to_history(session_id, role, content):
    """Ajoute un message à l'historique."""
    if session_id not in conversation_histories:
        conversation_histories[session_id] = []
    
    conversation_histories[session_id].append({"role": role, "content": content})
    
    # Limiter la taille de l'historique conservé en mémoire
    if len(conversation_histories[session_id]) > MAX_CONVERSATION_HISTORY * 2:
        conversation_histories[session_id] = conversation_histories[session_id][-(MAX_CONVERSATION_HISTORY * 2):]
def clear_history(session_id):
    """Efface l'historique d'une session."""
    conversation_histories[session_id] = []
# ======================================================================
# CALL GEMINI
# ======================================================================
def call_gemini(rag_prompt, system_prompt, gemini_client):
    """Appelle Google Gemini."""
    try:
        response = gemini_client.models.generate_content(
            model=GEMINI_MODEL,
            contents=f"{system_prompt}\n\n{rag_prompt}"
        )
        return response.text.replace("*", "")
    except Exception as e:
        print(f"❌ Erreur Gemini: {e}")
        return f"Erreur: {str(e)}"
# ======================================================================
# ANSWER PROCESS
# ======================================================================
def get_answer(query_text, collection, model_paraphrase, model_cross_encoder, conversation_history):
    """Exécute le processus RAG complet."""
    print(f"\n{'='*50}")
    print(f"🚀 Traitement: '{query_text}'")
    print(f"{'='*50}")
    
    df_results = retrieve_and_rerank(query_text, collection, model_paraphrase, model_cross_encoder)
    final_prompt = generate_rag_prompt(query_text, df_results, conversation_history)
    
    # On retourne le prompt final RAG pour référence, mais l'appel Gemini est fait après
    return final_prompt
# ======================================================================
# INITIALISATION GLOBALE
# ======================================================================
def initialize_global_resources():
    """Initialise tous les modèles et ressources."""
    global model_cross_encoder, model_paraphrase, collection, system_prompt, gemini_client
    
    print("\n" + "="*50)
    print("⚙️  INITIALISATION RAG")
    print("="*50)
    
    # Le répertoire /tmp est géré par la variable CHROMA_DB_PATH
    
    try:
        model_cross_encoder, model_paraphrase = load_models()
        df = load_data()
        system_prompt = load_system_prompt()
        gemini_client = initialize_gemini_client()
    except Exception:
        # L'erreur est déjà print dans les fonctions de chargement
        return False
    
    try:
        print(f"⏳ Initialisation de ChromaDB à l'emplacement: {CHROMA_DB_PATH}")
        # Le PersistentClient créera les fichiers dans le chemin spécifié (maintenant dans /tmp)
        chroma_client = chromadb.PersistentClient(path=CHROMA_DB_PATH)
        collection = setup_chromadb_collection(chroma_client, df, model_paraphrase)
        print("✅ INITIALISATION COMPLÈTE\n")
        return True
    except Exception as e:
        print(f"❌ Erreur lors de l'initialisation de ChromaDB ou du remplissage: {e}")
        return False
# ======================================================================
# FLASK API
# ======================================================================
app = Flask(__name__)
# CORS activé, permet les requêtes depuis n'importe quelle origine
CORS(app) 
@app.route('/status', methods=['GET'])
def api_status():
    """Route de ping pour vérifier l'état de l'API."""
    return jsonify({"status": "everything is good"}), 200
@app.route('/api/get_answer', methods=['POST'])
def api_get_answer():
    """Endpoint principal pour obtenir une réponse."""
    if any(x is None for x in [model_cross_encoder, model_paraphrase, collection, system_prompt, gemini_client]):
        return jsonify({"error": "Ressources non chargées. Veuillez vérifier les logs d'initialisation."}), 500
    
    try:
        data = request.get_json()
        query_text = data.get('query_text')
        session_id = data.get('session_id', 'archive')
        
        if not query_text:
            generic_message = "Problème avec l'API, veuillez réessayer plus tard."
            return jsonify({"error": generic_message}), 500
        
        # Récupère historique
        history = get_conversation_history(session_id)
        
        # Génère prompt RAG
        rag_prompt = get_answer(query_text, collection, model_paraphrase, model_cross_encoder, history)
        
        # Appelle Gemini
        response = call_gemini(rag_prompt, system_prompt, gemini_client)
        
        # Sauvegarde réponse
        add_to_history(session_id, "user", query_text)
        add_to_history(session_id, "assistant", response)
        
        return jsonify({"generated_response": response})
    
    except Exception as e:
        print(f"❌ Erreur générale de l'API: {e}")
        generic_message = "Problème avec l'API, veuillez réessayer plus tard."
        return jsonify({"error": generic_message}), 500
@app.route('/api/clear_history', methods=['POST'])
def api_clear_history():
    """Efface l'historique d'une session."""
    try:
        data = request.get_json()
        session_id = data.get('session_id', 'archive')
        clear_history(session_id)
        
        return jsonify({"message": f"Historique effacé: {session_id}"})
    except Exception as e:
        generic_message = "Problème avec l'API, veuillez réessayer plus tard."
        return jsonify({"error": generic_message}), 500
# ======================================================================
# MAIN
# ======================================================================
if __name__ == '__main__':
    print("start app.py")
    if initialize_global_resources():
        
        # Récupération de l'adresse IP si possible (pour l'affichage)
        try:
            import socket
            s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
            s.connect(("8.8.8.8", 80)) # Connecte à un serveur externe pour trouver l'IP locale utilisée
            local_ip = s.getsockname()[0]
            s.close()
        except Exception:
            local_ip = "127.0.0.1" # Fallback si échec
        
        print("\n" + "="*50)
        print("🌐 SERVEUR DÉMARRÉ")
        print(f"✅ API accessible à l'URL (via l'interface réseau locale): http://{local_ip}:{API_PORT}")
        print(f"✅ Route Status: http://{local_ip}:{API_PORT}/status")
        print(f"💡 N'oubliez pas de configurer 'app_port: 1212' et 'sdk: docker' dans votre README.md !")
        print("="*50 + "\n")
        
        # L'utilisation de host='0.0.0.0' dans app.run() permet l'accès depuis l'extérieur
        app.run(host=API_HOST, port=API_PORT, debug=False)
    else:
        print("❌ Impossible de démarrer le serveur. Veuillez vérifier les logs pour les erreurs d'initialisation.")