menikev committed on
Commit
af2b4ba
·
verified ·
1 Parent(s): 3db42ba

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +70 -62
app.py CHANGED
@@ -33,56 +33,32 @@ retriever = get_retriever()
33
  pipe = pipeline(
34
  "text-generation",
35
  model="google/flan-t5-base", # ✅ smaller + CPU friendly
36
- max_new_tokens=512,
37
- temperature=0.3
 
 
38
  )
39
  llm = HuggingFacePipeline(pipeline=pipe)
40
 
41
 
42
  # =====================================================
43
- # Prompts
44
  # =====================================================
45
- english_system_prompt = """
46
- You are a Nigerian Legal AI Assistant specialized in Nigerian law. You have deep knowledge of:
47
- - Nigerian Constitution 1999
48
- - Labour Act and Employment Laws
49
- - Nigeria Data Protection Act
50
- - Federal Competition and Consumer Protection Act (FCCPA)
51
-
52
- PERSONALITY: Professional but approachable, uses Nigerian legal terminology, understands local context.
53
-
54
- RESPONSE STYLE:
55
- - Start with direct answer to the question
56
- - Quote specific sections/articles when available
57
- - Explain in simple terms what the law means
58
- - Always include disclaimer: "⚠️ This is not legal advice. Please consult a qualified lawyer for specific issues."
59
- - Use Nigerian English expressions naturally (but not forced)
60
-
61
- CONVERSATION MEMORY: Remember previous questions in this chat to provide contextual follow-ups.
62
  """
63
 
64
- pidgin_system_prompt = """
65
- You be Nigerian Legal AI Assistant wey sabi Nigerian law well well. You get knowledge of:
66
- - Nigerian Constitution 1999
67
- - Labour Act and Employment Laws
68
- - Nigeria Data Protection Act
69
- - Federal Competition and Consumer Protection Act (FCCPA)
70
-
71
- PERSONALITY: Friendly, approachable, dey use Naija way of talk but still correct for legal matter.
72
-
73
- RESPONSE STYLE:
74
- - Start with direct answer
75
- - Mention the exact section/article if available
76
- - Explain am for clear Pidgin wey anybody fit understand
77
- - Always add disclaimer: "⚠️ No be legal advice o, abeg meet lawyer if matter serious."
78
- - Remember wetin dem don ask before, make conversation flow well.
79
  """
80
 
81
 
82
  # =====================================================
83
- # Conversational QA Chain
84
  # =====================================================
85
- memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
 
 
 
 
86
 
87
  qa_chain = ConversationalRetrievalChain.from_llm(
88
  llm=llm,
@@ -93,31 +69,55 @@ qa_chain = ConversationalRetrievalChain.from_llm(
93
 
94
 
95
  # =====================================================
96
- # Chat function
97
  # =====================================================
98
  def answer_question(user_input, lang_choice, history=[]):
99
- # Pick system prompt
100
- if lang_choice == "pidgin":
101
- system_prompt = pidgin_system_prompt
102
- else:
103
- system_prompt = english_system_prompt
104
-
105
- # Run QA
106
- result = qa_chain.invoke({"question": f"{system_prompt}\n\nUser: {user_input}"})
107
- answer = result["answer"]
108
-
109
- # Collect sources (with sections)
110
- sources = []
111
- for doc in result["source_documents"]:
112
- section = doc.metadata.get("section", "Unknown Section")
113
- source = doc.metadata.get("source", "Unknown Document").replace(".pdf", "")
114
- sources.append(f"[{section}] from {source}")
115
-
116
- if sources:
117
- answer += "\n\n📚 Sources:\n" + "\n".join(sources)
118
-
119
- history.append(("You: " + user_input, "Bot: " + answer))
120
- return history, history
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
 
122
 
123
  # =====================================================
@@ -129,16 +129,24 @@ with gr.Blocks(css=".gradio-container {max-width: 800px !important}") as demo:
129
  with gr.Row():
130
  with gr.Column(scale=4):
131
  chatbot = gr.Chatbot(label="Chat with Legal AI", height=500)
132
- msg = gr.Textbox(label="Ask your question here...")
 
 
 
 
133
  lang_choice = gr.Radio(["english", "pidgin"], value="english", label="Language")
134
  clear = gr.Button("Clear Chat")
135
 
136
  state = gr.State([])
137
 
138
  def reset():
 
 
 
139
  return [], []
140
 
141
  msg.submit(answer_question, [msg, lang_choice, state], [chatbot, state])
 
142
  clear.click(reset, None, [chatbot, state])
143
 
144
- demo.launch()
 
# flan-t5 is an encoder-decoder (seq2seq) model, so the correct pipeline
# task is "text2text-generation"; "text-generation" targets causal LMs and
# will fail or misbehave for T5-family checkpoints.
pipe = pipeline(
    "text2text-generation",
    model="google/flan-t5-base",  # ✅ smaller + CPU friendly
    max_new_tokens=256,           # reduced from 512 to fit within context
    temperature=0.3,
    do_sample=True,               # temperature only takes effect when sampling
    pad_token_id=0,               # T5 uses token id 0 as <pad>
)
llm = HuggingFacePipeline(pipeline=pipe)
42
 
43
 
44
# =====================================================
# Prompts (shortened to reduce token usage)
# =====================================================
# Both prompts are deliberately terse: they are prepended to every question
# sent through the retrieval chain, so every extra token here shrinks the
# budget left for the retrieved context and the answer.
english_system_prompt = (
    'You are a Nigerian Legal AI Assistant. Provide direct answers about '
    'Nigerian law with relevant sections/articles. Always end with: '
    '"⚠️ This is not legal advice. Consult a qualified lawyer."\n'
)

pidgin_system_prompt = (
    'You be Nigerian Legal AI Assistant. Give direct answer about Nigerian '
    'law with correct section/article. Always end with: '
    '"⚠️ No be legal advice o, abeg meet lawyer."\n'
)
52
 
53
 
54
# =====================================================
# Conversational QA Chain with fixed memory
# =====================================================
# output_key="answer" tells the buffer which chain output to record; the
# chain also returns "source_documents", which must stay out of the memory.
memory = ConversationBufferMemory(
    return_messages=True,
    memory_key="chat_history",
    output_key="answer",
)
62
 
63
  qa_chain = ConversationalRetrievalChain.from_llm(
64
  llm=llm,
 
69
 
70
 
71
# =====================================================
# Chat function with better token management
# =====================================================
def answer_question(user_input, lang_choice, history=None):
    """Answer a legal question through the conversational retrieval chain.

    Parameters:
        user_input: The user's question; truncated to 200 characters to
            bound prompt size.
        lang_choice: "pidgin" selects the Pidgin system prompt; anything
            else falls back to English.
        history: Gradio chat history as (user, bot) tuples. A fresh list
            is created when None.

    Returns:
        (history, history) — the same list twice, feeding both the Chatbot
        component and the State component.
    """
    # Fix: the original `history=[]` mutable default is shared across calls,
    # leaking chat history between independent sessions.
    if history is None:
        history = []
    try:
        # Pick system prompt for the requested language.
        if lang_choice == "pidgin":
            system_prompt = pidgin_system_prompt
        else:
            system_prompt = english_system_prompt

        # Truncate user input if too long to keep the prompt within budget.
        max_input_length = 200
        if len(user_input) > max_input_length:
            user_input = user_input[:max_input_length] + "..."

        # Create shorter question format.
        question = f"{system_prompt}\nQ: {user_input}"

        # Run QA.
        result = qa_chain.invoke({"question": question})
        answer = result["answer"]

        # Collect sources (with sections) — cite at most the top 3.
        sources = []
        for doc in result["source_documents"][:3]:
            section = doc.metadata.get("section", "Unknown Section")
            source = doc.metadata.get("source", "Unknown Document").replace(".pdf", "")
            sources.append(f"[{section}] from {source}")
        if sources:
            answer += "\n\n📚 Sources:\n" + "\n".join(sources)

        # Keep responses bounded for the UI.
        max_answer_length = 800
        if len(answer) > max_answer_length:
            answer = answer[:max_answer_length] + "...\n\n⚠️ Response truncated due to length limits."

        history.append(("You: " + user_input, "Bot: " + answer))

        # Limit history to the last 5 exchanges to prevent memory overflow.
        if len(history) > 5:
            history = history[-5:]

        return history, history

    except Exception as e:
        # UI boundary: surface a short error message instead of crashing.
        error_msg = f"Sorry, I encountered an error: {str(e)[:100]}..."
        history.append(("You: " + user_input, "Bot: " + error_msg))
        return history, history
121
 
122
 
123
  # =====================================================
 
129
  with gr.Row():
130
  with gr.Column(scale=4):
131
  chatbot = gr.Chatbot(label="Chat with Legal AI", height=500)
132
+ msg = gr.Textbox(
133
+ label="Ask your question here...",
134
+ placeholder="e.g., What are the rights of employees in Nigeria?",
135
+ max_lines=3
136
+ )
137
  lang_choice = gr.Radio(["english", "pidgin"], value="english", label="Language")
138
  clear = gr.Button("Clear Chat")
139
 
140
  state = gr.State([])
141
 
142
  def reset():
143
+ # Clear memory as well
144
+ global memory
145
+ memory.clear()
146
  return [], []
147
 
148
  msg.submit(answer_question, [msg, lang_choice, state], [chatbot, state])
149
+ msg.submit(lambda: "", None, msg) # Clear input after submit
150
  clear.click(reset, None, [chatbot, state])
151
 
152
+ demo.launch()