# mlstocks/backend/app/core/model_factory.py
# Deployed to Hugging Face Space by github-actions[bot] (commit abf702c)
import os
from typing import Optional, Dict
class AutoGenModelFactory:
    """
    Factory for creating AutoGen compatible model instances.

    Every supported provider is exposed through AutoGen's
    OpenAIChatCompletionClient; non-OpenAI providers are reached via
    their OpenAI-compatible endpoints.
    """

    @staticmethod
    def get_model(provider: str = "openai",
                  model_name: str = "gpt-4o",
                  temperature: float = 0,
                  model_info: Optional[Dict] = None
                  ):
        """
        Returns an AutoGen OpenAIChatCompletionClient instance.

        Args:
            provider: One of "openai", "google"/"gemini", "groq" or
                "ollama" (case-insensitive).
            model_name: Provider-specific model identifier.
            temperature: Sampling temperature forwarded to the client.
            model_info: Optional capability metadata dict. Used verbatim
                for "ollama"; for "google"/"gemini" it overrides the
                built-in defaults when supplied. Not used by "openai" or
                "groq", which rely on the client's own model registry.

        Raises:
            ValueError: If the provider is not supported (checked before
                the lazy import so invalid input fails fast).
            ImportError: If the AutoGen libraries are not installed.
        """
        key = provider.lower()  # normalize once instead of per-branch

        # Fail fast on bad input before touching the optional dependency,
        # so a typo'd provider gets ValueError rather than ImportError.
        if key not in ("openai", "google", "gemini", "groq", "ollama"):
            raise ValueError(f"Unsupported AutoGen provider: {provider}")

        # Lazy import to avoid dependency issues if autogen is not installed
        try:
            from autogen_ext.models.openai import OpenAIChatCompletionClient
        except ImportError as e:
            raise ImportError("AutoGen libraries (autogen-agentchat, autogen-ext[openai]) are not installed.") from e

        # ------------------------------------------------------------------
        # OPENAI
        # ------------------------------------------------------------------
        if key == "openai":
            return OpenAIChatCompletionClient(
                model=model_name,
                api_key=os.environ.get("OPENAI_API_KEY"),
                temperature=temperature,
            )

        # ------------------------------------------------------------------
        # GOOGLE (GEMINI) via OpenAI Compat
        # ------------------------------------------------------------------
        if key in ("google", "gemini"):
            # Honor caller-supplied capability metadata; otherwise assume a
            # text-only, tool-capable Gemini model.
            gemini_info = model_info if model_info is not None else {
                "family": "gemini",
                "vision": False,
                "function_calling": True,
                "json_output": True,
                "structured_output": True,
            }
            return OpenAIChatCompletionClient(
                model=model_name,
                base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
                api_key=os.environ.get("GOOGLE_API_KEY"),
                temperature=temperature,
                model_info=gemini_info,
                # Safety: Ensure n=1 and explicit tokens to prevent shim errors
                extra_body={"n": 1},
            )

        # ------------------------------------------------------------------
        # GROQ
        # ------------------------------------------------------------------
        if key == "groq":
            return OpenAIChatCompletionClient(
                model=model_name,
                base_url="https://api.groq.com/openai/v1",
                api_key=os.environ.get("GROQ_API_KEY"),
                temperature=temperature,
            )

        # ------------------------------------------------------------------
        # OLLAMA (local OpenAI-compatible server; no real key required)
        # ------------------------------------------------------------------
        base_url = os.environ.get("OLLAMA_BASE_URL", "http://localhost:11434/v1")
        return OpenAIChatCompletionClient(
            model=model_name,
            base_url=base_url,
            api_key="ollama",  # dummy key
            # Ensure model_info defaults to empty dict if None
            model_info=model_info if model_info is not None else {},
            temperature=temperature,
        )