File size: 4,330 Bytes
44a12cc
 
 
 
 
 
 
 
 
 
 
f3a7708
585edd8
f3a7708
44a12cc
585edd8
44a12cc
 
585edd8
 
 
44a12cc
 
3b07574
44a12cc
 
 
 
 
 
 
 
2abc1bb
44a12cc
 
 
 
 
 
f3a7708
 
 
44a12cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c410f41
646f518
44a12cc
 
 
 
 
 
 
c410f41
 
 
 
 
 
 
 
 
 
 
 
44a12cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import os
import faiss
import pickle
import numpy as np
import openai
import tiktoken
from dotenv import load_dotenv
from openai import OpenAI
from pathlib import Path
from huggingface_hub import hf_hub_download

# —— CONFIG ——
# Load variables from a local .env file FIRST, so every os.getenv() below
# (including CACHE_DIR) can be overridden from .env. Previously CACHE_DIR was
# read before load_dotenv() ran, so a value set in .env was ignored.
load_dotenv()

# Download cache for the index files fetched from Hugging Face.
# Must be writable; it is created here if it does not exist yet.
CACHE_DIR = os.getenv("CACHE_DIR", "/tmp/cache")
os.makedirs(CACHE_DIR, exist_ok=True)

# NOTE(review): huggingface_hub is imported at the top of the file, before these
# variables are set. It likely reads HF_HOME at import time, so this assignment
# may only affect libraries imported later. The downloads below pass cache_dir
# explicitly and do not rely on it.
os.environ["TRANSFORMERS_CACHE"] = os.getenv("TRANSFORMERS_CACHE", "/tmp/huggingface/cache")
os.environ["HF_HOME"]          = os.getenv("HF_HOME", "/tmp/huggingface")

# OpenAI client (v1-style API). It is created after load_dotenv(), so credentials
# defined in .env are presumably visible to it.
client = OpenAI()
EMBED_MODEL           = "text-embedding-3-large"   # model used to embed questions
CHAT_MODEL            = "o4-mini-2025-04-16"       # model that writes the final answer
FAISS_INDEX_FILE      = "tindle_index.faiss"       # FAISS index file in the HF dataset repo
IDS_PKL               = "tindle_ids.pkl"           # pickled list: index position -> chunk id
CHUNKS_PKL            = "tindle_chunks.pkl"        # pickled dict: chunk id -> chunk record ("text" + metadata)
TOP_K                 = 10                         # default number of nearest chunks to retrieve
MAX_TOKENS_CONTEXT    = 4000                       # token budget for the context sent to the LLM
# System instructions (in French): act as a tax-law expert, answer from the
# provided passages first, and fall back to general knowledge only while
# saying so explicitly.
SYSTEM_PROMPT = (
    "Tu es un assistant expert en droit fiscal. "
    "Fais d'abord appel aux passages fournis pour répondre. "
    "Si ces passages sont insuffisants, utilise tes connaissances générales en le précisant clairement."
)

# —— INDEX LOADING ——

# Hugging Face dataset repository holding the prebuilt FAISS index and its side files.
HF_REPO_ID = "Jordanche/fiscarag"


def _download(filename: str) -> str:
    """Download *filename* from the HF dataset repo into CACHE_DIR; return the local path."""
    return hf_hub_download(
        repo_id=HF_REPO_ID,
        filename=filename,
        repo_type="dataset",
        cache_dir=CACHE_DIR,
    )


# Fetch the three files (or reuse the cached copies).
index_path = _download(FAISS_INDEX_FILE)
ids_path = _download(IDS_PKL)
chunks_path = _download(CHUNKS_PKL)

# Load the files into memory.
# SECURITY NOTE(review): pickle.load can execute arbitrary code embedded in the
# file, and these pickles come from a remote repository. Only load them if that
# repository is trusted, or consider switching the side files to JSON.
index = faiss.read_index(index_path)
with open(ids_path, "rb") as f:
    ids = pickle.load(f)
with open(chunks_path, "rb") as f:
    chunks_dict = pickle.load(f)

# —— TOKEN COUNTER ——
# Tokenizer used to budget the size of the retrieved context.
# NOTE(review): o-series chat models likely use a different encoding
# (o200k_base), so counts here are an approximation of the real prompt size.
enc = tiktoken.get_encoding("cl100k_base")
def num_tokens(s: str) -> int:
    """Return the number of cl100k_base tokens in *s*.

    Special-token markers such as "<|endoftext|>" that appear in the text are
    counted as ordinary text. By default ``Encoding.encode`` raises ValueError on
    them, which would crash on arbitrary questions or passages.
    """
    return len(enc.encode(s, disallowed_special=()))


# —— RAG FUNCTIONS ——

def embed_question(question: str) -> list[float]:
    """Embed *question* with EMBED_MODEL and return the resulting vector.

    The question is sent as a one-element batch, and the embedding of that
    single item is returned.
    """
    response = client.embeddings.create(input=[question], model=EMBED_MODEL)
    first_item = response.data[0]
    return first_item.embedding


def retrieve_chunks(q_emb: list[float], k: int = TOP_K):
    """Return up to *k* chunks nearest to *q_emb* in the module-level FAISS index.

    Args:
        q_emb: query embedding; wrapped into a single-row float32 matrix.
        k: number of neighbours requested from FAISS.

    Returns:
        A list of dicts, in the order FAISS returns them, each with:
          - "score": the value FAISS reports for the hit, as a float,
          - "id": the chunk id found at that position in ``ids``,
          - "text": the chunk text,
          - "metadata": every other field of the chunk record.
        FAISS pads missing results with index -1 (for example when *k* exceeds
        the number of stored vectors). Those slots are skipped, so fewer than
        *k* results may be returned. ``k <= 0`` returns an empty list.
    """
    if k <= 0:
        return []
    xq = np.array([q_emb], dtype="float32")
    distances, indices = index.search(xq, k)
    out = []
    for dist, idx in zip(distances[0], indices[0]):
        if idx < 0:
            # Padding slot: ids[-1] would silently return an unrelated chunk.
            continue
        cid = ids[idx]
        meta = chunks_dict[cid]
        out.append({
            "score": float(dist),
            "id": cid,
            "text": meta["text"],
            # Everything except the text itself is passed along as metadata.
            "metadata": {key: val for key, val in meta.items() if key != "text"},
        })
    return out


def build_context(chunks, max_tokens=MAX_TOKENS_CONTEXT):
    """Pack chunk texts, each with a source header, into a token-bounded context.

    Each chunk becomes ``"(Source: <id> | key: value | ...) <text>"`` (the
    metadata part is omitted when empty), and pieces are joined by blank lines.
    Chunks are taken in ascending "score" order. Packing stops at the first
    piece that would push the running total (separators included) over
    *max_tokens*; that piece and all later ones are dropped.

    NOTE(review): ascending order assumes a lower score means a better match
    (a distance metric such as L2). If the FAISS index uses inner product,
    this puts the worst match first. Confirm against the index type.

    Args:
        chunks: dicts with "id", "text", "score" and "metadata" keys, as
            produced by retrieve_chunks.
        max_tokens: approximate token budget for the whole returned string.

    Returns:
        The context string ("" when no chunk fits).
    """
    separator = "\n\n"
    sep_tokens = num_tokens(separator)
    parts, tokens = [], 0
    for c in sorted(chunks, key=lambda x: x["score"]):
        metadata_str = " | ".join(f"{key}: {value}" for key, value in c["metadata"].items())
        source_info = f"(Source: {c['id']}"
        if metadata_str:
            source_info += f" | {metadata_str}"
        source_info += ")"

        piece = f"{source_info} {c['text']}"
        # Every piece after the first is preceded by a separator in the output,
        # so charge the separator against the budget as well.
        nt = num_tokens(piece) + (sep_tokens if parts else 0)
        if tokens + nt > max_tokens:
            break
        parts.append(piece)
        tokens += nt
    return separator.join(parts)


def make_prompt(question: str, context: str):
    """Build the chat message list: system prompt, then question plus retrieved context."""
    user_content = f"Question: {question}\n\nContexte:\n{context}"
    system_message = {"role": "system", "content": SYSTEM_PROMPT}
    user_message = {"role": "user", "content": user_content}
    return [system_message, user_message]


def answer_question(question: str, k: int = TOP_K) -> str:
    """Answer *question* with retrieval-augmented generation.

    Pipeline:
      1. embed the question,
      2. retrieve the *k* nearest chunks,
      3. pack them into a token-bounded context,
      4. send the system prompt, question and context to CHAT_MODEL.

    Returns the message content of the first completion choice.
    NOTE(review): the SDK may type this content as optional; confirm that
    callers handle None.
    """
    context = build_context(retrieve_chunks(embed_question(question), k))
    completion = client.chat.completions.create(
        model=CHAT_MODEL,
        messages=make_prompt(question, context),
    )
    return completion.choices[0].message.content

# —— EXAMPLE ——
if __name__ == "__main__":
    # Sample query (in French): deadlines for rehabilitating hotels in the
    # French overseas territories.
    sample_question = "Quels sont les délais pour la réhabilitation d'hôtels en outre-mer ?"
    answer = answer_question(sample_question, k=10)
    print(answer)