Spaces:
Sleeping
Sleeping
File size: 2,917 Bytes
import os
import gradio as gr
from langchain_community.llms import OpenAI
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_openai.chat_models import ChatOpenAI
# Generation parameters shared by every model instantiation.
# TEMPERATURE is clamped to a small positive floor (some backends reject 0.0);
# MAX_NEW_TOKENS caps the length of each generated reply.
_TEMPERATURE_FLOOR = 1e-2
_temperature = max(float(os.getenv("TEMPERATURE", "0.2")), _TEMPERATURE_FLOOR)
_max_new_tokens = int(os.getenv("MAX_NEW_TOKENS", "1024"))

GENERATE_ARGS = {
    "temperature": _temperature,
    "max_tokens": _max_new_tokens,
}
class Chat:
    """Streaming chat assistant backed by a LangChain chat model.

    Keeps one message history per session id and exposes a generator-based
    ``stream`` method suitable for driving a Gradio streaming UI.
    """

    def __init__(self, system_prompt: str):
        """Build the prompt | model pipeline wrapped with per-session history.

        Args:
            system_prompt: System instruction prepended to every conversation.
        """
        base = ChatOpenAI
        # NOTE(review): CHAT_MODEL is read with no default — if the env var is
        # unset, model is None; confirm the deployment always provides it.
        model = os.getenv("CHAT_MODEL")
        self.assistant_model = base(
            model=model,
            streaming=True,
            **GENERATE_ARGS
        )
        # session_id -> ChatMessageHistory. Grows without eviction for the
        # lifetime of this object.
        self.store: dict = {}
        self.prompt = ChatPromptTemplate.from_messages([
            ("system", system_prompt),
            MessagesPlaceholder(variable_name="history"),
            ("human", "{input}")
        ])
        self.runnable = self.prompt | self.assistant_model
        self.chat_model = RunnableWithMessageHistory(
            self.runnable,
            self.get_session_history,
            input_messages_key="input",
            history_messages_key="history",
        )

    def format_prompt(self, system_prompt: str, user_prompt: str):
        """Return a [SystemMessage, HumanMessage] pair for the given prompts."""
        messages = [
            SystemMessage(
                content=system_prompt
            ),
            HumanMessage(
                content=user_prompt
            ),
        ]
        return messages

    def get_session_history(self, session_id: (str | int)) -> BaseChatMessageHistory:
        """Return the history for *session_id*, creating it on first access."""
        if session_id not in self.store:
            self.store[session_id] = ChatMessageHistory()
        return self.store[session_id]

    def stream(self, user_prompt: str, session_id: (str | int) = 0):
        """Yield the progressively accumulated assistant reply.

        Args:
            user_prompt: The user's message for this turn.
            session_id: Key selecting which conversation history to use.

        Yields:
            The full response text accumulated so far (not just the delta),
            which is the shape Gradio expects for streaming updates.

        Raises:
            gr.Error: On rate limiting, authentication failure, or any other
                exception raised while streaming.
        """
        try:
            stream_answer = self.chat_model.stream(
                {"input": user_prompt},
                config={"configurable": {"session_id": session_id}},
            )
            output = ""
            for response in stream_answer:
                # Fix: isinstance instead of `type(...) == OpenAI`, which
                # fails for subclasses and is the non-idiomatic type check.
                if isinstance(self.assistant_model, OpenAI):
                    # Completion-style chunks carry .choices[0].delta.content
                    if response.choices[0].delta.content:
                        output += response.choices[0].delta.content
                        yield output
                else:
                    # Chat-model chunks (AIMessageChunk) carry .content
                    output += response.content
                    yield output
        except Exception as e:
            # Map provider/transport errors to user-facing Gradio errors.
            if "Too Many Requests" in str(e):
                raise gr.Error(f"Too many requests: {str(e)}")
            elif "Authorization header is invalid" in str(e):
                raise gr.Error("Authentication error: API token was either not provided or incorrect")
            else:
                raise gr.Error(f"Unhandled Exception: {str(e)}")
|