# main.py - Backend API FastAPI pour Kibali AI from fastapi import FastAPI, UploadFile, File, HTTPException from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel from typing import List, Optional import torch from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig from sentence_transformers import SentenceTransformer import faiss import numpy as np from threading import Thread import os from io import BytesIO # PDF Reader (même détection que avant) try: from pypdf import PdfReader as PypdfReader PDF_READER = "pypdf" except ImportError: try: import PyPDF2 PypdfReader = PyPDF2.PdfReader PDF_READER = "PyPDF2" except ImportError: raise ImportError("Installe pypdf ou PyPDF2 : pip install pypdf") # Outils personnalisés from tools.web import web_search from tools.todo import execute_reflection_plan from tools.geo import get_geo_context app = FastAPI(title="Kibali AI API", version="1.0") # CORS pour permettre au frontend JS d'appeler l'API app.add_middleware( CORSMiddleware, allow_origins=["*"], # À restreindre en prod allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) # --- CHARGEMENT MODÈLES (au démarrage) --- MODEL_PATH = "/home/belikan/geoscan/agent_kibali/model_cache" # à adapter ou via env var embed_model = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2') tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, local_files_only=True) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token bnb_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.float16 ) model = AutoModelForCausalLM.from_pretrained( MODEL_PATH, quantization_config=bnb_config, device_map="auto", torch_dtype=torch.float16, trust_remote_code=True ) # Base vectorielle globale (en mémoire, persiste tant que l'app tourne) dimension = 384 vector_index = faiss.IndexFlatL2(dimension) doc_chunks: List[str] = [] 
memory_text: List[str] = []
# BUG FIX: conversation memory gets its OWN index. Previously memory was
# searched against (and added to) the shared document `vector_index`, so the
# returned ids did not correspond to positions in `memory_text`, and memory
# embeddings polluted document retrieval.
memory_index = faiss.IndexFlatL2(dimension)


# --- Pydantic models ---
class Message(BaseModel):
    # One chat turn: `role` (e.g. "user"/"assistant") and raw text `content`.
    role: str
    content: str


class ChatRequest(BaseModel):
    messages: List[Message]
    latitude: float
    longitude: float
    city: Optional[str] = "Libreville"
    thinking_mode: bool = True  # not used by /chat yet; kept for API compatibility


class ChatResponse(BaseModel):
    response: str
    images: List[str] = []


# --- Utility functions ---
def extract_text_from_pdf(pdf_bytes: bytes) -> str:
    """Return the concatenated text of every page of a PDF.

    Pages with no extractable text (e.g. scanned images) are skipped.
    Each extracted page is followed by a newline.
    """
    parts: List[str] = []
    pdf_reader = PypdfReader(BytesIO(pdf_bytes))
    for page in pdf_reader.pages:
        page_text = page.extract_text()
        if page_text:
            parts.append(page_text + "\n")
    return "".join(parts)


def chunk_text(text: str, chunk_size: int = 400, overlap: int = 50) -> List[str]:
    """Split `text` into word-based chunks of `chunk_size` words,
    with `overlap` words shared between consecutive chunks.

    Raises:
        ValueError: if overlap >= chunk_size (the old implementation would
            loop forever because the step would be <= 0).
    """
    if overlap >= chunk_size:
        raise ValueError("overlap must be smaller than chunk_size")
    words = text.split()
    step = chunk_size - overlap
    return [" ".join(words[i:i + chunk_size]) for i in range(0, len(words), step)]


# --- API routes ---
@app.post("/chat")
async def chat(request: ChatRequest):
    """Answer the latest user message using document RAG, conversation
    memory, and a live web search, streamed from the local LLM."""
    prompt = request.messages[-1].content
    geo = {"latitude": request.latitude, "longitude": request.longitude, "city": request.city}

    # Embed the query once; reused for both document and memory retrieval.
    query_vec = embed_model.encode([prompt], normalize_embeddings=True).astype('float32')

    # RAG over uploaded documents
    rag_ctx = ""
    if vector_index.ntotal > 0 and doc_chunks:
        _, doc_ids = vector_index.search(query_vec, k=5)
        # FAISS pads with -1 when fewer than k vectors exist — filter those out
        # (the old `i < len(doc_chunks)` check let -1 select the LAST chunk).
        relevant = [doc_chunks[i] for i in doc_ids[0] if 0 <= i < len(doc_chunks)]
        rag_ctx = "\n\n".join(f"Doc: {c[:800]}" for c in relevant)

    # Conversation memory (dedicated index, so ids align with memory_text)
    past_ctx = ""
    if memory_text and memory_index.ntotal > 0:
        _, mem_ids = memory_index.search(query_vec, k=2)
        past_ctx = "\n".join(memory_text[i] for i in mem_ids[0] if 0 <= i < len(memory_text))

    # Web search
    search_data = web_search(prompt)
    web_ctx = "\n".join(f"- {r['content'][:300]}" for r in search_data.get("results", []))
    imgs = search_data.get("images", [])[:3]

    # Final prompt
    sys_instr = f"Tu es Kibali, assistant intelligent au Gabon ({geo['city']}). Réponds précisément."
    final_prompt = f"""### SYSTEM:
{sys_instr}
### DOCUMENTS:
{rag_ctx}
### MÉMOIRE:
{past_ctx}
### WEB:
{web_ctx}
### QUESTION:
{prompt}
### RÉPONSE:"""

    inputs = tokenizer(final_prompt, return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    def generate():
        # Runs in a worker thread; the streamer is consumed below.
        model.generate(**inputs, streamer=streamer, max_new_tokens=800,
                       temperature=0.3, do_sample=True)

    # daemon=True so a hung generation cannot block interpreter shutdown
    Thread(target=generate, daemon=True).start()

    full_response = ""
    for token in streamer:
        if "###" in token:
            # The model started a new "###" section header → stop streaming.
            break
        full_response += token

    # Persist this exchange into conversation memory
    new_mem = f"Q: {prompt} | R: {full_response[:500]}..."
    memory_text.append(new_mem)
    mem_emb = embed_model.encode([new_mem], normalize_embeddings=True).astype('float32')
    memory_index.add(mem_emb)

    return ChatResponse(response=full_response, images=imgs)


@app.post("/upload-pdfs")
async def upload_pdfs(files: List[UploadFile] = File(...)):
    """Extract, chunk, embed and index the text of uploaded PDF files."""
    new_chunks: List[str] = []
    for file in files:
        # Case-insensitive extension check; `filename` may be None.
        if not (file.filename or "").lower().endswith(".pdf"):
            continue
        content = await file.read()
        text = extract_text_from_pdf(content)
        if text.strip():
            new_chunks.extend(chunk_text(text))
    if new_chunks:
        embeddings = embed_model.encode(new_chunks, normalize_embeddings=True).astype('float32')
        vector_index.add(embeddings)
        doc_chunks.extend(new_chunks)
    return {"status": "success", "chunks_added": len(new_chunks), "total_chunks": len(doc_chunks)}


@app.get("/status")
async def status():
    """Liveness probe: index size and which PDF backend was detected."""
    return {"status": "ready", "chunks": len(doc_chunks), "pdf_library": PDF_READER}