Spaces:
Sleeping
Sleeping
Alex Arvanitidis committed on
Commit ·
52fdbfb
1
Parent(s): d37812b
feat: add llm streamlit app
Browse files- app.py +48 -2
- graph.py +40 -0
- requirements.txt +4 -1
- st_callable_util.py +95 -0
app.py
CHANGED
|
@@ -1,4 +1,50 @@
|
|
|
|
|
|
|
|
| 1 |
import streamlit as st
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
-
|
| 4 |
-
st.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os

import streamlit as st
from langchain_core.messages import AIMessage, HumanMessage

from graph import invoke_our_graph
from st_callable_util import get_streamlit_cb  # Utility function to get a Streamlit callback handler with context

st.title("StreamLit 🤝 LangGraph")
st.markdown("#### Simple Chat Streaming")

# st write magic: a bare string literal at module level is rendered as markdown by Streamlit
"""
In this example, we're going to be creating our own [`BaseCallbackHandler`](https://api.python.langchain.com/en/latest/callbacks/langchain_core.callbacks.base.BaseCallbackHandler.html) called StreamHandler
to stream our [_LangGraph_](https://langchain-ai.github.io/langgraph/) invocations and leveraging callbacks in our
graph's [`RunnableConfig`](https://api.python.langchain.com/en/latest/runnables/langchain_core.runnables.config.RunnableConfig.html).

The BaseCallBackHandler is a [Mixin](https://www.wikiwand.com/en/articles/Mixin) overloader function which we will use
to implement only `on_llm_new_token`, a method that run on every new generation of a token from the ChatLLM model.

---
"""

# Seed the chat history the first time this session runs
if "messages" not in st.session_state:
    # default initial message to render in message state
    st.session_state["messages"] = [AIMessage(content="How can I help you?")]

# Loop through all messages in the session state and render them as a chat on every st.refresh mech
for msg in st.session_state.messages:
    # https://docs.streamlit.io/develop/api-reference/chat/st.chat_message
    # we store them as AIMessage and HumanMessage as its easier to send to LangGraph
    if isinstance(msg, AIMessage):
        st.chat_message("assistant").write(msg.content)
    elif isinstance(msg, HumanMessage):
        st.chat_message("user").write(msg.content)

# takes new input in chat box from user and invokes the graph
if prompt := st.chat_input():
    st.session_state.messages.append(HumanMessage(content=prompt))
    st.chat_message("user").write(prompt)

    # Process the AI's response and handles graph events using the callback mechanism
    with st.chat_message("assistant"):
        # create a new container for streaming messages only, and give it context
        st_callback = get_streamlit_cb(st.container())
        response = invoke_our_graph(st.session_state.messages, [st_callback])
        # Append the final AI message to session state; the render loop above
        # will draw it automatically on Streamlit's next rerun.
        st.session_state.messages.append(AIMessage(content=response["messages"][-1].content))
|
graph.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
from functools import lru_cache
from typing import Annotated, TypedDict

from langgraph.graph import START, END, StateGraph
from langgraph.graph.message import AnyMessage, add_messages
from langchain_openai import ChatOpenAI


# This is the default state same as "MessagesState" TypedDict but allows us accessibility to custom keys
class GraphsState(TypedDict):
    # Conversation history; the add_messages reducer appends instead of overwriting
    messages: Annotated[list[AnyMessage], add_messages]
    # Custom keys for additional data can be added here such as - conversation_id: str


graph = StateGraph(GraphsState)


@lru_cache(maxsize=1)
def _get_llm() -> ChatOpenAI:
    """Build the chat model once and reuse it across all graph invocations.

    Environment variables are read lazily (on first call), so importing this
    module does not require LLM_MODEL_ID / LLM_API_BASE to already be set.
    """
    return ChatOpenAI(
        model=os.environ["LLM_MODEL_ID"],
        max_retries=2,
        api_key="None",  # NOTE(review): looks like a placeholder for a keyless local endpoint — confirm
        base_url=os.environ["LLM_API_BASE"],
    )


# Core invocation of the model
def _call_model(state: GraphsState):
    """Graph node: send the accumulated conversation to the LLM.

    Returns the response under the ``messages`` key so the ``add_messages``
    reducer appends it to the existing history (LangGraph reducer paradigm).
    """
    response = _get_llm().invoke(state["messages"])
    return {"messages": [response]}


# Define the structure (nodes and directional edges between nodes) of the graph
graph.add_edge(START, "modelNode")
graph.add_node("modelNode", _call_model)
graph.add_edge("modelNode", END)

# Compile the state graph into a runnable object
graph_runnable = graph.compile()


def invoke_our_graph(st_messages, callables):
    """Invoke the compiled graph with the given messages and callback handlers.

    Args:
        st_messages: Chat history to seed the graph state's ``messages`` key.
        callables: List of LangChain callback handlers (e.g. a Streamlit streamer).

    Raises:
        TypeError: If ``callables`` is not a list.
    """
    # Ensure the callables parameter is a list as you can have multiple callbacks
    if not isinstance(callables, list):
        raise TypeError("callables must be a list")
    # Invoke the graph with the current messages and callback configuration
    return graph_runnable.invoke({"messages": st_messages}, config={"callbacks": callables})
|
requirements.txt
CHANGED
|
@@ -1 +1,4 @@
|
|
| 1 |
-
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
langgraph
|
| 2 |
+
streamlit
|
| 3 |
+
langchain-openai
|
| 4 |
+
python-dotenv
|
st_callable_util.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Callable, TypeVar
import inspect

from streamlit.runtime.scriptrunner import add_script_run_ctx, get_script_run_ctx
from streamlit.delta_generator import DeltaGenerator

from langchain_core.callbacks.base import BaseCallbackHandler


def get_streamlit_cb(parent_container: DeltaGenerator) -> BaseCallbackHandler:
    """Build a LangChain callback handler that streams tokens into a Streamlit container.

    Every ``on_*`` callback method on the returned handler is wrapped so that it
    executes inside the current Streamlit script-run context — necessary because
    LangGraph may fire callbacks from threads other than the script thread.

    Args:
        parent_container (DeltaGenerator): Streamlit container where text is rendered.

    Returns:
        BaseCallbackHandler: A context-aware handler instance ready for LangGraph.
    """

    class StreamHandler(BaseCallbackHandler):
        """Accumulates generated tokens and re-renders them in a placeholder."""

        def __init__(self, container: DeltaGenerator, initial_text: str = ""):
            """Set up the placeholder and the running text buffer.

            Args:
                container (DeltaGenerator): Container the tokens stream into.
                initial_text (str): Optional text the buffer starts with.
            """
            self.container = container
            # Single placeholder rewritten on every token, giving the streaming effect
            self.token_placeholder = self.container.empty()
            self.text = initial_text

        def on_llm_new_token(self, token: str, **kwargs) -> None:
            """Append a newly generated token and refresh the rendered text.

            Args:
                token (str): The new token received from the model.
                **kwargs: Additional keyword arguments (ignored).
            """
            self.text += token
            self.token_placeholder.write(self.text)

    # Generic type variable so the decorator preserves the wrapped function's return type
    fn_return_type = TypeVar('fn_return_type')

    def add_streamlit_context(fn: Callable[..., fn_return_type]) -> Callable[..., fn_return_type]:
        """Decorate ``fn`` so it always runs within the captured Streamlit context.

        Args:
            fn: The callback method to wrap.

        Returns:
            The wrapped callable with identical signature and return type.
        """
        # Capture the script-run context of the thread constructing the handler;
        # it carries the session info Streamlit operations need.
        ctx = get_script_run_ctx()

        def wrapper(*args, **kwargs) -> fn_return_type:
            """Attach the captured context to the calling thread, then delegate."""
            add_script_run_ctx(ctx=ctx)
            return fn(*args, **kwargs)

        return wrapper

    # Instantiate the handler bound to the caller-supplied container
    st_cb = StreamHandler(parent_container)

    # Rebind every LLM-event callback ('on_...') with its context-aware wrapper
    for attr_name, bound_method in inspect.getmembers(st_cb, predicate=inspect.ismethod):
        if attr_name.startswith('on_'):
            setattr(st_cb, attr_name, add_streamlit_context(bound_method))

    # Fully configured, context-aware handler, ready to plug into any ChatLLM
    return st_cb
|