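"""Kibali AI API: FastAPI backend for a French-speaking assistant focused on Gabon.

Combines a 4-bit quantized causal LM, RAG over uploaded PDFs (FAISS),
an adaptive conversational memory, and web search context.
"""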
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from typing import List, Optional

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
    BitsAndBytesConfig,
)
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np

from threading import Thread
import os
from io import BytesIO
import logging
from datetime import datetime
import json
import hashlib

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# PDF text extraction backend: prefer pypdf, fall back to PyPDF2
# (both expose the same PdfReader API).
try:
    from pypdf import PdfReader as PypdfReader
    PDF_READER = "pypdf"
except ImportError:
    try:
        from PyPDF2 import PdfReader as PypdfReader
        PDF_READER = "PyPDF2"
    except ImportError:
        raise ImportError("Installe pypdf ou PyPDF2 : pip install pypdf")

# Local tool modules: web search, reflection planning and geolocation helpers.
from tools.web import web_search
from tools.todo import execute_reflection_plan
from tools.geo import get_geo_context

app = FastAPI(title="Kibali AI API", version="1.0")

# Serve the static/ directory that sits next to this file.
script_dir = os.path.dirname(os.path.abspath(__file__))
static_dir = os.path.join(script_dir, "static")
os.makedirs(static_dir, exist_ok=True)
app.mount("/static", StaticFiles(directory=static_dir), name="static")

# Wide-open CORS so any front-end can call the API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

HF_MODEL_ID = "BelikanM/kibali-final-merged"
CACHE_DIR = "/data/cache"
os.makedirs(CACHE_DIR, exist_ok=True)

# Multilingual sentence-embedding model (384-dimensional vectors).
logger.info("Chargement du modèle d'embedding...")
embed_model = SentenceTransformer(
    'paraphrase-multilingual-MiniLM-L12-v2',
    cache_folder=CACHE_DIR,
)

logger.info(f"Chargement du tokenizer et du modèle LLM depuis Hugging Face : {HF_MODEL_ID}")
tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_ID, cache_dir=CACHE_DIR)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# 4-bit NF4 quantization with double quantization, to fit the model
# in limited GPU memory.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

try:
    model = AutoModelForCausalLM.from_pretrained(
        HF_MODEL_ID,
        quantization_config=bnb_config,
        device_map="auto",
        torch_dtype=torch.float16,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        cache_dir=CACHE_DIR,
    )
    logger.info(f"Modèle chargé avec succès sur {model.device}")
except Exception as e:
    logger.error(f"Erreur lors du chargement du modèle : {e}")
    raise

# In-memory FAISS indexes: one for uploaded documents (RAG) and one for
# conversational memory. 384 matches the embedding model's output size.
dimension = 384
doc_index = faiss.IndexFlatL2(dimension)
doc_chunks: List[str] = []
doc_metadata: List[dict] = []

memory_index = faiss.IndexFlatL2(dimension)
memory_texts: List[str] = []
memory_metadata: List[dict] = []


class ConversationContext:
    """Tracks the current conversation subject and detects topic changes."""

    def __init__(self):
        self.current_subject = None
        self.subject_embedding = None
        self.subject_start_time = None
        self.message_count = 0
        self.subject_keywords = []

    def update_subject(self, message: str, embedding: np.ndarray):
        keywords = self._extract_keywords(message)
        if self.subject_embedding is not None:
            # Embeddings are L2-normalized, so the dot product is the cosine similarity.
            similarity = np.dot(embedding.flatten(), self.subject_embedding.flatten())
            if similarity < 0.6:
                logger.info(f"Changement de sujet détecté (similarité: {similarity:.2f})")
                self._archive_current_subject()
                self._start_subject(message, embedding, keywords)
            else:
                self.message_count += 1
                self.subject_keywords.extend(keywords)
                self.subject_keywords = list(set(self.subject_keywords))[:10]
        else:
            self._start_subject(message, embedding, keywords)

    def _start_subject(self, message: str, embedding: np.ndarray, keywords: List[str]):
        self.current_subject = message
        self.subject_embedding = embedding
        self.subject_start_time = datetime.now()
        self.message_count = 1
        self.subject_keywords = keywords

    def _extract_keywords(self, text: str) -> List[str]:
        # Naive keyword extraction: words longer than 3 characters,
        # minus common French stopwords.
        stopwords = {'le', 'la', 'les', 'un', 'une', 'des', 'de', 'du', 'et', 'ou',
                     'est', 'sont', 'à', 'au', 'en', 'pour', 'dans', 'sur', 'avec'}
        words = text.lower().split()
        keywords = [w for w in words if len(w) > 3 and w not in stopwords]
        return keywords[:5]

    def _archive_current_subject(self):
        # The summary is currently only logged, not persisted.
        if self.current_subject and memory_index.ntotal > 0:
            summary = {
                "subject": self.current_subject[:200],
                "keywords": self.subject_keywords,
                "message_count": self.message_count,
                # total_seconds() rather than .seconds, which wraps after 24 h.
                "duration": (datetime.now() - self.subject_start_time).total_seconds(),
                "archived_at": datetime.now().isoformat(),
            }
            logger.info(f"Sujet archivé: {summary['keywords']}")


conversation_ctx = ConversationContext()


class Message(BaseModel):
    role: str
    content: str


class ChatRequest(BaseModel):
    messages: List[Message]
    latitude: float
    longitude: float
    city: Optional[str] = "Libreville"
    thinking_mode: bool = True


class ChatResponse(BaseModel):
    response: str
    images: List[str] = []
    context_info: Optional[dict] = None

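# Illustrative /chat payload matching ChatRequest above (all values are
# made up for the example):
#
#   {
#     "messages": [{"role": "user", "content": "Quel temps fait-il ?"}],
#     "latitude": 0.39,
#     "longitude": 9.45,
#     "city": "Libreville",
#     "thinking_mode": true
#   }
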
def extract_text_from_pdf(pdf_bytes: bytes) -> str:
    """Extracts plain text from a PDF, returning an empty string on failure."""
    text = ""
    try:
        reader = PypdfReader(BytesIO(pdf_bytes))
        for page in reader.pages:
            page_text = page.extract_text()
            if page_text:
                text += page_text + "\n"
        return text.strip()
    except Exception as e:
        logger.error(f"Erreur extraction PDF : {e}")
        return ""


def chunk_text(text: str, chunk_size: int = 400, overlap: int = 50) -> List[str]:
    """Splits text into overlapping windows of chunk_size words."""
    if not text.strip():
        return []
    words = text.split()
    chunks = []
    step = max(1, chunk_size - overlap)  # guard against overlap >= chunk_size
    for i in range(0, len(words), step):
        chunk = " ".join(words[i:i + chunk_size]).strip()
        if chunk:
            chunks.append(chunk)
        if i + chunk_size >= len(words):
            # This window already reaches the end; stop here to avoid emitting
            # a tail chunk that is just a suffix of the previous one.
            break
    return chunks

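# Example: with the defaults (chunk_size=400, overlap=50, step=350), a
# 1000-word text yields three windows starting at words 0, 350 and 700,
# each sharing 50 words with its predecessor.

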
def add_to_memory_realtime(user_msg: str, ai_response: str, subject_keywords: List[str]):
    """Stores one exchange in the FAISS memory index, deduplicated by content hash."""
    timestamp = datetime.now().isoformat()
    memory_entry = f"""[{timestamp}]
Sujet: {', '.join(subject_keywords)}
Utilisateur: {user_msg}
Kibali: {ai_response}"""

    metadata = {
        "timestamp": timestamp,
        "subject_keywords": subject_keywords,
        "user_length": len(user_msg),
        "ai_length": len(ai_response),
        # Hash only the exchange itself: including the timestamp would make
        # every entry unique and the deduplication below would never trigger.
        "hash": hashlib.md5(f"{user_msg}|{ai_response}".encode()).hexdigest(),
    }

    if metadata["hash"] not in [m.get("hash") for m in memory_metadata]:
        memory_texts.append(memory_entry)
        memory_metadata.append(metadata)
        mem_emb = embed_model.encode([memory_entry], normalize_embeddings=True).astype('float32')
        memory_index.add(mem_emb)
        logger.info(f"Mémoire ajoutée en temps réel: {subject_keywords} (total: {len(memory_texts)})")
        return True
    return False


def retrieve_adaptive_memory(query: str, k: int = 5) -> tuple:
    """Returns the k most relevant memory entries and their scores, mixing
    semantic similarity, recency and current-subject keyword overlap."""
    if memory_index.ntotal == 0:
        return [], []

    query_emb = embed_model.encode([query], normalize_embeddings=True).astype('float32')
    k_search = min(k * 2, memory_index.ntotal)
    D, I = memory_index.search(query_emb, k=k_search)

    results = []
    for dist, idx in zip(D[0], I[0]):
        if 0 <= idx < len(memory_texts):
            metadata = memory_metadata[idx] if idx < len(memory_metadata) else {}
            age = datetime.now() - datetime.fromisoformat(
                metadata.get("timestamp", datetime.now().isoformat()))
            # total_seconds() rather than .seconds, which wraps after 24 h.
            recency_score = 1.0 / (1 + age.total_seconds() / 3600)
            similarity_score = 1.0 / (1 + dist)
            keyword_bonus = 0
            if conversation_ctx.subject_keywords:
                text_lower = memory_texts[idx].lower()
                keyword_bonus = sum(1 for kw in conversation_ctx.subject_keywords if kw in text_lower) * 0.1
            total_score = similarity_score * 0.6 + recency_score * 0.3 + keyword_bonus

            results.append({
                "text": memory_texts[idx],
                "score": total_score,
                "metadata": metadata,
            })

    results = sorted(results, key=lambda x: x["score"], reverse=True)[:k]
    texts = [r["text"] for r in results]
    scores = [r["score"] for r in results]
    return texts, scores
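

# Worked example of the scoring above (numbers are illustrative): an entry at
# L2 distance 0.5, written one hour ago and matching two subject keywords
# scores (1/1.5)*0.6 + (1/2)*0.3 + 2*0.1 ≈ 0.40 + 0.15 + 0.20 = 0.75.
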
@app.get("/status")
async def status():
    return {
        "status": "ready",
        "doc_chunks": len(doc_chunks),
        "memory_entries": len(memory_texts),
        "pdf_library": PDF_READER,
        "model_device": str(model.device),
        "torch_cuda_available": torch.cuda.is_available(),
        "current_subject": conversation_ctx.current_subject[:100] if conversation_ctx.current_subject else None,
        "subject_message_count": conversation_ctx.message_count,
    }
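
# Quick health check (host and port depend on the deployment):
#   curl http://localhost:7860/status
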

@app.post("/chat", response_model=ChatResponse)
def chat(request: ChatRequest):
    # Deliberately a sync endpoint: generation blocks for a long time, and
    # FastAPI runs sync endpoints in a worker threadpool instead of blocking
    # the event loop.
    if not request.messages:
        raise HTTPException(status_code=400, detail="Message vide")
    user_message = request.messages[-1].content.strip()
    if not user_message:
        raise HTTPException(status_code=400, detail="Message vide")

    geo = {
        "latitude": request.latitude,
        "longitude": request.longitude,
        "city": request.city or "Libreville",
    }

    # Update subject tracking with the new message's embedding.
    user_emb = embed_model.encode([user_message], normalize_embeddings=True).astype('float32')
    conversation_ctx.update_subject(user_message, user_emb)

    # 1. RAG context from uploaded documents.
    rag_context = ""
    rag_sources = []
    if doc_index.ntotal > 0 and len(doc_chunks) > 0:
        D, I = doc_index.search(user_emb, k=5)
        relevant_chunks = []
        for idx in I[0]:
            if 0 <= idx < len(doc_chunks):
                relevant_chunks.append(doc_chunks[idx][:1000])
                if idx < len(doc_metadata):
                    rag_sources.append(doc_metadata[idx].get("source", "PDF"))
        if relevant_chunks:
            rag_context = "\n\n".join(f"Document : {chunk}" for chunk in relevant_chunks)

    # 2. Adaptive conversational memory.
    memory_context = ""
    memory_texts_filtered, memory_scores = retrieve_adaptive_memory(user_message, k=5)
    if memory_texts_filtered:
        memory_context = "\n\n".join(
            f"Mémoire (score: {score:.2f}): {text}"
            for text, score in zip(memory_texts_filtered, memory_scores)
        )

    # 3. Optional reflection step (called for its side effects; the return
    # value is not used here).
    if request.thinking_mode:
        execute_reflection_plan(
            user_message,
            geo_info=geo,
            messages=request.messages,
            current_subject=conversation_ctx.current_subject,
            subject_keywords=conversation_ctx.subject_keywords,
        )

    # 4. Web search, with the query enriched by the current subject keywords.
    search_query = user_message
    if conversation_ctx.subject_keywords:
        search_query = f"{user_message} {' '.join(conversation_ctx.subject_keywords[:3])} Gabon"

    search_results = web_search(search_query)
    web_context = "\n".join(f"- {r['content'][:500]}" for r in search_results.get("results", [])[:6])
    web_images = search_results.get("images", [])[:4]

    # 5. Assemble the prompt: system instructions, then the three context
    # sources in order of priority, then the question.
    system_prompt = f"""Tu es Kibali, un assistant IA chaleureux, précis et expert du Gabon, basé à {geo['city']}.
Réponds toujours en français, de façon naturelle, concise et factuelle.

CONTEXTE CONVERSATIONNEL ACTUEL:
- Sujet en cours: {', '.join(conversation_ctx.subject_keywords) if conversation_ctx.subject_keywords else 'Nouveau sujet'}
- Nombre de messages sur ce sujet: {conversation_ctx.message_count}

PRIORITÉ DES SOURCES:
1. Documents uploadés (PDF Vault) - Source la plus fiable
2. Mémoire conversationnelle récente et pertinente
3. Informations Web actualisées

Si une information vient d'un document uploadé, mentionne-le brièvement.
Adapte-toi aux changements brusques de sujet en restant cohérent."""

    full_prompt = f"""### INSTRUCTIONS STRICTES :
{system_prompt}

### CONTEXTE DOCUMENTS (PDF Vault) :
{rag_context if rag_context else "Aucun document pertinent trouvé."}

### HISTORIQUE PERTINENT (Mémoire adaptative) :
{memory_context if memory_context else "Pas d'historique pertinent."}

### INFORMATIONS WEB RÉCENTES :
{web_context if web_context else "Pas d'informations web disponibles."}

### QUESTION :
{user_message}

### RÉPONSE (en français uniquement) :
"""

    # 6. Stream generation from a background thread and collect the text.
    inputs = tokenizer(full_prompt, return_tensors="pt", truncation=True, max_length=8192).to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=120.0)

    def generate_stream():
        try:
            model.generate(
                **inputs,
                streamer=streamer,
                max_new_tokens=1024,
                temperature=0.6,
                do_sample=True,
                top_p=0.85,
                top_k=50,
                repetition_penalty=1.2,
                pad_token_id=tokenizer.pad_token_id,  # silence the open-ended generation warning
            )
        except Exception as e:
            logger.error(f"Erreur génération : {e}")

    thread = Thread(target=generate_stream)
    thread.start()

    response_text = ""
    for new_text in streamer:
        if new_text is not None:
            response_text += new_text
    response_text = response_text.strip()

    # Persist the exchange in conversational memory.
    if response_text:
        add_to_memory_realtime(
            user_message,
            response_text,
            conversation_ctx.subject_keywords,
        )

    context_info = {
        "subject_keywords": conversation_ctx.subject_keywords,
        "message_count": conversation_ctx.message_count,
        "memory_used": len(memory_texts_filtered),
        "rag_sources": list(set(rag_sources)),
        "web_results": len(search_results.get("results", [])),
    }

    return ChatResponse(response=response_text, images=web_images, context_info=context_info)


@app.post("/upload")
|
|
|
async def upload(files: List[UploadFile] = File(...)):
|
|
|
total_added = 0
|
|
|
processed_files = 0
|
|
|
for file in files:
|
|
|
if not file.filename.lower().endswith(".pdf"):
|
|
|
continue
|
|
|
try:
|
|
|
content = await file.read()
|
|
|
text = extract_text_from_pdf(content)
|
|
|
if not text:
|
|
|
logger.warning(f"Aucun texte extrait de {file.filename}")
|
|
|
continue
|
|
|
chunks = chunk_text(text)
|
|
|
if not chunks:
|
|
|
continue
|
|
|
timestamp = datetime.now().isoformat()
|
|
|
for chunk in chunks:
|
|
|
doc_metadata.append({
|
|
|
"source": file.filename,
|
|
|
"timestamp": timestamp,
|
|
|
"length": len(chunk)
|
|
|
})
|
|
|
embeddings = embed_model.encode(chunks, normalize_embeddings=True).astype('float32')
|
|
|
doc_index.add(embeddings)
|
|
|
doc_chunks.extend(chunks)
|
|
|
total_added += len(chunks)
|
|
|
processed_files += 1
|
|
|
logger.info(f"Upload réussi : {file.filename} → {len(chunks)} chunks ajoutés")
|
|
|
except Exception as e:
|
|
|
logger.error(f"Erreur lors du traitement de {file.filename} : {e}")
|
|
|
return {
|
|
|
"status": "success",
|
|
|
"files_processed": processed_files,
|
|
|
"chunks_added": total_added,
|
|
|
"total_doc_chunks": len(doc_chunks)
|
|
|
}
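
# Illustrative upload call (multipart form; the field name must be "files",
# the filename is made up):
#   curl -F "files=@rapport.pdf" http://localhost:7860/upload
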

@app.post("/upload-pdfs")
async def upload_pdfs(files: List[UploadFile] = File(...)):
    # Backwards-compatible alias for /upload.
    return await upload(files)


@app.post("/clear-memory")
async def clear_memory():
    global memory_index, memory_texts, memory_metadata, conversation_ctx
    memory_index = faiss.IndexFlatL2(dimension)
    memory_texts = []
    memory_metadata = []
    conversation_ctx = ConversationContext()
    return {"status": "memory_cleared", "message": "Mémoire conversationnelle effacée"}


@app.on_event("startup")
|
|
|
async def startup_event():
|
|
|
logger.info("🚀 Kibali AI API démarrée avec succès sur Hugging Face Spaces !")
|
|
|
logger.info(f"Accès : https://your-username-your-space.hf.space | Docs : /docs")
|
|
|
logger.info(f"Mémoire adaptative et réflexion contextuelle activées ✓") |