Spaces:
Runtime error
Runtime error
| import os | |
| import gradio as gr | |
| import json | |
| from typing import Annotated | |
| from typing_extensions import TypedDict | |
| from langchain_huggingface import HuggingFaceEndpoint | |
| from langchain_community.tools.tavily_search import TavilySearchResults | |
| from langgraph.graph import StateGraph, START, END | |
| from langgraph.graph.message import add_messages | |
| from langchain_core.messages import ToolMessage | |
| from dotenv import load_dotenv | |
| import logging | |
# Initialize logging
logging.basicConfig(level=logging.INFO)

# Load environment variables (from a local .env file, if one exists)
load_dotenv()

# A missing HF_TOKEN used to crash the app at import time: os.getenv() returns
# None and None.strip() raises AttributeError before the UI can start.
# Treat a missing or blank token as "no token" and warn instead.
HF_TOKEN = (os.getenv("HF_TOKEN") or "").strip() or None
if HF_TOKEN is None:
    logging.warning(
        "HF_TOKEN is not set; requests to the Hugging Face Inference API "
        "may be rejected (the model repo may require an authenticated token)."
    )

# Initialize the HuggingFace model.
# NOTE(review): with huggingfacehub_api_token=None the endpoint presumably
# falls back to its own environment/cached-login lookup — confirm for the
# installed langchain_huggingface version.
llm = HuggingFaceEndpoint(
    repo_id="mistralai/Mistral-7B-Instruct-v0.3",
    huggingfacehub_api_token=HF_TOKEN,
    temperature=0.7,
    max_new_tokens=200,
)

# Initialize Tavily Search tool (at most 2 results per query).
# NOTE(review): the Tavily wrapper presumably needs TAVILY_API_KEY in the
# environment — confirm it is configured for the deployment.
tool = TavilySearchResults(max_results=2)
tools = [tool]
# Define the state structure
class State(TypedDict):
    """Graph state passed between the nodes of the chat graph."""

    # Conversation history. The Annotated add_messages reducer (from langgraph)
    # combines each node's returned messages with the existing list instead of
    # replacing it.
    messages: Annotated[list, add_messages]
# Create a state graph builder over `State`; nodes and edges are registered on
# it below, and it is compiled into `graph` once the wiring is complete.
graph_builder = StateGraph(State)
# Define the chatbot function and its input/output helpers
def _extract_query(messages) -> str:
    """Return the text of the most recent message ("" when there is none).

    Accepts a plain string, a ``{"content": ...}`` dict, or any object with a
    string ``content`` attribute (e.g. a LangChain ``HumanMessage``).

    Raises:
        ValueError: if the last message has no string content.
    """
    if not messages:
        return ""
    last = messages[-1]
    if isinstance(last, str):
        return last
    if isinstance(last, dict):
        content = last.get("content")
    else:
        content = getattr(last, "content", None)
    if isinstance(content, str):
        return content
    raise ValueError("Input message is not in the correct format")


def _extract_urls(search_results) -> list:
    """Return the ``"url"`` field of each search-result dict.

    The search tool may return an error string instead of a list of result
    dicts; iterating that string would yield single characters, so anything
    other than a list produces no URLs.
    """
    if not isinstance(search_results, list):
        logging.warning("Unexpected search results: %r", search_results)
        return []
    return [
        result.get("url", "No URL found")
        for result in search_results
        if isinstance(result, dict)
    ]


def chatbot(state: State):
    """Answer the latest message with the LLM and attach search-result URLs.

    Args:
        state: Graph state whose ``"messages"`` list holds the conversation.

    Returns:
        A state update ``{"messages": [reply]}`` where ``reply`` is an
        assistant-role dict. The ``add_messages`` reducer on
        ``State.messages`` merges it into the existing history, so only the
        new message is returned. On success the dict carries ``"content"``
        (the LLM output) and ``"urls"`` (search-result URLs). On failure
        ``"content"`` is ``"Error: <reason>"``.
    """
    try:
        query = _extract_query(state["messages"])
        if not query.strip():
            # Don't spend LLM / search calls on an empty prompt.
            raise ValueError("Input message is empty")
        logging.info("Input Message: %s", query)

        # HuggingFaceEndpoint is a text-completion LLM, so give it the raw
        # string. Wrapping it in a list makes LangChain build a chat prompt
        # that is rendered as "Human: <query>" before reaching the model.
        response = llm.invoke(query)
        logging.info("LLM Response: %s", response)

        # The search runs after the LLM call, so its results are surfaced only
        # as source links; they do not inform the generated answer.
        search_results = tool.invoke({"query": query})
        reply = {
            "role": "assistant",
            "content": response,
            "urls": _extract_urls(search_results),
        }
    except Exception as e:  # node boundary: report the failure to the user
        logging.exception("Error: %s", e)
        # Keep the assistant role: a bare string would be converted into a
        # *human* message by the add_messages reducer.
        reply = {"role": "assistant", "content": f"Error: {e}"}
    return {"messages": [reply]}
# Graph node that executes the tool calls requested by the latest message
class BasicToolNode:
    """Run every tool call attached to the most recent message.

    Each entry of ``message.tool_calls`` is indexed by ``"name"``, ``"args"``
    and ``"id"``. The tool registered under ``"name"`` is invoked with
    ``"args"``, and its JSON-encoded result is wrapped in a ``ToolMessage``.
    """

    def __init__(self, tools: list) -> None:
        # Lookup table from each tool's ``name`` attribute to the tool itself.
        self.tools_by_name = {t.name: t for t in tools}

    def __call__(self, inputs: dict):
        """Execute the requested tool calls and return them as a state update.

        Raises:
            ValueError: if ``inputs`` has no (or an empty) ``"messages"`` list.
        """
        history = inputs.get("messages", [])
        if not history:
            raise ValueError("No message found in input")
        latest = history[-1]
        return {"messages": [self._run_call(call) for call in latest.tool_calls]}

    def _run_call(self, call):
        """Invoke the tool named by ``call`` and wrap its result in a ToolMessage."""
        selected = self.tools_by_name[call["name"]]
        result = selected.invoke(call["args"])
        return ToolMessage(
            content=json.dumps(result),
            name=call["name"],
            tool_call_id=call["id"],
        )
# Add tool node to the graph under the name "tools". That name is the target
# of chatbot's conditional edge and the source of the edge back to "chatbot".
tool_node = BasicToolNode(tools=tools)
graph_builder.add_node("tools", tool_node)
# Conditional routing applied after the "chatbot" node
def route_tools(state: State):
    """Choose the node that follows ``chatbot``.

    ``state`` may be the graph-state dict or a bare list of messages. Returns
    ``"tools"`` when the newest message carries at least one tool call,
    otherwise ``END``.

    Raises:
        ValueError: if a state dict has no messages.
    """
    if isinstance(state, list):
        history = state
    else:
        history = state.get("messages", [])
        if not history:
            raise ValueError(f"No messages found in input state to tool_edge: {state}")
    last_message = history[-1]
    if not hasattr(last_message, "tool_calls"):
        return END
    return "tools" if len(last_message.tool_calls) > 0 else END
# Add nodes and conditional edges to the state graph.
# Flow: START -> chatbot -> route_tools -> ("tools" -> chatbot) or END.
# NOTE(review): chatbot's reply dict carries no tool_calls, so route_tools
# always returns END and the "tools" node is currently never reached.
graph_builder.add_node("chatbot", chatbot)
graph_builder.add_conditional_edges(
    "chatbot",
    route_tools,
    # Maps each route_tools return value to its destination node.
    {"tools": "tools", END: END}
)
graph_builder.add_edge("tools", "chatbot")
graph_builder.add_edge(START, "chatbot")
# Compiled, invokable graph used by the Gradio handler.
graph = graph_builder.compile()
# Gradio interface
def _format_reply(message) -> str:
    """Render the assistant's reply as plain text for the output Textbox.

    Handles message objects (``.content``, plus an optional ``"urls"`` entry
    in ``.additional_kwargs``), dicts with ``"content"``/``"urls"`` keys, and
    bare strings. Source URLs, when present, are listed below the reply.
    """
    if isinstance(message, dict):
        content = message.get("content", "")
        urls = message.get("urls") or []
    else:
        content = getattr(message, "content", message)
        # NOTE(review): relies on langchain_core moving the chatbot dict's
        # extra "urls" key into additional_kwargs during message conversion.
        extra = getattr(message, "additional_kwargs", None) or {}
        urls = extra.get("urls") or []
    text = content if isinstance(content, str) else str(content)
    if urls:
        text += "\n\nSources:\n" + "\n".join(f"- {url}" for url in urls)
    return text


def chat_interface(input_text, state):
    """Run one chat turn through the graph for the Gradio UI.

    Args:
        input_text: Text typed by the user.
        state: Per-session ``{"messages": [...]}`` dict held in ``gr.State``.
            ``None`` is treated as an empty conversation.

    Returns:
        ``(reply_text, new_state)``. ``reply_text`` is the assistant's reply
        as a plain string, because the output Textbox would otherwise show a
        message object's repr. ``new_state`` is the updated conversation.
    """
    if state is None:
        state = {"messages": []}
    if not input_text or not input_text.strip():
        # Nothing to send; avoid empty LLM and search calls.
        return "Please enter a message.", state
    # Build a fresh input dict rather than appending in place, so the stored
    # session state is not left with a dangling user message if the graph raises.
    graph_input = {"messages": list(state.get("messages", [])) + [input_text]}
    updated_state = graph.invoke(graph_input)
    return _format_reply(updated_state["messages"][-1]), updated_state
# Create Gradio app: a left column with the message box and Submit button, and
# a right column with a read-only reply box. `chat_state` holds the
# conversation dict that chat_interface receives and returns.
with gr.Blocks() as demo:
    gr.Markdown("### Chatbot with Tavily Search Integration")
    chat_state = gr.State({"messages": []})
    with gr.Row():
        with gr.Column():
            user_input = gr.Textbox(label="Your Message", placeholder="Type your message here...", lines=2)
            submit_button = gr.Button("Submit")
        with gr.Column():
            chatbot_output = gr.Textbox(label="Chatbot Response", interactive=False, lines=4)
    # On click: call chat_interface(user_input, chat_state). Its two return
    # values go to the reply box and back into chat_state.
    submit_button.click(chat_interface, inputs=[user_input, chat_state], outputs=[chatbot_output, chat_state])
# Launch the Gradio app only when run as a script (not on import)
if __name__ == "__main__":
    demo.launch()