| |
| import os |
| import gradio as gr |
| from typing import TypedDict, Annotated |
| from langgraph.graph import StateGraph, START, END |
| from langgraph.prebuilt import ToolNode, tools_condition |
| from langgraph.graph.message import add_messages |
| from langchain_core.messages import BaseMessage, HumanMessage, AIMessage |
| from langchain_groq import ChatGroq |
| from langchain_community.tools import DuckDuckGoSearchRun |
| from langchain_core.tools import tool |
| import requests |
|
|
# ChatGroq is constructed below without an explicit API key, so it must find
# GROQ_API_KEY in the environment. Fail fast with a readable message instead of
# the opaque "str expected, not NoneType" TypeError that assigning
# os.getenv(...) (None) back into os.environ would raise.
if not os.getenv("GROQ_API_KEY"):
    raise RuntimeError(
        "GROQ_API_KEY environment variable is not set; "
        "set it before starting the app."
    )
|
|
| |
# Groq-hosted Llama 3.1 8B chat model. Temperature 0 keeps answers as
# repeatable as the provider allows, and each reply is capped at 300 tokens.
llm = ChatGroq(
    temperature=0,
    max_tokens=300,
    model="llama-3.1-8b-instant",
)
|
|
| |
# NOTE: @tool hands the docstring below to the model as the tool description,
# so its wording is part of the prompt and is kept verbatim.
@tool
def search(query: str) -> str:
    """Search the web for current information."""
    # A fresh DuckDuckGo wrapper is built for every call.
    searcher = DuckDuckGoSearchRun()
    return searcher.run(query)
|
|
# Operation aliases the model may send instead of the canonical names.
# Small models frequently answer "multiply", "*" or "Mul" even when told "mul".
# Note: @tool hands the docstring below to the model as the tool description,
# so it is kept verbatim; aliases are only a server-side safety net.
@tool
def calculator(first_num: float, second_num: float, operation: str) -> dict:
    """Perform basic arithmetic. Operations: add, sub, mul, div"""
    aliases = {
        "+": "add", "plus": "add", "addition": "add",
        "-": "sub", "minus": "sub", "subtract": "sub", "subtraction": "sub",
        "*": "mul", "x": "mul", "×": "mul", "times": "mul",
        "multiply": "mul", "multiplication": "mul",
        "/": "div", "÷": "div", "divide": "div", "division": "div",
    }
    # Normalize case/whitespace, then map any alias onto its canonical name.
    op = str(operation).strip().lower()
    op = aliases.get(op, op)

    if op == "add":
        result = first_num + second_num
    elif op == "sub":
        result = first_num - second_num
    elif op == "mul":
        result = first_num * second_num
    elif op == "div":
        # Report division by zero as a message rather than raising.
        result = "Division by zero" if second_num == 0 else first_num / second_num
    else:
        # Echo the operation exactly as received so the model can correct it.
        result = f"Unknown operation: {operation}"
    return {"result": result}
|
|
# Note: @tool hands the docstring below to the model as the tool description,
# so it is kept verbatim.
@tool
def get_stock_price(symbol: str) -> dict:
    """Fetch latest stock price for a symbol like AAPL or TSLA."""
    api_key = os.getenv("STOCKS_API_KEY")
    if not api_key:
        return {"error": "STOCKS_API_KEY environment variable is not set"}
    try:
        # `params=` URL-encodes the values, so a symbol containing "&", "="
        # or spaces can no longer corrupt the query string.
        response = requests.get(
            "https://www.alphavantage.co/query",
            params={"function": "GLOBAL_QUOTE", "symbol": symbol, "apikey": api_key},
            # Without a timeout a stalled connection blocks the chat turn forever.
            timeout=10,
        )
        response.raise_for_status()
        return response.json()
    except (requests.RequestException, ValueError) as exc:
        # ValueError covers non-JSON bodies on older `requests` versions.
        # Only the exception type is reported: requests' messages include the
        # full URL, which would leak the API key into the model context / UI.
        return {"error": f"Stock price request failed ({type(exc).__name__})"}
|
|
# Every tool the agent may call. The same list is bound to the model (so it
# can emit tool calls) and handed to the graph's ToolNode (so they can run).
tools = [
    search,
    calculator,
    get_stock_price,
]
llm_with_tools = llm.bind_tools(tools)
|
|
| |
class ChatState(TypedDict):
    """Graph state: the running conversation for one agent invocation."""

    # The add_messages reducer merges the messages each node returns into this
    # list instead of overwriting it.
    messages: Annotated[list[BaseMessage], add_messages]
|
|
| |
def chat_node(state: ChatState):
    """Ask the tool-bound model for the next step of the conversation.

    The model sees the whole message history. Its reply is either a final
    answer or a request to call one of the bound tools.

    Returns:
        A partial state update holding the single new reply. The
        ``add_messages`` reducer on ``ChatState`` merges it into the history.
    """
    reply = llm_with_tools.invoke(state["messages"])
    return {"messages": [reply]}
|
|
# Agent loop: START -> chat_node, then chat_node <-> tools until the model
# stops requesting tools.
graph = StateGraph(ChatState)

# Nodes: the model step and a prebuilt executor for the bound tools.
graph.add_node("chat_node", chat_node)
graph.add_node("tools", ToolNode(tools))

# Edges: tools_condition routes chat_node's output to "tools" when the latest
# AI message carries tool calls, and to END otherwise. Tool results always go
# back to the model.
graph.add_edge(START, "chat_node")
graph.add_conditional_edges("chat_node", tools_condition)
graph.add_edge("tools", "chat_node")

agent = graph.compile()
|
|
| |
| def respond(message, history): |
| |
| messages = [] |
| for user_msg, bot_msg in history: |
| messages.append(HumanMessage(content=user_msg)) |
| if bot_msg: |
| messages.append(AIMessage(content=bot_msg)) |
| messages.append(HumanMessage(content=message)) |
|
|
| result = agent.invoke({"messages": messages}) |
| return result["messages"][-1].content |
|
|
# Gradio chat front end. Each user turn is handled by respond().
# The title/description previously contained mis-decoded bytes ("π", "β"). The
# dash is restored, and the unrecoverable emoji fragment is dropped.
demo = gr.ChatInterface(
    fn=respond,
    title="AI Research Agent",
    description="Ask me anything — I can search the web and do calculations!",
    examples=[
        "What is LangGraph?",
        "What's happening in AI news today?",
        "Calculate 128 multiplied by 37",
        # Exercises the get_stock_price tool, which the other examples do not.
        "What is the latest stock price of AAPL?",
    ],
    # Do not run the examples through the agent ahead of time.
    cache_examples=False,
)
|
|
def main() -> None:
    """Start the Gradio chat UI."""
    demo.launch()


# Launch only when executed as a script, not when imported.
if __name__ == "__main__":
    main()