Spaces:
Build error
Build error
| # Documentation: | |
| # https://www.gradio.app/docs/gradio/chatinterface | |
| # https://www.gradio.app/docs/gradio/chatbot | |
| import logging | |
| logging.basicConfig(level=logging.WARNING) # Level for the root logger | |
| import os | |
| import json | |
| import uuid | |
| from dotenv import load_dotenv | |
| from pydantic import BaseModel | |
| from functools import lru_cache | |
| load_dotenv(verbose=True) | |
| from langchain.prompts import ChatPromptTemplate, PromptTemplate | |
| from langchain_core.runnables import RunnablePassthrough | |
| from langchain.schema import AIMessage, HumanMessage, SystemMessage | |
| from langchain_huggingface import HuggingFaceEmbeddings | |
| from langchain.chains.combine_documents import create_stuff_documents_chain | |
| from langchain_core.messages.tool import ToolMessage | |
| from langgraph.types import StateSnapshot | |
| import gradio as gr | |
| from workflow import app, memory | |
# This module's logger is set to DEBUG, while the root logger was configured
# at WARNING by logging.basicConfig near the top of the file.
logger = logging.getLogger(__name__) # Child logger for this module
logger.setLevel(logging.DEBUG)

# Run configuration passed to app.stream() / app.update_state(). The thread_id
# is a random UUID generated once, when this module is imported.
# NOTE(review): every call to stream_response() in this process reads this same
# module-level config, so all chats appear to share one thread_id -- confirm
# this is intended.
config = {"configurable": {"thread_id": str(uuid.uuid4()) }}

# System prompt text; wrapped in a SystemMessage and written into the graph
# state by clear_history() and by the __main__ block.
system_messsage ="""You are a helpful assistant.
You only answer questions about data structures and algorithms, discrete math, and computer science in general.
You can use search tools for finding information in the web to answer the user's question.
Output your answers using markdown format and include links to the pages used in constructing the answer."""
def pretty_print(event: dict) -> None:
    """Print a truncated, human-readable dump of the messages in *event*.

    Args:
        event: A mapping with a ``'messages'`` entry holding the message
            objects to print.

    Each message is printed according to its type, with text truncated to 80
    characters:

    * ``SystemMessage`` / ``HumanMessage``: the content.
    * ``AIMessage``: the content, or the name and arguments of every entry in
      ``additional_kwargs['tool_calls']`` when that key is present.
    * ``ToolMessage``: the content is parsed as JSON and the ``url`` and
      ``content`` of each entry are printed; any failure is logged, not raised.
    * anything else: its type and value.

    A dashed separator line is printed after all messages.
    """
    for msg in event['messages']:
        if isinstance(msg, SystemMessage):
            print('SystemMessage:', msg.content[:80], sep='\n\t')
        elif isinstance(msg, HumanMessage):
            print('HumanMessage: ', msg.content[:80], sep='\n\t')
        elif isinstance(msg, AIMessage):
            extras = msg.additional_kwargs
            if extras and 'tool_calls' in extras:
                print('AIMessage: ', 'tool_call')
                for call in extras['tool_calls']:
                    print('\t', 'Name = ', call['function']['name'],
                          ' Args =', call['function']['arguments'][:80])
            else:
                print('AIMessage: ', msg.content[:80], sep='\n\t')
        elif isinstance(msg, ToolMessage):
            print('ToolMessage: ')
            try:
                # The tool output is expected to be a JSON list of result dicts.
                for result in json.loads(msg.content):
                    print('\t', 'url', result['url'])
                    print('\t', 'content', result['content'][:80])
            except Exception as exc:
                # Best-effort debug output: log the problem and keep going.
                logger.error(str(exc))
                logger.error(msg)
        else:
            print('UNKNOWN MESSAGE TYPE', type(msg), msg)
    print('-'*20, '\n')
class Message(BaseModel):
    """A single chat message with a role, optional metadata, and content.

    ``role`` and ``content`` default to ``None``, so they are annotated as
    optional; with a plain ``str`` annotation, passing ``None`` explicitly
    would fail validation even though it is the default.
    """

    # Who produced the message.
    # NOTE(review): presumably mirrors the "role" key of a Gradio
    # messages-format history entry -- confirm against callers.
    role: str | None = None
    # Free-form extra data. Pydantic copies mutable defaults per instance, so
    # the literal {} is not shared between Message objects.
    metadata: dict = {}
    # The message text.
    content: str | None = None
def stream_response(message: str, history: list[dict]):
    """Stream the assistant's reply to *message* for ``gr.ChatInterface``.

    Args:
        message: The user's input text. ``None`` produces no output.
        history: Chat history supplied by Gradio. It is not used here; only
            the new message is sent to the graph, together with the
            module-level ``config``.

    Yields:
        The content of each non-empty ``AIMessage`` that becomes the newest
        message in the graph state. If the run produces no such message, the
        content of the final state's last message is yielded once instead.
    """
    if message is None:
        return

    input_message = HumanMessage(content=message)
    latest = None
    replied = False
    for event in app.stream({"messages": [input_message]}, config, stream_mode="values"):
        pretty_print(event)
        latest = event["messages"][-1]
        # stream_mode="values" emits the whole state after every step, so the
        # newest message is at first the user's own HumanMessage and may later
        # be a tool-call AIMessage (empty content) or a raw ToolMessage.
        # Showing those as the assistant's reply is wrong; only surface real
        # assistant text.
        if isinstance(latest, AIMessage) and latest.content:
            replied = True
            yield latest.content

    # Fallback so the UI still receives something if the graph never produced
    # assistant text (this is what was yielded last before the filter existed).
    if not replied and latest is not None:
        yield latest.content
def clear_history() -> None:
    """Start a fresh conversation thread seeded with the system prompt.

    A new random ``thread_id`` is stored in the module-level ``config`` and a
    ``SystemMessage`` built from ``system_messsage`` is written into the graph
    state for that thread.

    The thread id is replaced unconditionally. Guarding the replacement with
    ``'configurable' not in config`` never fires, because ``config`` is created
    with that key, so the old thread would be kept and merely receive another
    system message.
    """
    session_id = str(uuid.uuid4())
    # Mutate the shared dict in place so every holder of `config` sees the
    # new thread id.
    config['configurable'] = {"thread_id": session_id}
    logger.debug('New config: %s', config)

    system_msg = SystemMessage(system_messsage)
    app.update_state(config, {'messages': system_msg})
# UI definition: one chat panel wired to stream_response through
# gr.ChatInterface, using the "messages" history format throughout.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Chat panel occupying 80% of the viewport height.
    chatbot = gr.Chatbot(type='messages', height="80vh")

    gr.ChatInterface(
        stream_response,
        type='messages',
        chatbot=chatbot,
        textbox=gr.Textbox(
            placeholder="Enter your query...",
            container=False,
            autoscroll=True,
            scale=7,
        ),
    )
    # Hooking clear_history to the chatbot's clear event is disabled:
    #chatbot.clear(clear_history) # Fails in huggingface
# Script entry point: runs only when this file is executed directly.
if __name__ == "__main__":
    logger.debug('Started main')
    # Write the system prompt into the graph state for the current thread
    # before the UI starts accepting messages.
    system_msg = SystemMessage(system_messsage)
    app.update_state(config, {'messages': system_msg})
    demo.launch(share=False, debug=False)
# TODO: When Gradio reloads the application, the SystemMessage is lost.