File size: 2,917 Bytes
197a291
 
 
 
 
 
 
 
 
 
 
 
 
5766729
 
197a291
 
 
 
 
 
 
 
0be2042
197a291
 
 
 
87370c1
197a291
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import os

import gradio as gr
from langchain_community.llms import OpenAI
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_openai.chat_models import ChatOpenAI


# Sampling parameters forwarded to the chat model, read from the environment
# once at import time. Temperature is clamped to a floor of 1e-2 so a
# configured value of 0 (or lower) is raised to that minimum.
GENERATE_ARGS = dict(
    temperature=max(float(os.getenv("TEMPERATURE", 0.2)), 1e-2),
    max_tokens=int(os.getenv("MAX_NEW_TOKENS", 1024)),
)


class Chat:
    """Streaming chat wrapper around ``ChatOpenAI`` with per-session history.

    The model name is read from the ``CHAT_MODEL`` environment variable and
    the sampling parameters come from the module-level ``GENERATE_ARGS``.
    Conversation history is kept in an in-memory dict keyed by session id.
    """

    def __init__(self, system_prompt: str):
        """Build the prompt -> model chain and wrap it with message history.

        Args:
            system_prompt: Text sent as the system message on every turn.
        """
        self.assistant_model = ChatOpenAI(
            model=os.getenv("CHAT_MODEL"),
            streaming=True,
            **GENERATE_ARGS,
        )

        # session_id -> ChatMessageHistory, held in this process's memory only.
        self.store = {}

        self.prompt = ChatPromptTemplate.from_messages([
            # Pass the system prompt as a literal message rather than a
            # ("system", ...) template tuple: a template would treat any
            # "{" / "}" in the prompt text as a variable and fail at invoke.
            SystemMessage(content=system_prompt),
            MessagesPlaceholder(variable_name="history"),
            ("human", "{input}"),
        ])
        self.runnable = self.prompt | self.assistant_model

        self.chat_model = RunnableWithMessageHistory(
            self.runnable,
            self.get_session_history,
            input_messages_key="input",
            history_messages_key="history",
        )

    def format_prompt(self, system_prompt: str, user_prompt: str):
        """Return a ``[SystemMessage, HumanMessage]`` list for the given texts.

        Args:
            system_prompt: Content of the system message.
            user_prompt: Content of the human message.

        Returns:
            A two-element list: the system message followed by the human one.
        """
        return [
            SystemMessage(content=system_prompt),
            HumanMessage(content=user_prompt),
        ]

    def get_session_history(self, session_id: str | int) -> BaseChatMessageHistory:
        """Return the history for ``session_id``, creating an empty one on first use.

        Args:
            session_id: Key identifying the conversation.

        Returns:
            The ``ChatMessageHistory`` stored for that session.
        """
        if session_id not in self.store:
            self.store[session_id] = ChatMessageHistory()
        return self.store[session_id]

    def stream(self, user_prompt: str, session_id: str | int = 0):
        """Stream the assistant's reply to ``user_prompt``.

        Each yielded value is the cumulative reply text so far (not just the
        newest chunk), so a UI can render it directly.

        Args:
            user_prompt: The user's message for this turn.
            session_id: Conversation key used to look up history.

        Yields:
            The reply text accumulated up to and including the latest chunk.

        Raises:
            gr.Error: Wraps any exception raised while streaming. A specific
                message is used for rate-limit and authorization failures,
                detected by substring match on the error text.
        """
        try:
            stream_answer = self.chat_model.stream(
                {"input": user_prompt},
                config={"configurable": {"session_id": session_id}},
            )
            output = ""
            # The chain ends in ChatOpenAI, so every chunk is a chat message
            # chunk exposing the newly generated text as ``.content``.
            for chunk in stream_answer:
                output += chunk.content
                yield output

        except Exception as e:
            message = str(e)
            # Chain with ``from e`` so the original traceback stays visible.
            if "Too Many Requests" in message:
                raise gr.Error(f"Too many requests: {message}") from e
            if "Authorization header is invalid" in message:
                raise gr.Error("Authentication error: API token was either not provided or incorrect") from e
            raise gr.Error(f"Unhandled Exception: {message}") from e