# ollive-api / llm / factory.py
# Karthik Namboori
# Deploy ollive FastAPI Docker Space
# 7b4b748
import logging
from langchain_anthropic import ChatAnthropic
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from langchain_openai import ChatOpenAI
from transformers import GenerationConfig
from config import FrontierConfig, OSSConfig
from llm.cleaned_hf_chat import CleanedChatHuggingFace
logger = logging.getLogger(__name__)
def _configure_local_generation(llm: CleanedChatHuggingFace, config: OSSConfig) -> None:
    """Store generation settings on the pipeline config to avoid duplicate kwargs warnings.

    Does nothing when the wrapped ``llm.llm`` object exposes no ``pipeline``
    attribute. Otherwise replaces the pipeline's ``generation_config`` and
    overwrites ``llm.llm.pipeline_kwargs`` with the runtime call options.

    Args:
        llm: Chat wrapper whose ``llm`` attribute may carry a ``pipeline``.
        config: Supplies ``max_tokens`` and ``temperature``.
    """
    hf_pipeline = getattr(llm.llm, "pipeline", None)
    if hf_pipeline is None:
        return

    sampling = config.temperature > 0
    generation_kwargs = {
        "max_new_tokens": config.max_tokens,
        "do_sample": sampling,
    }
    if sampling:
        # Temperature only applies to sampling. Passing a non-default value
        # together with do_sample=False makes GenerationConfig validation
        # complain, so it is left unset for greedy decoding.
        generation_kwargs["temperature"] = config.temperature
    hf_pipeline.generation_config = GenerationConfig(**generation_kwargs)

    llm.llm.pipeline_kwargs = {
        "return_full_text": False,
        "clean_up_tokenization_spaces": False,
        "truncation": False,
    }
def build_oss_llm(config: OSSConfig) -> BaseChatModel:
    """Build the OSS chat model for the backend named by ``config.backend``.

    The backend name is compared case-insensitively against ``"api"`` and
    ``"local"`` and the matching builder is called with ``config``.

    Raises:
        ValueError: If the backend is neither ``"local"`` nor ``"api"``.
    """
    builders = {
        "api": _build_oss_api_llm,
        "local": _build_oss_local_llm,
    }
    builder = builders.get(config.backend.lower())
    if builder is None:
        raise ValueError(f"Unknown OSS_BACKEND '{config.backend}'. Use 'local' or 'api'.")
    return builder(config)
def _build_oss_local_llm(config: OSSConfig) -> BaseChatModel:
    """Run the model locally via transformers (works for small models like 0.5B).

    The Hugging Face token, when set, is forwarded through ``model_kwargs``;
    generation settings are then applied by ``_configure_local_generation``.
    """
    hub_auth = {"token": config.hf_token} if config.hf_token else None
    chat_model = CleanedChatHuggingFace.from_model_id(
        model_id=config.model_id,
        backend="pipeline",
        task="text-generation",
        device=config.device,
        model_kwargs=hub_auth,
        pipeline_kwargs={
            "return_full_text": False,
            "clean_up_tokenization_spaces": False,
            "truncation": False,
        },
    )
    _configure_local_generation(chat_model, config)

    # A negative device index is reported as CPU; anything else as a CUDA index.
    if config.device < 0:
        device_label = "CPU"
    else:
        device_label = f"cuda:{config.device}"
    logger.info(
        "Built local OSS LangChain LLM model=%s device=%s",
        config.model_id,
        device_label,
    )
    return chat_model
def _build_oss_api_llm(config: OSSConfig) -> BaseChatModel:
    """Use Hugging Face Inference API (only models enabled on your HF account).

    Raises:
        ValueError: If ``config.hf_token`` is empty.
    """
    token = config.hf_token
    if not token:
        raise ValueError(
            "HF_TOKEN is required for OSS_BACKEND=api. "
            "Get one at https://huggingface.co/settings/tokens"
        )

    endpoint_settings = {
        "repo_id": config.model_id,
        "huggingfacehub_api_token": token,
        "max_new_tokens": config.max_tokens,
        "temperature": config.temperature,
        "return_full_text": False,
        "task": "text-generation",
    }
    chat_model = CleanedChatHuggingFace(llm=HuggingFaceEndpoint(**endpoint_settings))
    logger.info("Built HF API OSS LangChain LLM model=%s", config.model_id)
    return chat_model
def build_frontier_llm(config: FrontierConfig) -> BaseChatModel:
    """Build the frontier chat model selected by ``config.provider``.

    ``"anthropic"`` (case-insensitive) builds a ``ChatAnthropic`` using
    ``config.anthropic_api_key``. Every other value builds a ``ChatOpenAI``
    using ``config.api_key``; a warning is logged when that value is not
    ``"openai"`` so a mistyped provider does not go unnoticed.

    Raises:
        ValueError: If the API key for the selected provider is missing.
    """
    provider = config.provider.lower()
    if provider == "anthropic":
        if not config.anthropic_api_key:
            raise ValueError(
                "ANTHROPIC_API_KEY is required when FRONTIER_PROVIDER=anthropic"
            )
        llm = ChatAnthropic(
            model=config.model_id,
            api_key=config.anthropic_api_key,
            max_tokens=config.max_tokens,
            temperature=config.temperature,
        )
    else:
        if provider != "openai":
            # The OpenAI fallback is kept for backward compatibility, but it
            # is surfaced so it does not hide a typo.
            # NOTE(review): assumes "openai" is the canonical provider name;
            # confirm against FrontierConfig.
            logger.warning(
                "Unrecognized frontier provider '%s'; falling back to OpenAI",
                config.provider,
            )
        if not config.api_key:
            raise ValueError("OPENAI_API_KEY is required. Set it in ollive/.env")
        llm = ChatOpenAI(
            model=config.model_id,
            api_key=config.api_key,
            max_tokens=config.max_tokens,
            temperature=config.temperature,
        )
    logger.info("Built frontier LangChain LLM provider=%s model=%s", provider, config.model_id)
    return llm