import os

import gradio as gr

from haystack.components.agents import Agent
from haystack.components.converters.html import HTMLToDocument
from haystack.components.converters.output_adapter import OutputAdapter
from haystack.components.fetchers.link_content import LinkContentFetcher
from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.components.websearch.serper_dev import SerperDevWebSearch
from haystack.core.pipeline import Pipeline
from haystack.core.super_component import SuperComponent
from haystack.dataclasses import ChatMessage
from haystack.tools import ComponentTool, Tool
from haystack.utils import Secret
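
# Expected environment variables (inferred from how they are used below):
#   BASE_URL          - base URL of an OpenAI-compatible chat completion endpoint
#   API_KEY           - API key for that endpoint
#   SERPERDEV_API_KEY - API key for the Serper.dev web-search service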
base_url = os.getenv("BASE_URL")
api_key = os.getenv("API_KEY")


class ChatbotAgent:
    def __init__(self):
        self.agent: Agent = self.init_pipeline()

    def init_pipeline(self):
        # ------------------------------
        # 🧠 Define the Search Pipeline
        # ------------------------------
        search_pipeline = Pipeline()
        search_pipeline.add_component(
            "search",
            SerperDevWebSearch(
                # Read the key from the environment instead of hardcoding a secret.
                api_key=Secret.from_env_var("SERPERDEV_API_KEY"),
                top_k=10,
            ),
        )
        search_pipeline.add_component(
            "fetcher",
            LinkContentFetcher(timeout=3, raise_on_failure=False, retry_attempts=2),
        )
        search_pipeline.add_component("converter", HTMLToDocument())
        search_pipeline.add_component(
            "output_adapter",
            OutputAdapter(
                template="""
{%- for doc in docs -%}
  {%- if doc.content -%}
  <search-result url="{{ doc.meta.url }}">
  {{ doc.content|truncate(25000) }}
  </search-result>
  {%- endif -%}
{%- endfor -%}
""",
                output_type=str,
            ),
        )
        search_pipeline.connect("search.links", "fetcher.urls")
        search_pipeline.connect("fetcher.streams", "converter.sources")
        search_pipeline.connect("converter.documents", "output_adapter.docs")
        # ------------------------------
        # 🤖 Set up the Chat Generator
        # ------------------------------
        openai_generator = OpenAIChatGenerator(
            api_base_url=base_url,
            api_key=Secret.from_token(api_key),
            model="meta/llama-3.3-70b-instruct",
        )
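        # Assumption: BASE_URL points at an OpenAI-compatible endpoint that
        # serves the "meta/llama-3.3-70b-instruct" model id; any other
        # OpenAI-compatible chat model would work here as well.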
        # ------------------------------
        # 🔧 Wrap search pipeline in a SuperComponent Tool
        # ------------------------------
        search_component = SuperComponent(
            pipeline=search_pipeline,
            input_mapping={"query": ["search.query"]},
            output_mapping={"output_adapter.output": "search_result"},
        )
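        # The mappings expose the whole pipeline as a single component: callers
        # pass "query" (routed to search.query) and read "search_result" back
        # (taken from output_adapter.output).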
        search_tool = ComponentTool(
            name="search",
            description="Use this tool to search for information on the internet.",
            component=search_component,
            outputs_to_string={"source": "search_result"},
        )
        calculator = Tool(
            name="calculator",
            description="Use this tool to calculate math problems.",
            # parameters must be a JSON schema describing the tool's arguments.
            parameters={
                "type": "object",
                "properties": {"expression": {"type": "string"}},
                "required": ["expression"],
            },
            # Caution: eval() executes arbitrary Python, so this is unsafe for
            # untrusted (model-generated) input.
            function=lambda expression: eval(expression),
        )
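        # A more defensive sketch (assumption: only basic arithmetic is needed)
        # would walk the AST and whitelist operators instead of calling eval:
        #
        #   import ast, operator
        #   OPS = {ast.Add: operator.add, ast.Sub: operator.sub,
        #          ast.Mult: operator.mul, ast.Div: operator.truediv}
        #   def safe_eval(expr: str) -> float:
        #       def walk(node):
        #           if isinstance(node, ast.BinOp) and type(node.op) in OPS:
        #               return OPS[type(node.op)](walk(node.left), walk(node.right))
        #           if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
        #               return node.value
        #           raise ValueError("unsupported expression")
        #       return walk(ast.parse(expr, mode="eval").body)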
        # ------------------------------
        # 🧠 Define the Agent
        # ------------------------------
        agent = Agent(
            chat_generator=openai_generator,
            tools=[search_tool, calculator],
            system_prompt="""
You are a deep research assistant.
You create comprehensive research reports to answer the user's questions.
You use the 'search' tool to answer any questions.
You perform multiple searches until you have the information you need to answer the question.
Make sure you research different aspects of the question.
Use markdown to format your response.
When you use information from the web search results, cite your sources using markdown links.
It is important that you cite accurately.
""",
            exit_conditions=["text"],
            max_agent_steps=20,
        )
        agent.warm_up()
        return agent
    def run(self, query) -> str:
        # If query is already a list of ChatMessage objects, pass it through.
        if isinstance(query, list) and all(
            isinstance(msg, ChatMessage) for msg in query
        ):
            result = self.agent.run(query)
        # If query is a single string, wrap it in a user ChatMessage.
        elif isinstance(query, str):
            result = self.agent.run([ChatMessage.from_user(query)])
        else:
            raise ValueError(
                "Query must be either a string or a list of ChatMessage objects"
            )
        # The agent's final answer is the last message it produced.
        return result["messages"][-1].text
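

# Example usage of ChatbotAgent (hypothetical sketch; requires the environment
# variables listed at the top of the file):
#   agent = ChatbotAgent()
#   print(agent.run("Who maintains Haystack?"))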
def respond(
    message,
    history,
    # Defaults keep the signature compatible with gr.ChatInterface, which only
    # passes (message, history) while additional_inputs stays commented out.
    system_message=None,
    max_tokens=None,
    temperature=None,
    top_p=None,
):
    try:
        # Initialize the agent. Note: this rebuilds the whole pipeline on every
        # request; a single module-level instance would avoid the repeated setup.
        agent = ChatbotAgent()
        # Convert the chat history from Gradio "messages" format to Haystack ChatMessages.
        chat_messages = []
        for message_obj in history:
            if message_obj["role"] == "user":
                chat_messages.append(ChatMessage.from_user(message_obj["content"]))
            elif message_obj["role"] == "assistant":
                chat_messages.append(ChatMessage.from_assistant(message_obj["content"]))

        # Add the current user message.
        chat_messages.append(ChatMessage.from_user(message))

        # Run the agent and yield the complete response.
        response_content = agent.run(chat_messages)
        yield response_content
    except Exception as e:
        # Log the error for debugging.
        print(f"Error in respond(): {e}")
        # Gracefully respond with an error message.
        yield f"⚠️ An error occurred: `{e}`"
| """ | |
| For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface | |
| """ | |
demo = gr.ChatInterface(
    respond,
    type="messages",
    description="A deep research chatbot that searches the web and cites its sources.",
    # additional_inputs=[
    #     gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
    #     gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
    #     gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
    #     gr.Slider(
    #         minimum=0.1,
    #         maximum=1.0,
    #         value=0.95,
    #         step=0.05,
    #         label="Top-p (nucleus sampling)",
    #     ),
    # ],
)

if __name__ == "__main__":
    demo.launch()