LegalAI-Assist / app.py
prasannahf's picture
Update app.py
66042b6 verified
import os
import streamlit as st
from dotenv import load_dotenv
from langchain.schema import HumanMessage
from langchain_groq import ChatGroq
from langgraph.graph import StateGraph, START, END
from pydantic import BaseModel
from langsmith import traceable
import traceback
# βœ… Load API keys securely
load_dotenv()
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
LANGSMITH_API_KEY = os.getenv("LANGSMITH_API_KEY")
os.environ["GROQ_API_KEY"] = GROQ_API_KEY
os.environ["LANGCHAIN_TRACING_V2"] = "true"
os.environ["LANGCHAIN_API_KEY"] = LANGSMITH_API_KEY
# βœ… Initialize LLM (Using Groq Llama3-8B)
llm = ChatGroq(groq_api_key=GROQ_API_KEY, model_name="llama3-8b-8192")
# βœ… Define the LegalState Model
class LegalState(BaseModel):
original_text: str
tone: str
complexity: int = 1
rewritten_text: str = None
summary: str = None
key_clauses: str = None
risk_analysis: str = None
compliance_report: str = None
contract_suggestions: str = None
legal_arguments: str = None
formatted_text: str = None
comparison_result: str = None
final_report: str = None
# βœ… Function to invoke LLM with error handling
def generate_response(prompt):
try:
response = llm.invoke([HumanMessage(content=prompt)])
return response.content
except Exception as e:
return f"❌ Error: {str(e)}"
# βœ… Define Worker Functions
@traceable(name="Rewrite Legal Text")
def rewrite_text(state: LegalState):
prompt = f"""Rewrite this legal text in '{state.tone}' tone with complexity level {state.complexity}:
{state.original_text}"""
return {"rewritten_text": generate_response(prompt)}
@traceable(name="Summarize Legal Text")
def summarize_text(state: LegalState):
return {"summary": generate_response(f"Summarize this legal text:\n\n{state.rewritten_text}")}
@traceable(name="Extract Key Clauses")
def extract_clauses(state: LegalState):
return {"key_clauses": generate_response(f"Extract key legal clauses:\n\n{state.rewritten_text}")}
@traceable(name="Detect Risks in Document")
def detect_risks(state: LegalState):
return {"risk_analysis": generate_response(f"Analyze for risks:\n\n{state.rewritten_text}")}
@traceable(name="Check Compliance")
def check_compliance(state: LegalState):
return {"compliance_report": generate_response(f"Check legal compliance:\n\n{state.rewritten_text}")}
@traceable(name="Suggest Contract Improvements")
def suggest_improvements(state: LegalState):
return {"contract_suggestions": generate_response(f"Suggest improvements:\n\n{state.rewritten_text}")}
@traceable(name="Generate Legal Arguments")
def generate_arguments(state: LegalState):
return {"legal_arguments": generate_response(f"Generate legal arguments:\n\n{state.rewritten_text}")}
@traceable(name="Format Legal Document")
def format_document(state: LegalState):
return {"formatted_text": generate_response(f"Format this legal document:\n\n{state.rewritten_text}")}
@traceable(name="Compare Original vs Rewritten")
def compare_texts(state: LegalState):
prompt = f"Compare original vs rewritten:\n\nOriginal: {state.original_text}\nRewritten: {state.rewritten_text}"
return {"comparison_result": generate_response(prompt)}
# βœ… Build LangGraph Workflow
builder = StateGraph(LegalState)
builder.add_node("rewrite_text", rewrite_text)
builder.add_node("summarize_text", summarize_text)
builder.add_node("extract_clauses", extract_clauses)
builder.add_node("detect_risks", detect_risks)
builder.add_node("check_compliance", check_compliance)
builder.add_node("suggest_improvements", suggest_improvements)
builder.add_node("generate_arguments", generate_arguments)
builder.add_node("format_document", format_document)
builder.add_node("compare_texts", compare_texts)
builder.add_edge(START, "rewrite_text")
for node in ["summarize_text", "extract_clauses", "detect_risks", "check_compliance", "suggest_improvements", "generate_arguments", "format_document", "compare_texts"]:
builder.add_edge("rewrite_text", node)
graph = builder.compile()
# βœ… Streamlit UI
def main():
st.title("πŸ“œ AI-Powered Legal Text Processor")
st.write("A smart legal document assistant powered by LLMs.")
original_text = st.text_area("Enter Legal Text:")
tone = st.radio("Select Tone:", ["Formal", "Empathetic", "Neutral", "Strength-Based"])
if st.button("Process Text"):
input_data = {"original_text": original_text, "tone": tone, "complexity": 1}
result = graph.invoke(input_data)
st.subheader("πŸ”Ή Rewritten Text")
st.markdown(f"**{result['rewritten_text']}**")
st.subheader("πŸ“Œ Summary")
st.write(result['summary'])
st.subheader("πŸ“œ Key Clauses")
st.write(result['key_clauses'])
st.subheader("⚠️ Risk Analysis")
st.write(result['risk_analysis'])
st.subheader("βœ… Compliance Report")
st.write(result['compliance_report'])
st.subheader("πŸ’‘ Contract Suggestions")
st.write(result['contract_suggestions'])
st.subheader("βš–οΈ Legal Arguments")
st.write(result['legal_arguments'])
st.subheader("πŸ“ Formatted Legal Document")
st.write(result['formatted_text'])
st.subheader("πŸ” Comparison Result")
st.write(result['comparison_result'])
if __name__ == "__main__":
main()