from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from typing import List, Optional
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
from threading import Thread
import os
from io import BytesIO
import logging
from datetime import datetime
import json
import hashlib
# --- LOGGING SETUP ---
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# --- PDF SUPPORT ---
try:
    from pypdf import PdfReader as PypdfReader
    PDF_READER = "pypdf"
except ImportError:
    try:
        # PyPDF2 >= 2.0 exposes the same PdfReader class
        from PyPDF2 import PdfReader as PypdfReader
        PDF_READER = "PyPDF2"
    except ImportError:
        raise ImportError("Install pypdf or PyPDF2: pip install pypdf")
# --- CUSTOM TOOLS ---
from tools.web import web_search
from tools.todo import execute_reflection_plan
from tools.geo import get_geo_context

app = FastAPI(title="Kibali AI API", version="1.0")

# --- STATIC FILE SERVER ---
script_dir = os.path.dirname(os.path.abspath(__file__))
static_dir = os.path.join(script_dir, "static")
os.makedirs(static_dir, exist_ok=True)
app.mount("/static", StaticFiles(directory=static_dir), name="static")
# --- CORS ---
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# --- MODEL LOADING (downloaded from the Hugging Face Hub) ---
HF_MODEL_ID = "BelikanM/kibali-final-merged"
CACHE_DIR = "/data/cache"  # Persistent directory on HF Spaces
os.makedirs(CACHE_DIR, exist_ok=True)

logger.info("Loading the embedding model...")
embed_model = SentenceTransformer(
    'paraphrase-multilingual-MiniLM-L12-v2',
    cache_folder=CACHE_DIR
)

logger.info(f"Loading tokenizer and LLM from Hugging Face: {HF_MODEL_ID}")
tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_ID, cache_dir=CACHE_DIR)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# 4-bit quantization to reduce VRAM usage
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16
)
try:
    model = AutoModelForCausalLM.from_pretrained(
        HF_MODEL_ID,
        quantization_config=bnb_config,
        device_map="auto",
        torch_dtype=torch.float16,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        cache_dir=CACHE_DIR
    )
    logger.info(f"Model loaded successfully on {model.device}")
except Exception as e:
    logger.error(f"Error while loading the model: {e}")
    raise
# --- GLOBAL VECTOR STORES ---
dimension = 384  # output dimension of paraphrase-multilingual-MiniLM-L12-v2
doc_index = faiss.IndexFlatL2(dimension)
doc_chunks: List[str] = []
doc_metadata: List[dict] = []
memory_index = faiss.IndexFlatL2(dimension)
memory_texts: List[str] = []
memory_metadata: List[dict] = []
# --- CONVERSATIONAL CONTEXT TRACKING ---
class ConversationContext:
    def __init__(self):
        self.current_subject = None
        self.subject_embedding = None
        self.subject_start_time = None
        self.message_count = 0
        self.subject_keywords = []

    def update_subject(self, message: str, embedding: np.ndarray):
        keywords = self._extract_keywords(message)
        if self.subject_embedding is not None:
            # Cosine similarity (embeddings are L2-normalized at encoding time)
            similarity = np.dot(embedding.flatten(), self.subject_embedding.flatten())
            if similarity < 0.6:
                logger.info(f"Subject change detected (similarity: {similarity:.2f})")
                self._archive_current_subject()
                self.current_subject = message
                self.subject_embedding = embedding
                self.subject_start_time = datetime.now()
                self.message_count = 1
                self.subject_keywords = keywords
            else:
                self.message_count += 1
                self.subject_keywords.extend(keywords)
                self.subject_keywords = list(set(self.subject_keywords))[:10]
        else:
            self.current_subject = message
            self.subject_embedding = embedding
            self.subject_start_time = datetime.now()
            self.message_count = 1
            self.subject_keywords = keywords
    def _extract_keywords(self, text: str) -> List[str]:
        stopwords = {'le', 'la', 'les', 'un', 'une', 'des', 'de', 'du', 'et', 'ou',
                     'est', 'sont', 'à', 'au', 'en', 'pour', 'dans', 'sur', 'avec'}
        words = text.lower().split()
        keywords = [w for w in words if len(w) > 3 and w not in stopwords]
        return keywords[:5]
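    # Illustrative example (hypothetical input): for the message
    # "Quels sont les parcs nationaux du Gabon ?", split() yields lowercase
    # tokens; 'sont' (a stopword) and tokens of three characters or fewer
    # are dropped, leaving ['quels', 'parcs', 'nationaux', 'gabon'].
    # Punctuation is not stripped, so 'gabon?' would survive intact if the
    # '?' were attached to the word.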
    def _archive_current_subject(self):
        if self.current_subject and memory_index.ntotal > 0:
            summary = {
                "subject": self.current_subject[:200],
                "keywords": self.subject_keywords,
                "message_count": self.message_count,
                # total_seconds() counts full days, unlike .seconds which wraps at 24h
                "duration": (datetime.now() - self.subject_start_time).total_seconds(),
                "archived_at": datetime.now().isoformat()
            }
            logger.info(f"Subject archived: {summary['keywords']}")

conversation_ctx = ConversationContext()
# --- PYDANTIC MODELS ---
class Message(BaseModel):
    role: str
    content: str

class ChatRequest(BaseModel):
    messages: List[Message]
    latitude: float
    longitude: float
    city: Optional[str] = "Libreville"
    thinking_mode: bool = True

class ChatResponse(BaseModel):
    response: str
    images: List[str] = []
    context_info: Optional[dict] = None
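# Illustrative /chat request body (the coordinates approximate Libreville):
# {
#   "messages": [{"role": "user", "content": "Quel temps fait-il ?"}],
#   "latitude": 0.39,
#   "longitude": 9.45,
#   "city": "Libreville",
#   "thinking_mode": true
# }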
# --- UTILITIES ---
def extract_text_from_pdf(pdf_bytes: bytes) -> str:
    text = ""
    try:
        pdf_file = BytesIO(pdf_bytes)
        reader = PypdfReader(pdf_file)
        for page in reader.pages:
            page_text = page.extract_text()
            if page_text:
                text += page_text + "\n"
        return text.strip()
    except Exception as e:
        logger.error(f"PDF extraction error: {e}")
        return ""
def chunk_text(text: str, chunk_size: int = 400, overlap: int = 50) -> List[str]:
    if not text.strip():
        return []
    words = text.split()
    chunks = []
    i = 0
    while i < len(words):
        chunk_words = words[i:i + chunk_size]
        chunk = " ".join(chunk_words)
        if chunk.strip():
            chunks.append(chunk.strip())
        i += chunk_size - overlap  # slide forward, keeping `overlap` words of context
    return chunks
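# Worked example (illustrative): a 1,000-word text with chunk_size=400 and
# overlap=50 advances 350 words per step, producing chunks at word offsets
# 0, 350 and 700, i.e. three chunks of 400, 400 and 300 words.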
def add_to_memory_realtime(user_msg: str, ai_response: str, subject_keywords: List[str]):
    timestamp = datetime.now().isoformat()
    # Entry labels stay in French: they are embedded and later fed back
    # into the French prompt context.
    memory_entry = f"""[{timestamp}]
Sujet: {', '.join(subject_keywords)}
Utilisateur: {user_msg}
Kibali: {ai_response}"""
    metadata = {
        "timestamp": timestamp,
        "subject_keywords": subject_keywords,
        "user_length": len(user_msg),
        "ai_length": len(ai_response),
        "hash": hashlib.md5(memory_entry.encode()).hexdigest()
    }
    # Deduplicate on the MD5 hash of the full entry
    if metadata["hash"] not in [m.get("hash") for m in memory_metadata]:
        memory_texts.append(memory_entry)
        memory_metadata.append(metadata)
        mem_emb = embed_model.encode([memory_entry], normalize_embeddings=True).astype('float32')
        memory_index.add(mem_emb)
        logger.info(f"Memory added in real time: {subject_keywords} (total: {len(memory_texts)})")
        return True
    return False
def retrieve_adaptive_memory(query: str, k: int = 5) -> tuple:
    if memory_index.ntotal == 0:
        return [], []
    query_emb = embed_model.encode([query], normalize_embeddings=True).astype('float32')
    k_search = min(k * 2, memory_index.ntotal)
    D, I = memory_index.search(query_emb, k=k_search)
    results = []
    for dist, idx in zip(D[0], I[0]):
        if 0 <= idx < len(memory_texts):
            metadata = memory_metadata[idx] if idx < len(memory_metadata) else {}
            # Recency decays with age in hours; total_seconds() gives the true
            # age (.seconds alone wraps at one day)
            age_seconds = (datetime.now() - datetime.fromisoformat(
                metadata.get("timestamp", datetime.now().isoformat()))).total_seconds()
            recency_score = 1.0 / (1 + age_seconds / 3600)
            similarity_score = 1.0 / (1 + dist)
            keyword_bonus = 0
            if conversation_ctx.subject_keywords:
                text_lower = memory_texts[idx].lower()
                keyword_bonus = sum(1 for kw in conversation_ctx.subject_keywords if kw in text_lower) * 0.1
            total_score = similarity_score * 0.6 + recency_score * 0.3 + keyword_bonus
            results.append({
                "text": memory_texts[idx],
                "score": total_score,
                "metadata": metadata
            })
    results = sorted(results, key=lambda x: x["score"], reverse=True)[:k]
    texts = [r["text"] for r in results]
    scores = [r["score"] for r in results]
    return texts, scores
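# Worked example (illustrative): an entry at L2 distance 0.5 gives
# similarity 1/(1 + 0.5) ≈ 0.67; if it was written one hour ago,
# recency = 1/(1 + 3600/3600) = 0.5; two keyword hits add a 0.2 bonus.
# Total: 0.67 * 0.6 + 0.5 * 0.3 + 0.2 ≈ 0.75.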
# --- ROUTES ---
@app.get("/status")
async def status():
    return {
        "status": "ready",
        "doc_chunks": len(doc_chunks),
        "memory_entries": len(memory_texts),
        "pdf_library": PDF_READER,
        "model_device": str(model.device),
        "torch_cuda_available": torch.cuda.is_available(),
        "current_subject": conversation_ctx.current_subject[:100] if conversation_ctx.current_subject else None,
        "subject_message_count": conversation_ctx.message_count
    }
@app.post("/chat")
async def chat(request: ChatRequest):
    # Guard against an empty message list as well as an empty last message
    if not request.messages or not request.messages[-1].content.strip():
        raise HTTPException(status_code=400, detail="Message vide")
    user_message = request.messages[-1].content.strip()
    geo = {
        "latitude": request.latitude,
        "longitude": request.longitude,
        "city": request.city or "Libreville"
    }
    user_emb = embed_model.encode([user_message], normalize_embeddings=True).astype('float32')
    conversation_ctx.update_subject(user_message, user_emb)
    # RAG over uploaded PDF documents
    rag_context = ""
    rag_sources = []
    if doc_index.ntotal > 0 and len(doc_chunks) > 0:
        D, I = doc_index.search(user_emb, k=5)
        relevant_chunks = []
        for idx in I[0]:
            if 0 <= idx < len(doc_chunks):
                relevant_chunks.append(doc_chunks[idx][:1000])
                if idx < len(doc_metadata):
                    rag_sources.append(doc_metadata[idx].get("source", "PDF"))
        if relevant_chunks:
            rag_context = "\n\n".join([f"Document : {chunk}" for chunk in relevant_chunks])
    # Adaptive memory
    memory_context = ""
    memory_texts_filtered, memory_scores = retrieve_adaptive_memory(user_message, k=5)
    if memory_texts_filtered:
        memory_context = "\n\n".join([f"Mémoire (score: {score:.2f}): {text}"
                                      for text, score in zip(memory_texts_filtered, memory_scores)])
    # Strategic reflection
    if request.thinking_mode:
        execute_reflection_plan(
            user_message,
            geo_info=geo,
            messages=request.messages,
            current_subject=conversation_ctx.current_subject,
            subject_keywords=conversation_ctx.subject_keywords
        )
    # Web search, biased toward the current subject and the Gabon locale
    search_query = user_message
    if conversation_ctx.subject_keywords:
        search_query = f"{user_message} {' '.join(conversation_ctx.subject_keywords[:3])} Gabon"
    search_results = web_search(search_query)
    web_context = "\n".join([f"- {r['content'][:500]}" for r in search_results.get("results", [])[:6]])
    web_images = search_results.get("images", [])[:4]
    # Final prompt (kept in French: the assistant is instructed to answer in French)
    system_prompt = f"""Tu es Kibali, un assistant IA chaleureux, précis et expert du Gabon, basé à {geo['city']}.
Réponds toujours en français, de façon naturelle, concise et factuelle.
CONTEXTE CONVERSATIONNEL ACTUEL:
- Sujet en cours: {', '.join(conversation_ctx.subject_keywords) if conversation_ctx.subject_keywords else 'Nouveau sujet'}
- Nombre de messages sur ce sujet: {conversation_ctx.message_count}
PRIORITÉ DES SOURCES:
1. Documents uploadés (PDF Vault) - Source la plus fiable
2. Mémoire conversationnelle récente et pertinente
3. Informations Web actualisées
Si une information vient d'un document uploadé, mentionne-le brièvement.
Adapte-toi aux changements brusques de sujet en restant cohérent."""
    full_prompt = f"""### INSTRUCTIONS STRICTES :
{system_prompt}
### CONTEXTE DOCUMENTS (PDF Vault) :
{rag_context if rag_context else "Aucun document pertinent trouvé."}
### HISTORIQUE PERTINENT (Mémoire adaptative) :
{memory_context if memory_context else "Pas d'historique pertinent."}
### INFORMATIONS WEB RÉCENTES :
{web_context if web_context else "Pas d'informations web disponibles."}
### QUESTION :
{user_message}
### RÉPONSE (en français uniquement) :
"""
    inputs = tokenizer(full_prompt, return_tensors="pt", truncation=True, max_length=8192).to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=120.0)

    def generate_stream():
        try:
            model.generate(
                **inputs,
                streamer=streamer,
                max_new_tokens=1024,
                temperature=0.6,
                do_sample=True,
                top_p=0.85,
                top_k=50,
                repetition_penalty=1.2,
                length_penalty=0.8  # only affects beam search; inert under sampling
            )
        except Exception as e:
            logger.error(f"Generation error: {e}")

    # Run generation in a background thread and consume the stream here
    thread = Thread(target=generate_stream)
    thread.start()
    response_text = ""
    for new_text in streamer:
        if new_text is not None:
            response_text += new_text
    thread.join()
    response_text = response_text.strip()
    if response_text:
        add_to_memory_realtime(
            user_message,
            response_text,
            conversation_ctx.subject_keywords
        )
    context_info = {
        "subject_keywords": conversation_ctx.subject_keywords,
        "message_count": conversation_ctx.message_count,
        "memory_used": len(memory_texts_filtered),
        "rag_sources": list(set(rag_sources)),
        "web_results": len(search_results.get("results", []))
    }
    return ChatResponse(response=response_text, images=web_images, context_info=context_info)
@app.post("/upload")
async def upload(files: List[UploadFile] = File(...)):
    total_added = 0
    processed_files = 0
    for file in files:
        if not file.filename.lower().endswith(".pdf"):
            continue
        try:
            content = await file.read()
            text = extract_text_from_pdf(content)
            if not text:
                logger.warning(f"No text extracted from {file.filename}")
                continue
            chunks = chunk_text(text)
            if not chunks:
                continue
            timestamp = datetime.now().isoformat()
            for chunk in chunks:
                doc_metadata.append({
                    "source": file.filename,
                    "timestamp": timestamp,
                    "length": len(chunk)
                })
            embeddings = embed_model.encode(chunks, normalize_embeddings=True).astype('float32')
            doc_index.add(embeddings)
            doc_chunks.extend(chunks)
            total_added += len(chunks)
            processed_files += 1
            logger.info(f"Upload succeeded: {file.filename} → {len(chunks)} chunks added")
        except Exception as e:
            logger.error(f"Error while processing {file.filename}: {e}")
    return {
        "status": "success",
        "files_processed": processed_files,
        "chunks_added": total_added,
        "total_doc_chunks": len(doc_chunks)
    }
@app.post("/upload_pdfs")
async def upload_pdfs(files: List[UploadFile] = File(...)):
    # Alias for /upload
    return await upload(files)
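# Illustrative upload call (hypothetical file name, assuming the app listens
# on port 7860):
#   curl -X POST -F "files=@report.pdf" http://localhost:7860/upload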
@app.post("/clear_memory")
async def clear_memory():
    global memory_index, memory_texts, memory_metadata, conversation_ctx
    memory_index = faiss.IndexFlatL2(dimension)
    memory_texts = []
    memory_metadata = []
    conversation_ctx = ConversationContext()
    return {"status": "memory_cleared", "message": "Mémoire conversationnelle effacée"}
# --- STARTUP ---
@app.on_event("startup")
async def startup_event():
    logger.info("🚀 Kibali AI API started successfully on Hugging Face Spaces!")
    logger.info("Access: https://your-username-your-space.hf.space | Docs: /docs")
    logger.info("Adaptive memory and contextual reflection enabled ✓")