# agent-kibali / app.py — author: lojol469-cmd
# Clean initialization of Kibali: API and frontend (commit 30b5e11)
# app.py - FastAPI backend API for Kibali AI
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List, Optional
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
from threading import Thread
import os
from io import BytesIO
# PDF reader backend detection (same detection as before): prefer the
# maintained `pypdf` package, fall back to the legacy `PyPDF2` package.
# PDF_READER records which backend is active (reported by /status).
try:
    from pypdf import PdfReader as PypdfReader
    PDF_READER = "pypdf"
except ImportError:
    try:
        import PyPDF2
        PypdfReader = PyPDF2.PdfReader
        PDF_READER = "PyPDF2"
    except ImportError:
        # Neither backend is installed — fail fast at import time.
        raise ImportError("Installe pypdf ou PyPDF2 : pip install pypdf")
# Outils personnalisés
from tools.web import web_search
from tools.todo import execute_reflection_plan
from tools.geo import get_geo_context
# --- FastAPI application ---
app = FastAPI(title="Kibali AI API", version="1.0")

# CORS so the JS frontend (served from another origin) can call this API.
# NOTE: wildcard origins are fine for development — tighten before production.
_cors_kwargs = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_kwargs)
# --- MODEL LOADING (at startup) ---
# Model directory: overridable via the KIBALI_MODEL_PATH environment variable,
# defaulting to the original hard-coded path for backward compatibility.
MODEL_PATH = os.environ.get(
    "KIBALI_MODEL_PATH", "/home/belikan/geoscan/agent_kibali/model_cache"
)

# Sentence embedder used for RAG and conversation-memory retrieval.
embed_model = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2')

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, local_files_only=True)
if tokenizer.pad_token is None:
    # Some causal LMs ship without a pad token; reuse EOS for padding.
    tokenizer.pad_token = tokenizer.eos_token

# 4-bit NF4 quantization so the model fits on limited VRAM.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.float16,
    trust_remote_code=True,
)
# --- In-memory vector store (persists as long as the app is running) ---
# NOTE(review): this single FAISS index receives both document embeddings
# (from /upload-pdfs) and conversation-memory embeddings (from /chat), while
# doc_chunks and memory_text are kept as separate lists — a search result id
# can therefore point into the wrong list. Consider one index per corpus.
dimension = 384  # embedding size of paraphrase-multilingual-MiniLM-L12-v2
vector_index = faiss.IndexFlatL2(dimension)
doc_chunks: List[str] = []
memory_text: List[str] = []
# --- Pydantic models ---
class Message(BaseModel):
    """A single chat message (role/content pair, OpenAI-style)."""
    role: str  # e.g. "user" or "assistant" — not validated here
    content: str
class ChatRequest(BaseModel):
    """Payload for POST /chat: conversation history plus client geolocation."""
    messages: List[Message]  # full history; /chat uses only the last entry as the query
    latitude: float   # client latitude (no range validation performed here)
    longitude: float  # client longitude
    city: Optional[str] = "Libreville"
    thinking_mode: bool = True  # accepted but not read by the visible /chat handler
class ChatResponse(BaseModel):
    """Response body for POST /chat: generated text plus related image URLs."""
    response: str
    images: List[str] = []  # up to 3 image URLs collected from web search
# --- Utility functions ---
def extract_text_from_pdf(pdf_bytes: bytes) -> str:
    """Extract plain text from a PDF supplied as raw bytes.

    Pages whose extraction yields nothing are skipped; every extracted
    page is followed by a newline, matching the original concatenation.
    """
    reader = PypdfReader(BytesIO(pdf_bytes))
    page_texts = (page.extract_text() for page in reader.pages)
    return "".join(f"{t}\n" for t in page_texts if t)
def chunk_text(text: str, chunk_size: int = 400, overlap: int = 50) -> List[str]:
    """Split *text* into overlapping chunks of whitespace-separated words.

    Args:
        text: Input text; whitespace runs collapse (str.split semantics).
        chunk_size: Maximum number of words per chunk.
        overlap: Number of words shared between consecutive chunks.

    Returns:
        List of space-joined chunks; empty list for empty/whitespace text.

    Raises:
        ValueError: If overlap >= chunk_size (the original while-loop would
            never advance and loop forever in that case).
    """
    if overlap >= chunk_size:
        raise ValueError("overlap must be smaller than chunk_size")
    words = text.split()
    step = chunk_size - overlap  # guaranteed positive by the guard above
    return [" ".join(words[i:i + chunk_size]) for i in range(0, len(words), step)]
# --- API routes ---
@app.post("/chat")
async def chat(request: ChatRequest):
prompt = request.messages[-1].content
geo = {"latitude": request.latitude, "longitude": request.longitude, "city": request.city}
# RAG sur documents
rag_ctx = ""
if vector_index.ntotal > 0 and doc_chunks:
query_vec = embed_model.encode([prompt], normalize_embeddings=True).astype('float32')
D, I = vector_index.search(query_vec, k=5)
relevant = [doc_chunks[i] for i in I[0] if i < len(doc_chunks)]
rag_ctx = "\n\n".join([f"Doc: {c[:800]}" for c in relevant])
# Mémoire conversation
past_ctx = ""
if memory_text:
query_vec = embed_model.encode([prompt], normalize_embeddings=True).astype('float32')
D, I = vector_index.search(query_vec, k=2)
past_ctx = "\n".join([memory_text[i] for i in I[0] if 0 <= i < len(memory_text)])
# Web search
search_data = web_search(prompt)
web_ctx = "\n".join([f"- {r['content'][:300]}" for r in search_data.get("results", [])])
imgs = search_data.get("images", [])[:3]
# Prompt final
sys_instr = f"Tu es Kibali, assistant intelligent au Gabon ({geo['city']}). Réponds précisément."
final_prompt = f"""### SYSTEM: {sys_instr}
### DOCUMENTS: {rag_ctx}
### MÉMOIRE: {past_ctx}
### WEB: {web_ctx}
### QUESTION: {prompt}
### RÉPONSE:"""
inputs = tokenizer(final_prompt, return_tensors="pt").to(model.device)
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
def generate():
model.generate(**inputs, streamer=streamer, max_new_tokens=800, temperature=0.3, do_sample=True)
Thread(target=generate).start()
full_response = ""
for token in streamer:
if "###" in token:
break
full_response += token
# Mise à jour mémoire
new_mem = f"Q: {prompt} | R: {full_response[:500]}..."
memory_text.append(new_mem)
mem_emb = embed_model.encode([new_mem], normalize_embeddings=True).astype('float32')
vector_index.add(mem_emb)
return ChatResponse(response=full_response, images=imgs)
@app.post("/upload-pdfs")
async def upload_pdfs(files: List[UploadFile] = File(...)):
new_chunks = []
for file in files:
if not file.filename.endswith(".pdf"):
continue
content = await file.read()
text = extract_text_from_pdf(content)
if text.strip():
chunks = chunk_text(text)
new_chunks.extend(chunks)
if new_chunks:
embeddings = embed_model.encode(new_chunks, normalize_embeddings=True).astype('float32')
vector_index.add(embeddings)
doc_chunks.extend(new_chunks)
return {"status": "success", "chunks_added": len(new_chunks), "total_chunks": len(doc_chunks)}
@app.get("/status")
async def status():
return {"status": "ready", "chunks": len(doc_chunks), "pdf_library": PDF_READER}