# DSA_Search / app.py
# Author: Jorge Londoño
# Last change: "Changed system prompt" (commit 9c10b6e)
# Documentation:
# https://www.gradio.app/docs/gradio/chatinterface
# https://www.gradio.app/docs/gradio/chatbot
import logging
logging.basicConfig(level=logging.WARNING) # Level for the root logger
import os
import json
import uuid
from dotenv import load_dotenv
from pydantic import BaseModel
from functools import lru_cache
load_dotenv(verbose=True)
from langchain.prompts import ChatPromptTemplate, PromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain.schema import AIMessage, HumanMessage, SystemMessage
from langchain_huggingface import HuggingFaceEmbeddings
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.messages.tool import ToolMessage
from langgraph.types import StateSnapshot
import gradio as gr
from workflow import app, memory
logger = logging.getLogger(__name__) # Child logger for this module
logger.setLevel(logging.DEBUG)
# Conversation thread id for the LangGraph checkpointer: a fresh UUID per
# process start means each run of the app gets its own message history.
config = {"configurable": {"thread_id": str(uuid.uuid4()) }}
# System prompt seeded into the workflow state (see __main__ and clear_history).
# NOTE(review): the name keeps the original "messsage" typo because it is
# referenced elsewhere in this module; rename all uses together if fixing.
system_messsage ="""You are a helpful assistant.
You only answer questions about data structures and algorithms, discrete math, and computer science in general.
You can use search tools for finding information in the web to answer the user's question.
Output your answers using markdown format and include links to the pages used in constructing the answer."""
def pretty_print(event:dict) -> None:
    """Print a human-readable summary of every message in a workflow event.

    Message contents are truncated to 80 characters. Tool responses are
    parsed as JSON so each result's URL and content snippet can be shown;
    parse failures are logged rather than raised.
    """
    for msg in event['messages']:
        if isinstance(msg, SystemMessage):
            print('SystemMessage:', msg.content[:80], sep='\n\t')
        elif isinstance(msg, HumanMessage):
            print('HumanMessage: ', msg.content[:80], sep='\n\t')
        elif isinstance(msg, AIMessage):
            extras = msg.additional_kwargs
            if extras and 'tool_calls' in extras:
                # The model requested tool invocations instead of answering.
                print('AIMessage: ', 'tool_call')
                for call in extras['tool_calls']:
                    print('\t','Name = ', call['function']['name'],' Args =', call['function']['arguments'][:80])
            else:
                print('AIMessage: ', msg.content[:80], sep='\n\t')
        elif isinstance(msg, ToolMessage):
            # Tool output content is expected to be a JSON list of search hits.
            print('ToolMessage: ')
            try:
                docs = json.loads(msg.content)
                for doc in docs:
                    print('\t', 'url', doc['url'])
                    print('\t', 'content', doc['content'][:80])
            except Exception as err:
                logger.error(str(err))
                logger.error(msg)
        else:
            print('UNKNOWN MESSAGE TYPE', type(msg), msg)
    # Visual separator between streamed events.
    print('-'*20, '\n')
class Message(BaseModel):
    """Chat message in the Gradio 'messages' format (role/content/metadata).

    NOTE(review): not referenced elsewhere in this module — presumably kept
    for typed interop with gr.Chatbot(type='messages'); confirm before removing.
    """
    # Fix: the fields defaulted to None while being annotated as plain `str`;
    # widen to `str | None` so the annotation matches the default and explicit
    # None assignments still validate.
    role: str | None = None      # e.g. 'user' or 'assistant'
    metadata: dict = {}          # pydantic copies mutable defaults per instance
    content: str | None = None   # message text (markdown)
def stream_response(message:str, history:list[dict]):
    """Generator callback for gr.ChatInterface.

    Streams the workflow state after each update, yielding the newest
    message's content so the UI shows intermediate progress. `history` is
    supplied by Gradio but unused: the LangGraph checkpointer already holds
    the conversation state for the current thread_id in `config`.
    """
    if message is None:
        return
    events = app.stream(
        {"messages": [HumanMessage(content=message)]},
        config,
        stream_mode="values",
    )
    for event in events:
        pretty_print(event)
        yield event["messages"][-1].content
def clear_history() -> None:
    """Reset the conversation: switch to a fresh checkpointer thread and
    re-seed the system prompt.

    Bug fix: the previous version assigned the new thread_id only when
    'configurable' was missing from `config`, but it is always present
    (set at import time), so clearing never actually started a new thread.
    The new thread_id is now assigned unconditionally. (The old `-> dict`
    annotation was also wrong — the function has always returned None.)
    """
    global config
    session_id = str(uuid.uuid4())
    # Always start a fresh thread so the previous history is abandoned.
    config['configurable'] = {"thread_id": session_id}
    logger.debug(f'New config: {config}')
    # Seed the new thread with the system prompt so the assistant keeps its
    # instructions after the reset (mirrors the seeding done in __main__).
    system_msg = SystemMessage(system_messsage)
    app.update_state(config, {'messages': system_msg})
# Build the Gradio UI: a near-full-height chatbot wired to stream_response.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    chatbot=gr.Chatbot(
        type='messages',   # openai-style [{'role': ..., 'content': ...}] history
        height="80vh")
    gr.ChatInterface(stream_response,
            type='messages',
            chatbot=chatbot,
            textbox=gr.Textbox(placeholder="Enter your query...",
                container=False,
                autoscroll=True,
                scale=7),
    )
    # NOTE(review): wiring the clear button to clear_history reportedly fails
    # on Hugging Face Spaces — kept disabled; without it, clearing the UI does
    # not reset the LangGraph thread.
    #chatbot.clear(clear_history) # Fails in huggingface
if __name__ == "__main__":
    logger.debug('Started main')
    # Seed the initial conversation thread with the system prompt before
    # serving the UI (clear_history performs the equivalent reset later).
    system_msg = SystemMessage(system_messsage)
    app.update_state(config, {'messages': system_msg})
    # share=False: no public gradio.live tunnel; debug=False: do not block on errors.
    demo.launch(share=False, debug=False)
# TODO
# When gradio reloads the application, the SystemMessage is lost.