Spaces:
Sleeping
Sleeping
File size: 2,558 Bytes
686a009 88e12f1 686a009 88e12f1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 |
import os
from dotenv import load_dotenv
from utils.config_loader import load_config
from langchain_groq import ChatGroq
from langchain_openai import ChatOpenAI, AzureChatOpenAI
from azure.identity import AzureCliCredential, ManagedIdentityCredential
# Load variables from a .env file into the process environment at import time,
# so the os.getenv lookups performed later by ModelLoader (GROQ_API_KEY,
# OPENAI_API_KEY, AZURE_MANAGED_IDENTITY_CLIENT_ID) can see them.
load_dotenv()
class ConfigLoader:
    """Subscriptable view over the project configuration.

    The configuration is read once, via ``load_config()``, when the object is
    constructed; lookups afterwards go straight to the stored object.
    """

    def __init__(self):
        # Announce the load, then keep the result for later key lookups.
        print("Loading config...")
        loaded_config = load_config()
        self.config = loaded_config

    def __getitem__(self, key):
        """Return the entry stored under *key* in the loaded configuration.

        Errors for unknown keys come from the underlying config object.
        """
        section = self.config[key]
        return section
class ModelLoader:
    """Build a LangChain chat model for the configured LLM provider.

    Supported providers are ``"groq"``, ``"openai"`` and ``"azureopenai"``,
    matched case-insensitively. Model names and Azure connection settings are
    read from the ``llm`` section of the project configuration.
    """

    # OAuth scope requested when authenticating to Azure OpenAI via Entra ID.
    _AZURE_COGNITIVE_SCOPE = "https://cognitiveservices.azure.com/.default"

    def __init__(self, model_provider: str = "azureopenai"):
        """Record the provider to use and load the project configuration.

        Args:
            model_provider: Provider key ("groq", "openai" or "azureopenai");
                normalised to lower case.
        """
        print(f"Initializing ModelLoader with provider: {model_provider}")
        self.model_provider = model_provider.lower()
        self.config = ConfigLoader()

    def load_llm(self):
        """Return a chat model client for ``self.model_provider``.

        The provider chosen at construction time is honoured; it is not
        overridden here.

        Returns:
            A ``ChatGroq``, ``ChatOpenAI`` or ``AzureChatOpenAI`` instance.

        Raises:
            ValueError: if the provider is unsupported, the required API key
                environment variable is missing, or no Azure token could be
                obtained.
        """
        print(f"LLM loading from provider: {self.model_provider}")
        loaders = {
            "groq": self._load_groq,
            "openai": self._load_openai,
            "azureopenai": self._load_azure_openai,
        }
        loader = loaders.get(self.model_provider)
        if loader is None:
            raise ValueError(f"Unsupported model provider: {self.model_provider}")
        return loader()

    def _load_groq(self):
        """Create a ChatGroq client from GROQ_API_KEY and llm.groq.model_name."""
        print("Using Groq")
        groq_api_key = os.getenv("GROQ_API_KEY")
        # The key itself is deliberately never printed: stdout ends up in logs.
        if not groq_api_key:
            raise ValueError("GROQ_API_KEY environment variable is not set.")
        model_name = self.config["llm"]["groq"]["model_name"]
        return ChatGroq(model=model_name, api_key=groq_api_key)

    def _load_openai(self):
        """Create a ChatOpenAI client from OPENAI_API_KEY and llm.openai.model_name."""
        print("Using OpenAI")
        openai_api_key = os.getenv("OPENAI_API_KEY")
        # Fail fast with the same error style as the Groq branch.
        if not openai_api_key:
            raise ValueError("OPENAI_API_KEY environment variable is not set.")
        model_name = self.config["llm"]["openai"]["model_name"]
        return ChatOpenAI(model_name=model_name, api_key=openai_api_key)

    def _load_azure_openai(self):
        """Create an AzureChatOpenAI client authenticated with Microsoft Entra ID."""
        print("Using Azure OpenAI")
        client_id = os.getenv("AZURE_MANAGED_IDENTITY_CLIENT_ID")
        # Use the user-assigned managed identity when a client id is set;
        # otherwise fall back to the Azure CLI login session.
        if client_id and len(client_id) > 1:
            credential = ManagedIdentityCredential(client_id=client_id)
        else:
            credential = AzureCliCredential()
        scope = self._AZURE_COGNITIVE_SCOPE
        # Fetch a token up front so credential problems surface here rather
        # than on the first model call.
        if not credential.get_token(scope).token:
            raise ValueError("Azure token could not be retrieved.")
        azure_cfg = self.config["llm"]["azureopenai"]
        return AzureChatOpenAI(
            azure_endpoint=azure_cfg["endpoint"],
            azure_deployment=azure_cfg["model_name"],
            api_version=azure_cfg["api_version"],
            # Entra ID tokens must be sent as a bearer token, not as the
            # `api-key` header; a provider also re-fetches expired tokens.
            azure_ad_token_provider=lambda: credential.get_token(scope).token,
        )
|