test_law / app.py
omm7's picture
Upload app.py with huggingface_hub
8ea81be verified
import os
import json
import uuid
import gradio as gr
import chromadb
import numpy as np
from pathlib import Path
from sentence_transformers import SentenceTransformer
from huggingface_hub import CommitScheduler
from chromadb.errors import NotFoundError
from openai import OpenAI
# Sentence-embedding model; predict() encodes each question with it (using
# normalize_embeddings=True) before querying the clause index.
EMBEDDING_MODEL_ID = "BAAI/bge-small-en-v1.5"
embed_model = SentenceTransformer(EMBEDDING_MODEL_ID)
# On-disk Chroma store that holds the indexed contract clauses.
CLAUSE_INDEX_PATH = "./clause_index"
chroma_client = chromadb.PersistentClient(path=CLAUSE_INDEX_PATH)

# A missing "legal_clauses" collection is tolerated: `collection` keeps the
# None sentinel and predict() answers without retrieved clauses.
collection = None
try:
    collection = chroma_client.get_collection("legal_clauses")
except NotFoundError:
    pass  # keep the None sentinel assigned above
# OpenAI-compatible client pointed at the Hugging Face inference router
# (featherless-ai provider), authenticated with the HF_TOKEN environment variable.
HF_ROUTER_BASE_URL = "https://router.huggingface.co/featherless-ai/v1"
client = OpenAI(api_key=os.environ.get("HF_TOKEN"), base_url=HF_ROUTER_BASE_URL)
# Prompt templates
# System turn: frames the assistant and tells it what to do when retrieval
# produced no clauses (the second sentence was previously garbled:
# "...standards. and report that no clause retrieved").
system_message = """You are a legal AI assistant trained on contract clause examples from the CUAD dataset.
If no clauses are retrieved from the database, infer the answer from your understanding of common contractual standards, and explicitly state that no clause was retrieved."""
# User turn: filled via str.format with the retrieved clause text (or a
# fallback notice) as {context} and the user's question as {question}.
user_template = """
### Context:
{context}
### Question:
{question}
"""
# Query logging: each process writes to its own uniquely named file under
# logs/ (predict() appends one JSON object per line). CommitScheduler pushes
# that folder to the "legal-rag-output" dataset repo; every=2 is the upload
# interval (minutes, per the huggingface_hub docs).
LOG_DIR = Path("logs")
LOG_DIR.mkdir(exist_ok=True)
log_file = LOG_DIR / f"query_{uuid.uuid4()}.json"
scheduler = CommitScheduler(
    repo_id="legal-rag-output",
    repo_type="dataset",
    folder_path=LOG_DIR,
    path_in_repo="logs",
    every=2,
)
# Main QA function
# Number of nearest clauses pulled from the index for each question.
_TOP_K = 3
# Chat model served through the Hugging Face router.
_CHAT_MODEL = "mistralai/Mistral-7B-Instruct-v0.2"
# Context sent to the model when no clauses were retrieved.
_NO_CLAUSES_CONTEXT = (
    "No relevant clauses were found in the database. "
    "Please answer using your legal understanding from the CUAD dataset."
)
# Context sent to the model when the retrieval call itself failed.
_RETRIEVAL_ERROR_CONTEXT = (
    "Due to an internal retrieval issue, please answer based on your legal "
    "knowledge from CUAD dataset."
)


def _retrieve_context(query_embedding):
    """Build prompt context from the clauses nearest to *query_embedding*.

    Returns the formatted clauses joined by blank lines. It returns a fallback
    notice instead when the collection is missing, the query returns no
    documents, or the query fails (retrieval is best-effort).
    """
    if collection is None:
        return _NO_CLAUSES_CONTEXT
    try:
        results = collection.query(
            query_embeddings=[query_embedding.tolist()],
            n_results=_TOP_K,
        )
        documents = results["documents"][0]
        metadatas = results["metadatas"][0]
    except Exception:
        # Let the model fall back on its own knowledge rather than failing.
        return _RETRIEVAL_ERROR_CONTEXT
    if not documents:
        return _NO_CLAUSES_CONTEXT
    # A metadata entry may be None or lack "clause_type"; label such clauses
    # "Unknown" instead of discarding every retrieved clause.
    return "\n\n".join(
        f"[Clause Type: {(meta or {}).get('clause_type', 'Unknown')}] {doc}"
        for doc, meta in zip(documents, metadatas)
    )


def _generate_answer(context, question):
    """Stream a chat completion for *question* grounded in *context*; return the full text."""
    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": user_template.format(context=context, question=question)},
    ]
    stream = client.chat.completions.create(
        model=_CHAT_MODEL,
        messages=messages,
        temperature=0.4,
        top_p=0.7,
        stream=True,
    )
    parts = []
    for chunk in stream:
        # Some chunks can carry an empty `choices` list (e.g. a trailing
        # usage chunk). Indexing [0] there would raise and discard the
        # whole answer.
        if chunk.choices:
            parts.append(chunk.choices[0].delta.content or "")
    return "".join(parts)


def _log_interaction(question, context, output):
    """Append one JSON line recording the exchange to this process's log file."""
    # Hold the scheduler's lock so a scheduled upload does not run mid-write.
    with scheduler.lock:
        with log_file.open("a", encoding="utf-8") as f:
            f.write(json.dumps({
                "question": question,
                "context": context,
                "response": output,
            }) + "\n")


def predict(question):
    """Answer a contract-law question, grounding the model in retrieved clauses.

    The function encodes *question* and pulls the closest clauses from the
    ``legal_clauses`` collection (when it exists). It then streams a
    completion from the hosted chat model and appends the exchange to the
    JSON-lines query log.

    Args:
        question: Free-text question from the Gradio textbox.

    Returns:
        One of three strings: the model's answer, a prompt to enter a
        question when *question* is blank, or an error message. Failures
        while encoding, retrieving or generating are reported in the
        returned text rather than raised.
    """
    if not question or not question.strip():
        # Don't spend an embedding + LLM call on an empty request.
        return "Please enter a legal question."
    # Bound before the try so the log entry can always reference it. It stays
    # None if encoding the question fails before retrieval runs.
    context = None
    try:
        query_embedding = embed_model.encode([question], normalize_embeddings=True)[0]
        context = _retrieve_context(query_embedding)
        output = _generate_answer(context, question)
    except Exception as e:
        output = f"An internal error occurred while generating the response: {e}"
    _log_interaction(question, context, output)
    return output
# Gradio UI: a single question textbox in, a single answer textbox out,
# wired to predict().
demo = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(label="Enter your legal question:", lines=4),
    outputs=gr.Textbox(label="Answer"),
    title="⚖️ GL_LegalMind",
    description="Ask contract-related legal questions. Answers are based on retrieved clauses or inferred from CUAD knowledge.",
)
# Enable Gradio's request queue for incoming predictions.
demo.queue()

# Start the server only when run as a script (`python app.py`). Importing this
# module, e.g. from tests, then does not launch a web server.
if __name__ == "__main__":
    demo.launch()