Alex Arvanitidis commited on
Commit
52fdbfb
·
1 Parent(s): d37812b

feat: add llm streamlit app

Browse files
Files changed (4) hide show
  1. app.py +48 -2
  2. graph.py +40 -0
  3. requirements.txt +4 -1
  4. st_callable_util.py +95 -0
app.py CHANGED
@@ -1,4 +1,50 @@
 
 
1
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- x = st.slider('Select a value')
4
- st.write(x, 'squared is', x * x)
 
 
 
 
 
 
 
1
+ import os
2
+
3
  import streamlit as st
4
+ from langchain_core.messages import AIMessage, HumanMessage
5
+
6
+ from graph import invoke_our_graph
7
+ from st_callable_util import get_streamlit_cb # Utility function to get a Streamlit callback handler with context
8
+
9
+ st.title("StreamLit 🤝 LangGraph")
10
+ st.markdown("#### Simple Chat Streaming")
11
+
12
+ # st write magic
13
+ """
14
+ In this example, we're going to be creating our own [`BaseCallbackHandler`](https://api.python.langchain.com/en/latest/callbacks/langchain_core.callbacks.base.BaseCallbackHandler.html) called StreamHandler
15
+ to stream our [_LangGraph_](https://langchain-ai.github.io/langgraph/) invocations and leveraging callbacks in our
16
+ graph's [`RunnableConfig`](https://api.python.langchain.com/en/latest/runnables/langchain_core.runnables.config.RunnableConfig.html).
17
+
18
+ The BaseCallBackHandler is a [Mixin](https://www.wikiwand.com/en/articles/Mixin) overloader function which we will use
19
+ to implement only `on_llm_new_token`, a method that run on every new generation of a token from the ChatLLM model.
20
+
21
+ ---
22
+ """
23
+
24
+ # Check if the API key is available as an environment variable
25
+ if "messages" not in st.session_state:
26
+ # default initial message to render in message state
27
+ st.session_state["messages"] = [AIMessage(content="How can I help you?")]
28
+
29
+ # Loop through all messages in the session state and render them as a chat on every st.refresh mech
30
+ for msg in st.session_state.messages:
31
+ # https://docs.streamlit.io/develop/api-reference/chat/st.chat_message
32
+ # we store them as AIMessage and HumanMessage as its easier to send to LangGraph
33
+ if type(msg) == AIMessage:
34
+ st.chat_message("assistant").write(msg.content)
35
+ if type(msg) == HumanMessage:
36
+ st.chat_message("user").write(msg.content)
37
+
38
+ # takes new input in chat box from user and invokes the graph
39
+ if prompt := st.chat_input():
40
+ st.session_state.messages.append(HumanMessage(content=prompt))
41
+ st.chat_message("user").write(prompt)
42
 
43
+ # Process the AI's response and handles graph events using the callback mechanism
44
+ with st.chat_message("assistant"):
45
+ # create a new container for streaming messages only, and give it context
46
+ st_callback = get_streamlit_cb(st.container())
47
+ response = invoke_our_graph(st.session_state.messages, [st_callback])
48
+ # Add that last message to the st_message_state
49
+ # Streamlit's refresh the message will automatically be visually rendered bc of the msg render for loop above
50
+ st.session_state.messages.append(AIMessage(content=response["messages"][-1].content))
graph.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Annotated, TypedDict
3
+
4
+ from langgraph.graph import START, END, StateGraph
5
+ from langgraph.graph.message import AnyMessage, add_messages
6
+ from langchain_openai import ChatOpenAI
7
+
8
+ # This is the default state same as "MessageState" TypedDict but allows us accessibility to custom keys
9
+ class GraphsState(TypedDict):
10
+ messages: Annotated[list[AnyMessage], add_messages]
11
+ # Custom keys for additional data can be added here such as - conversation_id: str
12
+
13
+ graph = StateGraph(GraphsState)
14
+
15
+ # Core invocation of the model
16
+ def _call_model(state: GraphsState):
17
+ messages = state["messages"]
18
+ llm = ChatOpenAI(
19
+ model=os.environ["LLM_MODEL_ID"],
20
+ max_retries=2,
21
+ api_key="None",
22
+ base_url=os.environ["LLM_API_BASE"],
23
+ )
24
+ response = llm.invoke(messages)
25
+ return {"messages": [response]}# add the response to the messages using LangGraph reducer paradigm
26
+
27
+ # Define the structure (nodes and directional edges between nodes) of the graph
28
+ graph.add_edge(START, "modelNode")
29
+ graph.add_node("modelNode", _call_model)
30
+ graph.add_edge("modelNode", END)
31
+
32
+ # Compile the state graph into a runnable object
33
+ graph_runnable = graph.compile()
34
+
35
+ def invoke_our_graph(st_messages, callables):
36
+ # Ensure the callables parameter is a list as you can have multiple callbacks
37
+ if not isinstance(callables, list):
38
+ raise TypeError("callables must be a list")
39
+ # Invoke the graph with the current messages and callback configuration
40
+ return graph_runnable.invoke({"messages": st_messages}, config={"callbacks": callables})
requirements.txt CHANGED
@@ -1 +1,4 @@
1
- langchain_core
 
 
 
 
1
+ langgraph
2
+ streamlit
3
+ langchain-openai
4
+ python-dotenv
st_callable_util.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Callable, TypeVar
2
+ import inspect
3
+
4
+ from streamlit.runtime.scriptrunner import add_script_run_ctx, get_script_run_ctx
5
+ from streamlit.delta_generator import DeltaGenerator
6
+
7
+ from langchain_core.callbacks.base import BaseCallbackHandler
8
+
9
+
10
+ # Define a function to create a callback handler for Streamlit that updates the UI dynamically
11
+ def get_streamlit_cb(parent_container: DeltaGenerator) -> BaseCallbackHandler:
12
+ """
13
+ Creates a Streamlit callback handler that updates the provided Streamlit container with new tokens.
14
+
15
+ Args:
16
+ parent_container (DeltaGenerator): The Streamlit container where the text will be rendered.
17
+ Returns:
18
+ BaseCallbackHandler: An instance of a callback handler configured for Streamlit.
19
+ """
20
+
21
+ # Define a custom callback handler class for managing and displaying stream events from LangGraph in Streamlit
22
+ class StreamHandler(BaseCallbackHandler):
23
+ """
24
+ Custom callback handler for Streamlit that updates a Streamlit container with new tokens.
25
+ """
26
+
27
+ def __init__(self, container: DeltaGenerator, initial_text: str = ""):
28
+ """
29
+ Initializes the StreamHandler with a Streamlit container and optional initial text.
30
+
31
+ Args:
32
+ container (DeltaGenerator): The Streamlit container where text will be rendered.
33
+ initial_text (str): Optional initial text to start with in the container.
34
+ """
35
+ self.container = container # The Streamlit container to update
36
+ self.token_placeholder = self.container.empty() # Placeholder for dynamic token updates
37
+ self.text = initial_text # Initialize the text content, starting with any initial text
38
+
39
+ def on_llm_new_token(self, token: str, **kwargs) -> None:
40
+ """
41
+ Callback method triggered when a new token is received (e.g., from a language model).
42
+
43
+ Args:
44
+ token (str): The new token received.
45
+ **kwargs: Additional keyword arguments.
46
+ """
47
+ self.text += token # Append the new token to the existing text
48
+ self.token_placeholder.write(self.text) # Update the Streamlit container with the full text
49
+
50
+ # Define a type variable for generic type hinting in the decorator, to maintain
51
+ # the return type of the input function and the wrapped function
52
+ fn_return_type = TypeVar('fn_return_type')
53
+
54
+ # Decorator function to add the Streamlit execution context to a function
55
+ def add_streamlit_context(fn: Callable[..., fn_return_type]) -> Callable[..., fn_return_type]:
56
+ """
57
+ Decorator to ensure that the decorated function runs within the Streamlit execution context.
58
+
59
+ Args:
60
+ fn (Callable[..., fn_return_type]): The function to be decorated.
61
+
62
+ Returns:
63
+ Callable[..., fn_return_type]: The decorated function that includes the Streamlit context setup.
64
+ """
65
+ # Retrieve the current Streamlit script execution context.
66
+ # This context holds session information necessary for Streamlit operations.
67
+ ctx = get_script_run_ctx()
68
+
69
+ def wrapper(*args, **kwargs) -> fn_return_type:
70
+ """
71
+ Wrapper function that adds the Streamlit context and then calls the original function.
72
+
73
+ Args:
74
+ *args: Positional arguments to pass to the original function.
75
+ **kwargs: Keyword arguments to pass to the original function.
76
+
77
+ Returns:
78
+ fn_return_type: The result from the original function.
79
+ """
80
+ add_script_run_ctx(ctx=ctx) # Set the correct Streamlit context for execution
81
+ return fn(*args, **kwargs) # Call the original function with its arguments
82
+
83
+ return wrapper
84
+
85
+ # Create an instance of the custom StreamHandler with the provided Streamlit container
86
+ st_cb = StreamHandler(parent_container)
87
+
88
+ # Iterate over all methods of the StreamHandler instance
89
+ for method_name, method_func in inspect.getmembers(st_cb, predicate=inspect.ismethod):
90
+ if method_name.startswith('on_'): # Identify callback methods that respond to LLM events
91
+ setattr(st_cb, method_name,
92
+ add_streamlit_context(method_func)) # Wrap and replace the method with the context-aware version
93
+
94
+ # Return the fully configured StreamlitCallbackHandler instance, now context-aware and integrated with any ChatLLM
95
+ return st_cb