File size: 5,994 Bytes
5c60ed2 0bf6060 f8adcff 90d1e52 0635997 df1b3de 0635997 90d1e52 df1b3de 90d1e52 df1b3de 90d1e52 df1b3de 90d1e52 df1b3de 90d1e52 df1b3de 90d1e52 df1b3de 90d1e52 df1b3de 90d1e52 df1b3de 0635997 60ac7f7 df1b3de 60ac7f7 6e990be f846748 df1b3de e24fae8 dd76368 0635997 79c456d 0635997 df1b3de d5c54ef df1b3de aa78463 df1b3de a6051b9 df1b3de 5cc0589 e24fae8 df1b3de 6399f7b df1b3de 6399f7b df1b3de 6399f7b df1b3de 6399f7b df1b3de 6399f7b df1b3de 0b192b7 df1b3de 34f414b 8b0ad99 f846748 df1b3de f846748 df1b3de f846748 df1b3de f846748 df1b3de f846748 df1b3de d2eb5fb a6051b9 df1b3de 4c92796 f846748 cddcba8 f846748 df1b3de f846748 5cc0589 df1b3de | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 | import gradio as gr
from transformers import pipeline
from huggingface_hub import InferenceClient, login, snapshot_download
from langchain_community.vectorstores import FAISS, DistanceStrategy
from langchain_huggingface import HuggingFaceEmbeddings
import os
import pandas as pd
from datetime import datetime
from smolagents import Tool, HfApiModel, ToolCallingAgent
from langchain_core.vectorstores import VectorStore
class RetrieverTool(Tool):
    """Semantic-similarity retriever over the FAISS case-law vector store.

    For each of the top-k matching chunks, emits the case metadata plus an
    extract of the full case text centred on the matched chunk. Relies on the
    module-level DataFrame ``df`` mapping ``case_url`` to ``case_text``.
    """
    name = "retriever"
    description = "Using semantic similarity in German, French, English and Italian, retrieves some documents from the knowledge base that have the closest embeddings to the input query."
    inputs = {
        "query": {
            "type": "string",
            "description": "The query to perform. This should be semantically close to your target documents. Use the affirmative form rather than a question.",
        }
    }
    output_type = "string"

    def __init__(self, vectordb: VectorStore, **kwargs):
        super().__init__(**kwargs)
        self.vectordb = vectordb

    def forward(self, query: str) -> str:
        """Retrieve the 7 closest documents and format them for the LLM prompt."""
        assert isinstance(query, str), "Your search query must be a string"
        docs = self.vectordb.similarity_search(query, k=7)
        spacer = " \n"
        nb_char = 100  # context characters kept on each side of the matched chunk
        parts = []  # build with join() instead of quadratic string +=
        for doc in docs:
            # NOTE(review): depends on the module-level `df` DataFrame; assumes
            # every case_url in the index is present in the CSV — confirm.
            case_text = df[df["case_url"] == doc.metadata["case_url"]].case_text.values[0]
            index = case_text.find(doc.page_content)
            if index == -1:
                # Chunk not found verbatim in the full text (e.g. whitespace or
                # normalisation differences). The original code sliced from a
                # bogus offset in this case; fall back to the chunk itself.
                case_text_summary = doc.page_content
            else:
                start = max(0, index - nb_char)
                end = min(len(case_text), index + len(doc.page_content) + nb_char)
                case_text_summary = case_text[start:end]
            parts.append("#######" + spacer)
            parts.append("# Case number: " + doc.metadata["case_ref"] + " " + doc.metadata["case_nb"] + spacer)
            parts.append("# Case source: " + ("Swiss Federal Court" if doc.metadata["case_ref"] == "ATF" else "European Court of Human Rights") + spacer)
            parts.append("# Case date: " + doc.metadata["case_date"] + spacer)
            parts.append("# Case url: " + doc.metadata["case_url"] + spacer)
            parts.append("# Case extract: " + case_text_summary + spacer)
        return "\nRetrieved documents:\n" + "".join(parts)
"""
For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
"""
# Authenticate against the Hugging Face Hub; the token comes from the
# `TOKEN` environment variable (e.g. a Space secret).
HF_TOKEN=os.getenv('TOKEN')
login(HF_TOKEN)
model = "meta-llama/Meta-Llama-3-8B-Instruct"
#model = "swiss-ai/Apertus-8B-Instruct-2509"
client = InferenceClient(model)
# Download the prebuilt FAISS index dataset into the current working directory.
folder = snapshot_download(repo_id="umaiku/faiss_index", repo_type="dataset", local_dir=os.getcwd())
# Multilingual sentence embeddings; presumably the same model used to build the
# index — verify against the index-building pipeline.
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2")
# allow_dangerous_deserialization is needed to load a pickled FAISS index; the
# index comes from the project's own dataset repo downloaded just above.
vector_db = FAISS.load_local("faiss_index_mpnet_cos", embeddings, allow_dangerous_deserialization=True, distance_strategy=DistanceStrategy.COSINE)
# Full-text case database (file name suggests Swiss Federal Court + ECHR,
# 1954-2024); used by RetrieverTool.forward to expand chunks into extracts.
df = pd.read_csv("bger_cedh_db 1954-2024.csv")
retriever_tool = RetrieverTool(vector_db)
# NOTE(review): the agent is instantiated but never used by `respond` below.
agent = ToolCallingAgent(tools=[retriever_tool], model=HfApiModel(model))
def respond(message, history: list[tuple[str, str]], system_message, max_tokens, temperature, top_p, score,):
    """Stream an answer to *message*, grounded in retrieved case-law documents.

    Args:
        message: The user's question.
        history: Previous (user, assistant) turn pairs from the chat UI.
        system_message: System prompt from the UI textbox.
        max_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling threshold.
        score: Score-threshold slider value; currently unused, kept so the
            Gradio `additional_inputs` wiring stays compatible.

    Yields:
        The partial response text, growing as stream chunks arrive.
    """
    print(datetime.now())
    # Retrieve supporting documents for the question before prompting the LLM.
    context = retriever_tool(message)
    print(message)
    # The original code gated this prompt behind `if True:` with an unreachable
    # else branch (a disabled Law/Other classifier); the dead code is removed.
    prompt = f"""Given the question and supporting documents below, give a comprehensive answer to the question.
Respond only to the question asked, response should be relevant to the question and in the same language as the question.
Provide the number of the source document when relevant, as well as the link to the document.
If you cannot find information, do not give up and try calling your retriever again with different arguments!
Always give url of the sources at the end and only answer in the language the question is asked.
Question:
{message}
{context}
"""
    # Rebuild the full transcript in the chat-completion message format.
    messages = [{"role": "system", "content": system_message}]
    for val in history:
        if val[0]:
            messages.append({"role": "user", "content": val[0]})
        if val[1]:
            messages.append({"role": "assistant", "content": val[1]})
    messages.append({"role": "user", "content": prompt})
    response = ""
    # `chunk` replaces the original loop variable `message`, which shadowed the
    # function parameter of the same name.
    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        token = chunk.choices[0].delta.content
        # Some stream chunks (typically the final one) carry content=None;
        # the original `response += token` raised TypeError on those.
        if token:
            response += token
            yield response
"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
# Chat UI wiring: `respond` receives the extra inputs below, in order, after
# (message, history).
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        # Default system prompt ("lawyer" fixes the original "layer" typo).
        gr.Textbox(value="You are assisting a jurist or a lawyer in finding relevant Swiss Jurisprudence cases to their question.", label="System message"),
        gr.Slider(minimum=1, maximum=24000, value=5000, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.1, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
        # NOTE(review): this value is passed to `respond` as `score` but is not
        # currently used there — confirm whether filtering was intended.
        gr.Slider(minimum=0, maximum=1, value=0.75, step=0.05, label="Score Threshold"),
    ],
    description="# 📜 ALexI: Artificial Legal Intelligence for Swiss Jurisprudence",
)
# Script entry point: start the Gradio server.
if __name__ == "__main__":
    print("Ready!")
    # debug=True surfaces tracebacks in the UI and enables verbose logging.
    demo.launch(debug=True)