Spaces:
Sleeping
Sleeping
Bryce Guinta committed on
Commit ·
8a2fa64
0
Parent(s):
Initial commit
Browse files- .gitignore +5 -0
- README.md +1 -0
- app.py +83 -0
- graph.py +98 -0
- logging-config.json +38 -0
- pyproject.toml +16 -0
- setup.cfg +2 -0
.gitignore
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.env
|
| 2 |
+
uv.lock
|
| 3 |
+
.idea
|
| 4 |
+
.python-version
|
| 5 |
+
__pycache__
|
README.md
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
A template for a streaming chatbot built with LangGraph and Gradio
|
app.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
from uuid import uuid4
|
| 4 |
+
import logging
|
| 5 |
+
import logging.config
|
| 6 |
+
import json
|
| 7 |
+
|
| 8 |
+
import gradio as gr
|
| 9 |
+
from dotenv import load_dotenv
|
| 10 |
+
from langgraph.types import RunnableConfig
|
| 11 |
+
|
| 12 |
+
from graph import GraphProcessingState, graph
|
| 13 |
+
|
| 14 |
+
logger = logging.getLogger(__name__)
|
| 15 |
+
|
| 16 |
+
def setup_logging():
    """Configure the logging subsystem from the on-disk JSON config.

    Reads "logging-config.json" from the working directory and applies it
    via ``logging.config.dictConfig``.
    """
    with open("logging-config.json") as config_file:
        logging.config.dictConfig(json.load(config_file))
|
| 20 |
+
|
| 21 |
+
async def chat_fn(message, history, input_graph_state, uuid):
    """Stream the assistant's reply for one chat turn.

    Feeds the user message plus the chat history through the langgraph
    graph and yields the accumulated assistant text as tokens stream in.
    On any failure, logs the traceback and yields a generic error message.
    """
    try:
        input_graph_state.user_input = message
        input_graph_state.history = history

        config = RunnableConfig()
        # The thread id ties this stream to the per-session conversation.
        config["configurable"] = {"thread_id": uuid}

        accumulated = ""
        stream = graph.astream(
            {"user_input": input_graph_state.user_input, "history": input_graph_state.history},
            config=config,
            stream_mode="messages",
        )
        async for msg, metadata in stream:
            # assistant_node is the name we defined in the langraph graph;
            # only surface tokens produced by that node.
            if metadata['langgraph_node'] == "assistant_node" and msg.content:
                accumulated += msg.content
                yield accumulated
    except Exception:
        logger.exception("Exception occurred")
        user_error_message = "There was an error processing your request. Please try again."
        yield user_error_message  # , input_graph_state
|
| 43 |
+
|
| 44 |
+
def clear():
    """Reset per-session state: a fresh graph state and a new conversation id."""
    fresh_state = GraphProcessingState()
    fresh_thread_id = uuid4()
    return fresh_state, fresh_thread_id
|
| 46 |
+
|
| 47 |
+
if __name__ == "__main__":
    load_dotenv()
    setup_logging()
    logger.info("Starting the interface")

    with gr.Blocks(title="Langgraph Template", fill_height=True, css="footer {visibility: hidden}") as app:
        # Per-session state holders. NOTE(review): this passes callables
        # (the class / uuid4), presumably so gradio produces a fresh value
        # per session load — confirm against the gr.State docs.
        gradio_graph_state = gr.State(value=GraphProcessingState)
        uuid_state = gr.State(uuid4)

        chatbot = gr.Chatbot(
            # avatar_images=(None, "assets/ai-avatar.png"),
            type="messages",
            # placeholder=WELCOME_MESSAGE,
            scale=1,
        )
        # Clearing the chat also resets the graph state and conversation id.
        chatbot.clear(fn=clear, outputs=[gradio_graph_state, uuid_state])

        chat_interface = gr.ChatInterface(
            chatbot=chatbot,
            fn=chat_fn,
            additional_inputs=[
                gradio_graph_state,
                uuid_state,
            ],
            additional_outputs=[
                # gradio_graph_state
            ],
            type="messages",
            multimodal=False,
        )

    app.launch(
        server_name="127.0.0.1",
        server_port=7860,
        # favicon_path="assets/favicon.ico"
    )
|
graph.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Annotated
|
| 2 |
+
|
| 3 |
+
import aiohttp
|
| 4 |
+
from langchain_core.messages import AnyMessage
|
| 5 |
+
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
| 6 |
+
from langchain_core.tools import tool
|
| 7 |
+
from langchain_openai import ChatOpenAI
|
| 8 |
+
from langgraph.graph.state import CompiledStateGraph
|
| 9 |
+
from langgraph.prebuilt import ToolNode
|
| 10 |
+
from langgraph.types import RunnableConfig
|
| 11 |
+
from langgraph.graph import Graph, StateGraph, END, MessagesState, add_messages
|
| 12 |
+
from pydantic import BaseModel, Field, ValidationError
|
| 13 |
+
from trafilatura import extract
|
| 14 |
+
|
| 15 |
+
@tool
async def download_website_text(url: str, config: RunnableConfig) -> str:
    """Downloads the text from a website

    args:
        url: The URL of the website
    """
    # Fetch the raw page body; the session is closed before extraction.
    async with aiohttp.ClientSession() as session:
        async with session.get(url) as response:
            raw_html = await response.text()
    extracted = extract(
        raw_html,
        include_formatting=True,
        include_links=True,
        output_format='json',
        with_metadata=True,
    )
    if extracted:
        return extracted
    return "No text found on the website"
|
| 29 |
+
|
| 30 |
+
# Tools exposed to the model; ToolNode executes them when the model asks.
tools = [download_website_text]

# System prompt prepended to every assistant invocation.
ASSISTANT_SYSTEM_PROMPT = "You are a helpful assistant."
# gpt-4o-mini with the tools bound so the model can emit tool calls.
# tags=["assistant"] — presumably for filtering streamed events; confirm with callers.
assistant_model = ChatOpenAI(model="gpt-4o-mini", tags=["assistant"]).bind_tools(tools)
|
| 34 |
+
|
| 35 |
+
class GraphProcessingState(BaseModel):
    """State carried between nodes of the langgraph workflow."""
    # The current turn's raw user text.
    user_input: str = Field(default_factory=str, description="The original user input")
    # Prior chat turns as dicts — presumably gradio "messages" format
    # ({"role": ..., "content": ...}); verify against the chat_fn caller.
    history: list[dict] = Field(default_factory=list, description="Chat history")  # type: ignore
    # Messages accumulated during this graph run; the add_messages reducer
    # merges each node's returned messages into the list instead of replacing it.
    messages: Annotated[list[AnyMessage], add_messages] = Field(default_factory=list)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
async def assistant_node(state: GraphProcessingState, config=None):
    """Run the assistant LLM for one step and merge its reply into state.

    Builds a prompt from the system prompt, the prior chat history, the
    current user input, and any messages accumulated during this run, then
    invokes the tool-bound model asynchronously.
    """
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", ASSISTANT_SYSTEM_PROMPT),
            MessagesPlaceholder(variable_name="chat_history"),
            ("user", "{user_input}"),
            # Splice in messages from earlier steps of this run (e.g. tool
            # calls and tool outputs) so the model sees the full exchange.
            *state.messages,
        ]
    )
    chain = prompt | assistant_model
    response = await chain.ainvoke({"user_input": state.user_input, "chat_history": state.history}, config)
    # The add_messages reducer on the state field appends this response.
    return {"messages": response}
|
| 53 |
+
|
| 54 |
+
def assistant_cond_edge(state: "GraphProcessingState", config=None):
    """Route after the assistant node: run tools if the reply requests them.

    Returns "tools" when the latest message carries tool calls, END otherwise.
    """
    last_message = state.messages[-1]
    # Check the actual tool_calls attribute instead of inferring a tool call
    # from empty content: under the old `not content` heuristic, an empty
    # assistant reply without tool calls would loop into ToolNode forever,
    # and a reply containing both text and tool calls would skip its tools.
    if getattr(last_message, "tool_calls", None):
        return "tools"
    return END
|
| 58 |
+
|
| 59 |
+
def define_workflow() -> CompiledStateGraph:
    """Build and compile the assistant/tool workflow graph."""
    workflow = StateGraph(GraphProcessingState)

    # Nodes: the LLM step and the tool executor.
    workflow.add_node("assistant_node", assistant_node)
    workflow.add_node("tools", ToolNode(tools))

    # Tool results always flow back to the assistant for another step.
    workflow.add_edge("tools", "assistant_node")

    # After the assistant runs, either execute requested tools or finish:
    # assistant_cond_edge returns "tools" for a tool call, END otherwise.
    workflow.add_conditional_edges(
        "assistant_node",
        assistant_cond_edge,
    )

    workflow.set_entry_point("assistant_node")
    return workflow.compile()
|
| 82 |
+
|
| 83 |
+
# Compiled workflow, importable as `from graph import graph`.
# (Removed a large block of commented-out draft code — a dead
# `process_user_input_graph` helper superseded by app.py's chat_fn.)
graph = define_workflow()
|
logging-config.json
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"version": 1,
|
| 3 |
+
"disable_existing_loggers": true,
|
| 4 |
+
"formatters": {
|
| 5 |
+
"simple": {
|
| 6 |
+
"format": "%(levelname)s: %(message)s"
|
| 7 |
+
},
|
| 8 |
+
"detailed": {
|
| 9 |
+
"format": "[%(levelname)s|%(module)s|L%(lineno)d] %(asctime)s: %(message)s",
|
| 10 |
+
"datefmt": "%Y-%m-%dT%H:%M:%S%z"
|
| 11 |
+
}
|
| 12 |
+
},
|
| 13 |
+
"handlers": {
|
| 14 |
+
"stdout": {
|
| 15 |
+
"class": "logging.StreamHandler",
|
| 16 |
+
"level": "INFO",
|
| 17 |
+
"formatter": "detailed",
|
| 18 |
+
"stream": "ext://sys.stdout"
|
| 19 |
+
},
|
| 20 |
+
"file": {
|
| 21 |
+
"class": "logging.handlers.RotatingFileHandler",
|
| 22 |
+
"level": "DEBUG",
|
| 23 |
+
"formatter": "detailed",
|
| 24 |
+
"filename": "logs/my_app.log",
|
| 25 |
+
"maxBytes": 10000,
|
| 26 |
+
"backupCount": 3
|
| 27 |
+
}
|
| 28 |
+
},
|
| 29 |
+
"loggers": {
|
| 30 |
+
"root": {
|
| 31 |
+
"level": "DEBUG",
|
| 32 |
+
"handlers": [
|
| 33 |
+
"stdout",
|
| 34 |
+
"file"
|
| 35 |
+
]
|
| 36 |
+
}
|
| 37 |
+
}
|
| 38 |
+
}
|
pyproject.toml
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[project]
|
| 2 |
+
name = "gradio-langgraph-template"
|
| 3 |
+
version = "0.1.0"
|
| 4 |
+
description = "Add your description here"
|
| 5 |
+
readme = "README.md"
|
| 6 |
+
requires-python = ">=3.12"
|
| 7 |
+
dependencies = [
|
| 8 |
+
"aiohttp>=3.11.12",
|
| 9 |
+
"gradio==5.16.2",
|
| 10 |
+
"langchain-core==0.3.37",
|
| 11 |
+
"langchain-openai==0.3.6",
|
| 12 |
+
"langgraph==0.2.74",
|
| 13 |
+
"pydantic==2.10.6",
|
| 14 |
+
"python-dotenv==1.0.1",
|
| 15 |
+
"trafilatura==2.0.0",
|
| 16 |
+
]
|
setup.cfg
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[flake8]
|
| 2 |
+
extend-ignore = E302,E501,E305,E402
|