# chatbot-mimic-notes / src/agents/chat_llm.py
# Author: Jesse Liu
# Fix: Add boto3 dependency, fix LangChain deprecation warning (commit a6243bb)
import yaml
from box import Box
from langchain_community.chat_models import ChatOpenAI
from langchain_community.llms import DeepInfra
from langchain.schema import HumanMessage, SystemMessage, AIMessage
import logging
import os
def chat_llm(messages, model, temperature, max_tokens, n, timeout=600, stop=None, return_tokens=False):
    """Query a chat LLM (OpenAI GPT or a DeepInfra-hosted model) via LangChain.

    Args:
        messages: list of dicts with keys 'role' ('system' | 'user' |
            'assistant') and 'content' (str). Any other role raises
            NotImplementedError.
        model: model name; names containing "gpt" route to OpenAI, all
            others to DeepInfra's OpenAI-compatible endpoint.
        temperature: sampling temperature.
        max_tokens: maximum completion tokens per generation.
        n: number of completions. DeepInfra only supports one completion
            per request, so for non-GPT models the request is iterated
            n times and token usage is summed.
        timeout: per-request timeout in seconds.
        stop: optional single stop string.
        return_tokens: unused here; kept for caller compatibility.

    Returns:
        dict with 'generations' (list of response strings),
        'completion_tokens', and 'prompt_tokens' — or None if the OpenAI
        model fails to construct.
    """
    if "gpt" in model:  # idiomatic membership test (was model.__contains__)
        iterated_query = False
        try:
            chat = ChatOpenAI(
                model_name=model,
                temperature=temperature,
                max_tokens=max_tokens,
                n=n,
                request_timeout=timeout,
                # seed/top_p pinned for reproducible sampling
                model_kwargs={"seed": 0,
                              "top_p": 0
                              },
            )
        except Exception as e:
            logging.error("Error in loading model: %s", e)
            return None
    else:
        # DeepInfra's OpenAI-compatible endpoint only supports n=1 per
        # call, so the request is iterated below instead.
        iterated_query = True
        chat = ChatOpenAI(
            model_name=model,
            # NOTE(review): key presumably comes from the environment
            # (OPENAI_API_KEY / DEEPINFRA token) — confirm.
            openai_api_key=None,
            temperature=temperature,
            max_tokens=max_tokens,
            n=1,
            request_timeout=timeout,
            openai_api_base="https://api.deepinfra.com/v1/openai",
        )

    # Convert OpenAI-style message dicts into LangChain message objects.
    langchain_msgs = []
    for msg in messages:
        if msg['role'] == 'system':
            langchain_msgs.append(SystemMessage(content=msg['content']))
        elif msg['role'] == 'user':
            logging.debug("human message: %s", msg)
            langchain_msgs.append(HumanMessage(content=msg['content']))
        elif msg['role'] == 'assistant':
            logging.debug("assistant message: %s", msg['content'])
            langchain_msgs.append(AIMessage(content=msg['content']))
        else:
            raise NotImplementedError
    # add an empty user message to avoid no user message error
    langchain_msgs.append(HumanMessage(content=""))

    stop_list = [stop] if stop is not None else None
    if n > 1 and iterated_query:
        # DeepInfra path: issue n sequential single-completion requests
        # and accumulate token usage across them.
        response_list = []
        total_completion_tokens = 0
        total_prompt_tokens = 0
        for _ in range(n):  # fixed: loop variable no longer shadows n
            generations = chat.generate([langchain_msgs], stop=stop_list)
            responses = [
                chat_gen.message.content for chat_gen in generations.generations[0]]
            response_list.append(responses[0])
            usage = generations.llm_output['token_usage']
            total_completion_tokens += usage['completion_tokens']
            total_prompt_tokens += usage['prompt_tokens']
        responses = response_list
        completion_tokens = total_completion_tokens
        prompt_tokens = total_prompt_tokens
    else:
        # Single request: either an OpenAI model (supports n>1 natively)
        # or a DeepInfra model asked for exactly one completion.
        generations = chat.generate([langchain_msgs], stop=stop_list)
        responses = [
            chat_gen.message.content for chat_gen in generations.generations[0]]
        usage = generations.llm_output['token_usage']
        completion_tokens = usage['completion_tokens']
        prompt_tokens = usage['prompt_tokens']

    return {
        'generations': responses,
        'completion_tokens': completion_tokens,
        'prompt_tokens': prompt_tokens
    }