| | import gradio as gr |
| | from transformers import pipeline |
| | from huggingface_hub import InferenceClient, login, snapshot_download |
| | from langchain_community.vectorstores import FAISS, DistanceStrategy |
| | from langchain_huggingface import HuggingFaceEmbeddings |
| | import os |
| | import pandas as pd |
| | from datetime import datetime |
| |
|
| | from smolagents import Tool, HfApiModel, ToolCallingAgent |
| | from langchain_core.vectorstores import VectorStore |
| |
|
| |
|
class RetrieverTool(Tool):
    """Semantic-similarity retriever over the Swiss/ECHR case-law FAISS index.

    For each of the top-k matching chunks, the full case text is looked up in
    the cases DataFrame (``case_url`` -> ``case_text``) and the chunk is
    returned with ~100 characters of surrounding context plus its metadata.
    """

    name = "retriever"
    description = "Using semantic similarity in German, French, English and Italian, retrieves some documents from the knowledge base that have the closest embeddings to the input query."
    inputs = {
        "query": {
            "type": "string",
            "description": "The query to perform. This should be semantically close to your target documents. Use the affirmative form rather than a question.",
        }
    }
    output_type = "string"

    def __init__(self, vectordb: VectorStore, cases_df=None, k: int = 7, **kwargs):
        """
        Args:
            vectordb: vector store searched with ``similarity_search``.
            cases_df: optional DataFrame with ``case_url`` / ``case_text``
                columns; when omitted, falls back to the module-level ``df``
                (backward compatible with the original behavior).
            k: number of documents to retrieve per query (previously a
                hard-coded 7).
        """
        super().__init__(**kwargs)
        self.vectordb = vectordb
        self.cases_df = cases_df
        self.k = k

    def forward(self, query: str) -> str:
        """Retrieve the k most similar chunks and format them as one string."""
        if not isinstance(query, str):
            # Raise instead of assert: asserts are stripped under `python -O`.
            raise TypeError("Your search query must be a string")

        docs = self.vectordb.similarity_search(query, k=self.k)

        cases = self.cases_df if self.cases_df is not None else df
        spacer = " \n"
        nb_char = 100  # characters of surrounding context kept on each side
        parts = []  # build with join instead of quadratic `context +=`

        for doc in docs:
            matches = cases[cases["case_url"] == doc.metadata["case_url"]].case_text.values
            if len(matches) == 0:
                # Case missing from the CSV: fall back to the raw chunk
                # (previously raised IndexError on `.values[0]`).
                case_text_summary = doc.page_content
            else:
                case_text = matches[0]
                index = case_text.find(doc.page_content)
                if index == -1:
                    # Chunk not found verbatim (e.g. re-normalized text):
                    # previously this silently sliced an arbitrary prefix.
                    case_text_summary = doc.page_content
                else:
                    start = max(0, index - nb_char)
                    end = min(len(case_text), index + len(doc.page_content) + nb_char)
                    case_text_summary = case_text[start:end]

            parts.append("#######" + spacer)
            parts.append("# Case number: " + doc.metadata["case_ref"] + " " + doc.metadata["case_nb"] + spacer)
            parts.append("# Case source: " + ("Swiss Federal Court" if doc.metadata["case_ref"] == "ATF" else "European Court of Human Rights") + spacer)
            parts.append("# Case date: " + doc.metadata["case_date"] + spacer)
            parts.append("# Case url: " + doc.metadata["case_url"] + spacer)
            parts.append("# Case extract: " + case_text_summary + spacer)

        return "\nRetrieved documents:\n" + "".join(parts)
| |
|
| |
|
| | """ |
| | For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference |
| | """ |
| | HF_TOKEN=os.getenv('TOKEN') |
| | login(HF_TOKEN) |
| |
|
| | model = "meta-llama/Meta-Llama-3-8B-Instruct" |
| | |
| |
|
| | client = InferenceClient(model) |
| |
|
| | folder = snapshot_download(repo_id="umaiku/faiss_index", repo_type="dataset", local_dir=os.getcwd()) |
| |
|
| | embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2") |
| |
|
| | vector_db = FAISS.load_local("faiss_index_mpnet_cos", embeddings, allow_dangerous_deserialization=True, distance_strategy=DistanceStrategy.COSINE) |
| |
|
| | df = pd.read_csv("bger_cedh_db 1954-2024.csv") |
| |
|
| | retriever_tool = RetrieverTool(vector_db) |
| | agent = ToolCallingAgent(tools=[retriever_tool], model=HfApiModel(model)) |
| |
|
def respond(message, history: list[tuple[str, str]], system_message, max_tokens, temperature, top_p, score,):
    """Stream a retrieval-augmented answer for a Gradio ChatInterface turn.

    Retrieves supporting case-law via the module-level ``retriever_tool``,
    builds a RAG prompt around the user's question, and streams the model's
    reply through the module-level ``client``.

    Args:
        message: the user's question.
        history: prior (user, assistant) turn pairs from the chat UI.
        system_message: system prompt prepended to the conversation.
        max_tokens: generation cap for the completion.
        temperature: sampling temperature.
        top_p: nucleus-sampling threshold.
        score: value of the "Score Threshold" slider — currently unused;
            kept so the Gradio additional_inputs wiring stays intact.

    Yields:
        The accumulated response text after each streamed token.
    """
    print(datetime.now())  # lightweight request log
    context = retriever_tool(message)
    print(message)

    # Removed the dead `if True:` / unreachable non-RAG else branch that was here.
    prompt = f"""Given the question and supporting documents below, give a comprehensive answer to the question.
Respond only to the question asked, response should be relevant to the question and in the same language as the question.
Provide the number of the source document when relevant, as well as the link to the document.
If you cannot find information, do not give up and try calling your retriever again with different arguments!
Always give url of the sources at the end and only answer in the language the question is asked.

Question:
{message}

{context}
"""

    messages = [{"role": "system", "content": system_message}]

    for user_turn, assistant_turn in history:
        if user_turn:
            messages.append({"role": "user", "content": user_turn})
        if assistant_turn:
            messages.append({"role": "assistant", "content": assistant_turn})

    messages.append({"role": "user", "content": prompt})

    response = ""

    # `chunk` — the original loop variable `message` shadowed the user's question.
    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        token = chunk.choices[0].delta.content
        # Final/keep-alive chunks can carry delta.content == None; the original
        # `response += token` raised TypeError there.
        if token:
            response += token
            yield response
| |
|
| |
|
| | """ |
| | For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface |
| | """ |
| | demo = gr.ChatInterface( |
| | respond, |
| | additional_inputs=[ |
| | gr.Textbox(value="You are assisting a jurist or a layer in finding relevant Swiss Jurisprudence cases to their question.", label="System message"), |
| | gr.Slider(minimum=1, maximum=24000, value=5000, step=1, label="Max new tokens"), |
| | gr.Slider(minimum=0.1, maximum=4.0, value=0.1, step=0.1, label="Temperature"), |
| | gr.Slider( |
| | minimum=0.1, |
| | maximum=1.0, |
| | value=0.95, |
| | step=0.05, |
| | label="Top-p (nucleus sampling)", |
| | ), |
| | gr.Slider(minimum=0, maximum=1, value=0.75, step=0.05, label="Score Threshold"), |
| | ], |
| | description="# 📜 ALexI: Artificial Legal Intelligence for Swiss Jurisprudence", |
| | ) |
| |
|
| |
|
| | if __name__ == "__main__": |
| | print("Ready!") |
| | demo.launch(debug=True) |