# chat.py — last commit 9d5ea64 ("Update chat.py") by ha1772007; 2.99 kB
from langchain_openai import ChatOpenAI
import json
from langchain_core.prompts import ChatPromptTemplate
from langchain_groq import ChatGroq
from langchain.chains import ConversationChain
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.output_parsers import JsonOutputParser
from langchain_community.llms.cloudflare_workersai import CloudflareWorkersAI
import requests
from langchain_cohere import ChatCohere
def langchainConversation(conversation):
    """Convert a conversation into LangChain message objects.

    Args:
        conversation: Iterable of dicts, each with a ``role`` key (passed
            straight to ``ChatPromptTemplate.from_messages`` as the message
            type) and the message text under the ``context`` key.

    Returns:
        The list of messages produced by ``ChatPromptTemplate.format_messages``.
    """
    prompts = []
    for message in conversation:
        text = message['context']
        if isinstance(text, str):
            # The text is literal content, not a template. Doubling the braces
            # stops ChatPromptTemplate's f-string formatting from treating
            # "{...}" in user text (code, JSON, ...) as a template variable,
            # which would make format_messages() fail or substitute.
            text = text.replace('{', '{{').replace('}', '}}')
        prompts.append((message['role'], text))
    chat_template = ChatPromptTemplate.from_messages(prompts)
    return chat_template.format_messages()
def segmind_input_parser(input):
    """Build the ``messages`` list sent in the Segmind request body.

    Each incoming message's ``context`` field is renamed to ``content``.
    Only ``role`` and the text are carried over; any other keys are dropped.

    Args:
        input: Iterable of dicts with ``role`` and ``context`` keys.

    Returns:
        A new list of ``{'role': ..., 'content': ...}`` dicts, in input order.
    """
    return [
        {'role': message['role'], 'content': message['context']}
        for message in input
    ]
def segmind_output_parser(input):
    """Render a Segmind chat response as an AI-message-shaped JSON string.

    The output uses the same field names as a serialized LangChain AI
    message (``content``, ``type``, ``tool_calls``, ``usage_metadata``, ...).
    It is indented with 4 spaces.

    Args:
        input: Decoded response dict. It is read at
            ``choices[0].message.content``, ``id`` and
            ``usage.{prompt,completion,total}_tokens``.

    Returns:
        The JSON string.

    Raises:
        KeyError / IndexError: If the response lacks any of the fields above.
    """
    # Read the fields in the same order the output dict lists them.
    content = input['choices'][0]['message']['content']
    message_id = input['id']
    usage = input['usage']
    usage_metadata = {
        "input_tokens": usage['prompt_tokens'],
        "output_tokens": usage['completion_tokens'],
        "total_tokens": usage['total_tokens'],
    }
    serialized = {
        "content": content,
        "additional_kwargs": {},
        "response_metadata": {},
        "type": "ai",
        "name": None,
        "id": message_id,
        "example": False,
        "tool_calls": [],
        "invalid_tool_calls": [],
        "usage_metadata": usage_metadata,
    }
    return json.dumps(serialized, indent=4)
def converse(conversation, provider, model, key, other: "dict | None" = None):
    """Send ``conversation`` to ``model`` on ``provider`` and return the reply as JSON.

    Args:
        conversation: List of dicts with ``role`` and ``context`` keys.
        provider: One of 'groq', 'gemini', 'cohere', 'lepton', 'cloudflare',
            'openrouter', 'segmind'.
        model: Model identifier. For 'lepton' it is also the subdomain of the
            API URL; for 'segmind' it is the last URL path segment.
        key: API key. For 'cloudflare' it must be '<account_id>~<api_token>'.
        other: Currently unused; kept for interface compatibility.

    Returns:
        A JSON string:
        - For the LangChain chat providers and 'segmind', an indented
          serialization of the AI message.
        - For 'cloudflare', ``{"content": <text>}``.
        - For an unknown provider, ``{"content": "unspported Provider"}``.

    Raises:
        ValueError: If a 'cloudflare' key is not '<account_id>~<api_token>'.
        requests.HTTPError: If the Segmind API responds with an error status.
    """
    if provider == 'groq':
        chat = ChatGroq(temperature=0, groq_api_key=key, model_name=model)
    elif provider == 'gemini':
        chat = ChatGoogleGenerativeAI(model=model, google_api_key=key)
    elif provider == 'cohere':
        chat = ChatCohere(model=model, cohere_api_key=key)
    elif provider == 'lepton':
        url = f'https://{model}.lepton.run/api/v1/'
        # Pass the model explicitly, as the openrouter branch does, instead of
        # letting ChatOpenAI send its own default model name.
        chat = ChatOpenAI(openai_api_base=url, openai_api_key=key, model=model)
    elif provider == 'cloudflare':
        # The key packs both credentials as '<account_id>~<api_token>'.
        account_id, separator, api_token = (key or '').partition('~')
        if not (separator and account_id and api_token):
            raise ValueError('Invalid Account Id or api token')
        chat = CloudflareWorkersAI(account_id=account_id, api_token=api_token, model=model)
        # CloudflareWorkersAI comes from langchain_community.llms (a text LLM).
        # Its invoke() result is wrapped directly rather than serialized as a
        # message.
        return json.dumps({'content': chat.invoke(langchainConversation(conversation))})
    elif provider == 'openrouter':
        chat = ChatOpenAI(
            base_url="https://openrouter.ai/api/v1",
            api_key=key,
            model=model,
        )
    elif provider == 'segmind':
        url = f"https://api.segmind.com/v1/{model}"
        data = {"messages": segmind_input_parser(conversation)}
        response = requests.post(url, json=data, headers={'x-api-key': key})
        # Surface HTTP failures directly, rather than as a KeyError or
        # JSONDecodeError while parsing an error body as a chat completion.
        response.raise_for_status()
        return segmind_output_parser(response.json())
    else:
        return json.dumps({'content': 'unspported Provider'})
    return json.dumps(json.loads(chat.invoke(langchainConversation(conversation)).json()), indent=4)