Spaces:
Sleeping
Sleeping
| import os | |
| import gradio as gr | |
| from langchain_community.llms import OpenAI | |
| from langchain_core.messages import HumanMessage, SystemMessage | |
| from langchain_core.chat_history import BaseChatMessageHistory | |
| from langchain_community.chat_message_histories import ChatMessageHistory | |
| from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder | |
| from langchain_core.runnables.history import RunnableWithMessageHistory | |
| from langchain_openai.chat_models import ChatOpenAI | |
# Generation settings, resolved from the environment once at import time.
# The temperature is clamped to a floor of 0.01 — some backends reject 0.
_temperature = max(float(os.getenv("TEMPERATURE", 0.2)), 1e-2)
_max_new_tokens = int(os.getenv("MAX_NEW_TOKENS", 1024))

GENERATE_ARGS = {
    'temperature': _temperature,
    'max_tokens': _max_new_tokens,
}
class Chat:
    """Streaming chat wrapper around a LangChain chat model with per-session history.

    Each session_id gets its own in-memory ChatMessageHistory, so concurrent
    conversations do not bleed into each other.
    """

    def __init__(self, system_prompt: str):
        """Build the prompt | model pipeline with message-history support.

        Args:
            system_prompt: Text installed as the system message of every turn.
        """
        # Model name comes from the environment; if CHAT_MODEL is unset this is
        # None and ChatOpenAI falls back to its own default — TODO confirm intended.
        model = os.getenv("CHAT_MODEL")
        self.assistant_model = ChatOpenAI(
            model=model,
            streaming=True,
            **GENERATE_ARGS,
        )
        # session_id -> ChatMessageHistory; grows for the life of the process.
        self.store = {}
        self.prompt = ChatPromptTemplate.from_messages([
            ("system", system_prompt),
            MessagesPlaceholder(variable_name="history"),
            ("human", "{input}"),
        ])
        self.runnable = self.prompt | self.assistant_model
        self.chat_model = RunnableWithMessageHistory(
            self.runnable,
            self.get_session_history,
            input_messages_key="input",
            history_messages_key="history",
        )

    def format_prompt(self, system_prompt: str, user_prompt: str):
        """Return a [SystemMessage, HumanMessage] pair for one-shot prompting."""
        return [
            SystemMessage(content=system_prompt),
            HumanMessage(content=user_prompt),
        ]

    def get_session_history(self, session_id: (str | int)) -> BaseChatMessageHistory:
        """Return the history for session_id, creating an empty one on first use."""
        # setdefault replaces the original check-then-insert two-step.
        return self.store.setdefault(session_id, ChatMessageHistory())

    def stream(self, user_prompt: str, session_id: (str | int) = 0):
        """Stream the assistant's reply, yielding the accumulated text so far.

        Args:
            user_prompt: The user's message for this turn.
            session_id: Key selecting which conversation history to use.

        Yields:
            str: The full response text accumulated after each chunk (suitable
            for progressive display in a Gradio component).

        Raises:
            gr.Error: On rate limiting, invalid credentials, or any other
                backend failure (classified by substring match on the message).
        """
        try:
            chunks = self.chat_model.stream(
                {"input": user_prompt},
                config={"configurable": {"session_id": session_id}},
            )
            output = ""
            for chunk in chunks:
                # Fix: isinstance() instead of `type(...) == OpenAI`. With the
                # ChatOpenAI model built in __init__ this branch is never taken;
                # it only fires if assistant_model is swapped for a raw
                # completion-style OpenAI LLM whose chunks carry choices/delta.
                if isinstance(self.assistant_model, OpenAI):
                    delta = chunk.choices[0].delta.content
                    if delta:
                        output += delta
                        yield output
                else:
                    # Chat-model chunks expose the new text via .content.
                    output += chunk.content
                    yield output
        except Exception as e:
            msg = str(e)
            if "Too Many Requests" in msg:
                raise gr.Error(f"Too many requests: {msg}")
            elif "Authorization header is invalid" in msg:
                raise gr.Error("Authentication error: API token was either not provided or incorrect")
            else:
                raise gr.Error(f"Unhandled Exception: {msg}")