Spaces:
Sleeping
Sleeping
| import os | |
| import uuid | |
| import gradio as gr | |
| from langgraph.graph import MessagesState, StateGraph, START | |
| from langgraph.checkpoint.memory import MemorySaver | |
| from langchain_groq import ChatGroq | |
| from langchain_core.messages import HumanMessage, AIMessage | |
# Get API key from environment variables. Fail fast with an actionable message
# instead of a bare KeyError when the secret has not been configured.
try:
    GROQ_API_KEY = os.environ["GROQ_API_KEY"]
except KeyError as err:
    raise RuntimeError(
        "The GROQ_API_KEY environment variable is not set; "
        "configure it before starting the app."
    ) from err

# Initialize the Groq LLM used by every assistant turn.
llm = ChatGroq(groq_api_key=GROQ_API_KEY, model_name="llama3-8b-8192")
# Helper functions
def get_original_text(messages):
    """Return the legal text from the most recent rewrite request.

    Scans *messages* newest-first for a ``HumanMessage`` whose content is a
    command of the form ``"Rewrite this text in <tone> tone: <text>"`` (the
    format built by ``initial_submit_fn``) and returns ``<text>``.

    Args:
        messages: Conversation history (a list of LangChain messages).

    Returns:
        The original legal text, or ``""`` if no rewrite request is found.
    """
    rewrite_prefix = "Rewrite this text in "
    # Newest first: when a thread holds several submissions, "regenerate" must
    # work on the latest text, not on the first one ever submitted.
    for msg in reversed(messages):
        if not isinstance(msg, HumanMessage):
            continue
        content = msg.content
        # startswith (not a substring test) so that feedback which merely
        # mentions the phrase is not mistaken for a rewrite command.
        if not content.startswith(rewrite_prefix):
            continue
        # Split at the FIRST separator only: the legal text itself may contain
        # "tone: ", and everything after the first one belongs to the text.
        _, sep, text = content[len(rewrite_prefix):].partition(" tone: ")
        if sep:
            return text
    return ""
def get_current_text_and_tone(messages):
    """Return the latest rewritten text and the tone it was requested in.

    Args:
        messages: Conversation history (a list of LangChain messages).

    Returns:
        ``(text, tone)``:
          - *text* is the newest ``AIMessage`` that carries actual rewritten
            text. An approval reply is unwrapped to the text it approved.
          - *tone* is parsed from the newest
            ``"Rewrite this text in <tone> tone: ..."`` request, or is
            ``"Unknown"`` if no such request exists.
        Returns ``(None, None)`` when there is no usable ``AIMessage``.
    """
    rewrite_prefix = "Rewrite this text in "
    approved_prefix = "Text approved: "
    # Replies the assistant node emits for failed or unknown commands. They
    # are status messages, not text the user can refine or approve.
    status_prefixes = ("Error: ", "Invalid command.")

    current_text = None
    for msg in reversed(messages):
        if not isinstance(msg, AIMessage) or msg.content.startswith(status_prefixes):
            continue
        current_text = msg.content
        # An approval echoes the text behind a prefix. Unwrap it so a later
        # feedback/approve works on the text itself rather than on the prefix.
        if current_text.startswith(approved_prefix):
            current_text = current_text[len(approved_prefix):]
        break
    if current_text is None:
        return None, None

    # Recover the tone from the request instead of reporting "Unknown", so
    # that regenerate and feedback prompts keep the user's chosen tone.
    tone = "Unknown"
    for msg in reversed(messages):
        if isinstance(msg, HumanMessage) and msg.content.startswith(rewrite_prefix):
            header, sep, _ = msg.content[len(rewrite_prefix):].partition(" tone: ")
            if sep and header.strip():
                tone = header.strip()
            break
    return current_text, tone
def extract_display_text(message):
    """Return *message* without leading/trailing whitespace, ready for display."""
    cleaned = message.strip()
    return cleaned
# Assistant node with clean output
_REWRITE_PREFIX = "Rewrite this text in "


def _reply(text):
    """Build the MessagesState update that appends a single AIMessage."""
    return {"messages": [AIMessage(content=text)]}


def _ask_llm(prompt):
    """Send *prompt* to the LLM and wrap its stripped answer as a state update."""
    response = llm.invoke(prompt)
    return _reply(response.content.strip())


def _rewrite_prompt(tone, text):
    """Build the prompt asking the LLM to rewrite *text* in *tone*.

    Shared by the initial rewrite and by "regenerate" so both requests carry
    the same compliance instruction.
    """
    return (
        "Rewrite the following legal text related to special education in California "
        f"in a {tone} tone, ensuring it remains compliant with state and federal regulations. "
        "Only return the rewritten text, without any explanation or prefix.\n\n"
        f"Text: {text}"
    )


def _parse_rewrite_command(content):
    """Split ``"Rewrite this text in <tone> tone: <text>"`` into ``(tone, text)``.

    Only the first ``" tone: "`` separates tone from text, so text containing
    "tone: " is kept intact. Returns None when the separator is missing or
    the tone or text is blank.
    """
    tone, sep, text = content[len(_REWRITE_PREFIX):].partition(" tone: ")
    if not sep or not tone.strip() or not text.strip():
        return None
    return tone, text


def assistant(state: MessagesState):
    """LangGraph node that executes the latest human command.

    Recognised commands (content of the last message in ``state["messages"]``):

    * ``"Rewrite this text in <tone> tone: <text>"``: rewrite *text* in *tone*.
    * ``"regenerate"``: rewrite the original text again in the current tone.
    * ``"feedback: <feedback>"``: refine the current text using *feedback*.
    * ``"approve"``: echo the current text as approved.

    ``regenerate`` and ``approve`` are matched case-insensitively. Anything
    else yields ``"Invalid command."``.

    Returns:
        A state update containing exactly one new ``AIMessage``.
    """
    messages = state["messages"]
    latest_msg = messages[-1].content
    command = latest_msg.strip().lower()

    # startswith (not a substring test): feedback that merely mentions the
    # phrase must not be routed here.
    if latest_msg.startswith(_REWRITE_PREFIX):
        parsed = _parse_rewrite_command(latest_msg)
        if parsed is None:
            return _reply("Error: Malformed rewrite command.")
        tone, text = parsed
        return _ask_llm(_rewrite_prompt(tone, text))

    if command == "regenerate":
        original_text = get_original_text(messages)
        if not original_text:
            return _reply("Error: Original text not found.")
        _, current_tone = get_current_text_and_tone(messages)
        if not current_tone:
            return _reply("Error: Current tone not found.")
        return _ask_llm(_rewrite_prompt(current_tone, original_text))

    if latest_msg.startswith("feedback:"):
        # Slice off the prefix rather than split on "feedback: ": that split
        # fails without the space and truncates feedback that repeats it.
        feedback = latest_msg[len("feedback:"):].strip()
        if not feedback:
            return _reply("Error: Feedback is empty.")
        current_text, current_tone = get_current_text_and_tone(messages)
        if not current_text or not current_tone:
            return _reply("Error: Current text or tone not found.")
        prompt = f"Refine the following legal text related to special education in California based on this feedback: '{feedback}'. Maintain the {current_tone} tone and ensure compliance with regulations. Only return the refined text, without any explanation or prefix.\n\nText: {current_text}"
        return _ask_llm(prompt)

    if command == "approve":
        current_text, _ = get_current_text_and_tone(messages)
        if not current_text:
            return _reply("Error: No text to approve.")
        return _reply(f"Text approved: {current_text}")

    return _reply("Invalid command.")
# Human feedback node
def human_feedback(state: MessagesState):
    """No-op node marking where the graph pauses for user input.

    The graph is compiled with ``interrupt_before=["human_feedback"]``. The UI
    callbacks inject each user command via
    ``graph.update_state(..., as_node="human_feedback")`` and then resume, so
    this node contributes no state update of its own.
    """
    return None
# Build LangGraph: START -> human_feedback -> assistant -> human_feedback -> ...
# Execution is interrupted before every human_feedback step so the UI can inject
# the next command; the MemorySaver checkpointer keeps each thread's history
# in process memory.
builder = StateGraph(MessagesState)
for node_name, node_fn in (("assistant", assistant), ("human_feedback", human_feedback)):
    builder.add_node(node_name, node_fn)
for source, target in (
    (START, "human_feedback"),
    ("human_feedback", "assistant"),
    ("assistant", "human_feedback"),
):
    builder.add_edge(source, target)
memory = MemorySaver()
graph = builder.compile(checkpointer=memory, interrupt_before=["human_feedback"])
# Gradio logic
def initial_submit_fn(legal_text, initial_tone, state):
    """Start a new rewrite conversation for *legal_text* in *initial_tone*.

    Every accepted submission gets a fresh LangGraph thread, so regenerate,
    feedback and approve only ever see the text submitted last.

    Args:
        legal_text: Text entered by the user.
        initial_tone: Tone selected in the dropdown.
        state: Gradio session state, a LangGraph config of the form
            ``{"configurable": {"thread_id": ...}}``.

    Returns:
        ``(display_text, state)``: the rewritten text (or a validation message)
        and the session state.
    """
    if not legal_text or not legal_text.strip():
        # Validate before touching the thread so a previous result can still
        # be regenerated/refined after an accidental empty submit.
        return "Please enter legal text.", state
    # A new thread per submission. Reusing the previous thread would mix an
    # earlier text's history into later regenerate/feedback requests.
    state["configurable"]["thread_id"] = str(uuid.uuid4())
    command = f"Rewrite this text in {initial_tone} tone: {legal_text}"
    graph.update_state(state, {"messages": [HumanMessage(content=command)]}, as_node="human_feedback")
    display_text = ""  # stays empty if the stream yields no message state
    for event in graph.stream(None, state, stream_mode="values"):
        if "messages" in event:
            display_text = extract_display_text(event["messages"][-1].content)
    return display_text, state
def regenerate_fn(state):
    """Ask the assistant for a fresh rewrite of the latest submitted text.

    Args:
        state: Gradio session state (LangGraph config holding a ``thread_id``).

    Returns:
        ``(display_text, state)``: the new rewrite (or a hint to submit text
        first) and the session state.
    """
    # "1" is the placeholder thread id every UI session starts with; it is
    # replaced on the first submission. Commands sent to it would land in a
    # thread shared by all sessions that have not submitted yet.
    if state["configurable"]["thread_id"] == "1":
        return "Please submit legal text first.", state
    graph.update_state(state, {"messages": [HumanMessage(content="regenerate")]}, as_node="human_feedback")
    display_text = ""  # stays empty if the stream yields no message state
    for event in graph.stream(None, state, stream_mode="values"):
        if "messages" in event:
            display_text = extract_display_text(event["messages"][-1].content)
    return display_text, state
def submit_feedback_fn(feedback_text, state):
    """Send user feedback so the assistant refines the current text.

    Args:
        feedback_text: Feedback typed by the user.
        state: Gradio session state (LangGraph config holding a ``thread_id``).

    Returns:
        ``(display_text, state, feedback_box)``: the refined text (or a hint),
        the session state, and the new value for the feedback textbox. The
        box is cleared once feedback has been sent.
    """
    if not feedback_text or not feedback_text.strip():
        return "Please enter feedback.", state, ""
    # "1" is the placeholder thread id every UI session starts with; it is
    # replaced on the first submission. Commands sent to it would land in a
    # thread shared by all sessions. Keep the typed feedback so it isn't lost.
    if state["configurable"]["thread_id"] == "1":
        return "Please submit legal text first.", state, feedback_text
    command = f"feedback: {feedback_text}"
    graph.update_state(state, {"messages": [HumanMessage(content=command)]}, as_node="human_feedback")
    display_text = ""  # stays empty if the stream yields no message state
    for event in graph.stream(None, state, stream_mode="values"):
        if "messages" in event:
            display_text = extract_display_text(event["messages"][-1].content)
    return display_text, state, ""
def approve_fn(state):
    """Ask the assistant to approve the current text.

    Args:
        state: Gradio session state (LangGraph config holding a ``thread_id``).

    Returns:
        ``(confirmation, state)``: the assistant's approval reply (or a hint to
        submit text first) and the session state.
    """
    # "1" is the placeholder thread id every UI session starts with; it is
    # replaced on the first submission. Commands sent to it would land in a
    # thread shared by all sessions that have not submitted yet.
    if state["configurable"]["thread_id"] == "1":
        return "Please submit legal text first.", state
    graph.update_state(state, {"messages": [HumanMessage(content="approve")]}, as_node="human_feedback")
    confirmation = ""  # stays empty if the stream yields no message state
    for event in graph.stream(None, state, stream_mode="values"):
        if "messages" in event:
            confirmation = event["messages"][-1].content
    return confirmation, state
# Gradio UI: a single column holding the input form, the rewritten output and
# the controls for iterating on it. `state` carries the per-session LangGraph
# config ({"configurable": {"thread_id": ...}}) between callbacks.
with gr.Blocks(title="Legal Text Rewriter for Special Education in California") as demo:
    state = gr.State(value={"configurable": {"thread_id": "1"}})

    with gr.Column():
        # Input form.
        legal_text = gr.Textbox(placeholder="Enter your legal text here", label="Legal Text")
        initial_tone = gr.Dropdown(
            choices=["Formal", "Empathetic", "Neutral", "Strength-Based"],
            value="Formal",
            label="Initial Tone",
        )
        submit_btn = gr.Button("Submit")

        # Output and iteration controls.
        rewritten_text_display = gr.Textbox(interactive=False, label="Rewritten Text")
        gr.Markdown("")
        regenerate_btn = gr.Button("Regenerate")
        feedback_textbox = gr.Textbox(placeholder="Enter feedback to refine the text", label="Feedback")
        submit_feedback_btn = gr.Button("Submit Feedback")
        approve_btn = gr.Button("Approve")

    # Button bindings: every callback also returns the session state so
    # thread changes persist across clicks.
    submit_btn.click(
        fn=initial_submit_fn,
        inputs=[legal_text, initial_tone, state],
        outputs=[rewritten_text_display, state],
    )
    regenerate_btn.click(
        fn=regenerate_fn,
        inputs=[state],
        outputs=[rewritten_text_display, state],
    )
    submit_feedback_btn.click(
        fn=submit_feedback_fn,
        inputs=[feedback_textbox, state],
        outputs=[rewritten_text_display, state, feedback_textbox],
    )
    approve_btn.click(
        fn=approve_fn,
        inputs=[state],
        outputs=[rewritten_text_display, state],
    )

demo.launch()