# Documentation: # https://www.gradio.app/docs/gradio/chatinterface # https://www.gradio.app/docs/gradio/chatbot import logging logging.basicConfig(level=logging.WARNING) # Level for the root logger import os import json import uuid from dotenv import load_dotenv from pydantic import BaseModel from functools import lru_cache load_dotenv(verbose=True) from langchain.prompts import ChatPromptTemplate, PromptTemplate from langchain_core.runnables import RunnablePassthrough from langchain.schema import AIMessage, HumanMessage, SystemMessage from langchain_huggingface import HuggingFaceEmbeddings from langchain.chains.combine_documents import create_stuff_documents_chain from langchain_core.messages.tool import ToolMessage from langgraph.types import StateSnapshot import gradio as gr from workflow import app, memory logger = logging.getLogger(__name__) # Child logger for this module logger.setLevel(logging.DEBUG) config = {"configurable": {"thread_id": str(uuid.uuid4()) }} system_messsage ="""You are a helpful assistant. You only answer questions about data structures and algorithms, discrete math, and computer science in general. You can use search tools for finding information in the web to answer the user's question. Output your answers using markdown format and include links to the pages used in constructing the answer.""" def pretty_print(event:dict) -> None: msgs = event['messages'] for x in msgs: match x: case SystemMessage(): print('SystemMessage:', x.content[:80], sep='\n\t') case HumanMessage(): print('HumanMessage: ', x.content[:80], sep='\n\t') case AIMessage(): if x.additional_kwargs and 'tool_calls' in x.additional_kwargs: tool_calls = x.additional_kwargs['tool_calls'] print('AIMessage: ', 'tool_call') for call in tool_calls: print('\t','Name = ', call['function']['name'],' Args =', call['function']['arguments'][:80]) else: print('AIMessage: ', x.content[:80], sep='\n\t') case ToolMessage(): # print('ToolMessage', x.content[:80]) # Is a JSON string print('ToolMessage: ') try: l = json.loads(x.content) for d in l: print('\t', 'url', d['url']) print('\t', 'content', d['content'][:80]) except Exception as e: logger.error(str(e)) logger.error(x) case _: print('UNKNOWN MESSAGE TYPE', type(x), x) print('-'*20, '\n') class Message(BaseModel): role : str = None metadata: dict = {} content : str = None def stream_response(message:str, history:list[dict]): if message is not None: input_message = HumanMessage(content=message) for event in app.stream({"messages": [input_message]}, config, stream_mode="values"): pretty_print(event) yield event["messages"][-1].content def clear_history() -> dict: global config session_id = str(uuid.uuid4()) if 'configurable' not in config: logger.debug('New config') config['configurable'] = {"thread_id": session_id} logger.debug(f'New config: {config}') system_msg = SystemMessage(system_messsage) # state:StateSnapshot = app.get_state(config) # StateSnapshot app.update_state(config, {'messages': system_msg}) with gr.Blocks(theme=gr.themes.Soft()) as demo: chatbot=gr.Chatbot( type='messages', height="80vh") gr.ChatInterface(stream_response, type='messages', chatbot=chatbot, textbox=gr.Textbox(placeholder="Enter your query...", container=False, autoscroll=True, scale=7), ) #chatbot.clear(clear_history) # Fails in huggingface if __name__ == "__main__": logger.debug('Started main') system_msg = SystemMessage(system_messsage) app.update_state(config, {'messages': system_msg}) demo.launch(share=False, debug=False) # TODO # When gradio reloads the application, the SystemMessage is lost.