Spaces:
Build error
Build error
| import streamlit as st | |
| import os | |
| from langchain_groq import ChatGroq | |
| from langgraph.graph import MessagesState, StateGraph, START | |
| from langgraph.checkpoint.memory import MemorySaver | |
| from langchain_core.messages import HumanMessage, AIMessage | |
# Read the Groq API key from the environment (expected to be supplied as a
# Hugging Face Spaces Secret). Without it the app shows an error and halts.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
if GROQ_API_KEY is None or GROQ_API_KEY == "":
    st.error("GROQ_API_KEY not found. Please add it as a Secret in Hugging Face Spaces.")
    st.stop()

# Chat model shared by every rewrite / refine request below.
llm = ChatGroq(model_name="llama3-8b-8192", groq_api_key=GROQ_API_KEY)
# Helper functions
def extract_display_text(message):
    """Return the part of an assistant message that should be shown to the user.

    Messages of the form ``"Rewritten text in <tone> tone: <body>"`` are
    reduced to ``<body>``. Every other message (``"Text approved: ..."``,
    ``"Invalid command."``, and so on) is returned unchanged.

    Args:
        message: Message content as a string.

    Returns:
        The text to display.
    """
    # Only strip the header when the message actually *starts* with it, so
    # approved text that merely mentions the phrase keeps its own prefix.
    if message.startswith("Rewritten text in"):
        _, separator, body = message.partition(": ")
        # No ": " separator means there is no header to strip; fall through
        # and show the message as-is instead of raising IndexError.
        if separator:
            return body
    return message
# Assistant function
_REQUEST_PREFIX = "Rewrite this text in "
_REWRITTEN_PREFIX = "Rewritten text in "
_FEEDBACK_PREFIX = "feedback:"


def _split_tone_message(content, prefix):
    """Parse ``"<prefix><tone> tone: <text>"`` into ``(tone, text)``.

    Returns ``None`` when *content* is not a string of that shape.
    """
    if not isinstance(content, str) or not content.startswith(prefix):
        return None
    # partition() splits on the FIRST separator only, so text that itself
    # contains " tone: " is kept intact.
    tone, separator, text = content[len(prefix):].partition(" tone: ")
    if not separator:
        return None
    return tone, text


def _find_latest(messages, message_type, prefix):
    """Return ``(tone, text)`` from the most recent matching message, else ``None``."""
    for msg in reversed(messages):
        if isinstance(msg, message_type):
            parsed = _split_tone_message(msg.content, prefix)
            if parsed is not None:
                return parsed
    return None


def _rewritten_reply(tone, response):
    """Wrap an LLM response in the ``Rewritten text in <tone> tone:`` envelope."""
    return {"messages": [AIMessage(content=f"Rewritten text in {tone} tone: {response.content}")]}


def _invalid_reply():
    """Build the reply used for any command that cannot be handled."""
    return {"messages": [AIMessage(content="Invalid command.")]}


def assistant(state: MessagesState):
    """Graph node that interprets the latest message as a command.

    Supported commands (taken from the content of the last message):

    * ``Rewrite this text in <tone> tone: <text>`` - rewrite ``<text>``.
    * ``regenerate`` - redo the most recent rewrite request.
    * ``feedback: <notes>`` - refine the most recent rewritten text.
    * ``approve`` - echo the most recent rewritten text as approved.

    Args:
        state: Graph state; ``state["messages"]`` is the message history.

    Returns:
        A dict with a ``"messages"`` list holding a single ``AIMessage``:
        the rewritten text, the approved text, or ``"Invalid command."``
        when the command is unknown or nothing exists to act on.
    """
    messages = state["messages"]
    if not messages:
        return _invalid_reply()

    latest_msg = messages[-1].content
    if not isinstance(latest_msg, str):
        return _invalid_reply()

    # A rewrite request must START with the command prefix; a substring match
    # would hijack feedback such as "feedback: Rewrite this text in ...".
    request = _split_tone_message(latest_msg, _REQUEST_PREFIX)
    if request is not None:
        tone, text = request
        response = llm.invoke(f"Rewrite the following legal text in a {tone} tone: {text}")
        return _rewritten_reply(tone, response)

    command = latest_msg.strip().lower()

    if command == "regenerate":
        # Search newest-first so the latest request is the one regenerated.
        original = _find_latest(messages, HumanMessage, _REQUEST_PREFIX)
        if original is not None:
            tone, original_text = original
            response = llm.invoke(f"Rewrite the following legal text in a {tone} tone: {original_text}")
            return _rewritten_reply(tone, response)

    elif latest_msg.startswith(_FEEDBACK_PREFIX):
        # Slice off the prefix instead of splitting, so "feedback:x" (no
        # space) and feedback that repeats the word are both handled.
        feedback = latest_msg[len(_FEEDBACK_PREFIX):].strip()
        previous = _find_latest(messages, AIMessage, _REWRITTEN_PREFIX)
        if previous is not None:
            tone, text = previous
            response = llm.invoke(f"Refine the following text based on this feedback: '{feedback}'. Maintain the {tone} tone. Text: {text}")
            return _rewritten_reply(tone, response)

    elif command == "approve":
        previous = _find_latest(messages, AIMessage, _REWRITTEN_PREFIX)
        if previous is not None:
            _, text = previous
            return {"messages": [AIMessage(content=f"Text approved: {text}")]}

    return _invalid_reply()
# Human feedback node (stop condition)
def human_feedback(state: MessagesState):
    """Pass-through node: hand the incoming state back untouched.

    It performs no work of its own; it exists as a named point in the graph
    between two assistant runs.
    """
    return state
# Build the LangGraph pipeline
builder = StateGraph(MessagesState)
builder.add_node("assistant", assistant)
builder.add_node("human_feedback", human_feedback)

# Define edges with a stopping condition
builder.add_edge(START, "assistant")
builder.add_edge("assistant", "human_feedback")  # Stops after assistant
builder.add_edge("human_feedback", "assistant")  # Only runs if more input is provided

# Streamlit re-executes this script from the top on every interaction. A
# checkpointer created here unconditionally would start empty on each rerun,
# so "regenerate", "feedback" and "approve" would find no earlier messages.
# Keeping it in session_state preserves the history for the browser session.
if "checkpointer" not in st.session_state:
    st.session_state.checkpointer = MemorySaver()
memory = st.session_state.checkpointer

# Recompiling is cheap and picks up the current node functions; the
# conversation itself lives in the persisted checkpointer.
graph = builder.compile(interrupt_before=["human_feedback"], checkpointer=memory)
# Streamlit UI
# NOTE(review): the emoji in the title and button labels were mis-encoded in
# the source. The feedback and approve icons are recoverable from the
# surviving bytes; the title and regenerate icons are best guesses - confirm.
st.title("📜 Legal Text Rewriter")
st.markdown("Rewrite legal text into different tones using AI.")

# Store session state
if "state" not in st.session_state:
    st.session_state.state = {"configurable": {"thread_id": "1"}}


def _run_command(command):
    """Send *command* to the graph as a human message and store the reply.

    The command is appended to the checkpointed state, the graph is resumed,
    and the content of the last message of the final streamed state is saved
    (after display formatting) in ``st.session_state.rewritten_text``.
    """
    config = st.session_state.state
    graph.update_state(config, {"messages": [HumanMessage(content=command)]})

    latest_content = None
    for event in graph.stream(None, config, stream_mode="values"):
        if "messages" in event:
            latest_content = event["messages"][-1].content

    # Only the final state matters; earlier events still end with the command.
    if latest_content is not None:
        st.session_state.rewritten_text = extract_display_text(latest_content)


# User input section
legal_text = st.text_area("Enter Legal Text:", "")
initial_tone = st.selectbox("Select Tone:", ["Formal", "Empathetic", "Neutral", "Strength-Based"])
submit_btn = st.button("Submit")
if submit_btn:
    if legal_text.strip():
        _run_command(f"Rewrite this text in {initial_tone} tone: {legal_text}")
    else:
        st.warning("Please enter some legal text first.")

# Display rewritten text
if "rewritten_text" in st.session_state:
    st.subheader("Rewritten Text")
    # Reserve the output slot now but fill it after the buttons have been
    # handled, so a click shows its result in the same run instead of one
    # interaction late.
    output_slot = st.empty()

    # Buttons
    col1, col2, col3 = st.columns([1, 2, 1])
    with col1:
        if st.button("🔄 Regenerate"):
            _run_command("regenerate")
    with col2:
        feedback_text = st.text_input("Enter Feedback:")
        if st.button("💬 Submit Feedback"):
            # Skip the LLM round-trip when there is nothing to act on.
            if feedback_text.strip():
                _run_command(f"feedback: {feedback_text}")
            else:
                st.warning("Please enter feedback before submitting.")
    with col3:
        if st.button("✅ Approve"):
            _run_command("approve")

    output_slot.write(st.session_state.rewritten_text)