Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import os | |
| import re | |
| import matplotlib.pyplot as plt | |
| from io import BytesIO | |
| from langgraph.graph import StateGraph, MessagesState, START, END | |
| from langgraph.types import Command | |
| from langchain_core.messages import HumanMessage | |
| from langgraph.prebuilt import create_react_agent | |
| from langchain_anthropic import ChatAnthropic | |
# Anthropic credentials.
# ChatAnthropic reads ANTHROPIC_API_KEY from the environment by itself, so the
# key does not need to be copied anywhere. Re-assigning
# os.environ["ANTHROPIC_API_KEY"] = os.getenv("ANTHROPIC_API_KEY") is a no-op
# when the key is set and raises TypeError (os.environ only accepts str values)
# when it is missing, so only warn here instead.
if not os.getenv("ANTHROPIC_API_KEY"):
    print("Warning: ANTHROPIC_API_KEY is not set; LLM calls are expected to fail.")

# LangGraph setup: one chat model shared by every agent node.
llm = ChatAnthropic(model="claude-3-5-sonnet-latest")
def make_system_prompt(suffix: str) -> str:
    """Build the shared collaborator system prompt followed by a role line.

    Args:
        suffix: Role-specific instructions placed after a newline.

    Returns:
        The common multi-agent preamble, a newline, then ``suffix``.
    """
    sentences = [
        "You are a helpful AI assistant, collaborating with other assistants.",
        "Use the provided tools to progress towards answering the question.",
        # The trailing space here plus the join separator produces a double
        # space before "will"; kept so the prompt text is unchanged.
        "If you are unable to fully answer, that's OK, another assistant with different tools ",
        "will help where you left off. Execute what you can to make progress.",
        "If you or any of the other assistants have the final answer or deliverable,",
        "prefix your response with FINAL ANSWER so the team knows to stop.",
    ]
    preamble = " ".join(sentences)
    return f"{preamble}\n{suffix}"
def _build_agent(role: str):
    """Create a tool-less ReAct agent whose system prompt ends with *role*."""
    return create_react_agent(
        llm,
        tools=[],
        # NOTE(review): recent langgraph releases deprecate `state_modifier`
        # in favour of `prompt`; confirm against the pinned langgraph version.
        state_modifier=make_system_prompt(role),
    )


# Each agent is compiled once at import time and reused for every request,
# instead of being rebuilt on every node invocation.
_research_agent = _build_agent("You can only do research.")
_chart_agent = _build_agent("You can only generate charts.")


def _run_agent_turn(agent, state: MessagesState, name: str, next_node: str) -> Command[str]:
    """Run one agent turn on the shared conversation and pick the next node.

    Args:
        agent: A compiled agent from ``_build_agent``.
        state: Current graph state; passed unchanged to ``agent.invoke``.
        name: Name attached to the agent's final reply (the graph node name).
        next_node: Node to hand off to when the reply lacks "FINAL ANSWER".

    Returns:
        A ``Command`` whose update is the agent's resulting message list and
        whose ``goto`` is ``END`` if the last reply contains "FINAL ANSWER",
        otherwise *next_node*.
    """
    result = agent.invoke(state)
    reply = result["messages"][-1]
    goto = END if "FINAL ANSWER" in reply.content else next_node
    # Re-emit the reply as a named HumanMessage so the next agent receives it
    # as collaborator input (presumably because some chat providers reject a
    # history that ends with an AI message - confirm).
    result["messages"][-1] = HumanMessage(content=reply.content, name=name)
    return Command(update={"messages": result["messages"]}, goto=goto)


def research_node(state: MessagesState) -> Command[str]:
    """Research agent turn: hand off to "chart_generator" unless it is finished."""
    return _run_agent_turn(_research_agent, state, "researcher", "chart_generator")


def chart_node(state: MessagesState) -> Command[str]:
    """Chart agent turn: hand back to "researcher" unless it is finished."""
    return _run_agent_turn(_chart_agent, state, "chart_generator", "researcher")
# Create the LangGraph workflow.
# Hand-offs between the two agents are driven entirely by the Command(goto=...)
# that each node returns. No static edges are added out of those nodes:
# LangGraph follows static edges in addition to a Command's goto, so an edge
# such as researcher -> chart_generator would run the chart agent even after
# the researcher returned goto=END on a FINAL ANSWER.
workflow = StateGraph(MessagesState)
workflow.add_node("researcher", research_node)
workflow.add_node("chart_generator", chart_node)
workflow.add_edge(START, "researcher")
graph = workflow.compile()
# A year 1900-2099, then 1-10 non-digit separator characters, then an optional
# "$" and a number. Thousands separators ("21,433") are accepted only in
# well-formed groups of three, so "21,2020" still parses as 21 followed by a
# new year.
_YEAR_VALUE_RE = re.compile(
    r"\b((?:19|20)\d{2})[^\d]{1,10}\$?"
    r"(\d{1,3}(?:,\d{3})+(?!\d)(?:\.\d+)?|\d+(?:\.\d+)?)"
)


def extract_chart_data(text):
    """Extract (year, value) pairs from free-form LLM output.

    Args:
        text: Text to scan, e.g. "2019: 21.4\\n2020: 20.9". Non-string input
            (such as None) is treated as containing no pairs.

    Returns:
        ``(years, values)`` - parallel lists of year strings and floats - or
        ``(None, None)`` when no pair is found.
    """
    print("π§ͺ Raw LLM Output to parse:\n", text)
    matches = _YEAR_VALUE_RE.findall(text) if isinstance(text, str) else []
    if not matches:
        print("β No year-value pairs found.")
        return None, None
    years = [year for year, _ in matches]
    # The regex guarantees a well-formed number, so float() cannot fail once
    # the thousands separators are removed.
    values = [float(value.replace(",", "")) for _, value in matches]
    print("β Extracted:", years, values)
    return years, values
def generate_plot(years, values):
    """Render a bar chart of *values* per year and return it as PNG bytes.

    Args:
        years: X-axis labels, one per bar (year strings from extract_chart_data
            in this app).
        values: Bar heights, parallel to *years*.

    Returns:
        A BytesIO holding the PNG image, rewound to position 0.
    """
    # Build a standalone Figure rather than going through pyplot. pyplot keeps
    # every figure in a global registry (the figure here was never closed, so
    # a long-running server leaked one per request), and its global state is
    # not thread-safe, while Gradio may run handlers on worker threads.
    from matplotlib.figure import Figure

    fig = Figure()
    ax = fig.subplots()
    ax.bar(years, values)
    ax.set_title("Generated Chart")
    ax.set_xlabel("Year")
    ax.set_ylabel("Value")
    buf = BytesIO()
    fig.savefig(buf, format="png")
    buf.seek(0)
    return buf
def run_langgraph(user_input):
    """Run the research/chart graph on *user_input* and return the final text.

    Streams with ``stream_mode="values"`` so each event is the full graph
    state, a dict with a ``"messages"`` key. The default ``"updates"`` mode
    yields ``{node_name: update}`` dicts with no top-level ``"messages"`` key,
    so output was never captured before.

    Args:
        user_input: The user's research request.

    Returns:
        The content of the last message in the final state, or
        ``"No output generated"`` if no message content was produced.
    """
    print("π© Input to LangGraph:", user_input)
    events = graph.stream(
        {"messages": [("user", user_input)]},
        {"recursion_limit": 150},
        stream_mode="values",
    )
    final_message = None
    logged = 0
    for event in events:
        print("πΉ Event:", event)
        messages = event.get("messages") if isinstance(event, dict) else None
        if not messages:
            continue
        # Each "values" event carries the whole history; log only new entries.
        for m in messages[logged:]:
            print("πΈ Message:", m.content)
        logged = len(messages)
        final_message = messages[-1].content
    return final_message or "No output generated"
def _chart_to_image(png_buffer):
    """Decode a PNG byte stream into a uint8 RGBA array that gr.Image can show.

    gr.Image output accepts a numpy array, PIL image or file path, not a raw
    BytesIO, so the buffer returned by generate_plot is decoded here.
    """
    image = plt.imread(png_buffer, format="png")
    # plt.imread returns PNG pixels as floats in [0, 1]; scale to 8-bit.
    if image.dtype.kind == "f":
        image = (image * 255).round().astype("uint8")
    return image


def process_input(user_input):
    """Gradio handler: run the agents and chart any year/value data found.

    Args:
        user_input: The research task typed by the user.

    Returns:
        ``(response_text, chart)`` where *chart* is an image array, or None
        when no year/value pairs could be extracted from the response.
    """
    # Set to True to exercise extraction and plotting without calling the LLM.
    static_test = False
    if static_test:
        result_text = (
            "\n"
            "FINAL ANSWER:\n"
            "Here is the GDP of the USA:\n"
            "2019: 21.4\n"
            "2020: 20.9\n"
            "2021: 22.1\n"
            "2022: 23.0\n"
            "2023: 24.3\n"
        )
    else:
        result_text = run_langgraph(user_input)

    years, values = extract_chart_data(result_text)
    if not (years and values):
        return result_text, None
    return result_text, _chart_to_image(generate_plot(years, values))
# Gradio interface: one text input; outputs are the agents' response text and
# an optional chart image.
_output_components = [
    gr.Textbox(label="Generated Response"),
    gr.Image(type="pil", label="Generated Chart"),
]

interface = gr.Interface(
    fn=process_input,
    inputs="text",
    outputs=_output_components,
    title="LangGraph Research Automation",
    description=(
        "Enter your research task "
        "(e.g., 'Get GDP data for the USA over the past 5 years and create a chart.')"
    ),
)

if __name__ == "__main__":
    # Launch options are passed straight through to Interface.launch.
    launch_options = {"share": True, "ssr_mode": False}
    interface.launch(**launch_options)