from llama_index.core import PromptTemplate
from llama_index.core.workflow import Context
from llama_index.core.agent.workflow import ReActAgent, AgentStream, ToolCallResult
from llama_index.llms.huggingface_api import HuggingFaceInferenceAPI
from llama_index.tools.wikipedia import WikipediaToolSpec
from llama_index.tools.duckduckgo import DuckDuckGoSearchToolSpec
from llama_index.tools.code_interpreter import CodeInterpreterToolSpec

from .prompt import custom_react_system_header_str
from .custom_tools import query_image_tool, automatic_speech_recognition_tool
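
# NOTE: `query_image_tool` and `automatic_speech_recognition_tool` come from
# custom_tools.py, which is not shown here. Purely as an illustration -- the
# function, model, and parameter names below are assumptions, not the actual
# implementation -- such tools can be wrapped with FunctionTool.from_defaults:
#
#     from llama_index.core.tools import FunctionTool
#
#     def transcribe_audio(audio_path: str) -> str:
#         """Transcribe an audio file to text (e.g. via a Whisper pipeline)."""
#         ...
#
#     automatic_speech_recognition_tool = FunctionTool.from_defaults(
#         fn=transcribe_audio,
#         name="automatic_speech_recognition",
#         description="Transcribes an audio file and returns the text.",
#     )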


class LLamaIndexAgent:
    def __init__(self,
                 model_name="Qwen/Qwen2.5-Coder-32B-Instruct",
                 provider="hf-inference",
                 show_tools_desc=True,
                 show_prompt=True):

        # LLM served through the Hugging Face Inference API
        llm = HuggingFaceInferenceAPI(model_name=model_name,
                                      provider=provider)
        print(f"LLamaIndexAgent initialized with model \"{model_name}\"")

        # Assemble the tool list: Wikipedia lookup, DuckDuckGo web search,
        # a code interpreter, and the two custom multimodal tools
        tool_spec_list = []
        tool_spec_list += WikipediaToolSpec().to_tool_list()
        tool_spec_list += DuckDuckGoSearchToolSpec().to_tool_list()
        tool_spec_list += CodeInterpreterToolSpec().to_tool_list()
        tool_spec_list += [query_image_tool, automatic_speech_recognition_tool]

        # ReAct agent that reasons step by step and calls the tools above
        self.agent = ReActAgent(llm=llm, tools=tool_spec_list)

        # Replace the default ReAct system header with the custom prompt
        custom_react_system_header = PromptTemplate(custom_react_system_header_str)
        self.agent.update_prompts({"react_header": custom_react_system_header})

        # Context object that carries state across agent runs
        self.ctx = Context(self.agent)

        if show_tools_desc:
            for i, tool in enumerate(tool_spec_list):
                print("\n" + "="*30 + f" Tool {i+1} " + "="*30)
                print(tool.metadata.description)

        if show_prompt:
            prompt_dict = self.agent.get_prompts()
            for k, v in prompt_dict.items():
                print("\n" + "="*30 + f" Prompt: {k} " + "="*30)
                print(v.template)

    async def __call__(self, question: str) -> str:
        print("\n\n" + "*"*50)
        print(f"Agent received question: {question}")
        print("*"*50)

        # Run the agent and stream its intermediate reasoning tokens to stdout
        handler = self.agent.run(question, ctx=self.ctx)
        async for ev in handler.stream_events():
            if isinstance(ev, AgentStream):
                print(f"{ev.delta}", end="", flush=True)

        response = await handler

        # Keep only the text after the "FINAL ANSWER:" marker, if present
        # (str.split never raises, so the original try/except could not catch anything)
        response = str(response)
        if "FINAL ANSWER:" in response:
            response = response.split("FINAL ANSWER:")[-1].strip()
        else:
            print('Could not split response on "FINAL ANSWER:"')
        print("\n\n" + "-"*50)
        print(f"Agent returning with answer: {response}")

        # Clear the context so the next question starts from a fresh state
        self.ctx.clear()

        return response
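

# --- Illustrative usage (not part of the original module; names are assumptions) ---
# Assumes this file lives in a package alongside prompt.py and custom_tools.py,
# and that a Hugging Face API token (e.g. HF_TOKEN) is available for the
# Inference API. The agent is awaitable, so it must be driven from async code:
#
#     import asyncio
#     from my_agent_package.agent import LLamaIndexAgent  # hypothetical import path
#
#     async def main():
#         agent = LLamaIndexAgent(show_tools_desc=False, show_prompt=False)
#         answer = await agent("In which year did Apollo 11 land on the Moon?")
#         print(answer)
#
#     asyncio.run(main())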