Spaces:
Running
Running
| import streamlit as st | |
| from langchain.chat_models import init_chat_model | |
| from langchain_core.messages import HumanMessage, AIMessage | |
| from langgraph.checkpoint.memory import MemorySaver | |
| from langgraph.graph import START, MessagesState, StateGraph | |
| from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder | |
| from typing import Sequence | |
| from langchain_core.messages import SystemMessage, trim_messages | |
| from langchain_core.messages import BaseMessage | |
| from langgraph.graph.message import add_messages | |
| from typing_extensions import Annotated, TypedDict | |
| from dotenv import load_dotenv | |
| import os | |
| import asyncio | |
| import uuid | |
| import pandas as pd | |
| from src.functions_db import connect_to_db, ChatBot | |
# Load variables from a local .env file (if present) into os.environ.
load_dotenv()

# The key is not passed to init_chat_model below, so the Groq provider is expected
# to pick it up from the environment. Check it up front so a missing key fails with
# a clear message instead of an opaque provider validation error.
groq_api_key = os.getenv("GROQ_API_KEY")
if not groq_api_key:
    raise RuntimeError(
        "GROQ_API_KEY is not set; define it in the environment or in a .env file."
    )

model = init_chat_model("llama3-8b-8192", model_provider="groq")

# connect_to_db returns a 5-tuple; only the DB session (first item) and the
# chat-history helper (last item) are used by this app.
session, _, _, _, chatbot_history = connect_to_db(
    address="sqlite:///src/databases/main.db"
)
class State(TypedDict):
    """State shared by the nodes of the chat graph.

    A hand-built replacement for the prebuilt ``MessagesState`` that also
    carries the language the assistant must answer in.
    """

    # Conversation transcript. The ``add_messages`` reducer merges messages
    # returned by a node into this list rather than replacing it.
    messages: Annotated[Sequence[BaseMessage], add_messages]
    # Reply language interpolated into the system prompt (e.g. "english").
    language: str
def generate_thread_id():
    """Return a new random identifier for a conversation thread.

    The value is the canonical 36-character string form of a version-4 UUID.
    """
    new_id = uuid.uuid4()
    return str(new_id)
# System instruction that sets the model's response behavior; "{language}" is
# filled from the graph state on every call.
_SYSTEM_INSTRUCTION = (
    "You are a helpful assistant. Answer all questions to the best of your ability in {language}."
)

# Prompt = system instruction followed by the (trimmed) conversation messages.
prompt_template = ChatPromptTemplate.from_messages(
    [("system", _SYSTEM_INSTRUCTION), MessagesPlaceholder(variable_name="messages")]
)
# Token budget for the conversation history sent to the model on each turn.
# The previous value (65) was smaller than many single user messages: with
# strategy="last" and allow_partial=False, a latest message that exceeds the
# budget is dropped entirely, and the model then sees only the system prompt.
# The model name "llama3-8b-8192" suggests an 8192-token context window; 2048
# leaves ample room for the system prompt and the reply.
MAX_HISTORY_TOKENS = 2048

trimmer = trim_messages(
    max_tokens=MAX_HISTORY_TOKENS,
    strategy="last",      # keep the most recent messages
    token_counter=model,  # use the chat model object to count tokens
    # The messages trimmed in call_model come from the graph state, which holds
    # no SystemMessage (the prompt template adds it afterwards), so this is
    # currently a no-op; kept for safety if a system message is ever stored.
    include_system=True,
    allow_partial=False,  # never cut a message in half
    start_on="human",     # the kept history must begin with a user turn
)
async def call_model(state: State):
    """Graph node: produce the assistant's reply for the current state.

    The stored messages are trimmed to the token budget and rendered into the
    prompt template together with the requested reply language. The result is
    sent to the chat model asynchronously.

    Returns a partial state update ``{"messages": [reply]}``, which the
    reducer on ``State.messages`` adds to the transcript.
    """
    prompt_inputs = {
        "messages": trimmer.invoke(state["messages"]),
        "language": state["language"],
    }
    rendered_prompt = prompt_template.invoke(prompt_inputs)
    reply = await model.ainvoke(rendered_prompt)
    return {"messages": [reply]}
# Define the graph: a single "model" node reached from START. A hand-built State
# (instead of MessagesState) is used so the node also receives the language.
workflow = StateGraph(state_schema=State)
workflow.add_node("model", call_model)
workflow.add_edge(START, "model")

# MemorySaver keeps checkpoints in process memory, keyed by the run config's
# "thread_id". NOTE(review): if Streamlit re-executes this script on every
# interaction, a fresh MemorySaver is built each time and earlier checkpoints
# are not reused. Context then comes only from the history the caller passes in.
app = workflow.compile(checkpointer=MemorySaver())
async def chatbot(query, thread_id, history, language="english"):
    """Run one user query through the compiled graph and return the reply.

    Args:
        query: Text of the new user message.
        thread_id: Checkpoint thread key; each conversation/session should use
            its own id.
        history: Prior messages (a list) sent to the graph before the new
            query. It must not already contain ``query``, because this
            function appends it.
        language: Reply language interpolated into the system prompt;
            "english" by default.

    Returns:
        The last message of the resulting graph state, i.e. the model's reply.
    """
    run_config = {"configurable": {"thread_id": thread_id}}
    new_turn = HumanMessage(content=query)
    graph_input = {"messages": history + [new_turn], "language": language}
    final_state = await app.ainvoke(graph_input, run_config)
    return final_state["messages"][-1]
################################################################################
# TAB 1: chatbot UI
################################################################################
tab1, tab2 = st.tabs(["ChatBot", "DB_Extraction"])
st.sidebar.title("App parameters")
language = st.sidebar.selectbox("Select Language", ["english", "french", "spanish"])
tab1.write("This is the Chatbot LangChain app.")

# Keep the transcript and a stable thread/session id in st.session_state so
# they survive reruns of this script within the same browser session.
if "history" not in st.session_state:
    st.session_state["history"] = []
if "session_id" not in st.session_state:
    st.session_state["session_id"] = generate_thread_id()
session_id = st.session_state["session_id"]

user_input = tab1.chat_input("Say something")
if user_input:
    history = st.session_state["history"]
    # chatbot() appends HumanMessage(user_input) to the history it is given,
    # so pass only the prior turns. Appending the message first made the model
    # receive the same question twice.
    output = asyncio.run(chatbot(user_input, session_id, history, language))
    # Record the turn only after a successful reply, so a failed model call
    # does not leave an unanswered user message in the transcript.
    history.append(HumanMessage(content=user_input))
    history.append(output)
    chatbot_history.add_gpt_history(
        session,
        session_id=session_id,
        user_input=user_input,
        model_output=output.content,
    )

# Display the conversation history, newest first.
for message in reversed(st.session_state["history"]):
    if isinstance(message, HumanMessage):
        role = "user"
    elif isinstance(message, AIMessage):
        role = "assistant"
    else:
        continue
    # Write into the chat bubble itself. Calling tab1.write() inside
    # `with tab1.chat_message(...)` targets the tab, not the bubble, which
    # left every bubble empty with the text rendered below it.
    tab1.chat_message(role).write(message.content)
################################################################################
# TAB 2 DB
################################################################################
# Attributes copied from each ChatBot row, in display order. Passing them as
# explicit DataFrame columns keeps the headers visible when the table is empty;
# pd.DataFrame([]) alone has no columns at all.
_HISTORY_COLUMNS = ["id", "session_id", "user_input", "model_output"]

chatbot_histories = session.query(ChatBot).all()
chatbot_histories_df = pd.DataFrame(
    [{col: getattr(row, col) for col in _HISTORY_COLUMNS} for row in chatbot_histories],
    columns=_HISTORY_COLUMNS,
)
tab2.write("Chatbot history:")
# NOTE(review): data_editor lets users edit cells, but nothing in this file
# writes edits back to the database; st.dataframe may be what is intended.
tab2.data_editor(chatbot_histories_df)