Spaces:
Sleeping
Sleeping
| import os | |
| import uuid | |
| from datetime import datetime | |
| from typing import Annotated | |
| import gradio as gr | |
| from typing_extensions import TypedDict | |
| from langchain_core.messages import HumanMessage, AIMessage | |
| from langchain_core.tools import tool | |
| from langchain_community.utilities import WikipediaAPIWrapper | |
| from langchain_community.tools import WikipediaQueryRun | |
| from langchain_cohere import ChatCohere | |
| from langgraph.graph import StateGraph, START, END, MessagesState | |
| from langgraph.graph.message import add_messages | |
| from langgraph.prebuilt import ToolNode | |
| from langgraph.checkpoint.memory import MemorySaver | |
# =========================
# 1) Secrets / Environment
# =========================
# The Cohere key is read from the process environment. A missing or empty
# value aborts at import time, before any model object is constructed.
COHERE_API_KEY = os.getenv("COHERE_API_KEY")
if not COHERE_API_KEY:
    raise ValueError("Please set COHERE_API_KEY in your Hugging Face Spaces secrets")
# Written back unchanged; presumably so libraries that read the variable
# themselves see it too - TODO confirm this is still needed.
os.environ["COHERE_API_KEY"] = COHERE_API_KEY
# =========================
# 2) LLM (Cohere)
# =========================
# One shared chat model instance: it is bound to the tools for the chatbot
# node and is also invoked directly inside the historical_events tool.
llm = ChatCohere(
    model="command-a-03-2025",
    temperature=0.3,
)
# =========================
# 3) LangGraph State
# =========================
# Graph state schema: a single "messages" key.
class State(TypedDict):
    # Conversation messages. The add_messages annotation is the reducer that
    # LangGraph uses to merge each node's returned messages into this list.
    messages: Annotated[list, add_messages]
# =========================
# 4) Tools
# =========================
# Tool 1: Wikipedia
# top_k_results=1 presumably limits each query to the single best-matching
# article - verify against the WikipediaAPIWrapper documentation.
wiki_api_wrapper = WikipediaAPIWrapper(top_k_results=1)
wikipedia_tool = WikipediaQueryRun(api_wrapper=wiki_api_wrapper)
# Tool 2: Historical Events (LLM-powered tool)
# NOTE(review): this function is handed to bind_tools/ToolNode as a plain
# callable, which presumably uses the docstring as the tool description shown
# to the model - the docstring is therefore kept short and unchanged.
def historical_events(date_input: str) -> str:
    """Provide a list of important historical events for a given date."""
    date_text = (date_input or "").strip()
    if not date_text:
        # Without a date the prompt would read "occurred on ." and the model
        # would be invited to make something up.
        return "Error: no date provided."
    try:
        res = llm.invoke(
            f"You are a helpful historian. List important historical events that occurred on {date_text}. "
            f"Return a concise bullet list (5-10 items)."
        )
        return res.content
    except Exception as e:
        # Tool boundary: report the failure as text so the calling model can
        # relay it, rather than raising out of the tool.
        return f"Error: {str(e)}"
# Tool 3: Palindrome Checker
# NOTE(review): this function is handed to bind_tools/ToolNode as a plain
# callable, which presumably uses the docstring as the tool description shown
# to the model - the docstring is therefore kept short and unchanged.
def palindrome_checker(text: str) -> str:
    """Check if a word or phrase is a palindrome."""
    text = text or ""
    # Compare only letters and digits, case-insensitively. Lower-casing is done
    # per character so word-position rules (e.g. Greek final sigma) cannot make
    # the two ends of a phrase differ.
    cleaned = "".join(c.lower() for c in text if c.isalnum())
    if not cleaned:
        # Nothing left to compare: an empty string trivially equals its own
        # reverse, which would otherwise be reported as a palindrome.
        return f"'{text}' has no letters or digits to check."
    if cleaned == cleaned[::-1]:
        return f"'{text}' is a palindrome."
    return f"'{text}' is not a palindrome."
# The same tool list feeds both sides of the loop: ToolNode executes the calls,
# and the model is told about the tools so it can emit those calls.
tools = [wikipedia_tool, historical_events, palindrome_checker]
tool_node = ToolNode(tools=tools)
# Bind tools to the LLM
model_with_tools = llm.bind_tools(tools)
# =========================
# 5) Graph logic
# =========================
def should_continue(state: MessagesState):
    """Choose the node that follows the chatbot.

    Returns "tools" when the newest message carries a non-empty ``tool_calls``
    attribute, and END otherwise (including when the attribute is absent).
    """
    newest = state["messages"][-1]
    pending_calls = getattr(newest, "tool_calls", None)
    return "tools" if pending_calls else END
def call_model(state: MessagesState):
    """Chatbot node: send the accumulated messages to the tool-aware model.

    Returns a state update holding only the model's reply.
    """
    reply = model_with_tools.invoke(state["messages"])
    return {"messages": [reply]}
# Graph shape: START -> chatbot -> (tools -> chatbot)* -> END.
builder = StateGraph(State)
builder.add_node("chatbot", call_model)
builder.add_node("tools", tool_node)
builder.add_edge(START, "chatbot")
# should_continue returns either "tools" or END; the mapping routes each value.
builder.add_conditional_edges("chatbot", should_continue, {"tools": "tools", END: END})
# After tools run, control goes back to the chatbot node.
builder.add_edge("tools", "chatbot")
# In-memory checkpointer: state is kept per thread_id (supplied in the invoke
# config) for as long as this process lives.
memory = MemorySaver()
app = builder.compile(checkpointer=memory)
# =========================
# 6) Gradio Chat Formatting
# =========================
# Per-session "pretty display" history (separate from LangGraph checkpoint state)
# Maps session_id -> list of (msg_type, content) tuples, where msg_type is
# "human", "tool" or "ai". Nothing visible in this file removes entries, so the
# dict grows for the lifetime of the process.
conversations = {}
def format_message_for_display(msg, msg_type="ai"):
    """Format a message for markdown display.

    Args:
        msg: For ``msg_type="human"`` the message text itself; otherwise an
            object exposing ``content`` (and, for tools, optionally ``name``).
        msg_type: ``"human"``, ``"tool"``, or any other value for an
            assistant message.

    Returns:
        A markdown fragment with a header line, the body, and a trailing
        horizontal rule.

    NOTE: the timestamp is taken when this function runs, i.e. at render
    time, not when the message was originally produced.
    """
    timestamp = datetime.now().strftime("%H:%M")
    # The emoji below replace mis-decoded byte sequences in the header labels.
    if msg_type == "human":
        return f"**👤 You** *({timestamp})*\n\n{msg}\n\n---\n"
    if msg_type == "tool":
        # Fall back to "Tool" both when the attribute is missing and when it
        # is present but None, so the header never reads "None".
        tool_name = getattr(msg, "name", None) or "Tool"
        return f"**🔧 {tool_name}** *({timestamp})*\n```text\n{msg.content}\n```\n\n---\n"
    # AI message
    return f"**🤖 Assistant** *({timestamp})*\n\n{msg.content}\n\n---\n"
def chatbot_conversation(message, _history_markdown, session_id):
    """Run one chat turn through the LangGraph app and re-render the transcript.

    Args:
        message: Text the user just submitted.
        _history_markdown: Markdown currently shown in the UI. Unused; the
            transcript is rebuilt from ``conversations`` instead.
        session_id: Key for both the LangGraph checkpoint thread and the
            display history. A fresh UUID is generated when it is falsy.

    Returns:
        A ``(formatted_history, session_id)`` tuple: the whole conversation
        rendered as one markdown string, and the session id that was used.
    """
    # Generate session ID if not provided
    if not session_id:
        session_id = str(uuid.uuid4())
    # LangGraph checkpoint thread config
    config = {"configurable": {"thread_id": session_id}}
    # Initialize display history if new session, then record the user message
    display_history = conversations.setdefault(session_id, [])
    display_history.append(("human", message))
    # Invoke LangGraph with this single user message (checkpoint keeps state)
    inputs = {"messages": [HumanMessage(content=message)]}
    try:
        result = app.invoke(inputs, config)
        final_messages = result["messages"]
        # The checkpointed thread returns every message accumulated so far,
        # not just this turn's. Only what follows the last human message (the
        # one just submitted) is new; appending the full list would duplicate
        # all earlier tool/AI output on every turn.
        last_human_index = max(
            (i for i, m in enumerate(final_messages) if isinstance(m, HumanMessage)),
            default=-1,
        )
        for msg in final_messages[last_human_index + 1:]:
            # Tool messages in LangChain usually come back with a name
            if getattr(msg, "name", None):
                display_history.append(("tool", msg))
            elif getattr(msg, "content", None):
                # AIMessage or similar; replies with empty content (e.g. pure
                # tool-call requests) are not shown.
                display_history.append(("ai", msg))
    except Exception as e:
        # UI boundary: show the failure in the transcript instead of raising.
        error_msg = f"❌ Error: {str(e)}"
        display_history.append(("ai", AIMessage(content=error_msg)))
    # Render whole conversation as one markdown string. Entry kinds are only
    # "human", "tool" or "ai", which is exactly what the formatter dispatches on.
    formatted_history = "".join(
        format_message_for_display(content, kind) for kind, content in display_history
    )
    return formatted_history, session_id
def clear_conversation():
    """Reset the chat: return empty markdown for the UI and a brand-new session id."""
    fresh_session_id = str(uuid.uuid4())
    return "", fresh_session_id
# =========================
# 7) Gradio App (Spaces-ready)
# =========================
# Layout: header, then a row with the chat column (transcript, input, buttons)
# and a narrow column of example queries. The emoji in the labels replace
# mis-decoded byte sequences.
with gr.Blocks(theme=gr.themes.Soft(), title="🚀 Cohere + LangGraph Chatbot") as demo:
    gr.Markdown(
        """
        # 🚀 Cohere (Command A) + LangGraph Chatbot

        **LangGraph-powered conversational AI using Cohere's Command models**

        🔍 **Available Tools:**
        - 📚 **Wikipedia Search** - Get information from Wikipedia
        - 🔄 **Palindrome Checker** - Check if text is a palindrome
        - 📅 **Historical Events** - Find events that happened on specific dates

        💡 **Try asking:** *"Tell me about Alan Turing, then check if 'radar' is a palindrome"*
        """
    )
    with gr.Row():
        with gr.Column(scale=4):
            chat_history = gr.Markdown(
                value="**🤖 Assistant**: Hello! I'm your AI assistant powered by Cohere + LangGraph. "
                "I can search Wikipedia, check palindromes, and find historical events. What would you like to know?\n\n---\n",
                label="💬 Conversation",
            )
            with gr.Row():
                message_input = gr.Textbox(
                    placeholder="Type your message here...",
                    label="Your message",
                    scale=4,
                    lines=2,
                )
                send_btn = gr.Button("Send 🚀", scale=1, variant="primary")
            with gr.Row():
                clear_btn = gr.Button("🗑️ Clear Chat", variant="secondary")
        with gr.Column(scale=1):
            gr.Markdown(
                """
                ### 💡 Example Queries:
                - "What is machine learning?"
                - "Is 'level' a palindrome?"
                - "What happened on June 6, 1944?"
                - "Tell me about Python programming"
                - "Check if 'A man a plan a canal Panama' is a palindrome"
                """
            )
    # Start with no session id. A literal str(uuid.uuid4()) here is evaluated
    # once when the UI is built, so every visitor would begin with the same id
    # and share one checkpoint thread and display history. chatbot_conversation
    # mints a per-user id on the first message when it receives a falsy value.
    session_id = gr.State(value=None)

    def send_message(message, history, session_id_value):
        """Handle Send/Enter: run a chat turn, or leave everything as-is for blank input.

        Returns the new transcript markdown, the session id to store, and the
        text to leave in the input box (cleared after a successful send).
        """
        if message and message.strip():
            new_history, new_session_id = chatbot_conversation(message, history, session_id_value)
            return new_history, new_session_id, ""
        return history, session_id_value, message

    # The button and the textbox's submit event share the same handler.
    send_btn.click(
        send_message,
        inputs=[message_input, chat_history, session_id],
        outputs=[chat_history, session_id, message_input],
    )
    message_input.submit(
        send_message,
        inputs=[message_input, chat_history, session_id],
        outputs=[chat_history, session_id, message_input],
    )
    clear_btn.click(
        clear_conversation,
        outputs=[chat_history, session_id],
    )
if __name__ == "__main__":
    # Spaces uses PORT=7860 by default, and needs server_name="0.0.0.0"
    listen_port = int(os.environ.get("PORT", "7860"))
    demo.launch(server_name="0.0.0.0", server_port=listen_port)