File size: 2,624 Bytes
7919734
 
 
 
 
 
 
 
 
 
97cda46
7919734
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97cda46
 
 
7919734
97cda46
7919734
97cda46
7919734
 
 
 
 
 
 
 
97cda46
7919734
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
"""
data_loader.py — Dataset loading + FAISS index construction + query embedding.

Runs at Space startup (once). Each dataset gets a FAISS index built in memory.
Graceful degradation: if one source fails, the others continue.
"""

import os
import numpy as np
from datasets import load_dataset
from mistralai import Mistral

# Repository identifier handed to `load_dataset` for every config.
DATASET_REPO = "ArthurSrz/open_codes"
# Model name sent to the Mistral embeddings endpoint by `embed_query`.
EMBED_MODEL = "mistral-embed"
# Expected length of a query embedding; `embed_query` rejects other sizes.
EMBED_DIM = 1024

# Per-source load flag: every source starts as "not loaded" and is flipped
# to True by `load_all_datasets` once its dataset and index are ready.
LOADING_STATUS: dict[str, bool] = dict.fromkeys(
    ("articles", "jurisprudence", "circulaires", "reponses"),
    False,
)

# Module-level store that `load_all_datasets` updates with its latest result.
_datasets: dict = dict()


def load_all_datasets() -> dict:
    """
    Load the four configs of DATASET_REPO and attach a FAISS index to each.

    Every config is loaded from its "train" split and indexed on its
    "embedding" column. A failure on one config is reported on stdout and
    does not stop the remaining configs from loading.

    Side effects: updates LOADING_STATUS for every source and merges the
    result into the module-level `_datasets` dict.

    Returns:
        Dict keyed by articles, jurisprudence, circulaires, reponses.
        A source that failed to load maps to None.
    """
    # Result key -> config name in the dataset repository, in loading order.
    config_by_source = {
        "articles": "default",
        "jurisprudence": "jurisprudence",
        "circulaires": "circulaires",
        "reponses": "reponses_legis",
    }

    loaded: dict = {}

    for source, hub_config in config_by_source.items():
        try:
            print(f"[data_loader] Loading {hub_config}…")
            dataset = load_dataset(DATASET_REPO, name=hub_config, split="train")
            dataset.add_faiss_index(column="embedding")
            loaded[source], LOADING_STATUS[source] = dataset, True
            print(f"[data_loader] ✓ {hub_config}: {len(dataset)} rows, FAISS index built")
        except Exception as err:
            # Deliberate best-effort: record the failure and keep going.
            print(f"[data_loader] ✗ {hub_config} failed: {err}")
            loaded[source], LOADING_STATUS[source] = None, False

    _datasets.update(loaded)
    return loaded


def embed_query(query_text: str, hf_token: str) -> list[float]:
    """
    Embed a query string with the Mistral embeddings API (model EMBED_MODEL).

    The API key is read from the MISTRAL_API_KEY environment variable.

    Args:
        query_text: Text to embed; sent as a single-element batch.
        hf_token: Not used by this function; kept so existing callers that
            pass it keep working.

    Returns:
        The embedding of `query_text`, of length EMBED_DIM.

    Raises:
        ValueError: With a user-readable message when the API key is missing,
            when the API call fails, or when the returned embedding does not
            have EMBED_DIM components.
    """
    api_key = os.environ.get("MISTRAL_API_KEY", "")
    if not api_key:
        # Fail fast: without a key the request can only come back rejected.
        raise ValueError(
            "Impossible d'encoder la requête : MISTRAL_API_KEY est absent. "
            "Vérifiez que MISTRAL_API_KEY est configuré dans les secrets du Space."
        )

    try:
        client = Mistral(api_key=api_key)
        response = client.embeddings.create(
            model=EMBED_MODEL,
            inputs=[query_text],
        )
        embedding = response.data[0].embedding
        # Measured inside the try so a malformed response (e.g. no embedding)
        # is still reported as ValueError, as callers expect.
        dimension = len(embedding)
    except Exception as e:
        raise ValueError(
            f"Impossible d'encoder la requête : {e}. "
            "Vérifiez que MISTRAL_API_KEY est configuré dans les secrets du Space."
        ) from e

    # Checked outside the try: a size mismatch is not a credentials problem,
    # so it must not be re-wrapped with the "check MISTRAL_API_KEY" hint.
    if dimension != EMBED_DIM:
        raise ValueError(
            "Impossible d'encoder la requête : dimension d'embedding inattendue "
            f"(attendu {EMBED_DIM}, reçu {dimension})."
        )
    return embedding