Spaces:
Runtime error
Runtime error
| from openai import OpenAI | |
| from langchain_openai import ChatOpenAI | |
| from langchain_community.chat_models import ChatOllama | |
| from langchain_google_genai import ChatGoogleGenerativeAI | |
| from langchain_groq import ChatGroq | |
| try: | |
| from .utils.db import load_api_key, load_openai_url, load_model_settings, load_groq_api_key, load_google_api_key | |
| from .custom_callback import customcallback | |
| except ImportError: | |
| from utils.db import load_api_key, load_openai_url, load_model_settings, load_groq_api_key, load_google_api_key | |
| from custom_callback import customcallback | |
# Module-wide callback instance; get_model() attaches it to every ChatOpenAI
# it builds (those models are created with streaming enabled).
the_callback = customcallback(
    strip_tokens=False,
    answer_prefix_tokens=["Answer"],
)
# Backend that serves each supported model name. Plain strings keep this table
# independent of the provider classes, so the lookup is cheap and testable.
_MODEL_PROVIDERS = {
    # OpenAI
    "gpt-4o": "openai",
    "gpt-4-turbo": "openai",
    "gpt-3.5": "openai",
    "gpt-3.5-turbo": "openai",
    # Ollama
    "llava": "ollama",
    "llama3": "ollama",
    "bakllava": "ollama",
    # Google Generative AI - Gemini
    "gemini-pro": "google",
    # Groq
    "mixtral-8x7b-groq": "groq",
}


def _model_kwargs(provider, model, *, high_context, api_key, groq_api_key,
                  google_api_key, openai_url, callback):
    """Return the constructor keyword arguments for *provider*'s chat model.

    Args:
        provider: One of the values in ``_MODEL_PROVIDERS``.
        model: The configured model name (a key of ``_MODEL_PROVIDERS``).
        high_context: For OpenAI on the ``"default"`` endpoint, substitute
            ``"gpt-4-turbo"`` for the configured model.
        api_key: OpenAI API key.
        groq_api_key: Groq API key.
        google_api_key: Google Generative AI API key.
        openai_url: Custom base URL for OpenAI models, or the sentinel
            ``"default"`` meaning no ``base_url`` is passed.
        callback: Callback attached to OpenAI models, which stream.

    Returns:
        A dict of keyword arguments for the provider's model class.

    Raises:
        KeyError: If *provider* is not recognised.
    """
    if provider == "openai":
        kwargs = {
            "model": model,
            "api_key": api_key,
            "max_retries": 15,
            "streaming": True,
            "callbacks": [callback],
        }
        if openai_url == "default":
            if high_context:
                kwargs["model"] = "gpt-4-turbo"
        else:
            # A custom endpoint keeps the configured model even when
            # high_context is requested.
            kwargs["base_url"] = openai_url
        return kwargs
    if provider == "ollama":
        return {"model": model}
    if provider == "groq":
        # The settings key carries a "-groq" suffix to distinguish this
        # entry; it is stripped before being passed on as model_name.
        return {
            "temperature": 0,
            "model_name": model.replace("-groq", ""),
            "groq_api_key": groq_api_key,
        }
    if provider == "google":
        return {"model": model, "google_api_key": google_api_key}
    raise KeyError(f"unknown model provider: {provider!r}")


def get_model(high_context=False):
    """Build the LangChain chat model selected in the stored settings.

    The model name comes from ``load_model_settings()``. Credentials and the
    OpenAI endpoint come from the other ``load_*`` helpers.

    Args:
        high_context: When True and the stored OpenAI URL is ``"default"``,
            OpenAI models are replaced by ``"gpt-4-turbo"``.

    Returns:
        A ChatOpenAI, ChatOllama, ChatGroq or ChatGoogleGenerativeAI instance.

    Raises:
        KeyError: If the configured model is not in ``_MODEL_PROVIDERS``.
    """
    the_model = load_model_settings()
    the_api_key = load_api_key()
    the_groq_api_key = load_groq_api_key()
    the_google_api_key = load_google_api_key()
    the_openai_url = load_openai_url()

    provider = _MODEL_PROVIDERS[the_model]
    model_classes = {
        "openai": ChatOpenAI,
        "ollama": ChatOllama,
        "groq": ChatGroq,
        "google": ChatGoogleGenerativeAI,
    }
    kwargs = _model_kwargs(
        provider,
        the_model,
        high_context=high_context,
        api_key=the_api_key,
        groq_api_key=the_groq_api_key,
        google_api_key=the_google_api_key,
        openai_url=the_openai_url,
        callback=the_callback,
    )
    return model_classes[provider](**kwargs)
def get_client():
    """Return a raw OpenAI SDK client built from the stored settings.

    The stored API key is always passed. A stored URL other than the sentinel
    ``"default"`` is forwarded as ``base_url``.
    """
    api_key = load_api_key()
    endpoint = load_openai_url()
    client_kwargs = {"api_key": api_key}
    if endpoint != "default":
        client_kwargs["base_url"] = endpoint
    return OpenAI(**client_kwargs)