menikev committed on
Commit a508099 · verified · 1 Parent(s): af2b4ba

Update app.py

Files changed (1):
  1. app.py +58 -22
app.py CHANGED
@@ -34,25 +34,39 @@ pipe = pipeline(
     "text-generation",
     model="google/flan-t5-base", # ✅ smaller + CPU friendly
     max_new_tokens=256, # Reduced from 512 to fit within context
-    temperature=0.3,
+    temperature=0.7,
     do_sample=True,
-    pad_token_id=0 # Add padding token
+    pad_token_id=0, # Add padding token
+    truncation=True,
+    return_full_text=False # Only return generated text, not the prompt
 )
 llm = HuggingFacePipeline(pipeline=pipe)
 
 
 # =====================================================
-# Prompts (shortened to reduce token usage)
+# Custom prompt template for better responses
 # =====================================================
-english_system_prompt = """You are a Nigerian Legal AI Assistant. Provide direct answers about Nigerian law with relevant sections/articles. Always end with: "⚠️ This is not legal advice. Consult a qualified lawyer."
-"""
+custom_template = """Based on the following Nigerian law documents, answer the user's question clearly and directly.
 
-pidgin_system_prompt = """You be Nigerian Legal AI Assistant. Give direct answer about Nigerian law with correct section/article. Always end with: "⚠️ No be legal advice o, abeg meet lawyer."
-"""
+Context: {context}
+
+Question: {question}
+
+Instructions:
+- Give a direct, helpful answer
+- Quote specific sections when relevant
+- Use simple, clear language
+- For greetings, respond politely and ask how you can help with Nigerian law
+
+Answer:"""
+
+PROMPT = PromptTemplate(
+    template=custom_template, input_variables=["context", "question"]
+)
 
 
 # =====================================================
-# Conversational QA Chain with fixed memory
+# Conversational QA Chain with custom prompt
 # =====================================================
 memory = ConversationBufferMemory(
     memory_key="chat_history",
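
For review context, here is a minimal self-contained sketch of the new generation setup (the import paths are assumptions, since the diff doesn't show the file header). One caveat: flan-t5 is an encoder-decoder model, so transformers normally expects the "text2text-generation" task for it, and return_full_text is an option of decoder-only "text-generation" pipelines, so both of those arguments may deserve a second look:

from transformers import pipeline
from langchain.llms import HuggingFacePipeline
from langchain.prompts import PromptTemplate

# flan-t5 is seq2seq, so "text2text-generation" rather than "text-generation"
pipe = pipeline(
    "text2text-generation",
    model="google/flan-t5-base",
    max_new_tokens=256,
    do_sample=True,
    temperature=0.7,
    truncation=True,
)
llm = HuggingFacePipeline(pipeline=pipe)

PROMPT = PromptTemplate(
    template="Context: {context}\n\nQuestion: {question}\n\nAnswer:",
    input_variables=["context", "question"],
)
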
@@ -65,32 +79,51 @@ qa_chain = ConversationalRetrievalChain.from_llm(
     retriever=retriever,
     memory=memory,
     return_source_documents=True,
+    combine_docs_chain_kwargs={"prompt": PROMPT} # Use custom prompt
 )
 
 
 # =====================================================
-# Chat function with better token management
+# Chat function with better response handling
 # =====================================================
 def answer_question(user_input, lang_choice, history=[]):
     try:
-        # Pick system prompt
-        if lang_choice == "pidgin":
-            system_prompt = pidgin_system_prompt
-        else:
-            system_prompt = english_system_prompt
+        # Handle greetings and simple queries
+        user_lower = user_input.lower().strip()
+        if user_lower in ["hello", "hi", "hey", "good morning", "good afternoon", "good evening"]:
+            if lang_choice == "pidgin":
+                response = "Hello! How far? I be your Nigerian Legal AI Assistant. Wetin you wan know about Nigerian law today? ⚠️ No be legal advice o, abeg meet lawyer if matter serious."
+            else:
+                response = "Hello! I'm your Nigerian Legal AI Assistant. How can I help you with Nigerian law today? ⚠️ This is not legal advice. Please consult a qualified lawyer for specific issues."
+
+            history.append(("You: " + user_input, "Bot: " + response))
+            return history, history
 
         # Truncate user input if too long
         max_input_length = 200 # Limit user input length
         if len(user_input) > max_input_length:
             user_input = user_input[:max_input_length] + "..."
 
-        # Create shorter question format
-        question = f"{system_prompt}\nQ: {user_input}"
-
-        # Run QA
-        result = qa_chain.invoke({"question": question})
+        # Run QA with simple question
+        result = qa_chain.invoke({"question": user_input})
         answer = result["answer"]
 
+        # Clean up the answer - remove any retrieval artifacts
+        if "Use the following pieces of context" in answer:
+            # If the model returns retrieval instructions, provide a fallback
+            if lang_choice == "pidgin":
+                answer = "I dey try find information about your question for Nigerian law documents. Wetin specifically you wan know? ⚠️ No be legal advice o."
+            else:
+                answer = "I'm searching through Nigerian law documents for your question. Could you be more specific about what you'd like to know? ⚠️ This is not legal advice."
+
+        # Add disclaimer if not present
+        if lang_choice == "pidgin":
+            if "No be legal advice" not in answer:
+                answer += "\n\n⚠️ No be legal advice o, abeg meet lawyer if matter serious."
+        else:
+            if "not legal advice" not in answer.lower():
+                answer += "\n\n⚠️ This is not legal advice. Please consult a qualified lawyer for specific issues."
+
         # Collect sources (with sections) - limit to top 3
         sources = []
         for doc in result["source_documents"][:3]: # Limit to top 3 sources
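
Under the hood, combine_docs_chain_kwargs={"prompt": PROMPT} hands the custom template to the chain's internal "stuff" documents chain: the retrieved documents fill {context} and the (memory-condensed) question fills {question}. A usage sketch, with an illustrative question string; note that with return_source_documents=True plus memory, ConversationBufferMemory usually also needs output_key="answer" so it knows which output to store:

result = qa_chain.invoke({"question": "What does the Constitution say about bail?"})
print(result["answer"])                 # the generated answer
for doc in result["source_documents"]:  # the documents that filled {context}
    print(doc.metadata.get("source", "Unknown Document"))
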
@@ -98,11 +131,11 @@ def answer_question(user_input, lang_choice, history=[]):
             source = doc.metadata.get("source", "Unknown Document").replace(".pdf", "")
             sources.append(f"[{section}] from {source}")
 
-        if sources:
+        if sources and len(answer) < 400: # Only add sources if answer isn't too long
             answer += "\n\n📚 Sources:\n" + "\n".join(sources)
 
         # Truncate answer if too long
-        max_answer_length = 800
+        max_answer_length = 600
         if len(answer) > max_answer_length:
             answer = answer[:max_answer_length] + "...\n\n⚠️ Response truncated due to length limits."
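
The two length thresholds introduced here interact: sources are appended only while the answer is under 400 characters, which leaves roughly 200 characters of headroom before the 600-character cap cuts in. The same gating as a standalone helper (finalize is a hypothetical name, not part of the diff):

def finalize(answer, sources, max_answer_length=600):
    # Append sources only while the answer is short enough to leave room
    if sources and len(answer) < 400:
        answer += "\n\n📚 Sources:\n" + "\n".join(sources)
    # Hard cap, with an explicit truncation notice
    if len(answer) > max_answer_length:
        answer = answer[:max_answer_length] + "...\n\n⚠️ Response truncated due to length limits."
    return answer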
 
@@ -115,7 +148,10 @@ def answer_question(user_input, lang_choice, history=[]):
         return history, history
 
     except Exception as e:
-        error_msg = f"Sorry, I encountered an error: {str(e)[:100]}..."
+        if lang_choice == "pidgin":
+            error_msg = f"Sorry o, I get small wahala: {str(e)[:50]}... Try ask again."
+        else:
+            error_msg = f"Sorry, I encountered an error: {str(e)[:50]}... Please try asking again."
         history.append(("You: " + user_input, "Bot: " + error_msg))
         return history, history
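
One caveat the diff leaves in place: history=[] is a mutable default argument, so Python reuses the same list across calls and chat history can bleed between sessions (Gradio apps usually pass state explicitly, which masks this). The conventional fix, as a sketch:

def answer_question(user_input, lang_choice, history=None):
    if history is None:  # new list per call instead of one shared default
        history = []
    ...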
 
 