import os

from langchain_openai import ChatOpenAI
from langchain_groq import ChatGroq
from langchain_together import ChatTogether
from crewai import LLM

from Config import Config

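# When a model_type is passed explicitly, settings are read from environment
# variables named after the upper-cased model_type. For example, with
# model_type="groq" the functions below look for (defaults shown where applied):
#
#   GROQ_AGENT_MODEL=<model id>      (required)
#   GROQ_API_KEY=<key>               (required)
#   GROQ_URL=<base url>
#   GROQ_TEMPERATURE=0.7
#   GROQ_TOP_P=1.0
#   GROQ_FREQUENCY_PENALTY=0.0
#   GROQ_PRESENCE_PENALTY=0.0
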
# ========== PUBLIC INTERFACE ==========
def get_respondent_agent_llm_instance(model_type=None):
    """Build the respondent agent's LLM from Config defaults or env overrides."""
    # Default to Config if model_type is not specified
    if not model_type:
        model_type = Config.respondent_agent_host
        model = Config.respondent_agent_model
        api_key = Config.respondent_agent_api_key
        url = Config.respondent_agent_url
        temperature = Config.respondent_agent_temperature
        top_p = Config.respondent_agent_top_p
        frequency_penalty = Config.respondent_agent_frequency_penalty
        presence_penalty = Config.respondent_agent_presence_penalty
    # If model_type is specified, derive the env prefix (e.g. "GROQ" for model_type="groq")
    else:
        prefix = model_type.upper()
        model = os.getenv(f"{prefix}_AGENT_MODEL")
        api_key = os.getenv(f"{prefix}_API_KEY")
        url = os.getenv(f"{prefix}_URL")
        temperature = float(os.getenv(f"{prefix}_TEMPERATURE", "0.7"))
        top_p = float(os.getenv(f"{prefix}_TOP_P", "1.0"))
        frequency_penalty = float(os.getenv(f"{prefix}_FREQUENCY_PENALTY", "0.0"))
        presence_penalty = float(os.getenv(f"{prefix}_PRESENCE_PENALTY", "0.0"))

    if not api_key:
        raise ValueError(f"API key not found for model_type={model_type}.")
    if not model:
        raise ValueError(f"Model not found for model_type={model_type}.")

    print(f"Respondent Agent LLM: model_type={model_type}, model={model}, api_key={'*****' if api_key else 'MISSING'}, url={url}")
    print(f"Params: temperature={temperature}, top_p={top_p}, frequency_penalty={frequency_penalty}, presence_penalty={presence_penalty}")

    return get_crewai_instance(model_type, model, api_key, url, temperature, top_p, frequency_penalty, presence_penalty)

def get_processing_agent_llm_instance(model_type=None):
    """Build the processing agent's LLM from Config defaults or env overrides."""
    # Default to Config if model_type is not specified
    if not model_type:
        model_type = Config.processing_agent_host
        model = Config.processing_agent_model
        api_key = Config.processing_agent_api_key
        url = Config.processing_agent_url
        temperature = Config.processing_agent_temperature
        top_p = Config.processing_agent_top_p
        frequency_penalty = Config.processing_agent_frequency_penalty
        presence_penalty = Config.processing_agent_presence_penalty
    # If model_type is specified, derive the env prefix (e.g. "GROQ" for model_type="groq")
    else:
        prefix = model_type.upper()
        model = os.getenv(f"{prefix}_AGENT_MODEL")
        api_key = os.getenv(f"{prefix}_API_KEY")
        url = os.getenv(f"{prefix}_URL")
        temperature = float(os.getenv(f"{prefix}_TEMPERATURE", "0.7"))
        top_p = float(os.getenv(f"{prefix}_TOP_P", "1.0"))
        frequency_penalty = float(os.getenv(f"{prefix}_FREQUENCY_PENALTY", "0.0"))
        presence_penalty = float(os.getenv(f"{prefix}_PRESENCE_PENALTY", "0.0"))

    if not api_key:
        raise ValueError(f"API key not found for model_type={model_type}.")
    if not model:
        raise ValueError(f"Model not found for model_type={model_type}.")

    print(f"Processing Agent LLM: model_type={model_type}, model={model}, api_key={'*****' if api_key else 'MISSING'}, url={url}")
    print(f"Params: temperature={temperature}, top_p={top_p}, frequency_penalty={frequency_penalty}, presence_penalty={presence_penalty}")

    return get_crewai_instance(model_type, model, api_key, url, temperature, top_p, frequency_penalty, presence_penalty)

def get_processor_llm_instance(model_type=None):
    """Build the processor's LLM (plain LangChain) from Config defaults or env overrides."""
    # Default to Config if model_type is not specified
    if not model_type:
        model_type = Config.processor_host
        model = Config.processor_model
        api_key = Config.processor_api_key
        url = Config.processor_url
        temperature = Config.processor_temperature
        top_p = Config.processor_top_p
        frequency_penalty = Config.processor_frequency_penalty
        presence_penalty = Config.processor_presence_penalty
    # If model_type is specified, derive the env prefix (e.g. "GROQ" for model_type="groq")
    else:
        prefix = model_type.upper()
        model = os.getenv(f"{prefix}_AGENT_MODEL")
        api_key = os.getenv(f"{prefix}_API_KEY")
        url = os.getenv(f"{prefix}_URL")
        temperature = float(os.getenv(f"{prefix}_TEMPERATURE", "0.7"))
        top_p = float(os.getenv(f"{prefix}_TOP_P", "1.0"))
        frequency_penalty = float(os.getenv(f"{prefix}_FREQUENCY_PENALTY", "0.0"))
        presence_penalty = float(os.getenv(f"{prefix}_PRESENCE_PENALTY", "0.0"))

    if not api_key:
        raise ValueError(f"API key not found for model_type={model_type}.")
    if not model:
        raise ValueError(f"Model not found for model_type={model_type}.")

    print(f"Processor LLM: model_type={model_type}, model={model}, api_key={'*****' if api_key else 'MISSING'}, url={url}")
    print(f"Params: temperature={temperature}, top_p={top_p}, frequency_penalty={frequency_penalty}, presence_penalty={presence_penalty}")

    return get_langchain_instance(model_type, model, api_key, url, temperature, top_p, frequency_penalty, presence_penalty)

# ========== INTERNAL HELPERS ==========
def get_crewai_instance(model_type, model, api_key, url, temperature, top_p, frequency_penalty, presence_penalty):
    """Instantiate an LLM object for use by CrewAI agents."""
    model_type = model_type.lower()
    # Groq is handled separately and only receives temperature
    if model_type == 'groq':
        return ChatGroq(groq_api_key=api_key, model_name=f"{model_type}/{model}", temperature=temperature, model_kwargs={})

    common_args = {
        "temperature": temperature,
        "top_p": top_p,
        "frequency_penalty": frequency_penalty,
        "presence_penalty": presence_penalty,
    }
    common_args = {k: v for k, v in common_args.items() if v is not None}  # Drop entries left unset (None)

    if model_type == 'openai':
        return ChatOpenAI(model=model, api_key=api_key, **common_args)
    elif model_type == 'openrouter':
        return ChatOpenAI(base_url=url, model=f"{model_type}/{model}", api_key=api_key, **common_args)
    elif model_type == 'together_ai':
        # CrewAI's LLM class takes "provider/model" names
        return LLM(model=f"{model_type}/{model}", api_key=api_key, api_base=url, **common_args)
    else:
        raise ValueError(f"Unsupported model type for CrewAI: {model_type}")

def get_langchain_instance(model_type, model, api_key, url, temperature, top_p, frequency_penalty, presence_penalty):
    """Instantiate a plain LangChain chat model (no CrewAI wrapper)."""
    model_type = model_type.lower()
    # Groq is handled separately and only receives temperature
    if model_type == 'groq':
        return ChatGroq(groq_api_key=api_key, model_name=model, temperature=temperature, model_kwargs={})

    common_args = {
        "temperature": temperature,
        "top_p": top_p,
        "frequency_penalty": frequency_penalty,
        "presence_penalty": presence_penalty,
    }
    common_args = {k: v for k, v in common_args.items() if v is not None}  # Drop entries left unset (None)

    if model_type == 'openai':
        return ChatOpenAI(model=model, api_key=api_key, **common_args)
    elif model_type == 'openrouter':
        return ChatOpenAI(base_url=url, model=model, api_key=api_key, **common_args)
    elif model_type == 'together_ai':
        return ChatTogether(model=model, together_api_key=api_key, **common_args)
    else:
        raise ValueError(f"Unsupported model type for LangChain: {model_type}")