# rag_test_task/src/chat.py
# (Hugging Face viewer metadata: alexandraroze — "added readme", commit 87370c1)
import os
import gradio as gr
from langchain_community.llms import OpenAI
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_openai.chat_models import ChatOpenAI
# Generation settings shared by every chat-model instance, read from the
# environment with defaults of 0.2 / 1024.
_env_temperature = float(os.getenv("TEMPERATURE", 0.2))
_env_max_tokens = int(os.getenv("MAX_NEW_TOKENS", 1024))

GENERATE_ARGS = {
    # Clamped to a small positive floor so the temperature is never 0.
    'temperature': max(_env_temperature, 1e-2),
    'max_tokens': _env_max_tokens,
}
class Chat:
    """Streaming chat wrapper around a LangChain chat model with per-session history."""

    def __init__(self, system_prompt: str):
        """Build the prompt | model pipeline wrapped with message history.

        Args:
            system_prompt: Instructions prepended to every conversation.
        """
        # Model name comes from the environment; if CHAT_MODEL is unset this is
        # None and ChatOpenAI will fail at construction/call time — TODO confirm
        # the deployment always sets it.
        model = os.getenv("CHAT_MODEL")
        self.assistant_model = ChatOpenAI(
            model=model,
            streaming=True,
            **GENERATE_ARGS,
        )
        # Maps session_id -> ChatMessageHistory; grows without eviction.
        self.store: dict = {}
        self.prompt = ChatPromptTemplate.from_messages([
            ("system", system_prompt),
            MessagesPlaceholder(variable_name="history"),
            ("human", "{input}"),
        ])
        self.runnable = self.prompt | self.assistant_model
        self.chat_model = RunnableWithMessageHistory(
            self.runnable,
            self.get_session_history,
            input_messages_key="input",
            history_messages_key="history",
        )

    def format_prompt(self, system_prompt: str, user_prompt: str) -> list:
        """Return a [SystemMessage, HumanMessage] pair for a one-shot prompt."""
        return [
            SystemMessage(content=system_prompt),
            HumanMessage(content=user_prompt),
        ]

    def get_session_history(self, session_id: str | int) -> BaseChatMessageHistory:
        """Return the history for *session_id*, creating an empty one on first use."""
        if session_id not in self.store:
            self.store[session_id] = ChatMessageHistory()
        return self.store[session_id]

    def stream(self, user_prompt: str, session_id: str | int = 0):
        """Stream the assistant's answer for *user_prompt* in session *session_id*.

        Yields:
            The cumulative answer text so far (not just the delta), so a UI
            can re-render the full response on each chunk.

        Raises:
            gr.Error: on rate limiting, authentication failure, or any other
                upstream exception (original exception chained as the cause).
        """
        try:
            stream_answer = self.chat_model.stream(
                {"input": user_prompt},
                config={"configurable": {"session_id": session_id}},
            )
            output = ""
            for response in stream_answer:
                # FIX: was `type(self.assistant_model) == OpenAI`; isinstance is
                # the correct check and also covers subclasses. NOTE(review):
                # __init__ always builds a ChatOpenAI, so this branch appears
                # unreachable in practice — kept for safety.
                if isinstance(self.assistant_model, OpenAI):
                    # Raw OpenAI completion chunks carry the delta in choices[0].
                    if response.choices[0].delta.content:
                        output += response.choices[0].delta.content
                        yield output
                else:
                    # LangChain chat chunks expose the delta as .content.
                    output += response.content
                    yield output
        except Exception as e:
            # Map upstream failures to user-facing Gradio errors; `from e`
            # preserves the original traceback for server-side debugging.
            if "Too Many Requests" in str(e):
                raise gr.Error(f"Too many requests: {str(e)}") from e
            elif "Authorization header is invalid" in str(e):
                raise gr.Error("Authentication error: API token was either not provided or incorrect") from e
            else:
                raise gr.Error(f"Unhandled Exception: {str(e)}") from e