File size: 4,405 Bytes
0520cd8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84710e6
0520cd8
 
 
 
 
 
 
 
 
9c10b6e
0520cd8
 
9c10b6e
0520cd8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c002da4
0520cd8
84710e6
 
 
0520cd8
 
 
 
6737d5a
0520cd8
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133

# Documentation:
# https://www.gradio.app/docs/gradio/chatinterface
# https://www.gradio.app/docs/gradio/chatbot



import logging
logging.basicConfig(level=logging.WARNING)      # Level for the root logger

import os
import json
import uuid
from dotenv import load_dotenv
from pydantic import BaseModel
from functools import lru_cache

load_dotenv(verbose=True)

from langchain.prompts import ChatPromptTemplate, PromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain.schema import AIMessage, HumanMessage, SystemMessage
from langchain_huggingface import HuggingFaceEmbeddings
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.messages.tool import ToolMessage
from langgraph.types import StateSnapshot

import gradio as gr

from workflow import app, memory


# Module logger. It is set to DEBUG while the root logger configured at the top
# of the file stays at WARNING, so only this module's diagnostics are verbose.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# LangGraph run configuration shared by every call into `app`. The random
# thread_id presumably selects the conversation state held by the workflow's
# checkpointer (`memory`) — confirm in workflow.py.
config = {"configurable": {"thread_id": f"{uuid.uuid4()}"}}

# System prompt written into the thread state before the first turn.
# (The triple-s spelling is kept because other code refers to this name.)
system_messsage = "\n".join((
    "You are a helpful assistant.",
    "You only answer questions about data structures and algorithms, discrete math, and computer science in general.",
    "You can use search tools for finding information in the web to answer the user's question.",
    "Output your answers using markdown format and include links to the pages used in constructing the answer.",
))


def pretty_print(event:dict) -> None:
    msgs = event['messages']
    for x in msgs:
        match x:
            case SystemMessage():
                print('SystemMessage:', x.content[:80], sep='\n\t')
            case HumanMessage():
                print('HumanMessage: ', x.content[:80], sep='\n\t')
            case AIMessage():
                if x.additional_kwargs and 'tool_calls' in x.additional_kwargs:
                    tool_calls = x.additional_kwargs['tool_calls']
                    print('AIMessage:    ', 'tool_call')
                    for call in tool_calls:
                        print('\t','Name = ', call['function']['name'],' Args =', call['function']['arguments'][:80])
                else:
                    print('AIMessage:    ', x.content[:80], sep='\n\t')
            case ToolMessage():
                # print('ToolMessage', x.content[:80])        # Is a JSON string
                print('ToolMessage:  ')
                try:
                    l = json.loads(x.content)
                    for d in l:
                        print('\t', 'url', d['url'])
                        print('\t', 'content', d['content'][:80])
                except Exception as e:
                    logger.error(str(e))
                    logger.error(x)
            case _:
                print('UNKNOWN MESSAGE TYPE', type(x), x)
    print('-'*20, '\n')


class Message(BaseModel):
    """A chat message record with optional role, metadata, and content.

    Not used within this file. NOTE(review): check other modules before
    changing or removing it.
    """

    # `str | None` rather than `str`: pydantic does not validate defaults, so
    # `None` worked as a default, but an explicit `Message(role=None)` would
    # fail validation under pydantic v2 with a plain `str` annotation.
    role: str | None = None
    # Pydantic copies mutable defaults per instance, so this `{}` is not shared.
    metadata: dict = {}
    content: str | None = None


def stream_response(message: str, history: list[dict]):
    """Stream the assistant's reply to ``message`` for ``gr.ChatInterface``.

    Sends only the new user turn to the LangGraph ``app`` under the
    module-level ``config``. Earlier turns are not re-sent; the graph is
    expected to keep them in its thread state. Every graph event is traced to
    stdout via ``pretty_print``.

    Args:
        message: The user's new chat input. ``None`` produces no output.
        history: Chat history supplied by Gradio. It is required by the
            ChatInterface callback signature but unused here.

    Yields:
        The content of the newest AI message after each graph step that ends
        on a non-empty AI message. Gradio replaces the bot reply with each
        yielded value, so the last one is what the user ends up seeing.
    """
    if message is None:
        return

    input_message = HumanMessage(content=message)
    for event in app.stream({"messages": [input_message]}, config, stream_mode="values"):
        pretty_print(event)
        last = event["messages"][-1]
        # stream_mode="values" yields the whole state after every step, so the
        # last message may be the user's own input (first event) or a raw
        # ToolMessage (search-result JSON). An AIMessage that only requests a
        # tool call has empty content. Surface only real AI text.
        if isinstance(last, AIMessage) and last.content:
            yield last.content


def clear_history() -> None:
    """Switch to a fresh conversation thread and seed it with the system prompt.

    Replaces ``config['configurable']`` with a new random ``thread_id``. The
    module-level dict is mutated in place, so every later call that reads
    ``config`` uses the new thread. It then writes the system prompt into that
    thread's state with ``app.update_state``.

    Not currently called from this file; the UI hook for it is commented out.
    """
    # No `global` needed: `config` is only mutated (item assignment), never rebound.
    session_id = str(uuid.uuid4())
    if 'configurable' not in config:
        logger.debug('New config')
    config['configurable'] = {"thread_id": session_id}
    logger.debug('New config: %s', config)

    system_msg = SystemMessage(system_messsage)
    app.update_state(config, {'messages': system_msg})



# Gradio UI: a tall chatbot pane driven by a ChatInterface. For each submitted
# message, ChatInterface calls stream_response(message, history) and streams
# the values it yields into the bot reply.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # type='messages': history is a list of role/content dicts rather than tuples.
    chatbot=gr.Chatbot(
        type='messages',
        height="80vh")
    # Passing our own chatbot and textbox lets ChatInterface use these
    # customised components instead of creating default ones.
    gr.ChatInterface(stream_response,
                    type='messages',
                    chatbot=chatbot,
                    textbox=gr.Textbox(placeholder="Enter your query...",
                                    container=False,
                                    autoscroll=True,
                                    scale=7),
    )
    # Wiring the clear button to clear_history is disabled because it fails
    # on Hugging Face. As far as this file shows, clearing the chat UI therefore
    # leaves the graph thread (and its stored messages) unchanged.
    #chatbot.clear(clear_history) # Fails in huggingface



if __name__ == "__main__":
    logger.debug('Started main')

    # Seed the initial thread (config's thread_id) with the system prompt
    # before serving the UI. This only happens when the file runs as a
    # script, which is presumably why the TODO below notes the prompt can be lost.
    system_msg = SystemMessage(system_messsage)
    app.update_state(config, {'messages': system_msg})
    # share=False: serve locally only, without a public share link.
    demo.launch(share=False, debug=False)


# TODO
# When gradio reloads the application, the SystemMessage is lost.