# langchain-agent / app.py
# Author: lukmanaj (Hugging Face Space)
# Last commit: 7dfd205 "Update app.py" (verified)
import os
import uuid
from datetime import datetime
from typing import Annotated
import gradio as gr
from typing_extensions import TypedDict
from langchain_core.messages import HumanMessage, AIMessage
from langchain_core.tools import tool
from langchain_community.utilities import WikipediaAPIWrapper
from langchain_community.tools import WikipediaQueryRun
from langchain_cohere import ChatCohere
from langgraph.graph import StateGraph, START, END, MessagesState
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode
from langgraph.checkpoint.memory import MemorySaver
# =========================
# 1) Secrets / Environment
# =========================
# Fail fast at import time when the Cohere key is missing or empty, rather
# than letting the first chat request fail later.
COHERE_API_KEY = os.environ.get("COHERE_API_KEY")
if COHERE_API_KEY is None or COHERE_API_KEY == "":
    raise ValueError("Please set COHERE_API_KEY in your Hugging Face Spaces secrets")
# Re-export the (already present) key so downstream clients see it in the env.
os.environ["COHERE_API_KEY"] = COHERE_API_KEY
# =========================
# 2) LLM (Cohere)
# =========================
# Shared chat model: the agent binds its tools to it, and the
# historical_events tool calls it directly.
# The model id can be overridden through the COHERE_MODEL environment
# variable (e.g. when a model is retired) without a code change. An unset or
# empty value falls back to the original default.
llm = ChatCohere(
    model=os.getenv("COHERE_MODEL") or "command-a-03-2025",
    temperature=0.3,
)
# =========================
# 3) LangGraph State
# =========================
class State(TypedDict):
    """Graph state schema used to build the StateGraph.

    ``messages`` holds the conversation. The second element of the
    ``Annotated`` type is LangGraph's reducer: node updates to ``messages``
    are merged into the existing list via ``add_messages`` rather than
    replacing it (see LangGraph's ``add_messages`` docs for id-based updates).
    """

    # Annotated[<type>, <reducer>] -- LangGraph reads the reducer from here.
    messages: Annotated[list, add_messages]
# =========================
# 4) Tools
# =========================
# Tool 1: Wikipedia -- the wrapper is limited to the top search result to
# keep tool output short.
wiki_api_wrapper = WikipediaAPIWrapper(
    top_k_results=1,
)
wikipedia_tool = WikipediaQueryRun(
    api_wrapper=wiki_api_wrapper,
)
# Tool 2: Historical Events (LLM-powered tool)
# `@tool` uses the docstring below as the tool description given to the
# model, so it is kept verbatim.
@tool
def historical_events(date_input: str) -> str:
    """Provide a list of important historical events for a given date."""
    try:
        prompt = (
            "You are a helpful historian. "
            f"List important historical events that occurred on {date_input}. "
            "Return a concise bullet list (5-10 items)."
        )
        reply = llm.invoke(prompt)
        return reply.content
    except Exception as exc:
        # Best-effort tool: hand the failure back to the agent as text instead
        # of raising, so the conversation can continue.
        return f"Error: {exc}"
# Tool 3: Palindrome Checker
# `@tool` uses the docstring below as the tool description given to the
# model, so it is kept verbatim.
@tool
def palindrome_checker(text: str) -> str:
    """Check if a word or phrase is a palindrome."""
    # Compare only letters and digits, case-insensitively. casefold() is the
    # Unicode-aware caseless form. It runs on the whole text before filtering,
    # so any characters it expands to are filtered consistently.
    cleaned = "".join(ch for ch in text.casefold() if ch.isalnum())
    if not cleaned:
        # Nothing to compare (e.g. "" or "?!"); calling that a palindrome
        # would mislead the user.
        return f"'{text}' contains no letters or digits to check."
    if cleaned == cleaned[::-1]:
        return f"'{text}' is a palindrome."
    return f"'{text}' is not a palindrome."
# The same tool list feeds both the executor node and the model's tool schema.
tools = [
    wikipedia_tool,
    historical_events,
    palindrome_checker,
]
# Node that executes tool calls emitted by the model.
tool_node = ToolNode(tools=tools)
# Bind tools to the LLM so it can request them.
model_with_tools = llm.bind_tools(tools)
# =========================
# 5) Graph logic
# =========================
def should_continue(state: State):
    """Route after the chatbot node.

    Args:
        state: Graph state; only the last entry of ``state["messages"]`` is read.

    Returns:
        ``"tools"`` when the latest message carries a non-empty
        ``tool_calls`` list, so the ToolNode runs next; otherwise ``END``.
    """
    last_message = state["messages"][-1]
    # getattr covers messages without a tool_calls attribute; an empty list
    # (model answered directly) is falsy and also ends the turn.
    if getattr(last_message, "tool_calls", None):
        return "tools"
    return END
def call_model(state: State):
    """Run the tool-aware LLM on the conversation so far.

    Args:
        state: Graph state; the full ``state["messages"]`` list is sent.

    Returns:
        A partial state update ``{"messages": [response]}``. The
        ``add_messages`` reducer on ``State.messages`` merges it into history.
    """
    response = model_with_tools.invoke(state["messages"])
    return {"messages": [response]}
# Agent loop: START -> chatbot -> (tools -> chatbot)* -> END
builder = StateGraph(State)

# Nodes
builder.add_node("chatbot", call_model)
builder.add_node("tools", tool_node)

# Edges
builder.add_edge(START, "chatbot")
# should_continue yields either "tools" or END; map each label to its target.
builder.add_conditional_edges(
    "chatbot",
    should_continue,
    {"tools": "tools", END: END},
)
# Tool results are handed back to the model for the next step.
builder.add_edge("tools", "chatbot")

# Checkpointer keyed by thread_id, so later invocations on the same thread
# resume the stored conversation.
memory = MemorySaver()
app = builder.compile(checkpointer=memory)
# =========================
# 6) Gradio Chat Formatting
# =========================
# Per-session "pretty display" history (separate from LangGraph checkpoint state).
# Maps session_id -> list of (msg_type, payload, timestamp) tuples:
#   msg_type  "human", "tool" or "ai"
#   payload   the raw user string for "human", otherwise a message object
#   timestamp "HH:MM" string captured when the entry was recorded
# NOTE(review): entries are never evicted (not even by clear_conversation),
# so this dict grows with every session for the life of the process.
conversations = {}


def _now_hhmm():
    """Return the current local time formatted as ``HH:MM``."""
    return datetime.now().strftime("%H:%M")


def format_message_for_display(msg, msg_type="ai", timestamp=None):
    """Format one chat entry as a markdown snippet.

    Args:
        msg: The raw user text when ``msg_type == "human"``. Otherwise a
            message object exposing ``content`` (and, for tools, ``name``).
        msg_type: ``"human"``, ``"tool"`` or ``"ai"`` (default).
        timestamp: Pre-formatted ``HH:MM`` string to show. When omitted, the
            current local time is used.

    Returns:
        A markdown string ending in a horizontal rule.
    """
    if timestamp is None:
        timestamp = _now_hhmm()
    if msg_type == "human":
        return f"**πŸ‘€ You** *({timestamp})*\n\n{msg}\n\n---\n"
    if msg_type == "tool":
        # Fall back to a generic label when the name is missing or None.
        tool_name = getattr(msg, "name", None) or "Tool"
        return f"**πŸ”§ {tool_name}** *({timestamp})*\n```text\n{msg.content}\n```\n\n---\n"
    # AI message
    return f"**πŸ€– Assistant** *({timestamp})*\n\n{msg.content}\n\n---\n"


def _messages_from_latest_turn(messages):
    """Return the messages that follow the most recent human message.

    The checkpointed graph returns the whole thread on every invoke. Only the
    tail after the user message just sent is new for the current turn. If
    there is no human message, everything is returned.
    """
    last_human = max(
        (i for i, m in enumerate(messages) if getattr(m, "type", None) == "human"),
        default=-1,
    )
    return list(messages[last_human + 1:])


def _display_kind(msg):
    """Classify a graph message for the display history.

    Returns ``"tool"`` for tool results and ``"ai"`` for any other non-human
    message with non-empty content. Returns None for entries that should not
    be shown (human messages, or an AI turn that only carries tool calls).
    """
    kind = getattr(msg, "type", None)
    if kind == "tool":
        return "tool"
    if kind == "human":
        return None
    if getattr(msg, "content", None):
        return "ai"
    return None


def chatbot_conversation(message, _history_markdown, session_id):
    """Run one user turn through the graph and render the whole session.

    Args:
        message: The user's text for this turn.
        _history_markdown: The currently displayed markdown (unused; the
            per-session ``conversations`` store is the source of truth).
        session_id: LangGraph thread id. A fresh UUID is generated if falsy.

    Returns:
        ``(formatted_history, session_id)``: the full conversation rendered as
        markdown, and the session id (possibly newly generated).
    """
    if not session_id:
        session_id = str(uuid.uuid4())
    # The session id doubles as the LangGraph checkpoint thread id.
    config = {"configurable": {"thread_id": session_id}}
    history = conversations.setdefault(session_id, [])

    history.append(("human", message, _now_hhmm()))
    # Send only this user message; the checkpointer restores earlier turns.
    inputs = {"messages": [HumanMessage(content=message)]}
    try:
        result = app.invoke(inputs, config)
    except Exception as e:
        # Show the failure in the chat instead of crashing the UI handler.
        error_msg = f"❌ Error: {str(e)}"
        history.append(("ai", AIMessage(content=error_msg), _now_hhmm()))
    else:
        # Record only this turn's new tool/AI messages. result["messages"] is
        # the whole thread, so re-adding it would duplicate earlier turns.
        stamp = _now_hhmm()
        for msg in _messages_from_latest_turn(result["messages"]):
            kind = _display_kind(msg)
            if kind is not None:
                history.append((kind, msg, stamp))

    # Render the whole conversation as one markdown string, using each entry's
    # stored timestamp rather than the render time.
    formatted_history = "".join(
        format_message_for_display(payload, kind, timestamp)
        for kind, payload, timestamp in history
    )
    return formatted_history, session_id


def clear_conversation():
    """Clear the current conversation (UI + new session id).

    Returns:
        ``("", new_session_id)``: empty markdown and a fresh UUID string, so
        later turns start a new LangGraph thread.
    """
    return "", str(uuid.uuid4())
# =========================
# 7) Gradio App (Spaces-ready)
# =========================
with gr.Blocks(theme=gr.themes.Soft(), title="πŸš€ Cohere + LangGraph Chatbot") as demo:
    # Blank lines separate paragraphs from the bullet list. Without them the
    # "Try asking" line renders as a continuation of the last bullet.
    gr.Markdown(
        """
# πŸš€ Cohere (Command A) + LangGraph Chatbot

**LangGraph-powered conversational AI using Cohere's Command models**

πŸ” **Available Tools:**
- πŸ“š **Wikipedia Search** - Get information from Wikipedia
- πŸ”„ **Palindrome Checker** - Check if text is a palindrome
- πŸ“… **Historical Events** - Find events that happened on specific dates

πŸ’‘ **Try asking:** *"Tell me about Alan Turing, then check if 'radar' is a palindrome"*
"""
    )
    with gr.Row():
        with gr.Column(scale=4):
            chat_history = gr.Markdown(
                value="**πŸ€– Assistant**: Hello! I'm your AI assistant powered by Cohere + LangGraph. "
                "I can search Wikipedia, check palindromes, and find historical events. What would you like to know?\n\n---\n",
                label="πŸ’¬ Conversation",
            )
            with gr.Row():
                message_input = gr.Textbox(
                    placeholder="Type your message here...",
                    label="Your message",
                    scale=4,
                    lines=2,
                )
                send_btn = gr.Button("Send πŸš€", scale=1, variant="primary")
            with gr.Row():
                clear_btn = gr.Button("πŸ—‘οΈ Clear Chat", variant="secondary")
        with gr.Column(scale=1):
            gr.Markdown(
                """
### πŸ’‘ Example Queries:
- "What is machine learning?"
- "Is 'level' a palindrome?"
- "What happened on June 6, 1944?"
- "Tell me about Python programming"
- "Check if 'A man a plan a canal Panama' is a palindrome"
"""
            )

    # Start each browser session with no id. A literal default such as
    # str(uuid.uuid4()) would be evaluated once at startup and copied into
    # every visitor's session, so all users would share one LangGraph thread
    # and one display history. chatbot_conversation() generates a fresh UUID
    # when it receives a falsy id, and the new id is written back here.
    session_id = gr.State(value=None)

    def send_message(message, history, session_id_value):
        """Handle Send/Enter: run a turn for non-blank input, then clear the box.

        Returns (new markdown, session id, textbox value). Blank input leaves
        the conversation, session and textbox untouched.
        """
        if message and message.strip():
            new_history, new_session_id = chatbot_conversation(message, history, session_id_value)
            return new_history, new_session_id, ""
        return history, session_id_value, message

    # The Send button and Enter in the textbox share the same handler wiring.
    send_io = {
        "inputs": [message_input, chat_history, session_id],
        "outputs": [chat_history, session_id, message_input],
    }
    send_btn.click(send_message, **send_io)
    message_input.submit(send_message, **send_io)

    clear_btn.click(
        clear_conversation,
        outputs=[chat_history, session_id],
    )
if __name__ == "__main__":
    # Spaces serves on port 7860 unless PORT overrides it, and the server must
    # listen on all interfaces ("0.0.0.0") to be reachable.
    port = int(os.getenv("PORT", "7860"))
    demo.launch(
        server_name="0.0.0.0",
        server_port=port,
    )