import yaml
from box import Box
from langchain_community.chat_models import ChatOpenAI
from langchain_community.llms import DeepInfra
from langchain.schema import HumanMessage, SystemMessage, AIMessage
import logging
import os


def chat_llm(messages, model, temperature, max_tokens, n, timeout=600, stop=None, return_tokens=False):
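    """Send chat `messages` to an OpenAI or DeepInfra-hosted model via LangChain.

    `messages` is a list of OpenAI-style dicts with 'role' and 'content' keys.
    Returns a dict with the generated texts and prompt/completion token counts,
    or None if the OpenAI chat model fails to load. `return_tokens` is accepted
    but not used by the current implementation.
    """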
    if "gpt" in model:
        iterated_query = False
        try:
            # OpenAI chat models: fixed seed and top_p=0 for (near-)deterministic sampling
            chat = ChatOpenAI(model_name=model,
                              temperature=temperature,
                              max_tokens=max_tokens,
                              n=n,
                              request_timeout=timeout,
                              model_kwargs={"seed": 0,
                                            "top_p": 0})
        except Exception as e:
            print(f"Error in loading model: {e}")
            return None
    else:
        # DeepInfra exposes an OpenAI-compatible endpoint; with openai_api_key
        # left as None, the key is picked up from the environment.
        iterated_query = True
        chat = ChatOpenAI(model_name=model,
                          openai_api_key=None,
                          temperature=temperature,
                          max_tokens=max_tokens,
                          n=1,
                          request_timeout=timeout,
                          openai_api_base="https://api.deepinfra.com/v1/openai")
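    # Convert OpenAI-style role/content dicts into LangChain message objects.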
    langchain_msgs = []
    for msg in messages:
        if msg['role'] == 'system':
            langchain_msgs.append(SystemMessage(content=msg['content']))
        elif msg['role'] == 'user':
            print('human message', msg)
            langchain_msgs.append(HumanMessage(content=msg['content']))
        elif msg['role'] == 'assistant':
            print('8' * 20)
            print(msg['content'])
            langchain_msgs.append(AIMessage(content=msg['content']))
        else:
            raise NotImplementedError
    # add an empty user message to avoid no user message error
    langchain_msgs.append(HumanMessage(content=""))
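    # The DeepInfra client above is created with n=1, so multiple samples are
    # collected by issuing n sequential requests and summing the token usage.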
    if n > 1 and iterated_query:
        response_list = []
        total_completion_tokens = 0
        total_prompt_tokens = 0
        for _ in range(n):
            generations = chat.generate(
                [langchain_msgs], stop=[stop] if stop is not None else None)
            responses = [
                chat_gen.message.content for chat_gen in generations.generations[0]]
            response_list.append(responses[0])
            completion_tokens = generations.llm_output['token_usage']['completion_tokens']
            prompt_tokens = generations.llm_output['token_usage']['prompt_tokens']
            total_completion_tokens += completion_tokens
            total_prompt_tokens += prompt_tokens
        responses = response_list
        completion_tokens = total_completion_tokens
        prompt_tokens = total_prompt_tokens
    else:
        generations = chat.generate(
            [langchain_msgs], stop=[stop] if stop is not None else None)
        responses = [
            chat_gen.message.content for chat_gen in generations.generations[0]]
        completion_tokens = generations.llm_output['token_usage']['completion_tokens']
        prompt_tokens = generations.llm_output['token_usage']['prompt_tokens']
    return {
        'generations': responses,
        'completion_tokens': completion_tokens,
        'prompt_tokens': prompt_tokens
    }
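

if __name__ == "__main__":
    # Minimal usage sketch: assumes OPENAI_API_KEY is set in the environment and
    # that "gpt-3.5-turbo" is an available model name; both are placeholders to
    # adjust for your own setup.
    example_messages = [
        {"role": "system", "content": "You are a concise assistant."},
        {"role": "user", "content": "Name three prime numbers."},
    ]
    result = chat_llm(example_messages,
                      model="gpt-3.5-turbo",
                      temperature=0.0,
                      max_tokens=64,
                      n=1)
    if result is not None:
        print(result["generations"][0])
        print("prompt tokens:", result["prompt_tokens"],
              "completion tokens:", result["completion_tokens"])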