import streamlit as st
from typing import List, Optional
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
from threading import Thread
import os
import zipfile
import tempfile
from io import BytesIO
import logging
from datetime import datetime
import json
import hashlib
import base64
# --- PAGE CONFIGURATION ---
st.set_page_config(
    page_title="Kibali AI - Assistant IA du Gabon",
    page_icon="kibali_logo.svg",
    layout="wide"
)

# --- LOGGING CONFIGURATION ---
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- LOAD THE LOGO ---
with open("kibali_logo.svg", "rb") as f:
    logo_data = f.read()
logo_base64 = base64.b64encode(logo_data).decode()
# --- PDF SUPPORT ---
try:
    from pypdf import PdfReader as PypdfReader
    PDF_READER = "pypdf"
except ImportError:
    try:
        from PyPDF2 import PdfReader as PypdfReader
        PDF_READER = "PyPDF2"
    except ImportError:
        st.error("Installe pypdf ou PyPDF2 : pip install pypdf")
        st.stop()  # PypdfReader would be undefined past this point, so halt the app
# --- CUSTOM TOOLS ---
from tools.web import web_search
from tools.todo import execute_reflection_plan
from tools.geo import get_geo_context
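
# NOTE: tools.web, tools.todo and tools.geo are project-local modules that must
# ship alongside this file. Judging from the call sites below, web_search(query)
# is expected to return a dict shaped like
# {"results": [{"content": ...}, ...], "images": [...]};
# get_geo_context is imported here but never called in this file.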
# --- MODEL LOADING ---
HF_MODEL_ID = "C:/Users/Admin/Desktop/logiciel/kibali-api/qwen_model"
CACHE_DIR = "/data/cache"
os.makedirs(CACHE_DIR, exist_ok=True)

@st.cache_resource
def load_embed_model():
    logger.info("Chargement du modèle d'embedding...")
    return SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2', cache_folder=CACHE_DIR)

@st.cache_resource
def load_tokenizer():
    logger.info(f"Chargement du tokenizer depuis {HF_MODEL_ID}")
    tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_ID, cache_dir=CACHE_DIR, trust_remote_code=False)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer
@st.cache_resource
def load_model():
    # 4-bit NF4 quantization with double quantization keeps the LLM small
    # enough for a single GPU; compute runs in float16.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16
    )
    try:
        model = AutoModelForCausalLM.from_pretrained(
            HF_MODEL_ID,
            quantization_config=bnb_config,
            device_map="auto",
            torch_dtype=torch.float16,
            trust_remote_code=False,
            low_cpu_mem_usage=True,
            cache_dir=CACHE_DIR
        )
        logger.info(f"Modèle chargé avec succès sur {model.device}")
        return model
    except Exception as e:
        logger.error(f"Erreur lors du chargement du modèle : {e}")
        st.error(f"Erreur chargement modèle : {e}")
        return None
embed_model = load_embed_model()
tokenizer = load_tokenizer()
model = load_model()
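
# Streamlit re-executes this script on every user interaction; the
# @st.cache_resource decorators above ensure the embedding model, tokenizer
# and 4-bit LLM are loaded once per process instead of on every rerun.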
# --- VECTOR STORES ---
dimension = 384
if 'doc_index' not in st.session_state:
    st.session_state.doc_index = faiss.IndexFlatL2(dimension)
    st.session_state.doc_chunks = []
    st.session_state.doc_metadata = []
if 'memory_index' not in st.session_state:
    st.session_state.memory_index = faiss.IndexFlatL2(dimension)
    st.session_state.memory_texts = []
    st.session_state.memory_metadata = []
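
# Both indexes live in st.session_state, so they survive Streamlit reruns but
# are scoped to a single browser session. 384 is the output dimension of
# paraphrase-multilingual-MiniLM-L12-v2. IndexFlatL2 performs exact brute-force
# search; on the normalized embeddings used throughout, L2 ranking is
# equivalent to cosine ranking.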
# --- CONVERSATION CONTEXT TRACKING ---
class ConversationContext:
    def __init__(self):
        self.current_subject = None
        self.subject_embedding = None
        self.subject_start_time = None
        self.message_count = 0
        self.subject_keywords = []

    def update_subject(self, message: str, embedding: np.ndarray):
        keywords = self._extract_keywords(message)
        if self.subject_embedding is not None:
            # Embeddings are normalized, so the dot product is the cosine similarity.
            similarity = np.dot(embedding.flatten(), self.subject_embedding.flatten())
            if similarity < 0.4:  # threshold lowered to hold on to a subject longer
                logger.info(f"Changement de sujet détecté (similarité: {similarity:.2f})")
                self._archive_current_subject()
                self.current_subject = message
                self.subject_embedding = embedding
                self.subject_start_time = datetime.now()
                self.message_count = 1
                self.subject_keywords = keywords
            else:
                self.message_count += 1
                self.subject_keywords.extend(keywords)
                self.subject_keywords = list(set(self.subject_keywords))[:10]
        else:
            self.current_subject = message
            self.subject_embedding = embedding
            self.subject_start_time = datetime.now()
            self.message_count = 1
            self.subject_keywords = keywords
    def _extract_keywords(self, text: str) -> List[str]:
        stopwords = {'le', 'la', 'les', 'un', 'une', 'des', 'de', 'du', 'et', 'ou', 'est', 'sont', 'à', 'au', 'en', 'pour', 'dans', 'sur', 'avec'}
        words = text.lower().split()
        keywords = [w for w in words if len(w) > 3 and w not in stopwords]
        return keywords[:5]

    def _archive_current_subject(self):
        if self.current_subject and st.session_state.memory_index.ntotal > 0:
            summary = {
                "subject": self.current_subject[:200],
                "keywords": self.subject_keywords,
                # total_seconds() rather than .seconds, which wraps around at one day
                "duration": int((datetime.now() - self.subject_start_time).total_seconds()),
                "message_count": self.message_count,
                "archived_at": datetime.now().isoformat()
            }
            # Note: the summary is only logged; it is not written back into the vector memory.
            logger.info(f"Sujet archivé: {summary['keywords']}")
if 'conversation_ctx' not in st.session_state:
    st.session_state.conversation_ctx = ConversationContext()
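
# Rough illustration of the subject-change heuristic (not from the code):
# consecutive questions on the same topic, e.g. "Quelle est la capitale du
# Gabon ?" followed by "Et sa population ?", typically keep cosine similarity
# above 0.4 and extend the current subject, while a jump to an unrelated topic
# usually falls below the threshold and archives it. The 0.4 cutoff is an
# empirical knob, not a calibrated value.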
# --- UTILITIES ---
def extract_text_from_pdf(pdf_bytes: bytes) -> str:
    text = ""
    try:
        pdf_file = BytesIO(pdf_bytes)
        reader = PypdfReader(pdf_file)
        for page in reader.pages:
            page_text = page.extract_text()
            if page_text:
                text += page_text + "\n"
        return text.strip()
    except Exception as e:
        logger.error(f"Erreur extraction PDF : {e}")
        return ""
def process_zip(zip_file) -> List[tuple]:
    text_files = []
    with tempfile.TemporaryDirectory() as temp_dir:
        with zipfile.ZipFile(zip_file, 'r') as zip_ref:
            zip_ref.extractall(temp_dir)
        for root, dirs, files in os.walk(temp_dir):
            for file in files:
                filepath = os.path.join(root, file)
                if file.endswith(('.txt', '.md', '.py', '.js', '.html', '.css', '.json', '.xml', '.yaml', '.yml')):
                    try:
                        with open(filepath, 'r', encoding='utf-8') as f:
                            content = f.read()
                        text_files.append((file, content))
                    except (UnicodeDecodeError, OSError):
                        continue  # skip binary files and encoding errors
    return text_files
def chunk_text(text: str, chunk_size: int = 400, overlap: int = 50) -> List[str]:
    if not text.strip():
        return []
    words = text.split()
    chunks = []
    # Guard against a non-positive step (overlap >= chunk_size would loop forever).
    step = max(1, chunk_size - overlap)
    i = 0
    while i < len(words):
        chunk = " ".join(words[i:i + chunk_size]).strip()
        if chunk:
            chunks.append(chunk)
        i += step
    return chunks
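
# Worked example with the defaults (chunk_size=400, overlap=50): the window
# advances 350 words per step, so a 1,000-word text yields three chunks
# covering words 0-399, 350-749 and 700-999; each chunk shares 50 words with
# its neighbour, so sentences straddling a boundary remain retrievable.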
def add_to_memory_realtime(user_msg: str, ai_response: str, subject_keywords: List[str]):
    timestamp = datetime.now().isoformat()
    memory_entry = (
        f"[{timestamp}] Sujet: {', '.join(subject_keywords)}\n"
        f"Utilisateur: {user_msg}\n"
        f"Kibali: {ai_response}"
    )
    metadata = {
        "timestamp": timestamp,
        "subject_keywords": subject_keywords,
        "user_length": len(user_msg),
        "ai_length": len(ai_response),
        "hash": hashlib.md5(memory_entry.encode()).hexdigest()
    }
    if metadata["hash"] not in [m.get("hash") for m in st.session_state.memory_metadata]:
        st.session_state.memory_texts.append(memory_entry)
        st.session_state.memory_metadata.append(metadata)
        mem_emb = embed_model.encode([memory_entry], normalize_embeddings=True).astype('float32')
        st.session_state.memory_index.add(mem_emb)
        logger.info(f"Mémoire ajoutée en temps réel: {subject_keywords} (total: {len(st.session_state.memory_texts)})")
        return True
    return False
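
# Deduplication caveat: the md5 hash covers the whole entry including its
# timestamp, so two identical exchanges recorded at different times hash
# differently; the check only blocks exact re-insertions of the same entry,
# not semantically repeated exchanges.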
def retrieve_adaptive_memory(query: str, k: int = 5) -> tuple:
    if st.session_state.memory_index.ntotal == 0:
        return [], []
    query_emb = embed_model.encode([query], normalize_embeddings=True).astype('float32')
    k_search = min(k * 2, st.session_state.memory_index.ntotal)
    D, I = st.session_state.memory_index.search(query_emb, k=k_search)
    results = []
    for dist, idx in zip(D[0], I[0]):
        if 0 <= idx < len(st.session_state.memory_texts):
            metadata = st.session_state.memory_metadata[idx] if idx < len(st.session_state.memory_metadata) else {}
            # total_seconds() rather than .seconds, which ignores whole days and
            # would make day-old memories look fresh again.
            age_hours = (datetime.now() - datetime.fromisoformat(metadata.get("timestamp", datetime.now().isoformat()))).total_seconds() / 3600
            recency_score = 1.0 / (1 + age_hours)
            similarity_score = 1.0 / (1 + dist)
            keyword_bonus = 0
            if st.session_state.conversation_ctx.subject_keywords:
                text_lower = st.session_state.memory_texts[idx].lower()
                keyword_bonus = sum(1 for kw in st.session_state.conversation_ctx.subject_keywords if kw in text_lower) * 0.1
            total_score = similarity_score * 0.6 + recency_score * 0.3 + keyword_bonus
            results.append({
                "text": st.session_state.memory_texts[idx],
                "score": total_score,
                "metadata": metadata
            })
    results = sorted(results, key=lambda x: x["score"], reverse=True)[:k]
    texts = [r["text"] for r in results]
    scores = [r["score"] for r in results]
    return texts, scores
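
# Scoring sketch: total = 0.6 / (1 + L2_distance) + 0.3 / (1 + age_in_hours)
#                        + 0.1 * (subject keywords present in the entry).
# Example: a one-hour-old memory at distance 0.5 containing two of the current
# subject keywords scores 0.6 * 0.667 + 0.3 * 0.5 + 0.2 ≈ 0.75.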
def generate_response(user_message, geo, thinking_mode, messages_history):
    user_emb = embed_model.encode([user_message], normalize_embeddings=True).astype('float32')
    st.session_state.conversation_ctx.update_subject(user_message, user_emb)

    # RAG over uploaded PDF documents
    rag_context = ""
    rag_sources = []
    if st.session_state.doc_index.ntotal > 0 and len(st.session_state.doc_chunks) > 0:
        D, I = st.session_state.doc_index.search(user_emb, k=5)
        relevant_chunks = []
        for idx in I[0]:
            if 0 <= idx < len(st.session_state.doc_chunks):
                relevant_chunks.append(st.session_state.doc_chunks[idx][:1000])
                if idx < len(st.session_state.doc_metadata):
                    rag_sources.append(st.session_state.doc_metadata[idx].get("source", "PDF"))
        if relevant_chunks:
            rag_context = "\n\n".join([f"Document : {chunk}" for chunk in relevant_chunks])

    # Adaptive memory
    memory_context = ""
    memory_texts_filtered, memory_scores = retrieve_adaptive_memory(user_message, k=10)
    if memory_texts_filtered:
        memory_context = "\n\n".join([f"Mémoire (score: {score:.2f}): {text}" for text, score in zip(memory_texts_filtered, memory_scores)])

    # Strategic reflection (note: the return value is currently discarded)
    if thinking_mode:
        execute_reflection_plan(
            user_message,
            geo_info=geo,
            messages=messages_history,  # pass the full history
            current_subject=st.session_state.conversation_ctx.current_subject,
            subject_keywords=st.session_state.conversation_ctx.subject_keywords
        )
    # Web search, biased towards the current subject and Gabon
    search_query = user_message
    if st.session_state.conversation_ctx.subject_keywords:
        search_query = f"{user_message} {' '.join(st.session_state.conversation_ctx.subject_keywords[:3])} Gabon"
    search_results = web_search(search_query)
    web_context = "\n".join([f"- {r['content'][:500]}" for r in search_results.get("results", [])[:6]])
    web_images = search_results.get("images", [])[:4]

    # Raw recent chat history is deliberately not injected here: the adaptive
    # memory above covers it with less risk of hallucination.
    # Final prompt (the app is French-facing, so the prompt stays in French)
    system_prompt = f"""Tu es Kibali, un assistant IA chaleureux, précis et expert du Gabon, basé à {geo['city']}. Réponds toujours en français, de façon naturelle, concise et factuelle.
CONTEXTE CONVERSATIONNEL ACTUEL:
- Sujet en cours: {', '.join(st.session_state.conversation_ctx.subject_keywords) if st.session_state.conversation_ctx.subject_keywords else 'Nouveau sujet'}
- Nombre de messages sur ce sujet: {st.session_state.conversation_ctx.message_count}
PRIORITÉ DES SOURCES:
1. Documents uploadés (PDF Vault) - Source la plus fiable
2. Mémoire conversationnelle récente et pertinente
3. Informations Web actualisées
Si une information vient d'un document uploadé, mentionne-le brièvement.
Adapte-toi aux changements brusques de sujet en restant cohérent, mais reconnais quand le sujet reste le même.
Évite les hallucinations : base-toi uniquement sur les informations fournies. Utilise l'orthographe correcte : Gabon (pays d'Afrique centrale).
Écris en français impeccable : utilise une grammaire correcte, une syntaxe claire, un vocabulaire approprié et des phrases bien construites."""

    full_prompt = f"""### INSTRUCTIONS STRICTES :
{system_prompt}
### CONTEXTE DOCUMENTS (PDF Vault) :
{rag_context if rag_context else "Aucun document pertinent trouvé."}
### HISTORIQUE PERTINENT (Mémoire adaptative) :
{memory_context if memory_context else "Pas d'historique pertinent."}
### INFORMATIONS WEB RÉCENTES :
{web_context if web_context else "Pas d'informations web disponibles."}
### QUESTION :
{user_message}
### RÉPONSE (en français uniquement) :"""

    # Caveat: truncation=True cuts from the end of the sequence by default, so
    # an oversized RAG/web context can clip the QUESTION and RÉPONSE sections.
    inputs = tokenizer(full_prompt, return_tensors="pt", truncation=True, max_length=8192).to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=120.0)
    def generate_stream():
        try:
            model.generate(
                **inputs,
                streamer=streamer,
                max_new_tokens=512,
                do_sample=True,
                temperature=0.3,  # low temperature for output quality
                top_p=0.9,
                top_k=40,
                repetition_penalty=1.1
            )
        except Exception as e:
            logger.error(f"Erreur génération : {e}")

    thread = Thread(target=generate_stream)
    thread.start()
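
    # model.generate() blocks until the sequence finishes, so it runs on a
    # background thread while the main thread drains the streamer below.
    # TextIteratorStreamer raises queue.Empty if no token arrives within the
    # 120 s timeout set above, which ends the loop with an exception.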
    response_text = ""
    for new_text in streamer:
        if new_text is not None:
            response_text += new_text
    response_text = response_text.strip()

    if response_text:
        add_to_memory_realtime(
            user_message,
            response_text,
            st.session_state.conversation_ctx.subject_keywords
        )

    context_info = {
        "subject_keywords": st.session_state.conversation_ctx.subject_keywords,
        "message_count": st.session_state.conversation_ctx.message_count,
        "memory_used": len(memory_texts_filtered),
        "rag_sources": list(set(rag_sources)),
        "web_results": len(search_results.get("results", []))
    }
    return response_text, web_images, context_info
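
# End-to-end flow of generate_response: embed the query → update the subject
# tracker → gather the top-5 PDF chunks, up to 10 scored memories and up to 6
# web snippets → assemble one French prompt → stream a 4-bit generation →
# persist the exchange to adaptive memory and report sources in context_info.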
# --- STREAMLIT UI ---
# Center the logo
st.markdown(f"""
<style>
.center {{
    width: 100%;
    text-align: center;
}}
</style>
<div class="center">
    <img src="data:image/svg+xml;base64,{logo_base64}" width="200">
</div>
""", unsafe_allow_html=True)
# Sidebar: settings
st.sidebar.image("kibali_logo.svg", width=150)
st.sidebar.header("Paramètres")
latitude = st.sidebar.number_input("Latitude", value=0.0, format="%.6f")
longitude = st.sidebar.number_input("Longitude", value=0.0, format="%.6f")
city = st.sidebar.text_input("Ville", value="Libreville")
thinking_mode = st.sidebar.checkbox("Mode réflexion", value=True)

# Sidebar: PDF upload
st.sidebar.header("Upload PDFs")
uploaded_files = st.sidebar.file_uploader("Sélectionnez des PDFs", type="pdf", accept_multiple_files=True)
if uploaded_files and st.sidebar.button("Traiter PDFs"):
    total_added = 0
    processed_files = 0
    for file in uploaded_files:
        try:
            content = file.read()
            text = extract_text_from_pdf(content)
            if not text:
                st.sidebar.warning(f"Aucun texte extrait de {file.name}")
                continue
            chunks = chunk_text(text)
            if not chunks:
                continue
            timestamp = datetime.now().isoformat()
            for chunk in chunks:
                st.session_state.doc_metadata.append({
                    "source": file.name,
                    "timestamp": timestamp,
                    "length": len(chunk)
                })
            embeddings = embed_model.encode(chunks, normalize_embeddings=True).astype('float32')
            st.session_state.doc_index.add(embeddings)
            st.session_state.doc_chunks.extend(chunks)
            total_added += len(chunks)
            processed_files += 1
            logger.info(f"Upload réussi : {file.name} → {len(chunks)} chunks ajoutés")
        except Exception as e:
            logger.error(f"Erreur lors du traitement de {file.name} : {e}")
            st.sidebar.error(f"Erreur {file.name}: {e}")
    st.sidebar.success(f"{processed_files} fichiers traités, {total_added} chunks ajoutés")
# Sidebar: project ZIP upload
st.sidebar.header("Upload Projet ZIP")
uploaded_zip = st.sidebar.file_uploader("Sélectionnez un ZIP de projet", type="zip")
if uploaded_zip and st.sidebar.button("Traiter ZIP"):
    text_files = process_zip(uploaded_zip)
    total_added = 0
    processed_files = 0
    for filename, text in text_files:
        chunks = chunk_text(text)
        if not chunks:
            continue
        timestamp = datetime.now().isoformat()
        for chunk in chunks:
            st.session_state.doc_metadata.append({
                "source": filename,
                "timestamp": timestamp,
                "length": len(chunk)
            })
        embeddings = embed_model.encode(chunks, normalize_embeddings=True).astype('float32')
        st.session_state.doc_index.add(embeddings)
        st.session_state.doc_chunks.extend(chunks)
        total_added += len(chunks)
        processed_files += 1
        logger.info(f"Upload réussi : {filename} → {len(chunks)} chunks ajoutés")
    st.sidebar.success(f"{processed_files} fichiers traités, {total_added} chunks ajoutés")
# Sidebar: clear memory
if st.sidebar.button("Effacer mémoire"):
    st.session_state.memory_index = faiss.IndexFlatL2(dimension)
    st.session_state.memory_texts = []
    st.session_state.memory_metadata = []
    st.session_state.conversation_ctx = ConversationContext()
    st.sidebar.success("Mémoire effacée")

# Sidebar: status
st.sidebar.header("Statut")
st.sidebar.write(f"Chunks docs: {len(st.session_state.doc_chunks)}")
st.sidebar.write(f"Entrées mémoire: {len(st.session_state.memory_texts)}")
st.sidebar.write(f"CUDA: {torch.cuda.is_available()}")
st.sidebar.write(f"Sujet actuel: {st.session_state.conversation_ctx.current_subject[:50] if st.session_state.conversation_ctx.current_subject else 'Aucun'}")
st.sidebar.write(f"Messages sujet: {st.session_state.conversation_ctx.message_count}")
# Chat
if "messages" not in st.session_state:
    st.session_state.messages = []

for message in st.session_state.messages:
    with st.chat_message(message["role"], avatar="kibali_logo.svg" if message["role"] == "assistant" else None):
        st.markdown(message["content"])
        if "images" in message and message["images"]:
            for img_url in message["images"]:
                st.image(img_url)

if prompt := st.chat_input("Posez votre question..."):
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)
    geo = {"latitude": latitude, "longitude": longitude, "city": city}
    with st.chat_message("assistant", avatar="kibali_logo.svg"):
        with st.spinner("Génération de la réponse..."):
            response, images, context_info = generate_response(prompt, geo, thinking_mode, st.session_state.messages)
            st.markdown(response)
            if images:
                for img_url in images:
                    st.image(img_url)
    st.session_state.messages.append({"role": "assistant", "content": response, "images": images, "context": context_info})
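
# To run locally (assuming the model directory, kibali_logo.svg and the tools/
# package sit next to this file, which we presume is app.py):
#   pip install streamlit transformers accelerate bitsandbytes sentence-transformers faiss-cpu pypdf
#   streamlit run app.py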