File size: 9,689 Bytes
86e2833
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
from datetime import datetime
from sklearn.metrics.pairwise import cosine_similarity
from langchain_mistralai import MistralAIEmbeddings
import os
from dotenv import load_dotenv
import numpy as np
import sys
from rapidfuzz import fuzz
import requests
from pathlib import Path
import streamlit as st

# racine du projet au PYTHONPATH
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../")))

# chargement des variables d'environnements
load_dotenv()
try:
    HF_TOKEN = st.session_state["HF_API_KEY"]
except KeyError:
    HF_TOKEN = os.getenv("HF_API_KEY")

try:
    MISTRAL_API_KEY = st.session_state["MISTRAL_API_KEY"]
except KeyError:
    MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY")


class SecurityCheck:
    def __init__(self, mistral_api_key=MISTRAL_API_KEY):
        self.mistral_api_key = mistral_api_key

    def _get_ip_address(self) -> str:
        """

        Récupère l'adresse IP publique de l'utilisateur via l'API ipify.



        Cette fonction réalise les étapes suivantes :

        1. **Envoi d'une requête HTTP** :

        - Envoie une requête GET à l'API `ipify` pour récupérer l'adresse IP publique au format JSON.

        2. **Traitement de la réponse** :

        - Si la requête réussit, extrait l'adresse IP du JSON retourné.

        3. **Gestion des erreurs** :

        - Si une erreur survient lors de la requête (par exemple, problème de connexion), affiche un message d'erreur et retourne `None`.



        Returns:

            str or None: Retourne l'adresse IP publique sous forme de chaîne de caractères si la requête est réussie,

                        ou `None` en cas d'erreur.

        """

        try:
            # Envoie une requête GET à l'API ipify pour récupérer l'adresse IP publique
            response = requests.get("https://api.ipify.org?format=json")
            response.raise_for_status()  # Vérifie si la requête a échoué

            # Extraction de l'adresse IP depuis le JSON de la réponse
            ip_address = response.json().get("ip")
            return ip_address
        except requests.exceptions.RequestException as e:
            # En cas d'erreur, affiche un message d'erreur et retourne None
            print(f"Erreur lors de la récupération de l'IP : {e}")
            return None

    def filter_and_check_security(

        self, prompt: str, seuil_fuzzy: int = 80, check_char: bool = True

    ) -> dict:
        """

        Filtre et normalise les entrées utilisateur.

        Vérifie la présence de caractères interdits et de mots interdits dans le prompt.



        Cette fonction réalise les étapes suivantes :

        1. **Vérification des caractères interdits** :

        - Vérifie si le prompt contient des caractères interdits (par exemple, des symboles spéciaux ou des caractères de contrôle).

        2. **Vérification des mots interdits** :

        - Vérifie si le prompt contient des mots interdits à l'aide d'une comparaison floue (fuzzy matching) basée sur un seuil de similarité défini par `seuil_fuzzy`.

        3. **Gestion des résultats** :

        - Si des caractères ou mots interdits sont trouvés, le prompt est rejeté avec un message approprié.

        - Si aucune règle n'est violée, le prompt est accepté.

        4. **Ajout d'informations supplémentaires** :

        - Enregistre l'adresse IP de l'utilisateur et un timestamp pour chaque vérification.



        Args:

            prompt (str): L'entrée utilisateur à vérifier.

            seuil_fuzzy (int, optional): Seuil de similarité pour la comparaison floue des mots interdits (par défaut 80).

            check_char (bool, optional): Si `True`, vérifie la présence de caractères interdits dans le prompt (par défaut `True`).



        Returns:

            dict: Dictionnaire contenant le statut de l'entrée (`"Rejeté"` ou `"Accepté"`) et les informations associées (adresse IP et timestamp).

        """

        # Liste des caractères interdits
        forbidden_chars = set("{}[]<>|;$&%\n\r\t\\\"'\u200b\u202e")

        # Liste des mots interdits (en incluant des termes liés à la sécurité ou à des comportements malveillants)
        forbidden_words = [
            "Contournement",
            "Pirater",
            "Jailbreak",
            "Accéder",
            "Hack",
            "Exécuter",
            "Modifier",
            "Manipuler",
            "Tirer parti",
            "Exploiter",
            "Installer",
            "Télécharger",
            "Effacer",
            "Détruire",
            "Casser",
            "Supprimer",
            "Écrire",
            "Réinitialiser",
            "Réparer",
            "Réorganiser",
            "Activer",
            "Désactiver",
            "Modifier",
            "Interférer",
            "Forcer",
            "Simuler",
            "Ouvrir",
            "Vulnérabilité",
            "Commandes",
            "Commandes système",
            "Télécommande",
            "Déboguer",
            "Accéder à distance",
            "Redémarrer",
            "Arrêter",
            "Injection",
            "Détournement",
            "Rendre vulnérable",
            "Dépasser",
            "Systèmes critiques",
            "Réseau",
            "Exécution de code",
            "Privilèges",
            "Escalade",
            "Contournement des règles",
            "Violer",
            "Altérer",
            "Simulation de rôle",
            "Faire semblant",
            "Commande d’urgence",
            "Impersonner",
            "Redirection",
            "Dispositifs de sécurité"
        ]

        # Initialisation du dictionnaire des résultats
        results = dict()

        # Vérification de la présence de caractères interdits
        if check_char:
            if any(char in forbidden_chars for char in prompt):
                results["status"] = "Rejeté : caractères interdits"
                results["origin"] = self._get_ip_address()
                results["timestamp"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                return results

        # Vérification de la présence de mots interdits avec une comparaison floue (fuzzy matching)
        prompt_words = prompt.split()
        for p_word in prompt_words:
            for f_word in forbidden_words:
                if fuzz.ratio(p_word.lower(), f_word.lower()) >= seuil_fuzzy:
                    results["status"] = "Rejeté : mots interdits"
                    results["origin"] = self._get_ip_address()
                    results["timestamp"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                    return results

        # Si aucun problème n'a été détecté, l'entrée est acceptée
        results["status"] = "Accepté"
        results["origin"] = self._get_ip_address()
        results["timestamp"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

        return results

    def prompt_check(

        self, prompt: str, docs_embeddings: list[list[float]], threshold: float = 0.6

    ) -> tuple[bool, np.ndarray]:
        """

        Vérifie si une requête utilisateur est pertinente.

        Si la requête est hors contexte par rapport aux documents de référence, elle est bloquée.



        Cette fonction calcule la similarité entre le prompt utilisateur et les documents de référence

        en utilisant des embeddings, puis la compare avec un seuil de similarité donné.



        Args:

            - prompt (str): Le texte que l'utilisateur soumet pour interroger le modèle.

            - docs_embeddings (list): Liste des embeddings des documents de référence qui serviront de base de comparaison.

            - threshold (float): Seuil de similarité pour déterminer si le prompt est pertinent (par défaut 0.6).



        Returns:

            - bool: `True` si la similarité entre le prompt et les documents de référence est suffisante (au-dessus du seuil),

                    `False` sinon.

        """
        # Embedding du prompt
        try:
            mistral_embeddings = MistralAIEmbeddings(
                model="mistral-embed", api_key=self.mistral_api_key
            )
            prompt_embedding = mistral_embeddings.embed_query(prompt)

            prompt_embedding = np.array(prompt_embedding).reshape(1, -1)

            # Conversion des docs_embeddings en numpy array (matrice 2D)
            docs_embeddings = np.array(docs_embeddings)

            # Vérification de la cohérence des dimensions
            if prompt_embedding.shape[1] != docs_embeddings.shape[1]:
                raise ValueError(
                    f"Incompatible dimensions: prompt_embedding has {prompt_embedding.shape[1]} dimensions "
                    f"while docs_embeddings has {docs_embeddings.shape[1]} dimensions."
                )

            # Calcul de la similarité cosine
            similarities = cosine_similarity(prompt_embedding, docs_embeddings)
            max_similarity = max(similarities[0])

            # Trouver les indices des 3 documents les plus similaires
            top_indices = np.argsort(similarities)[-3:][::-1]

            # Vérification par rapport au seuil
            test_sim_cosine = max_similarity >= threshold
            return (test_sim_cosine, top_indices)

        except Exception as e:
            print(f"Erreur lors de la vérification du prompt : {e}")
            return (False, np.array([]))