import os

import gradio as gr
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_community.llms import OpenAI
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_openai.chat_models import ChatOpenAI

# Generation parameters, overridable through environment variables.
# Temperature is floored at 1e-2 because a temperature of exactly 0 is
# rejected by some inference backends.
GENERATE_ARGS = {
    'temperature': max(float(os.getenv("TEMPERATURE", 0.2)), 1e-2),
    'max_tokens': int(os.getenv("MAX_NEW_TOKENS", 1024)),
}


class Chat:
    """Streaming chat assistant backed by a LangChain chat model.

    Keeps an in-memory, per-session message history so that multiple
    conversations (keyed by ``session_id``) do not bleed into each other.
    """

    def __init__(self, system_prompt: str):
        """Build the prompt pipeline and the history-aware runnable.

        Args:
            system_prompt: System instruction prepended to every conversation.
        """
        base = ChatOpenAI
        # Model name comes from the environment; None lets the client fall
        # back to its own default model — TODO confirm that is intended.
        model = os.getenv("CHAT_MODEL")
        self.assistant_model = base(
            model=model,
            streaming=True,
            **GENERATE_ARGS
        )
        # session_id -> ChatMessageHistory; in-memory only, lost on restart.
        self.store: dict[str | int, ChatMessageHistory] = {}
        self.prompt = ChatPromptTemplate.from_messages([
            ("system", system_prompt),
            MessagesPlaceholder(variable_name="history"),
            ("human", "{input}")
        ])
        self.runnable = self.prompt | self.assistant_model
        # Wrap the chain so prior turns are injected into the "history" slot
        # and new turns are appended automatically after each call.
        self.chat_model = RunnableWithMessageHistory(
            self.runnable,
            self.get_session_history,
            input_messages_key="input",
            history_messages_key="history",
        )

    def format_prompt(self, system_prompt: str, user_prompt: str):
        """Return a [SystemMessage, HumanMessage] pair for a one-shot call."""
        messages = [
            SystemMessage(
                content=system_prompt
            ),
            HumanMessage(
                content=user_prompt
            ),
        ]
        return messages

    def get_session_history(self, session_id: (str | int)) -> BaseChatMessageHistory:
        """Return the message history for *session_id*, creating it on first use."""
        if session_id not in self.store:
            self.store[session_id] = ChatMessageHistory()
        return self.store[session_id]

    def stream(self, user_prompt: str, session_id: (str | int) = 0):
        """Yield the progressively accumulated assistant answer.

        Args:
            user_prompt: The user's message for this turn.
            session_id: Conversation key; defaults to a single shared session.

        Yields:
            The full answer text accumulated so far (not just the latest
            delta), which matches what Gradio streaming outputs expect.

        Raises:
            gr.Error: On rate limiting, invalid credentials, or any other
                failure raised by the underlying model call.
        """
        try:
            stream_answer = self.chat_model.stream(
                {"input": user_prompt},
                config={"configurable": {"session_id": session_id}},
            )
            output = ""
            for response in stream_answer:
                # BUGFIX: use isinstance() rather than `type(...) == OpenAI`
                # so subclasses of the completion-style client take the
                # correct branch as well.
                if isinstance(self.assistant_model, OpenAI):
                    # Completion-style chunks carry text in choices[].delta.
                    if response.choices[0].delta.content:
                        output += response.choices[0].delta.content
                        yield output
                else:
                    # Chat-model chunks expose the text directly as .content.
                    output += response.content
                    yield output
        except Exception as e:
            # UI boundary: translate backend failures into gr.Error so the
            # frontend shows a readable message. Chain with `from e` to keep
            # the original traceback for debugging.
            if "Too Many Requests" in str(e):
                raise gr.Error(f"Too many requests: {str(e)}") from e
            elif "Authorization header is invalid" in str(e):
                raise gr.Error("Authentication error: API token was either not provided or incorrect") from e
            else:
                raise gr.Error(f"Unhandled Exception: {str(e)}") from e