omm7 committed on
Commit
f6585d4
·
verified ·
1 Parent(s): 2bb7f94

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +121 -0
app.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import uuid
4
+ import gradio as gr
5
+ import chromadb
6
+ import numpy as np
7
+ from pathlib import Path
8
+ from sentence_transformers import SentenceTransformer
9
+ from huggingface_hub import CommitScheduler
10
+ from chromadb.errors import NotFoundError
11
+ from openai import OpenAI
12
+
13
# Load embedding model
# BGE-small sentence-transformer used to embed incoming questions;
# embeddings are normalized at encode time in predict().
embed_model = SentenceTransformer("BAAI/bge-small-en-v1.5")

# Load ChromaDB client
# Persistent on-disk vector store expected to hold the CUAD clause index
# under the "legal_clauses" collection.
chroma_client = chromadb.PersistentClient(path="./clause_index")
try:
    collection = chroma_client.get_collection("legal_clauses")
except NotFoundError:
    # Index not built yet: predict() falls back to the LLM's own knowledge.
    collection = None

# Setup OpenAI/Hugging Face client
# OpenAI-compatible client routed through Hugging Face's inference router;
# HF_TOKEN is read from the environment (None if unset).
client = OpenAI(
    base_url="https://router.huggingface.co/featherless-ai/v1",
    api_key=os.getenv("HF_TOKEN"),
)
28
+
29
# Prompt template
# System instructions sent with every request. The string content is
# load-bearing prompt text and is deliberately left byte-for-byte untouched.
system_message = """You are a legal AI assistant trained on contract clause examples from the CUAD dataset.
If no clauses are retrieved from the database, infer the answer using your understanding of common contractual standards. and report that no clause retrieved"""
# User-turn template, filled via str.format in predict() with the retrieved
# clause context and the raw user question.
user_template = """
### Context:
{context}

### Question:
{question}
"""
39
+
40
# Setup logging
# One log file per process; the random UUID keeps concurrent runs from
# clobbering each other's logs in the shared dataset repo.
log_file = Path("logs") / f"query_{uuid.uuid4()}.json"
log_file.parent.mkdir(exist_ok=True)

# Background job that periodically commits the logs folder to a
# Hugging Face dataset repo ("every" is in minutes).
scheduler = CommitScheduler(
    repo_id="legal-rag-output",
    folder_path=log_file.parent,
    path_in_repo="logs",
    repo_type="dataset",
    every=2,
)
50
+
51
# Main QA function
def predict(question):
    """Answer a contract-law question via retrieval-augmented generation.

    Embeds the question, retrieves up to 3 similar CUAD clauses from the
    ChromaDB index (when available), asks the LLM with that context, logs
    the full exchange, and returns the streamed answer text.

    Args:
        question: The user's legal question as plain text.

    Returns:
        The model's answer string, or a human-readable error message if
        generation failed (this function never raises to the UI).
    """
    # BUGFIX: initialize fallbacks BEFORE the try-block. In the original,
    # an early failure (e.g. in embed_model.encode) jumped to the outer
    # except with `context` still unbound, so the logging code below then
    # crashed with NameError and masked the real error.
    context = "No relevant clauses were found in the database. Please answer using your legal understanding from the CUAD dataset."
    output = ""

    try:
        # Encode query (normalized embeddings, matching how the index was built).
        query_embedding = embed_model.encode([question], normalize_embeddings=True)[0]

        # If collection exists, try retrieval; retrieval is best-effort and
        # any failure falls back to the LLM's own knowledge.
        if collection is not None:
            try:
                results = collection.query(
                    query_embeddings=[query_embedding.tolist()],
                    n_results=3,
                )
                documents = results["documents"][0]
                metadatas = results["metadatas"][0]

                if documents:
                    # Label each retrieved clause with its type so the LLM
                    # can cite it in the answer.
                    context = "\n\n".join(
                        f"[Clause Type: {m['clause_type']}] {doc}"
                        for doc, m in zip(documents, metadatas)
                    )
            except Exception:
                context = "Due to an internal retrieval issue, please answer based on your legal knowledge from CUAD dataset."

        # Construct the chat prompt from the shared templates.
        prompt = [
            {"role": "system", "content": system_message},
            {"role": "user", "content": user_template.format(context=context, question=question)},
        ]

        # Generate response (streamed); accumulate chunks and join once
        # instead of repeated string concatenation.
        stream = client.chat.completions.create(
            model="mistralai/Mistral-7B-Instruct-v0.2",
            messages=prompt,
            temperature=0.4,
            top_p=0.7,
            stream=True,
        )
        pieces = []
        for chunk in stream:
            # delta.content can be None on role/terminator chunks.
            pieces.append(chunk.choices[0].delta.content or "")
        output = "".join(pieces)

    except Exception as e:
        # Surface the failure to the user instead of crashing the UI.
        output = f"An internal error occurred while generating the response: {str(e)}"

    # Log to file as JSON-lines; scheduler.lock prevents the background
    # commit job from reading a half-written record.
    with scheduler.lock:
        with log_file.open("a") as f:
            f.write(json.dumps({
                "question": question,
                "context": context,
                "response": output,
            }) + "\n")

    return output
110
+
111
# Gradio UI
# Single-turn question/answer interface wired to predict().
demo = gr.Interface(
    fn=predict,
    title="⚖️ GL_LegalMind",
    description="Ask contract-related legal questions. Answers are based on retrieved clauses or inferred from CUAD knowledge.",
    inputs=gr.Textbox(lines=4, label="Enter your legal question:"),
    outputs=gr.Textbox(label="Answer"),
)

# Queueing serializes requests so concurrent users don't contend for the
# model/client; launch() starts the web server.
demo.queue()
demo.launch()