File size: 5,994 Bytes
5c60ed2
0bf6060
f8adcff
90d1e52
0635997
df1b3de
 
 
0635997
90d1e52
df1b3de
90d1e52
 
 
 
df1b3de
90d1e52
 
 
 
 
 
 
 
df1b3de
90d1e52
 
 
 
 
df1b3de
 
 
 
 
90d1e52
 
 
 
df1b3de
90d1e52
df1b3de
 
 
 
 
 
90d1e52
df1b3de
 
 
 
 
90d1e52
 
 
df1b3de
0635997
60ac7f7
df1b3de
 
 
 
 
60ac7f7
6e990be
 
f846748
df1b3de
e24fae8
dd76368
0635997
79c456d
0635997
df1b3de
d5c54ef
df1b3de
aa78463
df1b3de
 
a6051b9
df1b3de
5cc0589
e24fae8
df1b3de
 
 
6399f7b
df1b3de
 
 
 
6399f7b
df1b3de
 
 
 
6399f7b
 
 
 
df1b3de
6399f7b
df1b3de
 
6399f7b
 
df1b3de
 
 
 
0b192b7
 
df1b3de
 
 
 
 
34f414b
8b0ad99
f846748
df1b3de
 
 
 
 
f846748
df1b3de
f846748
 
 
df1b3de
 
 
 
f846748
 
df1b3de
 
 
f846748
 
 
df1b3de
d2eb5fb
a6051b9
df1b3de
 
 
 
 
 
 
4c92796
f846748
cddcba8
f846748
 
df1b3de
f846748
5cc0589
df1b3de
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import gradio as gr
from transformers import pipeline
from huggingface_hub import InferenceClient, login, snapshot_download
from langchain_community.vectorstores import FAISS, DistanceStrategy
from langchain_huggingface import HuggingFaceEmbeddings
import os
import pandas as pd
from datetime import datetime

from smolagents import Tool, HfApiModel, ToolCallingAgent
from langchain_core.vectorstores import VectorStore


class RetrieverTool(Tool):
    """smolagents tool: semantic retrieval over the jurisprudence FAISS index.

    Wraps a LangChain vector store and formats the top matches (plus a short
    window of surrounding text pulled from the module-level ``df`` DataFrame)
    into a single string the agent can cite from.
    """

    name = "retriever"
    description = "Using semantic similarity in German, French, English and Italian, retrieves some documents from the knowledge base that have the closest embeddings to the input query."
    inputs = {
        "query": {
            "type": "string",
            "description": "The query to perform. This should be semantically close to your target documents. Use the affirmative form rather than a question.",
        }
    }
    output_type = "string"

    def __init__(self, vectordb: VectorStore, **kwargs):
        super().__init__(**kwargs)
        self.vectordb = vectordb

    def forward(self, query: str) -> str:
        """Return a formatted digest of the 7 chunks closest to *query*.

        Raises:
            TypeError: if *query* is not a string.
        """
        # `assert` is removed under `python -O`, so validate explicitly.
        if not isinstance(query, str):
            raise TypeError("Your search query must be a string")

        docs = self.vectordb.similarity_search(query, k=7)

        spacer = " \n"
        nb_char = 100  # characters of surrounding context kept on each side
        parts: list[str] = []

        for doc in docs:
            # NOTE(review): depends on the module-level `df` loaded at startup;
            # each chunk's case_url is expected to exist in the CSV.
            matches = df.loc[df["case_url"] == doc.metadata["case_url"], "case_text"].values
            if len(matches) == 0:
                # Skip chunks whose full text is missing from the CSV instead
                # of crashing with an IndexError on `.values[0]`.
                continue
            case_text = matches[0]

            # Locate the chunk in the full text and widen the window by
            # nb_char on each side (find() == -1 degrades to a prefix slice,
            # same as the original behavior).
            index = case_text.find(doc.page_content)
            start = max(0, index - nb_char)
            end = min(len(case_text), index + len(doc.page_content) + nb_char)
            case_text_summary = case_text[start:end]

            parts.append("#######" + spacer)
            parts.append("# Case number: " + doc.metadata["case_ref"] + " " + doc.metadata["case_nb"] + spacer)
            parts.append("# Case source: " + ("Swiss Federal Court" if doc.metadata["case_ref"] == "ATF" else "European Court of Human Rights") + spacer)
            parts.append("# Case date: " + doc.metadata["case_date"] + spacer)
            parts.append("# Case url: " + doc.metadata["case_url"] + spacer)
            parts.append("# Case extract: " + case_text_summary + spacer)

        # join() avoids quadratic += string growth across the 7 documents.
        return "\nRetrieved documents:\n" + "".join(parts)


"""
For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
"""
HF_TOKEN=os.getenv('TOKEN')
login(HF_TOKEN)

model = "meta-llama/Meta-Llama-3-8B-Instruct"
#model = "swiss-ai/Apertus-8B-Instruct-2509"

client = InferenceClient(model)

folder = snapshot_download(repo_id="umaiku/faiss_index", repo_type="dataset", local_dir=os.getcwd())

embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2")

vector_db = FAISS.load_local("faiss_index_mpnet_cos", embeddings, allow_dangerous_deserialization=True, distance_strategy=DistanceStrategy.COSINE)

df = pd.read_csv("bger_cedh_db 1954-2024.csv")

retriever_tool = RetrieverTool(vector_db)
agent = ToolCallingAgent(tools=[retriever_tool], model=HfApiModel(model))

def respond(message, history: list[tuple[str, str]], system_message, max_tokens, temperature, top_p, score,):
    """Stream an answer to *message*, grounded in retrieved jurisprudence.

    Args:
        message: The user's question.
        history: Prior (user, assistant) turns from the Gradio chat.
        system_message: System prompt from the UI textbox.
        max_tokens / temperature / top_p: Sampling settings from the sliders.
        score: Score-threshold slider value; currently unused by this function.

    Yields:
        The progressively accumulated response text (Gradio streaming contract).
    """
    print(datetime.now())

    # Retrieve supporting case extracts for the question.
    context = retriever_tool(message)

    print(message)

    # Build the grounded prompt (runtime string kept exactly as deployed).
    prompt = f"""Given the question and supporting documents below, give a comprehensive answer to the question.
Respond only to the question asked, response should be relevant to the question and in the same language as the question.
Provide the number of the source document when relevant, as well as the link to the document.
If you cannot find information, do not give up and try calling your retriever again with different arguments!
Always give url of the sources at the end and only answer in the language the question is asked.
    
Question:
{message}
    
{context}
"""

    messages = [{"role": "system", "content": system_message}]

    # Replay prior turns, skipping empty sides of a pair.
    for user_turn, assistant_turn in history:
        if user_turn:
            messages.append({"role": "user", "content": user_turn})
        if assistant_turn:
            messages.append({"role": "assistant", "content": assistant_turn})

    messages.append({"role": "user", "content": prompt})

    response = ""

    # Named `chunk` (not `message`) so the function parameter isn't shadowed.
    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        token = chunk.choices[0].delta.content

        # The final/keep-alive chunk can carry content=None; guarding avoids
        # `str + None` raising TypeError at the end of the stream.
        if token:
            response += token
            yield response


"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="You are assisting a jurist or a layer in finding relevant Swiss Jurisprudence cases to their question.", label="System message"),
        gr.Slider(minimum=1, maximum=24000, value=5000, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.1, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
        gr.Slider(minimum=0, maximum=1, value=0.75, step=0.05, label="Score Threshold"),
    ],
    description="# 📜 ALexI: Artificial Legal Intelligence for Swiss Jurisprudence",
)


if __name__ == "__main__":
    print("Ready!")
    demo.launch(debug=True)