Aranwer committed on
Commit
904c6a6
·
verified ·
1 Parent(s): d26beb1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -95
app.py CHANGED
@@ -1,113 +1,66 @@
 
 
 
 
 
1
  import gradio as gr
2
- from datasets import load_dataset
3
- from sentence_transformers import SentenceTransformer
4
  import faiss
5
- import numpy as np
 
6
  from transformers import pipeline
7
 
8
- # Load dataset
9
- dataset = load_dataset("lex_glue", "scotus")
10
- corpus = [doc['text'] for doc in dataset['train'].select(range(200))] # just 200 to keep it light
 
 
 
 
 
 
 
 
 
11
 
12
- # Embedding model
13
- embedder = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
14
- corpus_embeddings = embedder.encode(corpus, convert_to_numpy=True)
 
 
 
15
 
16
- # Build FAISS index
 
 
 
 
17
  dimension = corpus_embeddings.shape[1]
18
  index = faiss.IndexFlatL2(dimension)
19
- index.add(corpus_embeddings)
20
 
21
- # Text generation model
22
- gen_pipeline = pipeline("text2text-generation", model="facebook/bart-large-cnn")
23
 
24
- # RAG-like query function
25
- def rag_query(user_question):
26
- # Encode the user question
27
- question_embedding = embedder.encode([user_question])
28
-
29
- k = 3 # top 3 documents
30
- if index.ntotal < k:
31
- k = index.ntotal # Adjust if there are fewer documents than requested
32
-
33
- # Perform the search in the FAISS index
34
- _, indices = index.search(np.array(question_embedding), k=k)
35
 
36
- # Ensure indices are valid (within range of the corpus)
37
- valid_indices = [i for i in indices[0] if i < len(corpus)]
38
-
39
- if len(valid_indices) == 0:
40
- return "Sorry, no relevant documents were found."
41
-
42
- # Extract relevant context from the corpus based on valid indices
43
- context = " ".join([corpus[i] for i in valid_indices])
44
 
45
- # Prepare the prompt and generate the response
46
- prompt = f"Question: {user_question}\nContext: {context}\nAnswer:"
47
- result = gen_pipeline(prompt, max_length=250, do_sample=False)[0]['generated_text']
48
 
49
- return result
50
-
51
- # Gradio UI
52
- def chatbot_interface(query):
53
- return rag_query(query)
54
-
55
- # Styling for the interface
56
- css = """
57
- .gradio-container {
58
- background-color: #f0f4f8;
59
- font-family: Arial, sans-serif;
60
- }
61
- .gradio-input {
62
- background-color: #ffffff;
63
- border-radius: 5px;
64
- border: 1px solid #d1d1d1;
65
- font-size: 16px;
66
- padding: 10px;
67
- }
68
- .gradio-button {
69
- background-color: #4CAF50;
70
- color: white;
71
- border-radius: 5px;
72
- border: none;
73
- padding: 10px 20px;
74
- font-size: 16px;
75
- }
76
- .gradio-button:hover {
77
- background-color: #45a049;
78
- }
79
- .gradio-output {
80
- background-color: #ffffff;
81
- border-radius: 5px;
82
- padding: 15px;
83
- font-size: 16px;
84
- border: 1px solid #d1d1d1;
85
- }
86
- .gradio-title {
87
- font-size: 28px;
88
- font-weight: bold;
89
- color: #333333;
90
- text-align: center;
91
- margin-bottom: 20px;
92
- }
93
- .gradio-description {
94
- font-size: 16px;
95
- color: #666666;
96
- text-align: center;
97
- margin-bottom: 30px;
98
- }
99
- """
100
 
101
- # Create the Gradio interface
102
  iface = gr.Interface(
103
- fn=chatbot_interface,
104
- inputs="text",
105
- outputs="text",
106
  title="🧑‍⚖️ Legal Assistant Chatbot",
107
- description="Ask legal questions based on case data (LexGLUE - SCOTUS subset). Get answers derived from relevant court case texts.",
108
- theme="compact",
109
- css=css
110
  )
111
 
112
- # Launch the Gradio interface
113
- iface.launch()
 
1
+ import zipfile
2
+ import os
3
+ import pandas as pd
4
+ import numpy as np
5
+ import ast
6
  import gradio as gr
 
 
7
  import faiss
8
+
9
+ from sentence_transformers import SentenceTransformer
10
  from transformers import pipeline
11
 
12
# One-time setup: unpack the LexGLUE archive if it has not been extracted yet.
zip_path = "lexglue-legal-nlp-benchmark-dataset.zip"
extract_dir = "lexglue_data"

if not os.path.exists(extract_dir):
    with zipfile.ZipFile(zip_path) as archive:
        archive.extractall(extract_dir)
19
+
20
+ # Load CSV from extracted folder
21
+ df = pd.read_csv(os.path.join(extract_dir, "case_hold_test.csv"))
22
+ df = df[['context', 'endings', 'label']]
23
+ df['endings'] = df['endings'].apply(ast.literal_eval)
24
 
25
+ # Prepare corpus: concatenate context with each ending
26
+ corpus = []
27
+ for idx, row in df.iterrows():
28
+ context = row['context']
29
+ for ending in row['endings']:
30
+ corpus.append(f"{context.strip()} {ending.strip()}")
31
 
32
+ # Load Sentence Transformer and encode the corpus
33
+ embedder = SentenceTransformer('all-MiniLM-L6-v2')
34
+ corpus_embeddings = embedder.encode(corpus, show_progress_bar=True)
35
+
36
+ # Create FAISS index
37
  dimension = corpus_embeddings.shape[1]
38
  index = faiss.IndexFlatL2(dimension)
39
+ index.add(np.array(corpus_embeddings))
40
 
41
+ # Load text generation pipeline
42
+ generator = pipeline("text-generation", model="gpt2")
43
 
44
# Query Function
def legal_assistant_query(query):
    """Answer a legal question via retrieval-augmented generation.

    Encodes the query, retrieves the most similar corpus passages from the
    FAISS index, and prompts the generator to answer from that context.

    Parameters
    ----------
    query : str
        The user's legal question.

    Returns
    -------
    str
        The generated answer (text after the final "Answer:" marker), or an
        apology message when no documents can be retrieved.
    """
    query_embedding = embedder.encode([query])

    # Never ask for more neighbours than the index holds: FAISS pads missing
    # results with -1, and corpus[-1] would silently return the LAST passage
    # instead of failing.
    k = min(5, index.ntotal)
    if k == 0:
        return "Sorry, no relevant documents were found."

    _, neighbour_ids = index.search(np.array(query_embedding), k=k)

    # Keep only ids that actually address the corpus (drops any -1 padding).
    retrieved_docs = [corpus[i] for i in neighbour_ids[0] if 0 <= i < len(corpus)]
    if not retrieved_docs:
        return "Sorry, no relevant documents were found."

    context_combined = "\n\n".join(retrieved_docs)

    prompt = f"Given the following legal references, answer the question:\n\n{context_combined}\n\nQuestion: {query}\nAnswer:"
    result = generator(prompt, max_new_tokens=200, do_sample=True)[0]['generated_text']

    # GPT-2 echoes the prompt; keep only the text after the last "Answer:".
    return result.split("Answer:")[-1].strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
# Wire the query function into a simple question/answer Gradio UI.
question_input = gr.Textbox(lines=2, placeholder="Ask a legal question...")
answer_output = gr.Textbox(label="Legal Response")

iface = gr.Interface(
    fn=legal_assistant_query,
    inputs=question_input,
    outputs=answer_output,
    title="🧑‍⚖️ Legal Assistant Chatbot",
    description=(
        "Ask any legal question and get context-based case references "
        "using the LexGLUE dataset."
    ),
)

iface.launch()