# app.py — LLM Legal Advisor (Streamlit)
# Author: NeerAbhy — "Update app.py" (commit bd3b1d7, verified)
import time
import os
import streamlit as st
from dotenv import load_dotenv
from PyPDF2 import PdfReader
from langchain_groq import ChatGroq
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.prompts import PromptTemplate
from langchain.memory import ConversationBufferWindowMemory
from langchain.chains import ConversationalRetrievalChain
from footer import footer
# ───────────────────────────────
# PAGE CONFIG
# ───────────────────────────────
st.set_page_config(page_title="LLM Legal Advisor", layout="centered")

# Center the banner image by giving the middle column nearly all the width.
col1, col2, col3 = st.columns([1, 30, 1])
with col2:
    # use_column_width is deprecated in recent Streamlit releases;
    # use_container_width is the supported replacement with the same effect.
    st.image("images/law.png", use_container_width=True)

# Hide Streamlit's default hamburger menu and built-in footer.
st.markdown("""
<style>
#MainMenu {visibility:hidden;}
footer {visibility:hidden;}
</style>
""", unsafe_allow_html=True)
# ───────────────────────────────
# SESSION STATE
# ───────────────────────────────
# Streamlit reruns the whole script on every interaction, so each
# per-session default is guarded and assigned exactly once.
if "messages" not in st.session_state:
    st.session_state["messages"] = []

if "memory" not in st.session_state:
    # Sliding-window memory: keep only the 2 most recent exchanges.
    window_memory = ConversationBufferWindowMemory(
        k=2,
        memory_key="chat_history",
        return_messages=True,
    )
    st.session_state["memory"] = window_memory

if "uploaded_doc_text" not in st.session_state:
    st.session_state["uploaded_doc_text"] = None
# ───────────────────────────────
# LOAD EMBEDDINGS
# ───────────────────────────────
@st.cache_resource
def load_embeddings():
    """Load the InLegalBERT embedding model (cached across Streamlit reruns)."""
    return HuggingFaceEmbeddings(model_name="law-ai/InLegalBERT")

embeddings = load_embeddings()

# Load the prebuilt FAISS index of IPC passages from local disk.
# NOTE(review): allow_dangerous_deserialization unpickles the stored index —
# acceptable only because "ipc_embed_db" is a trusted, locally generated
# artifact; never enable this for indexes from untrusted sources.
db = FAISS.load_local(
    "ipc_embed_db",
    embeddings,
    allow_dangerous_deserialization=True
)

# Retriever returns the top-3 most similar chunks for each query.
db_retriever = db.as_retriever(
    search_type="similarity",
    search_kwargs={"k": 3}
)
# ───────────────────────────────
# PROMPT TEMPLATE
# ───────────────────────────────
# Llama-style instruction prompt. The instruction block must be closed with
# [/INST]; the original ended with "</s>[INST]", which re-opens an
# instruction instead of closing it and can degrade model output.
prompt_template = """
<s>[INST]
As a legal chatbot specializing in the Indian Penal Code, provide clear and accurate answers.
Rules:
- Use bullet points
- Avoid unnecessary complexity
- Clarify misconceptions
- End with a short summary
CONTEXT: {context}
CHAT HISTORY: {chat_history}
QUESTION: {question}
ANSWER:
- [Key legal point]
- [Clarification]
- [Exception]
- [Summary]
[/INST]
"""

# {context} is filled with retrieved IPC chunks, {chat_history} by the
# windowed memory, {question} by the user's message.
prompt = PromptTemplate(
    template=prompt_template,
    input_variables=["context", "question", "chat_history"],
)
# ───────────────────────────────
# GROQ API SETUP
# ───────────────────────────────
# Resolve the Groq API key once per session: prefer the GROQ_API_KEY
# environment variable, otherwise prompt the user. The script halts
# (st.stop) until a key is available, so nothing below runs without one.
if "GROQ_API_KEY" not in st.session_state:
    api_key = os.getenv("GROQ_API_KEY") or st.text_input(
        "Enter your Groq API key:",
        type="password"
    )
    if api_key:
        # Persist for later reruns and expose to the SDK via the environment.
        st.session_state["GROQ_API_KEY"] = api_key
        os.environ["GROQ_API_KEY"] = api_key
        st.success("API key set successfully")
    else:
        st.warning("Please enter your Groq API key")
        st.stop()  # abort this rerun until the user supplies a key
else:
    # Key captured in an earlier rerun — re-export it for the SDK.
    os.environ["GROQ_API_KEY"] = st.session_state["GROQ_API_KEY"]
# ───────────────────────────────
# LLM
# ───────────────────────────────
llm = ChatGroq(
    model="llama-3.1-8b-instant",
    api_key=os.environ["GROQ_API_KEY"]
)

# Retrieval-augmented QA chain: fetches IPC context via db_retriever,
# threads the windowed chat history from session memory, and answers
# using the prompt template defined above.
qa = ConversationalRetrievalChain.from_llm(
    llm=llm,
    memory=st.session_state.memory,
    retriever=db_retriever,
    combine_docs_chain_kwargs={"prompt": prompt}
)
# ───────────────────────────────
# FUNCTIONS
# ───────────────────────────────
def extract_answer(full_response):
    """Return the text after the first "Response:" marker, stripped.

    If the marker is absent, the whole response is returned stripped.
    """
    _, marker, tail = full_response.partition("Response:")
    return tail.strip() if marker else full_response.strip()
def reset_conversation():
    """Wipe the on-screen transcript and the chain's conversation memory."""
    state = st.session_state
    state.messages = []
    state.memory.clear()
# SAFE FILE READER
def read_uploaded_file(file):
    """Extract plain text from an uploaded PDF or TXT file.

    PDF pages with no extractable text contribute nothing; TXT content is
    decoded as UTF-8 with a Latin-1 fallback. Any other extension yields
    the sentinel string "Unsupported file type".
    """
    name = file.name.lower()

    if name.endswith(".pdf"):
        reader = PdfReader(file)
        pages = (page.extract_text() or "" for page in reader.pages)
        return "".join(pages).strip()

    if name.endswith(".txt"):
        raw = file.read()
        try:
            return raw.decode("utf-8").strip()
        except UnicodeDecodeError:
            return raw.decode("latin-1").strip()

    return "Unsupported file type"
# ───────────────────────────────
# UI
# ───────────────────────────────
# The original strings contained mojibake ("βš–οΈ", "πŸ“Ž"): the UTF-8 bytes
# of the intended emoji mis-decoded through a legacy codepage. Restored to
# the intended characters.
st.header("⚖️ AI-Powered Legal Advisor")

# Optional document upload; extracted text is kept in session state so
# later "summarize" requests can reference it.
uploaded_file = st.file_uploader(
    "📎 Upload legal document (PDF or TXT)",
    type=["pdf", "txt"]
)
if uploaded_file is not None:
    st.session_state.uploaded_doc_text = read_uploaded_file(uploaded_file)
    st.success("Document uploaded successfully")
# ───────────────────────────────
# SHOW CHAT HISTORY
# ───────────────────────────────
# Replay the stored transcript; each rerun starts from a blank page.
for entry in st.session_state.messages:
    with st.chat_message(entry["role"]):
        st.write(entry["content"])
# ───────────────────────────────
# USER INPUT
# ───────────────────────────────
input_prompt = st.chat_input("Ask a legal question...")
if input_prompt:
    # Echo the user's message and record it in the transcript.
    with st.chat_message("user"):
        st.markdown(input_prompt)
    st.session_state.messages.append(
        {"role": "user", "content": input_prompt}
    )
    with st.chat_message("assistant"):
        with st.spinner("Thinking..."):
            # Crude intent routing: "summarize"/"summary" anywhere in the
            # question triggers document summarization instead of RAG QA.
            if "summarize" in input_prompt.lower() or "summary" in input_prompt.lower():
                if st.session_state.uploaded_doc_text:
                    # Truncated to the first 6000 characters, presumably to
                    # stay within the model's context window — TODO confirm.
                    summary_prompt = f"""
You are a legal assistant.
Summarize the following legal document.
Include:
- Purpose
- Key clauses
- Legal obligations
- Simple summary
Document:
{st.session_state.uploaded_doc_text[:6000]}
"""
                    # Direct LLM call — retrieval is not used for summaries.
                    summary = llm.invoke(summary_prompt)
                    answer = summary.content
                else:
                    answer = "Please upload a document first."
            else:
                # Retrieval-augmented answer over the IPC vector store.
                result = qa.invoke({"question": input_prompt})
                answer = extract_answer(result["answer"])
            st.write(answer)
    st.session_state.messages.append(
        {"role": "assistant", "content": answer}
    )
# ───────────────────────────────
# RESET BUTTON
# ───────────────────────────────
if st.button("Reset Chat"):
    reset_conversation()
    st.rerun()  # rerun immediately so the cleared state is reflected

footer()  # custom page footer rendered by the local footer module