prasannahf commited on
Commit
11a14d3
Β·
verified Β·
1 Parent(s): d3e4a2b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -29
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import os
2
  import streamlit as st
 
3
  from langchain.schema import HumanMessage
4
  from langchain_groq import ChatGroq
5
  from langgraph.graph import StateGraph, START, END
@@ -7,18 +8,24 @@ from pydantic import BaseModel
7
  from langsmith import traceable
8
  import traceback
9
 
10
- # βœ… Load API keys
 
11
  GROQ_API_KEY = os.getenv("GROQ_API_KEY")
12
  LANGSMITH_API_KEY = os.getenv("LANGSMITH_API_KEY")
13
 
14
- # βœ… Initialize LLM (Using Groq Llama3-8B)
 
 
 
 
 
15
  llm = ChatGroq(groq_api_key=GROQ_API_KEY, model_name="llama3-8b-8192")
16
 
17
  # βœ… Define the LegalState Model
18
  class LegalState(BaseModel):
19
  original_text: str
20
  tone: str
21
- complexity: int = 1 # Default complexity level
22
  rewritten_text: str = None
23
  summary: str = None
24
  key_clauses: str = None
@@ -30,8 +37,7 @@ class LegalState(BaseModel):
30
  comparison_result: str = None
31
  final_report: str = None
32
 
33
- # βœ… Function to invoke LLM
34
- @traceable(name="Generate Response")
35
  def generate_response(prompt):
36
  try:
37
  response = llm.invoke([HumanMessage(content=prompt)])
@@ -39,10 +45,11 @@ def generate_response(prompt):
39
  except Exception as e:
40
  return f"❌ Error: {str(e)}"
41
 
42
- # βœ… Define Workflow Functions
43
  @traceable(name="Rewrite Legal Text")
44
  def rewrite_text(state: LegalState):
45
- prompt = f"""Rewrite this legal text in '{state.tone}' tone with complexity level {state.complexity}:\n\n{state.original_text}"""
 
46
  return {"rewritten_text": generate_response(prompt)}
47
 
48
  @traceable(name="Summarize Legal Text")
@@ -55,14 +62,23 @@ def extract_clauses(state: LegalState):
55
 
56
  @traceable(name="Final Synthesis")
57
  def synthesizer(state: LegalState):
58
- return {"final_report": f"""
59
- πŸ“œ **Final Legal Report** πŸ“œ\n\n
60
- **Rewritten Legal Text:** {state.rewritten_text}\n\n
61
- **Summary:** {state.summary}\n\n
62
- **Key Clauses:** {state.key_clauses}\n"""}
 
 
 
 
 
 
 
 
63
 
64
- # βœ… Orchestrator
65
  builder = StateGraph(LegalState)
 
66
  builder.add_node("rewrite_text", rewrite_text)
67
  builder.add_node("summarize_text", summarize_text)
68
  builder.add_node("extract_clauses", extract_clauses)
@@ -82,22 +98,18 @@ st.title("πŸ“œ AI-Powered Legal Text Processor")
82
  original_text = st.text_area("Enter Legal Text")
83
  tone = st.radio("Select Tone", ["Formal", "Empathetic", "Neutral", "Strength-Based"])
84
 
85
- if st.button("Generate"):
86
  input_data = {"original_text": original_text, "tone": tone, "complexity": 1}
87
  result = graph.invoke(input_data)
88
-
89
- st.subheader("πŸ”Ή Rewritten Text")
90
- rewritten_text = st.text_area("", value=result["rewritten_text"], height=150)
91
- if st.button("Regenerate"):
92
- input_data["original_text"] = rewritten_text
 
93
  result = graph.invoke(input_data)
94
- st.experimental_rerun()
95
-
96
- st.subheader("πŸ”Ή Summary")
97
- st.write(result["summary"])
98
-
99
- st.subheader("πŸ”Ή Key Clauses")
100
- st.write(result["key_clauses"])
101
-
102
- st.subheader("πŸ“œ Final Report")
103
- st.write(result["final_report"])
 
1
  import os
2
  import streamlit as st
3
+ from dotenv import load_dotenv
4
  from langchain.schema import HumanMessage
5
  from langchain_groq import ChatGroq
6
  from langgraph.graph import StateGraph, START, END
 
8
  from langsmith import traceable
9
  import traceback
10
 
11
+ # βœ… Load API keys securely
12
+ load_dotenv()
13
  GROQ_API_KEY = os.getenv("GROQ_API_KEY")
14
  LANGSMITH_API_KEY = os.getenv("LANGSMITH_API_KEY")
15
 
16
+ # βœ… Set API keys as environment variables
17
+ os.environ["GROQ_API_KEY"] = GROQ_API_KEY
18
+ os.environ["LANGCHAIN_TRACING_V2"] = "true"
19
+ os.environ["LANGCHAIN_API_KEY"] = LANGSMITH_API_KEY
20
+
21
+ # βœ… Initialize Open-Source LLM (Using Groq Llama3-8B)
22
  llm = ChatGroq(groq_api_key=GROQ_API_KEY, model_name="llama3-8b-8192")
23
 
24
  # βœ… Define the LegalState Model
25
  class LegalState(BaseModel):
26
  original_text: str
27
  tone: str
28
+ complexity: int = 1
29
  rewritten_text: str = None
30
  summary: str = None
31
  key_clauses: str = None
 
37
  comparison_result: str = None
38
  final_report: str = None
39
 
40
+ # βœ… Function to invoke LLM with error handling
 
41
  def generate_response(prompt):
42
  try:
43
  response = llm.invoke([HumanMessage(content=prompt)])
 
45
  except Exception as e:
46
  return f"❌ Error: {str(e)}"
47
 
48
+ # βœ… Define Worker Functions with LangSmith Debugging
49
  @traceable(name="Rewrite Legal Text")
50
  def rewrite_text(state: LegalState):
51
+ prompt = f"""Rewrite this legal text in '{state.tone}' tone with complexity level {state.complexity}:
52
+ {state.original_text}"""
53
  return {"rewritten_text": generate_response(prompt)}
54
 
55
  @traceable(name="Summarize Legal Text")
 
62
 
63
  @traceable(name="Final Synthesis")
64
  def synthesizer(state: LegalState):
65
+ final_output = f"""
66
+ πŸ“œ **AI-Powered Legal Document Processing Report** πŸ“œ
67
+
68
+ **πŸ”Ή Rewritten Legal Text:**
69
+ **{state.rewritten_text}**
70
+
71
+ **πŸ”Ή Summary:**
72
+ {state.summary}
73
+
74
+ **πŸ”Ή Key Clauses Identified:**
75
+ {state.key_clauses}
76
+ """
77
+ return {"final_report": final_output}
78
 
79
+ # βœ… Build LangGraph Workflow
80
  builder = StateGraph(LegalState)
81
+
82
  builder.add_node("rewrite_text", rewrite_text)
83
  builder.add_node("summarize_text", summarize_text)
84
  builder.add_node("extract_clauses", extract_clauses)
 
98
  original_text = st.text_area("Enter Legal Text")
99
  tone = st.radio("Select Tone", ["Formal", "Empathetic", "Neutral", "Strength-Based"])
100
 
101
+ if st.button("Rewrite Text"):
102
  input_data = {"original_text": original_text, "tone": tone, "complexity": 1}
103
  result = graph.invoke(input_data)
104
+ st.session_state["rewritten_text"] = result.get("rewritten_text", "")
105
+ st.session_state["show_regen"] = True
106
+
107
+ if "show_regen" in st.session_state and st.session_state["show_regen"]:
108
+ if st.button("Regenerate Text"):
109
+ input_data = {"original_text": original_text, "tone": tone, "complexity": 1}
110
  result = graph.invoke(input_data)
111
+ st.session_state["rewritten_text"] = result.get("rewritten_text", "")
112
+
113
+ if "rewritten_text" in st.session_state:
114
+ st.subheader("Rewritten Text")
115
+ st.write(f"**{st.session_state['rewritten_text']}**")