| | from time import time |
| | from pprint import pprint |
| | import huggingface_hub |
| | import streamlit as st |
| | from typing import Literal, Dict |
| | from typing_extensions import TypedDict |
| | import langchain |
| | from langgraph.graph import END, StateGraph |
| | from langchain_community.chat_models import ChatOllama |
| | from logger import logger |
| |
|
| | from config import config |
| | from agents import get_agents, tools_dict |
| |
|
| |
|
class GraphState(TypedDict):
    """State dictionary passed between the nodes of the workflow graph.

    Each node reads the keys it needs and returns a partial update that
    sets one of the remaining keys.
    """

    # Text supplied by the caller; read by question_node and route_question.
    question: str
    # Written by question_node; read by function_agent_node and output_node.
    rephrased_question: str
    # Written by function_agent_node; read by output_node as its context.
    function_agent_output: str
    # Written by output_node; the value the front ends display.
    generation: str
| |
|
| |
|
@st.cache_resource(show_spinner="Loading model..")
def init_agents() -> dict[str, langchain.agents.AgentExecutor]:
    """Log in to the Hugging Face Hub and build the agents on an Ollama chat model.

    The function is wrapped in Streamlit's ``cache_resource`` decorator, which
    shows the "Loading model.." spinner while it runs.

    Returns:
        Whatever ``get_agents`` produces for the configured chat model
        (annotated as a mapping of agent name to executor).
    """
    huggingface_hub.login(token=config.hf_token, new_session=False)
    chat_model = ChatOllama(model=config.ollama_model, temperature=0.8)
    return get_agents(chat_model)
| |
|
| |
|
| | |
| |
|
def question_node(state: GraphState) -> Dict[str, str]:
    """Rephrase the incoming question for the function agent.

    Reads ``state["question"]``, passes it to the ``rephrase_agent`` and
    returns a state update holding the result under ``"rephrased_question"``.
    """
    logger.info("Generating question for function agent")

    original = state["question"]
    logger.info(f"Original question: {original}")

    rephrased = agents["rephrase_agent"].invoke({"question": original})
    logger.info(f"Rephrased question: {rephrased}")

    return {"rephrased_question": rephrased}
| |
|
def function_agent_node(state: GraphState) -> Dict[str, str]:
    """Call the function agent with the rephrased question.

    Reads ``state["rephrased_question"]``, invokes the ``function_agent``
    with it (together with ``tools_dict``) and returns a state update that
    stores the agent's ``"output"`` field under ``"function_agent_output"``.

    The return annotation is a state-update dict, matching the other graph
    nodes; the function never returned the literal ``"finished"``.
    """
    logger.info("Calling function agent")

    question = state["rephrased_question"]
    result = agents["function_agent"].invoke({"input": question, "tools": tools_dict})
    # ``.get`` is kept on purpose: a result without "output" yields None
    # instead of raising here.
    response = result.get("output")

    logger.info(f"Function agent output: {response}")
    return {"function_agent_output": response}
| |
|
def output_node(state: GraphState) -> Dict[str, str]:
    """Produce the final answer from the function agent's output.

    Invokes the ``output_agent`` with the function agent output as
    ``"context"`` and the rephrased question as ``"question"``, and returns
    the result as a state update under ``"generation"``.
    """
    logger.info("Generating output")

    output_agent = agents["output_agent"]
    agent_inputs = {
        "context": state["function_agent_output"],
        "question": state["rephrased_question"],
    }
    return {"generation": output_agent.invoke(agent_inputs)}
| |
|
| | |
| |
|
def route_question(state: GraphState) -> Literal["vectorstore", "websearch"]:
    """Route the question to web search or RAG.

    Asks the ``router_agent`` which datasource fits ``state["question"]``
    and returns that datasource name.

    Returns:
        ``"vectorstore"`` or ``"websearch"``.

    Raises:
        ValueError: If the router agent reports any other datasource.
            Previously the function fell through and returned ``None``,
            which contradicts the declared return type.
    """
    logger.info("Routing question")

    question = state["question"]
    logger.info(f"Question: {question}")

    source = agents["router_agent"].invoke({"question": question})
    logger.info(source)

    datasource = source["datasource"]
    logger.info(datasource)

    if datasource in ("vectorstore", "websearch"):
        return datasource
    raise ValueError(f"Unknown datasource from router agent: {datasource!r}")
| |
|
| |
|
| | |
| |
|
# Assemble the graph over GraphState. Three nodes are registered and wired
# in a straight line: question_rephrase -> function_agent -> output_node.
workflow = StateGraph(GraphState)

# Nodes
workflow.add_node("question_rephrase", question_node)
workflow.add_node("function_agent", function_agent_node)
workflow.add_node("output_node", output_node)

# Edges: entry point, the two hops between nodes, and the finish point.
workflow.set_entry_point("question_rephrase")
workflow.add_edge("question_rephrase", "function_agent")
workflow.add_edge("function_agent", "output_node")
workflow.set_finish_point("output_node")

# Compiled graph streamed by main() and main_headless().
flow = workflow.compile()
| |
|
# Status label shown in the Streamlit UI after the node with the given name
# has finished running (keys match the node names registered on the graph).
progress_map = {
    "question_rephrase": ":mag: Collecting data",
    "function_agent": ":bulb: Preparing response",
    "output_node": ":bulb: Done!",
}
| |
|
def main():
    """Render the Streamlit demo page and answer the entered question.

    Shows a title, a text area and a "Generate" button. When the button is
    pressed with non-empty text, the compiled graph is streamed inside an
    ``st.status`` container whose label follows ``progress_map``; the
    generated answer is then written to the page. If every attempt fails,
    the status is marked as an error instead of reporting "Finished!".
    """
    # Local import: only needed for the nested helper's return annotation.
    from typing import Optional

    st.title("LLM-ADE 9B Demo")

    input_text = st.text_area("Enter your text here:", value="", height=200)

    def get_response(input_text: str, max_attempts: int = 5) -> Optional[str]:
        """Run the graph and return its generation, retrying on failure.

        Returns None when all ``max_attempts`` attempts raised.
        """
        for attempt in range(1, max_attempts + 1):
            # Reset per attempt so a failed run never leaks a stale update.
            last_update = None
            try:
                for output in flow.stream({"question": input_text}):
                    for key, value in output.items():
                        config.status.update(label=progress_map[key])
                        pprint(f"Finished running: {key}")
                        last_update = value
                # The last streamed update is expected to be output_node's,
                # which carries "generation". An empty stream leaves
                # last_update as None and is handled as a failed attempt.
                return last_update["generation"]
            except Exception as e:
                logger.error(e)
                # Only announce a retry when one will actually happen.
                if attempt < max_attempts:
                    logger.info("Retrying..")
        return None

    if st.button("Generate"):
        if input_text:
            with st.status("Generating response...") as status:
                # The status handle is shared through config so that
                # get_response can update the label while streaming.
                config.status = status
                config.status.update(label=":question: Breaking down question")
                response = get_response(input_text)
                if response is None:
                    st.error("Failed to generate a response. Please try again.")
                    config.status.update(label="Failed", state="error", expanded=True)
                else:
                    st.write(response)
                    config.status.update(label="Finished!", state="complete", expanded=True)
        else:
            st.warning("Please enter some text to generate a response.")
| |
|
| |
|
def main_headless(prompt: str):
    """Run the graph once for ``prompt`` and print the answer to the terminal.

    Prints a line for every node that finishes, then the final generation
    wrapped in ANSI blue escape codes, then the elapsed wall-clock time.
    """
    started_at = time()

    for step in flow.stream({"question": prompt}):
        for node_name, value in step.items():
            pprint(f"Finished running: {node_name}")

    # ``value`` still holds the update from the last node that ran.
    print("\033[94m" + value["generation"] + "\033[0m")
    print(f"Time taken: {time() - started_at:.2f}s\n" + "-" * 20)
| |
|
| |
|
# Build the agents at import time; the graph node functions look them up
# through this module-level name.
agents = init_agents()


if __name__ == "__main__":
    if not config.headless:
        # Streamlit front end.
        main()
    else:
        # Command-line front end; fire is imported only on this path.
        import fire

        fire.Fire(main_headless)
| |
|