File size: 10,606 Bytes
eec757a
 
 
 
 
 
 
95c9540
eec757a
 
 
 
 
95c9540
eec757a
 
 
 
 
 
bc9198b
eec757a
 
 
95c9540
 
eec757a
bc9198b
eec757a
bc9198b
eec757a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95c9540
eec757a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95c9540
 
 
 
 
 
 
eec757a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95c9540
bc9198b
eec757a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95c9540
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
from flask import Flask, request, jsonify
from llama_index.core import VectorStoreIndex
from llama_index.core import Settings
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.groq import Groq
from llama_index.vector_stores.pinecone import PineconeVectorStore
from pinecone import Pinecone
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from PyPDF2 import PdfReader
from flask_cors import CORS
from functools import wraps
from dotenv import load_dotenv
from huggingface_hub import InferenceClient
import re, jwt, os, json, gc, torch

# Load environment variables from a local .env file before reading any keys.
load_dotenv()

# HS256 signing key shared with the service that issues the JWTs we verify.
SECRET_KEY = os.getenv("SECRET_KEY")

# Initialize Hugging Face Inference Client for embeddings
client = InferenceClient(token=os.getenv("HF_API_KEY"))

# Load summarization model and tokenizer
# NOTE(review): loaded eagerly at import time — startup blocks on the
# download/load of this seq2seq model; confirm that is acceptable.
model_path = "Jurisight/legal_led"
model = AutoModelForSeq2SeqLM.from_pretrained(model_path, token=os.getenv("HF_API_KEY"))
tokenizer = AutoTokenizer.from_pretrained(model_path, token=os.getenv("HF_API_KEY"))

embed_model = "BAAI/bge-base-en-v1.5"
# Configure LlamaIndex settings
# Global defaults used by VectorStoreIndex in the /chat endpoint.
Settings.embed_model = HuggingFaceEmbedding(model_name=embed_model)
Settings.llm = Groq(model="llama3-8b-8192", api_key=os.getenv("GROQ_API_KEY"))

# Initialize Pinecone
# Two separate indexes: one backing the chatbot, one for judgment search.
pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
pinecone_index_chat = "llamaindex"
pinecone_index_retrieval = "judgment-search"

app = Flask(__name__)
CORS(app)  # allow cross-origin requests from the frontend

# Authentication decorator
def authenticate_user(f):
    """Decorator that validates the JWT supplied in the ``x-auth-token`` header.

    On success the wrapped view is called with the token's ``id`` claim as
    its first positional argument; any failure (missing token, expired or
    invalid signature, missing/empty ``id`` claim) returns a 401 JSON error.
    """
    @wraps(f)
    def decorated_function(*args, **kwargs):
        token = request.headers.get("x-auth-token")
        if not token:
            return jsonify({"error": "Authentication token is missing"}), 401
        try:
            decoded_token = jwt.decode(token, SECRET_KEY, algorithms=["HS256"])
            # .get() instead of ["id"]: a valid token without an "id" claim
            # previously raised an unhandled KeyError (-> 500) instead of
            # reaching the intended "Invalid token structure" 401 below.
            user_id = decoded_token.get("id")
            if not user_id:
                return jsonify({"error": "Invalid token structure"}), 401
        except jwt.ExpiredSignatureError:
            return jsonify({"error": "Token has expired"}), 401
        except jwt.InvalidTokenError:
            return jsonify({"error": "Invalid token"}), 401
        return f(user_id, *args, **kwargs)
    return decorated_function

# System prompt for the chatbot
# Prepended verbatim to every /chat query; constrains the model to the
# legal domain and to the user's most recent upload only.
SYSTEM_PROMPT = (
    "You are Jurisight, a highly knowledgeable legal chatbot. Your purpose is to assist "
    "users with questions related to legal documents, laws, judgments, and legal topics. "
    "Do not answer questions unrelated to the legal domain. Provide accurate and concise "
    "legal responses based on your training and knowledge.\n\n"
    "If the user has uploaded a document, consider only the most recently uploaded document "
    "and its generated summary in your responses. Forget any previous documents or summaries "
    "when a new one is uploaded. If no document has been uploaded, do not assume otherwise.\n\n"
    "Maintain continuity by considering the chat history. If a user follows up on a previous question, "
    "use the past interactions for context rather than responding in isolation. However, "
    "do not reference any document unless one is currently available."
)

# Global storage for document text, summaries, and chat history
# NOTE(review): in-process, unsynchronized state — lost on restart and not
# shared across workers; confirm single-process deployment.
document_text_storage = {}  # user_id -> raw text of the latest uploaded PDF
summarized_content = {}     # user_id -> generated summary of that PDF
context_text = ""           # NOTE(review): process-wide (NOT per-user) accumulator — review for cross-user leakage
chat_history = {}           # user_id -> list of "User:"/"Jurisight:" lines

# Function to extract entities from text
def extract_entities(text):
    """Extract structured case details from a legal document via the Groq LLM.

    Prompts the model to return a single JSON object with a fixed field set,
    then parses the first ``{`` .. last ``}`` span of the reply.

    Returns the parsed dict, or ``{}`` when the model reply contains no
    parseable JSON object.
    """
    llm = Groq(model="llama3-8b-8192", api_key=os.getenv("GROQ_API_KEY"))
    prompt = f"""
    Read the following legal document and extract structured data in valid JSON format.
    If some values are missing, **generate a concise 50-100 word summary** based on the document’s context.

    Ensure the following fields are always present:
    - "Client Name": Extract or infer the client's full name.
    - "Gender": Identify if explicitly mentioned; otherwise, infer based on name.
    - "Matter": Identify the case type or legal matter.
    - "Client Objectives": Summarize the client's main objective.
    - "Custody Status": Extract whether the petitioner is in custody (Yes/No).
    - "Crime Registered": Indicate whether a crime has been registered (Yes/No).
    - "Application Filing": Indicate whether an application has been filed (Yes/No).
    - "Legal Analysis.Prayer Details": Summarize the relief sought in 50-100 words.
    - "Legal Analysis.Interim Relief Details": Summarize any interim relief in 50-100 words.
    - "Legal Analysis.Grounds": Extract or infer legal grounds in 50-100 words.

    Return only a valid JSON object without any extra text.

    Document:
    {text}
    """
    try:
        response = llm.complete(prompt)
        extracted_text = response.text.strip()
        json_start = extracted_text.find("{")
        json_end = extracted_text.rfind("}") + 1
        # Guard against a reply with no braces: find() returning -1 made the
        # old slice start from the last character (text[-1:end]) — nonsense
        # input that only worked by accident of the JSONDecodeError handler.
        if json_start == -1 or json_end == 0:
            return {}
        return json.loads(extracted_text[json_start:json_end])
    except (json.JSONDecodeError, AttributeError):
        # Malformed JSON, or a response object without .text — best-effort
        # endpoint, so degrade to an empty result rather than erroring.
        return {}

# Chat endpoint
@app.route('/chat', methods=['POST'])
@authenticate_user
def chat(user_id):
    """Answer a legal question for ``user_id``.

    Builds a prompt from the system prompt, the user's latest uploaded
    document and summary, and the last 10 chat-history lines, then queries
    the Pinecone-backed vector index. Appends the exchange to the user's
    chat history on success.
    """
    try:
        if not request.json or 'message' not in request.json:
            return jsonify({"error": "Invalid request format"}), 400

        user_message = request.json['message']
        document_text = document_text_storage.get(user_id, "")
        summary_text = summarized_content.get(user_id, "")
        # Fix: the old code also appended document/summary text to the shared
        # module-level ``context_text`` string — a process-wide accumulator
        # that mixed every user's documents together (cross-user leakage) and
        # was never read when building the prompt below. Dropped entirely;
        # the prompt is assembled from per-user values only.
        chat_memory = "\n".join(chat_history.get(user_id, [])[-10:])
        formatted_message = (
            f"{SYSTEM_PROMPT}\n\nDocument Context:\n{document_text}\n\n"
            f"Summarized Content:\n{summary_text}\n\n"
            f"Chat History:\n{chat_memory}\nUser: {user_message}\nJurisight:"
        )
        pinecone_index = pc.Index(pinecone_index_chat)
        vector_store = PineconeVectorStore(pinecone_index=pinecone_index)
        index = VectorStoreIndex.from_vector_store(vector_store=vector_store)
        query_engine = index.as_query_engine()
        response = query_engine.query(formatted_message)
        # Record the exchange only after a successful query.
        chat_history.setdefault(user_id, []).append(f"User: {user_message}")
        chat_history[user_id].append(f"Jurisight: {response}")
        return jsonify({"response": f"{response}"}), 200
    except Exception as e:
        # Boundary handler: never leak internals to the client.
        return jsonify({"error": "Internal server error"}), 500

# Summarize endpoint
@app.route('/summarize', methods=['POST'])
@authenticate_user
def summarize(user_id):
    """Summarize an uploaded PDF for ``user_id``.

    Extracts text with PyPDF2, summarizes it chunk-by-chunk with the LED
    model, caches both the raw text and the summary in the module-level
    per-user stores, and returns the summary as JSON.
    """
    def clean_text(text):
        # Collapse every run of whitespace so chunking sees clean text.
        return re.sub(r'\s+', ' ', text).strip()

    def summarize_legal_document(document_text, chunk_size=1024, max_output_length=128):
        # Summarize fixed-size character chunks independently and join the
        # pieces — the model cannot take the whole document in one pass.
        # (The old per-call ``try: ... except Exception: raise`` wrapper was
        # a no-op and has been removed; errors propagate to the handler below.)
        chunks = [document_text[i:i + chunk_size]
                  for i in range(0, len(document_text), chunk_size)]
        summaries = []
        for chunk in chunks:
            inputs = tokenizer(
                chunk,
                max_length=chunk_size,
                padding="max_length",
                truncation=True,
                return_tensors="pt"
            )
            with torch.no_grad():  # inference only — skip gradient buffers
                summary_ids = model.generate(
                    inputs["input_ids"],
                    num_beams=4,
                    max_length=max_output_length,
                    early_stopping=True
                )
            summaries.append(
                tokenizer.decode(summary_ids[0], skip_special_tokens=True).strip()
            )
        return " ".join(summaries)

    if 'file' not in request.files:
        return jsonify({"error": "No file provided"}), 400

    file = request.files['file']
    if file.filename == '':
        return jsonify({"error": "Empty file uploaded"}), 400

    try:
        reader = PdfReader(file)
        document_text = ""
        for page in reader.pages:
            text = page.extract_text()
            if text:  # extract_text() may return None for image-only pages
                document_text += text.strip() + " "

        document_text = clean_text(document_text)
        # Reject scans/empty PDFs: under 10 words is not summarizable.
        if not document_text or len(document_text.split()) < 10:
            return jsonify({"error": "The document does not contain sufficient readable text."}), 400

        document_text_storage[user_id] = document_text
        summary = summarize_legal_document(document_text)
        summarized_content[user_id] = summary
        return jsonify({"summary": summary}), 200
    except Exception as e:
        return jsonify({"error": "Error processing the file"}), 500

# Retrieve cases endpoint
@app.route('/retrieve-cases', methods=['POST'])
@authenticate_user
def retrieve_cases(user_id):
    """Find judgments similar to the user's last uploaded document.

    Embeds the stored document text via the HF Inference API, queries the
    ``judgment-search`` Pinecone index, and returns the matches' scores and
    URLs. Requires a prior successful /summarize upload.
    """
    def generate_embedding(text):
        # Use Hugging Face Inference API for embeddings
        result = client.feature_extraction(model=embed_model, text=text)
        return result

    def query_pinecone(query_text, top_k=10):
        query_embedding = generate_embedding(query_text)
        retrieval_index = pc.Index(pinecone_index_retrieval)
        # feature_extraction returns an ndarray; Pinecone wants a plain list.
        results = retrieval_index.query(vector=query_embedding.tolist(), top_k=top_k, include_metadata=True)
        return results

    if not request.json:
        return jsonify({"error": "No file or query provided"}), 400

    document_text = document_text_storage.get(user_id, None)
    if not document_text:
        return jsonify({"error": "No document available for retrieval"}), 400

    try:
        top_k = request.json.get('top_k', 10)
        results = query_pinecone(document_text, top_k=top_k)
        # Deliberate 200: "no matches" is a valid outcome, not a failure.
        if not results['matches']:
            return jsonify({"error": "No relevant cases found."}), 200
        case_links = [{"score": result['score'], "url": result['metadata']['url']} for result in results['matches']]
        return jsonify({"case_links": case_links}), 200
    except Exception as e:
        # Fixed message: the old "Error processing the file" was copy-pasted
        # from /summarize — no file is processed here, which misled clients
        # debugging embedding/Pinecone failures.
        return jsonify({"error": "Error retrieving similar cases"}), 500

# Fetch form data endpoint
@app.route('/fetch-form-data', methods=['GET'])
@authenticate_user
def fetch_form_data(user_id):
    """Return structured case fields extracted from the user's last upload."""
    stored_text = document_text_storage.get(user_id)
    if stored_text is None:
        # No prior /summarize upload for this user.
        return jsonify({"error": "No document found"}), 400
    return jsonify(extract_entities(stored_text)), 200

# Run the app
if __name__ == '__main__':
    # NOTE(review): debug=True enables the Werkzeug debugger (arbitrary code
    # execution) while binding to all interfaces (0.0.0.0) — confirm this
    # entry point is development-only; production should use a WSGI server.
    app.run(debug=True, host='0.0.0.0', port=7860)