|
|
|
|
|
from fastapi import FastAPI, UploadFile, File, HTTPException |
|
|
from fastapi.middleware.cors import CORSMiddleware |
|
|
from pydantic import BaseModel |
|
|
from typing import List, Optional |
|
|
import torch |
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig |
|
|
from sentence_transformers import SentenceTransformer |
|
|
import faiss |
|
|
import numpy as np |
|
|
from threading import Thread |
|
|
import os |
|
|
from io import BytesIO |
|
|
|
|
|
|
|
|
# Resolve a PDF reader at import time: prefer the maintained `pypdf`
# package, fall back to the legacy `PyPDF2` (same PdfReader API), and
# fail fast with an installation hint if neither is present.
# PDF_READER records which library won; it is reported by GET /status.
try:
    from pypdf import PdfReader as PypdfReader
    PDF_READER = "pypdf"
except ImportError:
    try:
        import PyPDF2
        # PyPDF2 >= 3.x exposes the same PdfReader class name.
        PypdfReader = PyPDF2.PdfReader
        PDF_READER = "PyPDF2"
    except ImportError:
        raise ImportError("Installe pypdf ou PyPDF2 : pip install pypdf")
|
|
|
|
|
|
|
|
from tools.web import web_search |
|
|
from tools.todo import execute_reflection_plan |
|
|
from tools.geo import get_geo_context |
|
|
|
|
|
app = FastAPI(title="Kibali AI API", version="1.0")

# Development-friendly CORS: accept calls from any origin.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# rejected by browsers under the CORS spec (a wildcard origin may not be
# served with credentials). Pin explicit origins before production use.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
|
|
|
# --- Heavyweight model / retrieval setup (runs once at import time) ---

# Local directory holding the cached causal-LM weights and tokenizer.
MODEL_PATH = "/home/belikan/geoscan/agent_kibali/model_cache"

# Multilingual sentence encoder used for BOTH document chunks and chat
# memory; its 384-dim output must match `dimension` below.
embed_model = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2')

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, local_files_only=True)
if tokenizer.pad_token is None:
    # Some causal LMs ship without a pad token; reuse EOS so padding works.
    tokenizer.pad_token = tokenizer.eos_token

# 4-bit NF4 quantization (with double quantization) to shrink VRAM usage;
# compute runs in fp16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16
)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.float16,
    trust_remote_code=True
)

# In-memory vector store. NOTE(review): a SINGLE FAISS index receives both
# document-chunk vectors (/upload-pdfs) and chat-memory vectors (/chat), but
# /chat maps raw FAISS ids straight into doc_chunks / memory_text. Ids are
# assigned in global insertion order across both collections, so once both
# kinds of vectors exist the id->list mapping is wrong. Consider keeping two
# separate indices (or one id->(kind, text) table).
dimension = 384  # embedding size of paraphrase-multilingual-MiniLM-L12-v2
vector_index = faiss.IndexFlatL2(dimension)
doc_chunks: List[str] = []   # text of every ingested PDF chunk, in add order
memory_text: List[str] = []  # "Q: ... | R: ..." summaries of past chats
|
|
|
|
|
|
|
|
class Message(BaseModel):
    """One chat turn sent to POST /chat."""
    role: str     # speaker role — presumably "user"/"assistant"; not inspected by /chat
    content: str  # message text; /chat reads only the LAST message's content
|
|
|
|
|
class ChatRequest(BaseModel):
    """Request payload for POST /chat."""
    messages: List[Message]            # conversation history; only messages[-1].content is used as the prompt
    latitude: float                    # caller GPS latitude, forwarded as geo context
    longitude: float                   # caller GPS longitude
    city: Optional[str] = "Libreville" # city name injected into the system instruction
    thinking_mode: bool = True         # accepted but not read anywhere in this file — TODO confirm intended use
|
|
|
|
|
class ChatResponse(BaseModel):
    """Response payload of POST /chat."""
    response: str            # the generated answer text
    images: List[str] = []   # up to 3 image URLs from web search; pydantic deep-copies this default, so the shared-[] pitfall does not apply
|
|
|
|
|
|
|
|
def extract_text_from_pdf(pdf_bytes: bytes) -> str:
    """Return the concatenated text of every page of a PDF given as raw bytes.

    Each page that yields text contributes its text plus a trailing
    newline; pages with no extractable text are skipped.
    """
    reader = PypdfReader(BytesIO(pdf_bytes))
    page_texts = (page.extract_text() for page in reader.pages)
    return "".join(t + "\n" for t in page_texts if t)
|
|
|
|
|
def chunk_text(text: str, chunk_size: int = 400, overlap: int = 50) -> List[str]:
    """Split *text* into word-based chunks with overlap between neighbors.

    Args:
        text: Raw text to split; empty or whitespace-only text yields [].
        chunk_size: Maximum number of words per chunk (must be > 0).
        overlap: Number of words repeated at the start of the next chunk
            (must satisfy 0 <= overlap < chunk_size).

    Returns:
        Chunk strings in document order.

    Raises:
        ValueError: If chunk_size <= 0 or overlap is not in [0, chunk_size);
            the original while-loop would spin forever when the step
            (chunk_size - overlap) was <= 0.
    """
    if chunk_size <= 0 or overlap < 0 or overlap >= chunk_size:
        raise ValueError("chunk_size must be > 0 and 0 <= overlap < chunk_size")
    words = text.split()
    step = chunk_size - overlap  # guaranteed positive by the guard above
    return [" ".join(words[i:i + chunk_size]) for i in range(0, len(words), step)]
|
|
|
|
|
|
|
|
@app.post("/chat")
async def chat(request: ChatRequest):
    """Answer the latest user message with RAG + memory + web context.

    Pipeline: embed the prompt, retrieve document chunks and past-chat
    memory from the shared FAISS index, fetch web results, build a single
    prompt, stream-generate with the local LM, store a Q/A summary back
    into memory, and return the answer plus up to 3 web image URLs.

    NOTE(review): web_search() and the streamer consumption loop are
    synchronous calls inside an `async def`, so they block the event loop
    for the whole generation — consider run_in_executor / threadpool.
    """
    # Only the last message is used as the query; earlier turns are ignored.
    prompt = request.messages[-1].content
    geo = {"latitude": request.latitude, "longitude": request.longitude, "city": request.city}

    # --- Document retrieval (RAG) ---
    # NOTE(review): vector_index holds BOTH doc and memory vectors (see
    # /upload-pdfs and the memory write-back below), but ids returned here
    # are used as raw positions into doc_chunks. Once memory vectors are
    # interleaved, id i no longer corresponds to doc_chunks[i].
    rag_ctx = ""
    if vector_index.ntotal > 0 and doc_chunks:
        query_vec = embed_model.encode([prompt], normalize_embeddings=True).astype('float32')
        D, I = vector_index.search(query_vec, k=5)
        # NOTE(review): FAISS pads with id -1 when fewer than k vectors
        # exist; `i < len(doc_chunks)` lets -1 through and silently picks
        # doc_chunks[-1]. The memory branch below guards `0 <= i` — this
        # one should too.
        relevant = [doc_chunks[i] for i in I[0] if i < len(doc_chunks)]
        rag_ctx = "\n\n".join([f"Doc: {c[:800]}" for c in relevant])

    # --- Conversation-memory retrieval ---
    # NOTE(review): same shared-index issue — ids from vector_index are
    # positions across docs AND memory, so memory_text[i] is only correct
    # while no document vectors precede the memory vectors.
    past_ctx = ""
    if memory_text:
        query_vec = embed_model.encode([prompt], normalize_embeddings=True).astype('float32')
        D, I = vector_index.search(query_vec, k=2)
        past_ctx = "\n".join([memory_text[i] for i in I[0] if 0 <= i < len(memory_text)])

    # --- Web context (blocking call; see docstring note) ---
    search_data = web_search(prompt)
    web_ctx = "\n".join([f"- {r['content'][:300]}" for r in search_data.get("results", [])])
    imgs = search_data.get("images", [])[:3]

    # --- Prompt assembly (French persona; sections delimited by "###") ---
    sys_instr = f"Tu es Kibali, assistant intelligent au Gabon ({geo['city']}). Réponds précisément."
    final_prompt = f"""### SYSTEM: {sys_instr}

### DOCUMENTS: {rag_ctx}

### MÉMOIRE: {past_ctx}

### WEB: {web_ctx}

### QUESTION: {prompt}

### RÉPONSE:"""

    # --- Generation: run model.generate in a thread, consume the streamer here ---
    inputs = tokenizer(final_prompt, return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    def generate():
        model.generate(**inputs, streamer=streamer, max_new_tokens=800, temperature=0.3, do_sample=True)

    Thread(target=generate).start()

    full_response = ""
    for token in streamer:
        # Stop if the model starts hallucinating a new "###" section.
        # NOTE(review): breaking out leaves the generate() thread running
        # to completion in the background; it is never joined.
        if "###" in token:
            break
        full_response += token

    # --- Write a summary of this exchange back into memory + index ---
    new_mem = f"Q: {prompt} | R: {full_response[:500]}..."
    memory_text.append(new_mem)
    mem_emb = embed_model.encode([new_mem], normalize_embeddings=True).astype('float32')
    vector_index.add(mem_emb)

    return ChatResponse(response=full_response, images=imgs)
|
|
|
|
|
@app.post("/upload-pdfs")
async def upload_pdfs(files: List[UploadFile] = File(...)):
    """Ingest uploaded PDFs into the in-memory RAG store.

    For each PDF: read its bytes, extract the text, split it into word
    chunks, embed the chunks, and append them to the global FAISS index
    and `doc_chunks` list. Files that are not PDFs or cannot be parsed
    are skipped so one bad file does not fail the whole batch.

    Returns:
        A status dict with the number of chunks added, the running total,
        and the names of any skipped files.
    """
    new_chunks: List[str] = []
    skipped: List[str] = []
    for file in files:
        # Fix: file.filename may be None (AttributeError before), and the
        # original check was case-sensitive, silently dropping ".PDF" files.
        name = file.filename or ""
        if not name.lower().endswith(".pdf"):
            skipped.append(name or "<unnamed>")
            continue
        content = await file.read()
        try:
            text = extract_text_from_pdf(content)
        except Exception:
            # Corrupt/encrypted PDF: skip it instead of 500-ing the batch.
            skipped.append(name)
            continue
        if text.strip():
            new_chunks.extend(chunk_text(text))

    if new_chunks:
        # Embed once for the whole batch, then keep index and text in sync.
        embeddings = embed_model.encode(new_chunks, normalize_embeddings=True).astype('float32')
        vector_index.add(embeddings)
        doc_chunks.extend(new_chunks)

    return {
        "status": "success",
        "chunks_added": len(new_chunks),
        "total_chunks": len(doc_chunks),
        "skipped_files": skipped,  # backward-compatible addition
    }
|
|
|
|
|
@app.get("/status")
async def status():
    """Health probe: readiness flag, chunk count, and active PDF backend."""
    payload = {
        "status": "ready",
        "chunks": len(doc_chunks),
        "pdf_library": PDF_READER,
    }
    return payload