# Legal_YTS / app.py
# prasannahf's picture
# Update app.py
# 733239a verified
import streamlit as st
import os
from langchain_groq import ChatGroq
from langgraph.graph import MessagesState, StateGraph, START
from langgraph.checkpoint.memory import MemorySaver
from langchain_core.messages import HumanMessage, AIMessage
# Load API key from Hugging Face Secrets
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
if not GROQ_API_KEY:
    st.error("GROQ_API_KEY not found. Please add it as a Secret in Hugging Face Spaces.")
    # Halt this script run: nothing below can work without an API key.
    st.stop()

# Initialize the Groq LLM.
# The model can be swapped through the optional GROQ_MODEL secret/env var
# without editing code; when it is unset, the model this app was written
# against is used.
GROQ_MODEL = os.getenv("GROQ_MODEL", "llama3-8b-8192")
llm = ChatGroq(groq_api_key=GROQ_API_KEY, model_name=GROQ_MODEL)
# Helper functions
def extract_display_text(message):
    """Return the part of a graph message that should be shown to the user.

    A rewrite result of the form ``"Rewritten text in <tone> tone: <text>"``
    is reduced to ``"<text>"``. Every other message (for example
    ``"Text approved: ..."`` or ``"Invalid command."``) is returned unchanged.

    Args:
        message: Content string of the latest message in the graph state.

    Returns:
        The text to display.
    """
    # Use a prefix test, not a substring test: an approval message whose body
    # happens to contain "Rewritten text in" must not be stripped.
    if message.startswith("Rewritten text in"):
        _, separator, body = message.partition(": ")
        # Without a separator there is no body to extract; show the message as-is
        # instead of raising IndexError.
        if separator:
            return body
    return message
# Assistant function

# Prefix of a user rewrite request: "Rewrite this text in <tone> tone: <text>".
_REQUEST_PREFIX = "Rewrite this text in "
# Prefix of an assistant rewrite result: "Rewritten text in <tone> tone: <text>".
_RESULT_PREFIX = "Rewritten text in "


def _split_tone_and_text(content, prefix):
    """Parse ``"<prefix><tone> tone: <text>"`` into ``(tone, text)``.

    Only the FIRST " tone: " separator is used, so a text that itself
    contains "tone: " is kept intact. Returns ``None`` when *content* does
    not start with *prefix* or has no separator.
    """
    if not content.startswith(prefix):
        return None
    tone, separator, text = content[len(prefix):].partition(" tone: ")
    if not separator:
        return None
    return tone, text


def _latest_parsed(messages, message_type, prefix):
    """Return ``(tone, text)`` from the newest *message_type* message with *prefix*, else ``None``."""
    for msg in reversed(messages):
        if isinstance(msg, message_type):
            parsed = _split_tone_and_text(msg.content, prefix)
            if parsed is not None:
                return parsed
    return None


def _rewrite_result(tone, text):
    """Wrap *text* as the assistant's state update for a rewrite in *tone*."""
    return {"messages": [AIMessage(content=f"Rewritten text in {tone} tone: {text}")]}


def assistant(state: MessagesState):
    """Graph node that reacts to the latest user command in the conversation.

    Recognized commands (content of the last message):

    * ``"Rewrite this text in <tone> tone: <text>"``: rewrite ``<text>`` in ``<tone>``.
    * ``"regenerate"``: run the most recent rewrite request again.
    * ``"feedback: <feedback>"``: refine the most recent rewrite using ``<feedback>``.
    * ``"approve"`` (any letter case): approve the most recent rewrite.

    Args:
        state: Graph state; only ``state["messages"]`` is read.

    Returns:
        A state update containing one new ``AIMessage``. Its content is one of:

        * ``"Rewritten text in <tone> tone: <text>"``
        * ``"Text approved: <text>"``
        * ``"Invalid command."`` when the command is unknown or there is no
          earlier request/rewrite to act on.
    """
    messages = state["messages"]
    if not messages:
        return {"messages": [AIMessage(content="Invalid command.")]}
    latest_msg = messages[-1].content

    request = _split_tone_and_text(latest_msg, _REQUEST_PREFIX)
    if request is not None:
        tone, text = request
        response = llm.invoke(f"Rewrite the following legal text in a {tone} tone: {text}")
        return _rewrite_result(tone, response.content)

    if latest_msg == "regenerate":
        # Newest request first, so a later Submit is the one that gets regenerated.
        previous_request = _latest_parsed(messages, HumanMessage, _REQUEST_PREFIX)
        if previous_request is not None:
            tone, original_text = previous_request
            response = llm.invoke(f"Rewrite the following legal text in a {tone} tone: {original_text}")
            return _rewrite_result(tone, response.content)
    elif latest_msg.startswith("feedback:"):
        # Slice off the prefix instead of splitting, so feedback containing
        # "feedback: " is not truncated.
        feedback = latest_msg[len("feedback:"):].strip()
        previous_rewrite = _latest_parsed(messages, AIMessage, _RESULT_PREFIX)
        if previous_rewrite is not None:
            tone, text = previous_rewrite
            response = llm.invoke(f"Refine the following text based on this feedback: '{feedback}'. Maintain the {tone} tone. Text: {text}")
            return _rewrite_result(tone, response.content)
    elif latest_msg.lower() == "approve":
        previous_rewrite = _latest_parsed(messages, AIMessage, _RESULT_PREFIX)
        if previous_rewrite is not None:
            return {"messages": [AIMessage(content=f"Text approved: {previous_rewrite[1]}")]}

    # Unknown command, or nothing earlier in the thread to act on.
    return {"messages": [AIMessage(content="Invalid command.")]}
# Human feedback node (stop condition)
def human_feedback(state: MessagesState):
    """Pass-through node that marks where the graph pauses for user input.

    The graph is compiled with ``interrupt_before=["human_feedback"]``, so a run
    stops just before this node. When a run resumes, the node returns the
    state it received as-is, and its outgoing edge leads back to ``assistant``.

    Args:
        state: Current graph state.

    Returns:
        The same state object that was passed in.
    """
    return state
# Build the LangGraph pipeline
builder = StateGraph(MessagesState)
builder.add_node("assistant", assistant)
builder.add_node("human_feedback", human_feedback)
# Define edges with a stopping condition
builder.add_edge(START, "assistant")
builder.add_edge("assistant", "human_feedback")  # Stops after assistant
builder.add_edge("human_feedback", "assistant")  # Only runs if more input is provided

# Streamlit re-executes this whole script on every widget interaction. A plain
# module-level MemorySaver() would therefore be recreated, empty, on each rerun.
# The Regenerate / Feedback / Approve commands would then never see the earlier
# rewrite request. Keep one checkpointer per browser session in session state
# so the conversation thread survives reruns.
if "checkpointer" not in st.session_state:
    st.session_state.checkpointer = MemorySaver()
memory = st.session_state.checkpointer

# Pause before "human_feedback" so each run ends right after the assistant replies.
graph = builder.compile(interrupt_before=["human_feedback"], checkpointer=memory)
# Streamlit UI
# The emoji was mis-encoded (UTF-8 bytes decoded as a legacy code page); restore it.
st.title("📜 Legal Text Rewriter")
st.markdown("Rewrite legal text into different tones using AI.")
# Store session state: despite its name, ``state`` holds the LangGraph run
# config (the conversation thread id) passed to every update_state/stream call.
if "state" not in st.session_state:
    st.session_state.state = {"configurable": {"thread_id": "1"}}
# User input section: the text to rewrite, the target tone, and a Submit button.
legal_text = st.text_area("Enter Legal Text:", "")
initial_tone = st.selectbox(
    "Select Tone:",
    ["Formal", "Empathetic", "Neutral", "Strength-Based"],
)
submit_btn = st.button("Submit")

# Only send a request when the button was pressed and the text is not blank.
wants_rewrite = submit_btn and legal_text.strip()
if wants_rewrite:
    command = f"Rewrite this text in {initial_tone} tone: {legal_text}"
    thread_config = st.session_state.state
    graph.update_state(thread_config, {"messages": [HumanMessage(content=command)]})
    # Resume the graph; the last streamed state's final message is the result.
    for event in graph.stream(None, thread_config, stream_mode="values"):
        if "messages" not in event:
            continue
        latest = event["messages"][-1]
        st.session_state.rewritten_text = extract_display_text(latest.content)
# Display rewritten text, once one has been produced in this session.
has_result = "rewritten_text" in st.session_state
if has_result:
    st.subheader("Rewritten Text")
    st.write(st.session_state["rewritten_text"])
# Buttons
def _send_command(command):
    """Append *command* as a user message, resume the graph and refresh the page.

    The display text of the final streamed state's last message is stored in
    ``st.session_state.rewritten_text``. The "Rewritten Text" section is drawn
    earlier in the script than these buttons, so without a rerun it would keep
    showing the previous result until the next interaction.
    """
    thread_config = st.session_state.state
    graph.update_state(thread_config, {"messages": [HumanMessage(content=command)]})
    for event in graph.stream(None, thread_config, stream_mode="values"):
        if "messages" in event:
            st.session_state.rewritten_text = extract_display_text(event["messages"][-1].content)
    st.rerun()


col1, col2, col3 = st.columns([1, 2, 1])
with col1:
    # Labels' emoji were mis-encoded (UTF-8 read as a legacy code page); restored.
    if st.button("🔄 Regenerate"):
        _send_command("regenerate")
with col2:
    feedback_text = st.text_input("Enter Feedback:")
    if st.button("💬 Submit Feedback"):
        # Don't spend an LLM call refining against empty feedback.
        if feedback_text.strip():
            _send_command(f"feedback: {feedback_text}")
        else:
            st.warning("Please enter feedback before submitting.")
with col3:
    if st.button("✅ Approve"):
        _send_command("approve")