prasannahf commited on
Commit
3ee0652
Β·
verified Β·
1 Parent(s): 7c18bcd

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +146 -0
app.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import streamlit as st
3
+ from dotenv import load_dotenv
4
+ from langchain.schema import HumanMessage
5
+ from langchain_groq import ChatGroq
6
+ from langgraph.graph import StateGraph, START, END
7
+ from pydantic import BaseModel
8
+ from langsmith import traceable
9
+ import traceback
10
+
11
+ # βœ… Load API keys securely
12
+ load_dotenv()
13
+ GROQ_API_KEY = os.getenv("GROQ_API_KEY")
14
+ LANGSMITH_API_KEY = os.getenv("LANGSMITH_API_KEY")
15
+
16
+ os.environ["GROQ_API_KEY"] = GROQ_API_KEY
17
+ os.environ["LANGCHAIN_TRACING_V2"] = "true"
18
+ os.environ["LANGCHAIN_API_KEY"] = LANGSMITH_API_KEY
19
+
20
+ # βœ… Initialize LLM (Using Groq Llama3-8B)
21
+ llm = ChatGroq(groq_api_key=GROQ_API_KEY, model_name="llama3-8b-8192")
22
+
23
+ # βœ… Define the LegalState Model
24
+ class LegalState(BaseModel):
25
+ original_text: str
26
+ tone: str
27
+ complexity: int = 1
28
+ rewritten_text: str = None
29
+ summary: str = None
30
+ key_clauses: str = None
31
+ risk_analysis: str = None
32
+ compliance_report: str = None
33
+ contract_suggestions: str = None
34
+ legal_arguments: str = None
35
+ formatted_text: str = None
36
+ comparison_result: str = None
37
+ final_report: str = None
38
+
39
+ # βœ… Function to invoke LLM with error handling
40
+ def generate_response(prompt):
41
+ try:
42
+ response = llm.invoke([HumanMessage(content=prompt)])
43
+ return response.content
44
+ except Exception as e:
45
+ return f"❌ Error: {str(e)}"
46
+
47
+ # βœ… Define Worker Functions
48
+ @traceable(name="Rewrite Legal Text")
49
+ def rewrite_text(state: LegalState):
50
+ prompt = f"""Rewrite this legal text in '{state.tone}' tone with complexity level {state.complexity}:
51
+
52
+ {state.original_text}"""
53
+ return {"rewritten_text": generate_response(prompt)}
54
+
55
+ @traceable(name="Summarize Legal Text")
56
+ def summarize_text(state: LegalState):
57
+ return {"summary": generate_response(f"Summarize this legal text:\n\n{state.rewritten_text}")}
58
+
59
+ @traceable(name="Extract Key Clauses")
60
+ def extract_clauses(state: LegalState):
61
+ return {"key_clauses": generate_response(f"Extract key legal clauses:\n\n{state.rewritten_text}")}
62
+
63
+ @traceable(name="Detect Risks in Document")
64
+ def detect_risks(state: LegalState):
65
+ return {"risk_analysis": generate_response(f"Analyze for risks:\n\n{state.rewritten_text}")}
66
+
67
+ @traceable(name="Check Compliance")
68
+ def check_compliance(state: LegalState):
69
+ return {"compliance_report": generate_response(f"Check legal compliance:\n\n{state.rewritten_text}")}
70
+
71
+ @traceable(name="Suggest Contract Improvements")
72
+ def suggest_improvements(state: LegalState):
73
+ return {"contract_suggestions": generate_response(f"Suggest improvements:\n\n{state.rewritten_text}")}
74
+
75
+ @traceable(name="Generate Legal Arguments")
76
+ def generate_arguments(state: LegalState):
77
+ return {"legal_arguments": generate_response(f"Generate legal arguments:\n\n{state.rewritten_text}")}
78
+
79
+ @traceable(name="Format Legal Document")
80
+ def format_document(state: LegalState):
81
+ return {"formatted_text": generate_response(f"Format this legal document:\n\n{state.rewritten_text}")}
82
+
83
+ @traceable(name="Compare Original vs Rewritten")
84
+ def compare_texts(state: LegalState):
85
+ prompt = f"Compare original vs rewritten:\n\nOriginal: {state.original_text}\nRewritten: {state.rewritten_text}"
86
+ return {"comparison_result": generate_response(prompt)}
87
+
88
+ # βœ… Build LangGraph Workflow
89
+ builder = StateGraph(LegalState)
90
+ builder.add_node("rewrite_text", rewrite_text)
91
+ builder.add_node("summarize_text", summarize_text)
92
+ builder.add_node("extract_clauses", extract_clauses)
93
+ builder.add_node("detect_risks", detect_risks)
94
+ builder.add_node("check_compliance", check_compliance)
95
+ builder.add_node("suggest_improvements", suggest_improvements)
96
+ builder.add_node("generate_arguments", generate_arguments)
97
+ builder.add_node("format_document", format_document)
98
+ builder.add_node("compare_texts", compare_texts)
99
+
100
+ builder.add_edge(START, "rewrite_text")
101
+ for node in ["summarize_text", "extract_clauses", "detect_risks", "check_compliance", "suggest_improvements", "generate_arguments", "format_document", "compare_texts"]:
102
+ builder.add_edge("rewrite_text", node)
103
+
104
+ graph = builder.compile()
105
+
106
+ # βœ… Streamlit UI
107
+ def main():
108
+ st.title("πŸ“œ AI-Powered Legal Text Processor")
109
+ st.write("A smart legal document assistant powered by LLMs.")
110
+
111
+ original_text = st.text_area("Enter Legal Text:")
112
+ tone = st.radio("Select Tone:", ["Formal", "Empathetic", "Neutral", "Strength-Based"])
113
+
114
+ if st.button("Process Text"):
115
+ input_data = {"original_text": original_text, "tone": tone, "complexity": 1}
116
+ result = graph.invoke(input_data)
117
+
118
+ st.subheader("πŸ”Ή Rewritten Text")
119
+ st.markdown(f"**{result['rewritten_text']}**")
120
+
121
+ st.subheader("πŸ“Œ Summary")
122
+ st.write(result['summary'])
123
+
124
+ st.subheader("πŸ“œ Key Clauses")
125
+ st.write(result['key_clauses'])
126
+
127
+ st.subheader("⚠️ Risk Analysis")
128
+ st.write(result['risk_analysis'])
129
+
130
+ st.subheader("βœ… Compliance Report")
131
+ st.write(result['compliance_report'])
132
+
133
+ st.subheader("πŸ’‘ Contract Suggestions")
134
+ st.write(result['contract_suggestions'])
135
+
136
+ st.subheader("βš–οΈ Legal Arguments")
137
+ st.write(result['legal_arguments'])
138
+
139
+ st.subheader("πŸ“ Formatted Legal Document")
140
+ st.write(result['formatted_text'])
141
+
142
+ st.subheader("πŸ” Comparison Result")
143
+ st.write(result['comparison_result'])
144
+
145
+ if __name__ == "__main__":
146
+ main()