"""Gradio app that rewrites California special-education legal text with a Groq LLM.

A LangGraph loop alternates between a no-op ``human_feedback`` node (the graph is
interrupted before it) and an ``assistant`` node. The Gradio handlers inject the
next human command into the thread state and resume the graph.
"""
import os
import uuid

import gradio as gr
from langchain_core.messages import AIMessage, HumanMessage
from langchain_groq import ChatGroq
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import START, MessagesState, StateGraph

# Get API key from environment variables (raises KeyError if it is not set).
GROQ_API_KEY = os.environ["GROQ_API_KEY"]

# Initialize the Groq LLM
llm = ChatGroq(groq_api_key=GROQ_API_KEY, model_name="llama3-8b-8192")

# Command protocol shared by the Gradio handlers (producers) and the assistant
# node (parser). Keeping them in one place prevents the two sides drifting apart.
REWRITE_PREFIX = "Rewrite this text in "
TONE_SEPARATOR = " tone: "
FEEDBACK_PREFIX = "feedback:"

# Name attached to AI messages that contain an actual rewritten draft. Error and
# status replies ("Invalid command.", "Text approved: ...") are left unnamed so
# they are never mistaken for the current text.
DRAFT_NAME = "draft"

# Placeholder thread id held in the initial gr.State; replaced by a per-session
# UUID on the first submission. Commands must never run on this shared id.
DEFAULT_THREAD_ID = "1"
NO_SESSION_MSG = "Please submit legal text first."


# Helper functions
def _parse_rewrite_command(content):
    """Return ``(tone, text)`` for a rewrite command, or ``None`` otherwise.

    Rewrite commands look like ``"Rewrite this text in {tone} tone: {text}"``.
    Only the first separator is split on, so legal text that itself contains
    "tone: " is kept intact.
    """
    if not isinstance(content, str) or not content.startswith(REWRITE_PREFIX):
        return None
    tone, separator, text = content[len(REWRITE_PREFIX):].partition(TONE_SEPARATOR)
    if not separator:
        return None
    return tone, text


def _latest_rewrite_command(messages):
    """Return ``(tone, text)`` from the newest rewrite command in *messages*, or ``None``."""
    for msg in reversed(messages):
        if isinstance(msg, HumanMessage):
            parsed = _parse_rewrite_command(msg.content)
            if parsed is not None:
                return parsed
    return None


def get_original_text(messages):
    """Return the text of the most recent rewrite request, or ``""`` if there is none.

    Searching newest-first means a second submission on the same thread is the
    one that "Regenerate" works from, not the first one.
    """
    command = _latest_rewrite_command(messages)
    return command[1] if command else ""


def get_current_text_and_tone(messages):
    """Return ``(current_draft, tone)``, or ``(None, None)`` when no draft exists yet.

    The draft is the newest AI message tagged with ``DRAFT_NAME``; the tone is
    taken from the most recent rewrite request (``None`` if there is none).
    """
    current_text = next(
        (
            msg.content
            for msg in reversed(messages)
            if isinstance(msg, AIMessage) and msg.name == DRAFT_NAME
        ),
        None,
    )
    if current_text is None:
        return None, None
    command = _latest_rewrite_command(messages)
    return current_text, (command[0] if command else None)


def extract_display_text(message):
    """Return *message* trimmed of surrounding whitespace for display."""
    return message.strip()


def _rewrite_prompt(tone, text):
    """Build the prompt used for both the initial rewrite and regeneration."""
    return (
        "Rewrite the following legal text related to special education in California "
        f"in a {tone} tone, ensuring it remains compliant with state and federal regulations. "
        "Only return the rewritten text, without any explanation or prefix.\n\n"
        f"Text: {text}"
    )


def _ask_llm(prompt):
    """Send *prompt* to the LLM and return its reply with whitespace stripped."""
    return llm.invoke(prompt).content.strip()


def _draft(text):
    """Wrap a rewritten draft as a state update, tagged so it can be found later."""
    return {"messages": [AIMessage(content=text, name=DRAFT_NAME)]}


def _status(text):
    """Wrap an error/status reply as a state update (not tagged as a draft)."""
    return {"messages": [AIMessage(content=text)]}


# Assistant node with clean output
def assistant(state: MessagesState):
    """Interpret the latest human command and append the assistant's reply.

    Supported commands: a rewrite request, ``"regenerate"``, ``"feedback: ..."``
    and ``"approve"`` (case-insensitive). Anything else yields "Invalid command.".
    """
    messages = state["messages"]
    latest_msg = messages[-1].content

    # startswith (not substring) so feedback that quotes the phrase
    # "Rewrite this text in" is not misrouted here.
    rewrite = _parse_rewrite_command(latest_msg)
    if rewrite is not None:
        tone, text = rewrite
        return _draft(_ask_llm(_rewrite_prompt(tone, text)))

    if latest_msg == "regenerate":
        command = _latest_rewrite_command(messages)
        if command is None or not command[1]:
            return _status("Error: Original text not found.")
        current_tone, original_text = command
        if not current_tone:
            return _status("Error: Current tone not found.")
        return _draft(_ask_llm(_rewrite_prompt(current_tone, original_text)))

    if latest_msg.startswith(FEEDBACK_PREFIX):
        # Slice off the prefix rather than split, so feedback that itself
        # contains "feedback: " is not truncated (and "feedback:x" does not crash).
        feedback = latest_msg[len(FEEDBACK_PREFIX):].strip()
        current_text, current_tone = get_current_text_and_tone(messages)
        if not current_text or not current_tone:
            return _status("Error: Current text or tone not found.")
        prompt = (
            "Refine the following legal text related to special education in California "
            f"based on this feedback: '{feedback}'. Maintain the {current_tone} tone and "
            "ensure compliance with regulations. Only return the refined text, without "
            f"any explanation or prefix.\n\nText: {current_text}"
        )
        return _draft(_ask_llm(prompt))

    if latest_msg.lower() == "approve":
        current_text, _ = get_current_text_and_tone(messages)
        if not current_text:
            return _status("Error: No text to approve.")
        return _status(f"Text approved: {current_text}")

    return _status("Invalid command.")


# Human feedback node
def human_feedback(state: MessagesState):
    """No-op node.

    The graph is compiled with ``interrupt_before=["human_feedback"]``. The UI
    handlers supply the human input with ``update_state(..., as_node="human_feedback")``.
    """
    pass


# Build LangGraph
builder = StateGraph(MessagesState)
builder.add_node("assistant", assistant)
builder.add_node("human_feedback", human_feedback)
builder.add_edge(START, "human_feedback")
builder.add_edge("human_feedback", "assistant")
builder.add_edge("assistant", "human_feedback")
memory = MemorySaver()
graph = builder.compile(interrupt_before=["human_feedback"], checkpointer=memory)


# Gradio logic
def _has_session(state):
    """Return True once the session has been given its own (non-placeholder) thread id."""
    return state["configurable"]["thread_id"] != DEFAULT_THREAD_ID


def _run_command(command, state):
    """Inject *command* as human input, resume the graph, and return the newest message text.

    Returns ``""`` if the stream yields no message events, instead of raising
    ``UnboundLocalError``.
    """
    graph.update_state(state, {"messages": [HumanMessage(content=command)]}, as_node="human_feedback")
    latest = ""
    for event in graph.stream(None, state, stream_mode="values"):
        if event.get("messages"):
            latest = event["messages"][-1].content
    return extract_display_text(latest)


def initial_submit_fn(legal_text, initial_tone, state):
    """Start (or continue) this session's thread with a rewrite request."""
    if state["configurable"]["thread_id"] == DEFAULT_THREAD_ID:
        state["configurable"]["thread_id"] = str(uuid.uuid4())
    if not legal_text.strip():
        return "Please enter legal text.", state
    command = f"{REWRITE_PREFIX}{initial_tone}{TONE_SEPARATOR}{legal_text}"
    return _run_command(command, state), state


def regenerate_fn(state):
    """Rewrite the most recently submitted text again in its original tone."""
    if not _has_session(state):
        return NO_SESSION_MSG, state
    return _run_command("regenerate", state), state


def submit_feedback_fn(feedback_text, state):
    """Refine the current draft with the user's feedback; clears the feedback box on success."""
    if not feedback_text.strip():
        return "Please enter feedback.", state, ""
    if not _has_session(state):
        # Keep the typed feedback so the user does not lose it.
        return NO_SESSION_MSG, state, feedback_text
    return _run_command(f"feedback: {feedback_text}", state), state, ""


def approve_fn(state):
    """Approve the current draft and return the confirmation message."""
    if not _has_session(state):
        return NO_SESSION_MSG, state
    return _run_command("approve", state), state


# Gradio UI
with gr.Blocks(title="Legal Text Rewriter for Special Education in California") as demo:
    state = gr.State(value={"configurable": {"thread_id": DEFAULT_THREAD_ID}})
    with gr.Column():
        legal_text = gr.Textbox(label="Legal Text", placeholder="Enter your legal text here")
        initial_tone = gr.Dropdown(
            label="Initial Tone",
            choices=["Formal", "Empathetic", "Neutral", "Strength-Based"],
            value="Formal",
        )
        submit_btn = gr.Button("Submit")
        rewritten_text_display = gr.Textbox(label="Rewritten Text", interactive=False)
        gr.Markdown("")
        regenerate_btn = gr.Button("Regenerate")
        feedback_textbox = gr.Textbox(label="Feedback", placeholder="Enter feedback to refine the text")
        submit_feedback_btn = gr.Button("Submit Feedback")
        approve_btn = gr.Button("Approve")

    # Button bindings
    submit_btn.click(initial_submit_fn, [legal_text, initial_tone, state], [rewritten_text_display, state])
    regenerate_btn.click(regenerate_fn, [state], [rewritten_text_display, state])
    submit_feedback_btn.click(
        submit_feedback_fn,
        [feedback_textbox, state],
        [rewritten_text_display, state, feedback_textbox],
    )
    approve_btn.click(approve_fn, [state], [rewritten_text_display, state])

if __name__ == "__main__":
    demo.launch()