Aranwer commited on
Commit
38b37ec
·
verified ·
1 Parent(s): b6df14b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -24
app.py CHANGED
@@ -5,11 +5,9 @@ import numpy as np
5
  import ast
6
  import gradio as gr
7
  import faiss
8
-
9
  from sentence_transformers import SentenceTransformer
10
  from transformers import pipeline
11
 
12
- # Unzip the dataset if not already done
13
  zip_path = "lexglue-legal-nlp-benchmark-dataset.zip"
14
  extract_dir = "lexglue_data"
15
 
@@ -17,61 +15,81 @@ if not os.path.exists(extract_dir):
17
  with zipfile.ZipFile(zip_path, 'r') as zip_ref:
18
  zip_ref.extractall(extract_dir)
19
 
20
- # Load CSV from extracted folder
21
  df = pd.read_csv(os.path.join(extract_dir, "case_hold_test.csv"))
22
  df = df[['context', 'endings', 'label']]
23
  df['endings'] = df['endings'].apply(ast.literal_eval)
24
 
25
- # Prepare corpus: concatenate context with each ending
26
  corpus = []
27
  for idx, row in df.iterrows():
28
  context = row['context']
29
  for ending in row['endings']:
30
  corpus.append(f"{context.strip()} {ending.strip()}")
31
 
32
- # Load Sentence Transformer and encode the corpus
33
  embedder = SentenceTransformer('all-MiniLM-L6-v2')
34
  corpus_embeddings = embedder.encode(corpus, show_progress_bar=True)
35
 
36
- # Create FAISS index
37
  dimension = corpus_embeddings.shape[1]
38
  index = faiss.IndexFlatL2(dimension)
39
  index.add(np.array(corpus_embeddings))
40
 
41
- # Load text generation pipeline
42
  generator = pipeline("text-generation", model="gpt2")
43
 
44
- # Query Function
 
 
 
 
 
 
45
  def legal_assistant_query(query):
46
  query_embedding = embedder.encode([query])
47
  D, I = index.search(np.array(query_embedding), k=5)
48
 
49
- # Limit the number of retrieved documents or trim context
50
  retrieved_docs = [corpus[i] for i in I[0]]
51
-
52
- # Combine retrieved documents into a single context and ensure it doesn't exceed token limit
53
- context_combined = "\n\n".join(retrieved_docs[:3]) # Limit to 3 docs to avoid overflow
54
- max_length = 1024 # Set appropriate limit based on GPT-2's token length (around 1024 tokens)
55
-
56
- # Ensure the context combined does not exceed max length
57
  context_combined = context_combined[:max_length]
58
 
59
- # Prepare the prompt for GPT-2
60
  prompt = f"Given the following legal references, answer the question:\n\n{context_combined}\n\nQuestion: {query}\nAnswer:"
61
-
62
- # Generate the response
63
  result = generator(prompt, max_new_tokens=200, do_sample=True)[0]['generated_text']
64
 
65
- # Extract the answer from the generated text
66
- return result.split("Answer:")[-1].strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
- # Gradio Interface
69
  iface = gr.Interface(
70
  fn=legal_assistant_query,
71
- inputs=gr.Textbox(lines=2, placeholder="Ask a legal question..."),
72
- outputs=gr.Textbox(label="Legal Response"),
 
 
 
 
 
 
 
73
  title="🧑‍⚖️ Legal Assistant Chatbot",
74
- description="Ask any legal question and get context-based case references using the LexGLUE dataset."
75
  )
76
 
77
  iface.launch()
 
# --- Bootstrap: data, retrieval corpus, embedding index, generator ----------
import ast
import gradio as gr
import faiss
from sentence_transformers import SentenceTransformer
from transformers import pipeline

# Unpack the LexGLUE archive only on the first run.
zip_path = "lexglue-legal-nlp-benchmark-dataset.zip"
extract_dir = "lexglue_data"

if not os.path.exists(extract_dir):
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(extract_dir)

# CASE-HOLD test split: each row pairs a context with candidate holdings.
df = pd.read_csv(os.path.join(extract_dir, "case_hold_test.csv"))
df = df[['context', 'endings', 'label']]
df['endings'] = df['endings'].apply(ast.literal_eval)

# One retrieval passage per (context, ending) pair.
corpus = [
    f"{row['context'].strip()} {ending.strip()}"
    for _, row in df.iterrows()
    for ending in row['endings']
]

# Dense embeddings for every passage, stored in a flat L2 FAISS index.
embedder = SentenceTransformer('all-MiniLM-L6-v2')
corpus_embeddings = embedder.encode(corpus, show_progress_bar=True)

dimension = corpus_embeddings.shape[1]
index = faiss.IndexFlatL2(dimension)
index.add(np.array(corpus_embeddings))

# GPT-2 used both for answering and for plain-English simplification.
generator = pipeline("text-generation", model="gpt2")
36
 
37
# Rolling record of the last 5 question/answer pairs for this session.
history = []

def simplify_legal_text(text):
    """Rewrite legal *text* in plain English via the GPT-2 generator.

    Bug fix: Hugging Face text-generation pipelines include the prompt at
    the start of ``generated_text``, so the original function returned the
    instruction plus the input verbatim, followed by the rewrite. Strip the
    echoed prompt so only the generated continuation is returned.

    Args:
        text: Legal prose to simplify.

    Returns:
        The model's continuation, stripped of the prompt and whitespace.
    """
    prompt = f"Simplify the following legal text into plain English:\n\n{text}"
    generated = generator(prompt, max_new_tokens=100, do_sample=False)[0]['generated_text']
    # Drop the echoed prompt; keep only what the model appended.
    return generated.removeprefix(prompt).strip()
43
+
44
def legal_assistant_query(query):
    """Answer a legal question with retrieval-augmented GPT-2 generation.

    Embeds the query, retrieves the 5 nearest passages from the FAISS
    index, prompts GPT-2 with the top 3, simplifies the answer, and
    records the exchange in the rolling module-level ``history``.

    Args:
        query: The user's legal question.

    Returns:
        The simplified answer string.
    """
    query_embedding = embedder.encode([query])
    D, I = index.search(np.array(query_embedding), k=5)

    # Bug fix: FAISS pads I with -1 when the index holds fewer than k
    # vectors; corpus[-1] would silently return the wrong document.
    retrieved_docs = [corpus[i] for i in I[0] if i != -1]

    # Keep the prompt within GPT-2's ~1024-token window.
    # NOTE(review): truncation is by characters, not tokens — a rough proxy.
    context_combined = "\n\n".join(retrieved_docs[:3])
    max_length = 1024
    context_combined = context_combined[:max_length]

    prompt = f"Given the following legal references, answer the question:\n\n{context_combined}\n\nQuestion: {query}\nAnswer:"
    result = generator(prompt, max_new_tokens=200, do_sample=True)[0]['generated_text']

    # The prompt ends with "Answer:", so the final split segment is the
    # model's reply (this also discards the echoed prompt).
    answer = result.split("Answer:")[-1].strip()

    # Simplify the answer if it's complex
    simplified_answer = simplify_legal_text(answer)

    # Maintain session history of last 5 questions and answers
    history.append({"question": query, "answer": simplified_answer})
    if len(history) > 5:
        history.pop(0)

    return simplified_answer
67
+
68
def show_history():
    """Render the session history as a readable Q/A transcript.

    Returns a placeholder message when no questions have been asked yet.
    """
    if not history:
        return "No history yet."
    blocks = []
    for entry in history:
        blocks.append(f"Q: {entry['question']}\nA: {entry['answer']}")
    return "\n\n".join(blocks)
71
+
72
+ sample_questions = [
73
+ "Can you explain the constitutional rights of a citizen in simple terms?",
74
+ "What does a breach of contract mean?",
75
+ "How do courts determine if someone is guilty of a crime?",
76
+ "What is the difference between civil and criminal law?",
77
+ "Can you explain what 'reasonable doubt' is in a criminal trial?"
78
+ ]
79
 
 
80
  iface = gr.Interface(
81
  fn=legal_assistant_query,
82
+ inputs=[
83
+ gr.Textbox(lines=2, placeholder="Ask a legal question..."),
84
+ gr.Button("Show History")
85
+ ],
86
+ outputs=[
87
+ gr.Textbox(label="Legal Response"),
88
+ gr.Textbox(label="Session History", lines=10),
89
+ gr.Textbox(label="Sample Questions", value="\n".join(sample_questions), lines=6)
90
+ ],
91
  title="🧑‍⚖️ Legal Assistant Chatbot",
92
+ description="Ask any legal question and get context-based case references using the LexGLUE dataset. The assistant will also simplify legal language into plain English and maintain a session history."
93
  )
94
 
95
  iface.launch()