import os

from langchain_community.chat_models import ChatOpenAI
from langchain.schema import AIMessage, HumanMessage, SystemMessage


def chat_llm(messages, model, temperature, max_tokens, n, timeout=600, stop=None, return_tokens=False):
    """Query a chat model and return its generations plus token usage.

    Accepts OpenAI-style message dicts ({'role': ..., 'content': ...}).
    GPT models are queried directly through the OpenAI API; all other
    models are routed to DeepInfra's OpenAI-compatible endpoint, which
    only supports n=1, so multiple samples are drawn by iterating.
    Note: return_tokens is kept for backward compatibility; token counts
    are always included in the returned dict.
    """
    if "gpt" in model:
        iterated_query = False
        try:
            chat = ChatOpenAI(
                model_name=model,
                temperature=temperature,
                max_tokens=max_tokens,
                n=n,
                request_timeout=timeout,
                model_kwargs={"seed": 0, "top_p": 0},
            )
        except Exception as e:
            print(f"Error in loading model: {e}")
            return None
    else:
        # DeepInfra's OpenAI-compatible endpoint. Assumption: the DeepInfra
        # token is stored in the DEEPINFRA_API_KEY environment variable.
        iterated_query = True
        chat = ChatOpenAI(
            model_name=model,
            openai_api_key=os.environ.get("DEEPINFRA_API_KEY"),
            temperature=temperature,
            max_tokens=max_tokens,
            n=1,
            request_timeout=timeout,
            openai_api_base="https://api.deepinfra.com/v1/openai",
        )

    # Convert OpenAI-style role/content dicts into LangChain message objects.
    langchain_msgs = []
    for msg in messages:
        if msg["role"] == "system":
            langchain_msgs.append(SystemMessage(content=msg["content"]))
        elif msg["role"] == "user":
            langchain_msgs.append(HumanMessage(content=msg["content"]))
        elif msg["role"] == "assistant":
            langchain_msgs.append(AIMessage(content=msg["content"]))
        else:
            raise NotImplementedError(f"Unsupported role: {msg['role']}")
    # Add an empty user message to avoid a "no user message" error.
    langchain_msgs.append(HumanMessage(content=""))

    stop_list = [stop] if stop is not None else None
    if n > 1 and iterated_query:
        # DeepInfra returns only one sample per request, so call it n times
        # and accumulate the responses and token counts.
        responses = []
        completion_tokens = 0
        prompt_tokens = 0
        for _ in range(n):
            generations = chat.generate([langchain_msgs], stop=stop_list)
            responses.append(generations.generations[0][0].message.content)
            token_usage = generations.llm_output["token_usage"]
            completion_tokens += token_usage["completion_tokens"]
            prompt_tokens += token_usage["prompt_tokens"]
    else:
        # A single request covers all n samples for OpenAI models.
        generations = chat.generate([langchain_msgs], stop=stop_list)
        responses = [
            chat_gen.message.content for chat_gen in generations.generations[0]
        ]
        token_usage = generations.llm_output["token_usage"]
        completion_tokens = token_usage["completion_tokens"]
        prompt_tokens = token_usage["prompt_tokens"]

    return {
        "generations": responses,
        "completion_tokens": completion_tokens,
        "prompt_tokens": prompt_tokens,
    }
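

# A minimal usage sketch, not part of the original module: it assumes
# OPENAI_API_KEY is set in the environment and that "gpt-3.5-turbo" is an
# available model name. The message list and sampling parameters below are
# illustrative only.
if __name__ == "__main__":
    demo_messages = [
        {"role": "system", "content": "You are a concise assistant."},
        {"role": "user", "content": "Say hello in one word."},
    ]
    result = chat_llm(
        demo_messages,
        model="gpt-3.5-turbo",
        temperature=0.0,
        max_tokens=16,
        n=1,
    )
    if result is not None:
        print(result["generations"])
        print("prompt tokens:", result["prompt_tokens"],
              "completion tokens:", result["completion_tokens"])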