|
|
import os |
|
|
import json |
|
|
import uuid |
|
|
import gradio as gr |
|
|
import chromadb |
|
|
import numpy as np |
|
|
from pathlib import Path |
|
|
from sentence_transformers import SentenceTransformer |
|
|
from huggingface_hub import CommitScheduler |
|
|
from chromadb.errors import NotFoundError |
|
|
from openai import OpenAI |
|
|
|
|
|
|
|
|
# Sentence-embedding model used to vectorize user questions for clause
# retrieval. Embeddings are normalized at encode time in predict().
embed_model = SentenceTransformer("BAAI/bge-small-en-v1.5")

# Persistent Chroma vector store holding pre-indexed CUAD contract clauses.
# NOTE(review): assumes ./clause_index was populated by a separate indexing
# step — nothing in this file creates or fills the collection; confirm.
chroma_client = chromadb.PersistentClient(path="./clause_index")
try:
    collection = chroma_client.get_collection("legal_clauses")
except NotFoundError:
    # No pre-built index: predict() falls back to the LLM's own knowledge.
    collection = None
|
|
|
|
|
|
|
|
# OpenAI-compatible client pointed at the Hugging Face inference router
# (Featherless provider). Reads the HF_TOKEN environment variable;
# os.getenv returns None when unset, which only fails at request time.
client = OpenAI(
    base_url="https://router.huggingface.co/featherless-ai/v1",
    api_key=os.getenv("HF_TOKEN"),
)
|
|
|
|
|
|
|
|
# System prompt for the LLM. The second sentence was previously a broken
# fragment ("standards. and report that no clause retrieved"); repaired so
# the instruction to the model is a single coherent sentence.
system_message = """You are a legal AI assistant trained on contract clause examples from the CUAD dataset.
If no clauses are retrieved from the database, infer the answer using your understanding of common contractual standards, and report that no clause was retrieved."""

# Template for the user turn; filled via str.format in predict() with the
# retrieved clause context and the user's question.
user_template = """
### Context:
{context}

### Question:
{question}
"""
|
|
|
|
|
|
|
|
# One JSON-lines log file per process run; the uuid in the name avoids
# collisions with log files uploaded from earlier runs.
log_file = Path("logs/") / f"query_{uuid.uuid4()}.json"
log_file.parent.mkdir(exist_ok=True)

# Background job that periodically commits the logs/ folder to a Hugging
# Face dataset repo.
# NOTE(review): `every=2` is the commit interval (minutes, per
# huggingface_hub docs) — confirm; also verify whether repo_id needs a
# namespace prefix ("username/legal-rag-output").
scheduler = CommitScheduler(
    repo_id="legal-rag-output",
    repo_type="dataset",
    folder_path=log_file.parent,
    path_in_repo="logs",
    every=2
)
|
|
|
|
|
|
|
|
def predict(question):
    """Answer a legal question using retrieved CUAD clauses plus an LLM.

    Embeds the question, retrieves up to 3 similar clauses from the Chroma
    collection (when one exists), builds a system/user prompt, and streams a
    completion from the router-hosted Mistral model. The question, context,
    and response are appended to the JSON-lines log file under the
    CommitScheduler lock.

    Parameters
    ----------
    question : str
        The user's legal question, as entered in the Gradio textbox.

    Returns
    -------
    str
        The model's answer, or an error message if generation failed.
    """
    # Defaults hoisted above the try: previously, an exception raised by
    # embed_model.encode left `context` unassigned, and the log write below
    # then crashed with NameError instead of recording the failure.
    context = "No relevant clauses were found in the database. Please answer using your legal understanding from the CUAD dataset."
    output = ""

    try:
        query_embedding = embed_model.encode([question], normalize_embeddings=True)[0]

        if collection:
            try:
                results = collection.query(
                    query_embeddings=[query_embedding.tolist()],
                    n_results=3
                )
                documents = results["documents"][0]
                metadatas = results["metadatas"][0]

                if documents:
                    context = "\n\n".join(
                        f"[Clause Type: {m['clause_type']}] {doc}"
                        for doc, m in zip(documents, metadatas)
                    )
            except Exception:
                # Retrieval is best-effort: tell the model to answer from
                # its own knowledge rather than failing the whole request.
                context = "Due to an internal retrieval issue, please answer based on your legal knowledge from CUAD dataset."

        prompt = [
            {"role": "system", "content": system_message},
            {"role": "user", "content": user_template.format(context=context, question=question)}
        ]

        stream = client.chat.completions.create(
            model="mistralai/Mistral-7B-Instruct-v0.2",
            messages=prompt,
            temperature=0.4,
            top_p=0.7,
            stream=True
        )

        for chunk in stream:
            # Some stream chunks (e.g. keep-alives) carry an empty choices
            # list; indexing [0] unconditionally would raise IndexError.
            if chunk.choices:
                output += chunk.choices[0].delta.content or ""

    except Exception as e:
        output = f"An internal error occurred while generating the response: {str(e)}"

    # Append the interaction to the per-process log; the scheduler lock
    # prevents a concurrent commit from uploading a half-written line.
    with scheduler.lock:
        with log_file.open("a") as f:
            f.write(json.dumps({
                "question": question,
                "context": context,
                "response": output
            }) + "\n")

    return output
|
|
|
|
|
|
|
|
# Gradio UI: a single text input (the legal question) mapped to predict(),
# and a single text output for the collected answer.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(label="Enter your legal question:", lines=4),
    outputs=gr.Textbox(label="Answer"),
    title="⚖️ GL_LegalMind",
    description="Ask contract-related legal questions. Answers are based on retrieved clauses or inferred from CUAD knowledge."
)

# Enable request queuing for concurrent users, then start the web server.
demo.queue()
demo.launch()
|
|
|