Spaces:
Build error
Build error
| from openai import OpenAI | |
| from langchain_openai import ChatOpenAI | |
| from langchain_community.chat_models import ChatOllama | |
| from langchain_google_genai import ChatGoogleGenerativeAI | |
| from langchain_groq import ChatGroq | |
| try: | |
| from .utils.db import ( | |
| load_api_key, | |
| load_openai_url, | |
| load_model_settings, | |
| load_groq_api_key, | |
| load_google_api_key, | |
| ) | |
| from .custom_callback import customcallback | |
| from .llm_settings import llm_settings | |
| except ImportError: | |
| from utils.db import ( | |
| load_api_key, | |
| load_openai_url, | |
| load_model_settings, | |
| load_groq_api_key, | |
| load_google_api_key, | |
| ) | |
| from custom_callback import customcallback | |
| from llm_settings import llm_settings | |
# Module-level callback instance; get_model() passes this same object in the
# "callbacks" list of every ChatOpenAI argument set it builds.
# NOTE(review): the meaning of strip_tokens / answer_prefix_tokens is defined
# in custom_callback, which is not visible here — confirm before changing.
| the_callback = customcallback(strip_tokens=False, answer_prefix_tokens=["Answer"]) | |
def get_model(high_context=False):
    """Build the chat model object for the currently configured model.

    The model name, the API keys and the OpenAI base URL are read through
    the settings loaders on every call. The provider of the configured
    model is looked up in ``llm_settings`` and decides which chat class is
    instantiated.

    Args:
        high_context: When true and the OpenAI URL setting is "default",
            OpenAI-provider models are built with "gpt-4-turbo" instead of
            the configured model name. Ignored for a custom base URL and
            for the other providers.

    Returns:
        A chat model instance, or ``None`` when the configured model is not
        listed in ``llm_settings`` or its provider is not one of
        "openai", "ollama", "google" or "groq".
    """
    the_model = load_model_settings()
    the_api_key = load_api_key()
    the_groq_api_key = load_groq_api_key()
    the_google_api_key = load_google_api_key()
    the_openai_url = load_openai_url()

    def open_ai_args():
        """Return the keyword arguments for a ChatOpenAI instance."""
        args = {
            "model": the_model,
            "api_key": the_api_key,
            "max_retries": 15,
            "streaming": True,
            "callbacks": [the_callback],
        }
        if the_openai_url == "default":
            # The high-context override only applies to the stock endpoint.
            if high_context:
                args["model"] = "gpt-4-turbo"
        else:
            args["base_url"] = the_openai_url
        return args

    model_settings = llm_settings.get(the_model)
    if model_settings is None:
        # Unknown model: return None instead of raising KeyError, which is
        # what the original trailing "else None" intended but never reached.
        return None

    provider = model_settings["provider"]
    if provider == "openai":
        return ChatOpenAI(**open_ai_args())
    if provider == "ollama":
        # Ollama models are reached through ChatOpenAI pointed at this
        # fixed local /v1 URL, with a placeholder API key.
        return ChatOpenAI(
            api_key="ollama",
            base_url="http://localhost:11434/v1",
            model=the_model,
        )
    if provider == "google":
        return ChatGoogleGenerativeAI(
            model=the_model,
            google_api_key=the_google_api_key,
        )
    if provider == "groq":
        return ChatGroq(
            temperature=0,
            # Configured names carry a "-groq" suffix that is stripped
            # before the name is passed to ChatGroq.
            model_name=the_model.replace("-groq", ""),
            # Fixed: the original passed the OpenAI URL here and never used
            # the Groq key it had loaded.
            groq_api_key=the_groq_api_key,
        )
    return None
def get_client():
    """Return an OpenAI client built from the stored API key and URL.

    When the stored URL setting is "default" the client is created with
    only the API key; otherwise the stored URL is passed as ``base_url``.
    """
    client_kwargs = {"api_key": load_api_key()}
    endpoint = load_openai_url()
    if endpoint != "default":
        client_kwargs["base_url"] = endpoint
    return OpenAI(**client_kwargs)