Bryce Guinta commited on
Commit
8a2fa64
·
0 Parent(s):

Initial commit

Browse files
Files changed (7) hide show
  1. .gitignore +5 -0
  2. README.md +1 -0
  3. app.py +83 -0
  4. graph.py +98 -0
  5. logging-config.json +38 -0
  6. pyproject.toml +16 -0
  7. setup.cfg +2 -0
.gitignore ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ .env
2
+ uv.lock
3
+ .idea
4
+ .python-version
5
+ __pycache__
README.md ADDED
@@ -0,0 +1 @@
 
 
1
+ A template for chatbot streaming from langgraph with gradio
app.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from uuid import uuid4
4
+ import logging
5
+ import logging.config
6
+ import json
7
+
8
+ import gradio as gr
9
+ from dotenv import load_dotenv
10
+ from langgraph.types import RunnableConfig
11
+
12
+ from graph import GraphProcessingState, graph
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+ def setup_logging():
17
+ with open("logging-config.json") as fh:
18
+ config = json.load(fh)
19
+ logging.config.dictConfig(config)
20
+
21
+ async def chat_fn(message, history, input_graph_state, uuid):
22
+ try:
23
+ input_graph_state.user_input = message
24
+ input_graph_state.history = history
25
+ config = RunnableConfig()
26
+ config["configurable"] = {}
27
+ config["configurable"]["thread_id"] = uuid
28
+
29
+ output = ""
30
+ async for msg, metadata in graph.astream(
31
+ {"user_input": input_graph_state.user_input, "history": input_graph_state.history},
32
+ config=config,
33
+ stream_mode="messages",
34
+ ):
35
+ # assistant_node is the name we defined in the langraph graph
36
+ if metadata['langgraph_node'] == "assistant_node" and msg.content:
37
+ output += msg.content
38
+ yield output
39
+ except Exception:
40
+ logger.exception("Exception occurred")
41
+ user_error_message = "There was an error processing your request. Please try again."
42
+ yield user_error_message # , input_graph_state
43
+
44
+ def clear():
45
+ return GraphProcessingState(), uuid4()
46
+
47
+ if __name__ == "__main__":
48
+ load_dotenv()
49
+ setup_logging()
50
+ logger.info("Starting the interface")
51
+ with gr.Blocks(title="Langgraph Template", fill_height=True, css="footer {visibility: hidden}") as app:
52
+ gradio_graph_state = gr.State(
53
+ value=GraphProcessingState
54
+ )
55
+ uuid_state = gr.State(
56
+ uuid4
57
+ )
58
+ chatbot = gr.Chatbot(
59
+ # avatar_images=(None, "assets/ai-avatar.png"),
60
+ type="messages",
61
+ # placeholder=WELCOME_MESSAGE,/
62
+ scale=1,
63
+ )
64
+ chatbot.clear(fn=clear, outputs=[gradio_graph_state, uuid_state])
65
+ chat_interface = gr.ChatInterface(
66
+ chatbot=chatbot,
67
+ fn=chat_fn,
68
+ additional_inputs=[
69
+ gradio_graph_state,
70
+ uuid_state,
71
+ ],
72
+ additional_outputs=[
73
+ # gradio_graph_state
74
+ ],
75
+ type="messages",
76
+ multimodal=False,
77
+ )
78
+
79
+ app.launch(
80
+ server_name="127.0.0.1",
81
+ server_port=7860,
82
+ # favicon_path="assets/favicon.ico"
83
+ )
graph.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Annotated
2
+
3
+ import aiohttp
4
+ from langchain_core.messages import AnyMessage
5
+ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
6
+ from langchain_core.tools import tool
7
+ from langchain_openai import ChatOpenAI
8
+ from langgraph.graph.state import CompiledStateGraph
9
+ from langgraph.prebuilt import ToolNode
10
+ from langgraph.types import RunnableConfig
11
+ from langgraph.graph import Graph, StateGraph, END, MessagesState, add_messages
12
+ from pydantic import BaseModel, Field, ValidationError
13
+ from trafilatura import extract
14
+
15
+ @tool
16
+ async def download_website_text(url: str, config: RunnableConfig) -> str:
17
+ """Downloads the text from a website
18
+
19
+ args:
20
+ url: The URL of the website
21
+ """
22
+ async with aiohttp.ClientSession() as session:
23
+ async with session.get(url) as response:
24
+ downloaded = await response.text()
25
+ result = extract(downloaded, include_formatting=True, include_links=True, output_format='json', with_metadata=True)
26
+ if result:
27
+ return result
28
+ return "No text found on the website"
29
+
30
+ tools = [download_website_text]
31
+
32
+ ASSISTANT_SYSTEM_PROMPT = "You are a helpful assistant."
33
+ assistant_model = ChatOpenAI(model="gpt-4o-mini", tags=["assistant"]).bind_tools(tools)
34
+
35
+ class GraphProcessingState(BaseModel):
36
+ user_input: str = Field(default_factory=str, description="The original user input")
37
+ history: list[dict] = Field(default_factory=list, description="Chat history") # type: ignore
38
+ messages: Annotated[list[AnyMessage], add_messages] = Field(default_factory=list)
39
+
40
+
41
+ async def assistant_node(state: GraphProcessingState, config=None):
42
+ prompt = ChatPromptTemplate.from_messages(
43
+ [
44
+ ("system", ASSISTANT_SYSTEM_PROMPT),
45
+ MessagesPlaceholder(variable_name="chat_history"),
46
+ ("user", "{user_input}"),
47
+ *state.messages,
48
+ ]
49
+ )
50
+ chain = prompt | assistant_model
51
+ response = await chain.ainvoke({"user_input": state.user_input, "chat_history": state.history}, config)
52
+ return {"messages": response}
53
+
54
+ def assistant_cond_edge(state: GraphProcessingState, config=None):
55
+ if not state.messages[-1].content:
56
+ return "tools"
57
+ return END
58
+
59
+ def define_workflow() -> CompiledStateGraph:
60
+ """Defines the workflow graph"""
61
+ # Initialize the graph
62
+ workflow = StateGraph(GraphProcessingState)
63
+
64
+ # Add nodes
65
+ workflow.add_node("assistant_node", assistant_node)
66
+ workflow.add_node("tools", ToolNode(tools))
67
+
68
+ workflow.add_edge("tools", "assistant_node")
69
+
70
+ # Conditional routing
71
+ workflow.add_conditional_edges(
72
+ "assistant_node",
73
+ # If the latest message (result) from assistant is a tool call -> tools_condition routes to tools
74
+ # If the latest message (result) from assistant is a not a tool call -> tools_condition routes to END
75
+ assistant_cond_edge,
76
+ )
77
+ # Set end nodes
78
+ workflow.set_entry_point("assistant_node")
79
+ # workflow.set_finish_point("assistant_node")
80
+
81
+ return workflow.compile()
82
+
83
+ graph = define_workflow()
84
+ #
85
+ # async def process_user_input_graph(input_state: GraphProcessingState, thread_id=None) -> GraphProcessingState:
86
+ # config: RunnableConfig = RunnableConfig()
87
+ # if "configurable" not in config:
88
+ # config["configurable"] = {}
89
+ # if thread_id:
90
+ # config["configurable"]["thread_id"] = thread_id
91
+ # final_state_dict = await graph.ainvoke(
92
+ # input_state,
93
+ # config=config,
94
+ # )
95
+ # final_state = GraphProcessingState(**final_state_dict)
96
+ # final_state.user_input = ""
97
+ # final_state.history = []
98
+ # return final_state
logging-config.json ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "version": 1,
3
+ "disable_existing_loggers": true,
4
+ "formatters": {
5
+ "simple": {
6
+ "format": "%(levelname)s: %(message)s"
7
+ },
8
+ "detailed": {
9
+ "format": "[%(levelname)s|%(module)s|L%(lineno)d] %(asctime)s: %(message)s",
10
+ "datefmt": "%Y-%m-%dT%H:%M:%S%z"
11
+ }
12
+ },
13
+ "handlers": {
14
+ "stdout": {
15
+ "class": "logging.StreamHandler",
16
+ "level": "INFO",
17
+ "formatter": "detailed",
18
+ "stream": "ext://sys.stdout"
19
+ },
20
+ "file": {
21
+ "class": "logging.handlers.RotatingFileHandler",
22
+ "level": "DEBUG",
23
+ "formatter": "detailed",
24
+ "filename": "logs/my_app.log",
25
+ "maxBytes": 10000,
26
+ "backupCount": 3
27
+ }
28
+ },
29
+ "loggers": {
30
+ "root": {
31
+ "level": "DEBUG",
32
+ "handlers": [
33
+ "stdout",
34
+ "file"
35
+ ]
36
+ }
37
+ }
38
+ }
pyproject.toml ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "gradio-langgraph-template"
3
+ version = "0.1.0"
4
+ description = "Add your description here"
5
+ readme = "README.md"
6
+ requires-python = ">=3.12"
7
+ dependencies = [
8
+ "aiohttp>=3.11.12",
9
+ "gradio==5.16.2",
10
+ "langchain-core==0.3.37",
11
+ "langchain-openai==0.3.6",
12
+ "langgraph==0.2.74",
13
+ "pydantic==2.10.6",
14
+ "python-dotenv==1.0.1",
15
+ "trafilatura==2.0.0",
16
+ ]
setup.cfg ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ [flake8]
2
+ extend-ignore = E302,E501,E305,E402