"""Gradio chat app: Cohere (Command A) LLM + LangGraph tool-calling agent.

Tools: Wikipedia search, an LLM-backed historical-events lookup, and a
palindrome checker. Conversation state is checkpointed per session via
LangGraph's MemorySaver; a separate per-session "pretty display" history
is kept for markdown rendering in the Gradio UI.
"""

import os
import uuid
from datetime import datetime
from typing import Annotated

import gradio as gr
from typing_extensions import TypedDict

from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langchain_core.tools import tool
from langchain_community.utilities import WikipediaAPIWrapper
from langchain_community.tools import WikipediaQueryRun
from langchain_cohere import ChatCohere
from langgraph.graph import StateGraph, START, END, MessagesState
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode
from langgraph.checkpoint.memory import MemorySaver

# =========================
# 1) Secrets / Environment
# =========================
COHERE_API_KEY = os.getenv("COHERE_API_KEY")
if not COHERE_API_KEY:
    raise ValueError("Please set COHERE_API_KEY in your Hugging Face Spaces secrets")
# Re-export so downstream libraries that read the env var directly find it.
os.environ["COHERE_API_KEY"] = COHERE_API_KEY

# =========================
# 2) LLM (Cohere)
# =========================
llm = ChatCohere(
    model="command-a-03-2025",
    temperature=0.3,
)


# =========================
# 3) LangGraph State
# =========================
class State(TypedDict):
    """Graph state: a message list merged via LangGraph's add_messages reducer."""

    messages: Annotated[list, add_messages]


# =========================
# 4) Tools
# =========================
# Tool 1: Wikipedia
wiki_api_wrapper = WikipediaAPIWrapper(top_k_results=1)
wikipedia_tool = WikipediaQueryRun(api_wrapper=wiki_api_wrapper)


# Tool 2: Historical Events (LLM-powered tool)
@tool
def historical_events(date_input: str) -> str:
    """Provide a list of important historical events for a given date."""
    try:
        res = llm.invoke(
            f"You are a helpful historian. List important historical events that occurred on {date_input}. "
            f"Return a concise bullet list (5-10 items)."
        )
        return res.content
    except Exception as e:
        # Return the error as text so the agent can surface it instead of crashing.
        return f"Error: {str(e)}"


# Tool 3: Palindrome Checker
@tool
def palindrome_checker(text: str) -> str:
    """Check if a word or phrase is a palindrome."""
    # Compare only alphanumerics, case-insensitively ("A man a plan..." works).
    # NOTE: input with no alphanumerics reduces to "" which reads as a palindrome.
    cleaned = "".join(c.lower() for c in text if c.isalnum())
    if cleaned == cleaned[::-1]:
        return f"'{text}' is a palindrome."
    return f"'{text}' is not a palindrome."


tools = [wikipedia_tool, historical_events, palindrome_checker]
tool_node = ToolNode(tools=tools)

# Bind tools to the LLM so it can emit structured tool calls.
model_with_tools = llm.bind_tools(tools)


# =========================
# 5) Graph logic
# =========================
def should_continue(state: MessagesState):
    """Route to the tool node when the last model message requested tool calls.

    Returns "tools" if the most recent message carries tool calls, END otherwise.
    (The original double-checked ``tool_calls`` in a redundant nested ``if``.)
    """
    last_message = state["messages"][-1]
    if getattr(last_message, "tool_calls", None):
        return "tools"
    return END


def call_model(state: MessagesState):
    """Invoke the tool-bound model on the accumulated messages and append its reply."""
    messages = state["messages"]
    response = model_with_tools.invoke(messages)
    return {"messages": [response]}


builder = StateGraph(State)
builder.add_node("chatbot", call_model)
builder.add_node("tools", tool_node)
builder.add_edge(START, "chatbot")
builder.add_conditional_edges("chatbot", should_continue, {"tools": "tools", END: END})
builder.add_edge("tools", "chatbot")

memory = MemorySaver()
app = builder.compile(checkpointer=memory)

# =========================
# 6) Gradio Chat Formatting
# =========================
# Per-session "pretty display" history (separate from LangGraph checkpoint state).
conversations = {}


def format_message_for_display(msg, msg_type="ai"):
    """Format a message for markdown display.

    Args:
        msg: For ``msg_type == "human"`` a plain string; otherwise a message
            object with a ``content`` attribute (and ``name`` for tools).
        msg_type: One of "human", "tool", or "ai" (the default).

    Returns:
        A markdown fragment with a timestamp header and trailing separator.
    """
    timestamp = datetime.now().strftime("%H:%M")
    if msg_type == "human":
        return f"**👤 You** *({timestamp})*\n\n{msg}\n\n---\n"
    if msg_type == "tool":
        tool_name = getattr(msg, "name", "Tool")
        return f"**🔧 {tool_name}** *({timestamp})*\n```text\n{msg.content}\n```\n\n---\n"
    # AI message
    return f"**🤖 Assistant** *({timestamp})*\n\n{msg.content}\n\n---\n"


def chatbot_conversation(message, _history_markdown, session_id):
    """Main chatbot function that maintains conversation history.

    Runs one turn through the LangGraph app (the checkpointer keeps model
    state per ``session_id``), appends tool/AI outputs to the per-session
    display history, and re-renders the whole conversation as markdown.

    Returns:
        (formatted_history, session_id) — the markdown transcript and the
        (possibly newly generated) session id.
    """
    # Generate session ID if not provided.
    if not session_id:
        session_id = str(uuid.uuid4())

    # LangGraph checkpoint thread config.
    config = {"configurable": {"thread_id": session_id}}

    # Initialize display history if new session.
    if session_id not in conversations:
        conversations[session_id] = []

    # Add user message to display history.
    conversations[session_id].append(("human", message))

    # Invoke LangGraph with this single user message (checkpoint keeps state).
    inputs = {"messages": [HumanMessage(content=message)]}
    try:
        result = app.invoke(inputs, config)
        final_messages = result["messages"]

        # Append tool + AI outputs to display history.
        for msg in final_messages:
            if isinstance(msg, HumanMessage):
                continue
            # Fix: classify by message type, not by the presence of a `name`
            # attribute — AI messages may also carry a name.
            if isinstance(msg, ToolMessage):
                conversations[session_id].append(("tool", msg))
            elif getattr(msg, "content", None):
                # AIMessage or similar with non-empty content.
                conversations[session_id].append(("ai", msg))
    except Exception as e:
        error_msg = f"❌ Error: {str(e)}"
        conversations[session_id].append(("ai", AIMessage(content=error_msg)))

    # Render whole conversation as one markdown string.
    formatted_history = ""
    for msg_type, msg_content in conversations[session_id]:
        formatted_history += format_message_for_display(msg_content, msg_type)

    return formatted_history, session_id


def clear_conversation(session_id_value=None):
    """Clear the current conversation (UI + new session id).

    Fix: also drop the old session's display history from ``conversations``
    so cleared chats don't accumulate in memory. ``session_id_value`` is
    optional, keeping the original zero-argument call signature working.
    """
    if session_id_value:
        conversations.pop(session_id_value, None)
    return "", str(uuid.uuid4())


# =========================
# 7) Gradio App (Spaces-ready)
# =========================
with gr.Blocks(theme=gr.themes.Soft(), title="🚀 Cohere + LangGraph Chatbot") as demo:
    gr.Markdown(
        """
# 🚀 Cohere (Command A) + LangGraph Chatbot

**LangGraph-powered conversational AI using Cohere's Command models**

🔍 **Available Tools:**
- 📚 **Wikipedia Search** - Get information from Wikipedia
- 🔄 **Palindrome Checker** - Check if text is a palindrome
- 📅 **Historical Events** - Find events that happened on specific dates

💡 **Try asking:** *"Tell me about Alan Turing, then check if 'radar' is a palindrome"*
"""
    )

    with gr.Row():
        with gr.Column(scale=4):
            chat_history = gr.Markdown(
                value="**🤖 Assistant**: Hello! I'm your AI assistant powered by Cohere + LangGraph. "
                "I can search Wikipedia, check palindromes, and find historical events. What would you like to know?\n\n---\n",
                label="💬 Conversation",
            )
            with gr.Row():
                message_input = gr.Textbox(
                    placeholder="Type your message here...",
                    label="Your message",
                    scale=4,
                    lines=2,
                )
                send_btn = gr.Button("Send 🚀", scale=1, variant="primary")
            with gr.Row():
                clear_btn = gr.Button("🗑️ Clear Chat", variant="secondary")
        with gr.Column(scale=1):
            gr.Markdown(
                """
### 💡 Example Queries:
- "What is machine learning?"
- "Is 'level' a palindrome?"
- "What happened on June 6, 1944?"
- "Tell me about Python programming"
- "Check if 'A man a plan a canal Panama' is a palindrome"
"""
            )

    session_id = gr.State(value=str(uuid.uuid4()))

    def send_message(message, history, session_id_value):
        """Handle a send event: run one chat turn and clear the input box."""
        if message and message.strip():
            new_history, new_session_id = chatbot_conversation(message, history, session_id_value)
            return new_history, new_session_id, ""
        # Blank input: leave everything as-is.
        return history, session_id_value, message

    send_btn.click(
        send_message,
        inputs=[message_input, chat_history, session_id],
        outputs=[chat_history, session_id, message_input],
    )
    message_input.submit(
        send_message,
        inputs=[message_input, chat_history, session_id],
        outputs=[chat_history, session_id, message_input],
    )
    clear_btn.click(
        clear_conversation,
        inputs=[session_id],
        outputs=[chat_history, session_id],
    )

if __name__ == "__main__":
    # Spaces uses PORT=7860 by default, and needs server_name="0.0.0.0".
    demo.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", "7860")))