Spaces:
Sleeping
Sleeping
| """ | |
| Factory functions for API-backed LLM clients. | |
| Detects provider and key, returns an API-based LLM instance. | |
| """ | |
| import os | |
| import requests | |
| from openai import (api_key, ChatCompletion) | |
| import anthropic | |
| from llama_index.core import Settings | |
class OpenAI:
    """Minimal wrapper around the openai ``ChatCompletion`` API.

    ``complete(prompt)`` sends the prompt as a single user message and returns
    an object whose ``text`` attribute holds the model's reply.

    NOTE(review): ``ChatCompletion.create`` is the pre-1.0 ``openai`` SDK
    interface; confirm the pinned ``openai`` version still supports it.
    """

    def __init__(self, api_key: str, temperature: float = 0.7, model_name: str = "gpt-3.5-turbo"):
        # Keep the key on the instance and pass it with every request, instead
        # of relying on (and silently failing to set) module-global openai state.
        self.api_key = api_key
        self.temperature = temperature
        self.model_name = model_name

    def complete(self, prompt: str):
        """Send ``prompt`` as one user message; return an object with ``.text``."""
        from types import SimpleNamespace

        resp = ChatCompletion.create(
            api_key=self.api_key,
            model=self.model_name,
            messages=[{"role": "user", "content": prompt}],
            temperature=self.temperature,
        )
        return SimpleNamespace(text=resp.choices[0].message["content"])
class Anthropic:
    """Minimal wrapper around Anthropic's Text Completions API.

    ``complete(prompt)`` returns an object whose ``text`` attribute holds the
    generated completion.
    """

    def __init__(self, api_key: str, temperature: float = 0.7, model_name: str = "claude-2",
                 max_tokens: int = 256):
        # The SDK client that provides ``completions.create`` takes its
        # constructor arguments by keyword only; a positional key raises TypeError.
        self.client = anthropic.Client(api_key=api_key)
        self.temperature = temperature
        self.model_name = model_name
        # Upper bound on generated tokens (sent as ``max_tokens_to_sample``).
        self.max_tokens = max_tokens

    def complete(self, prompt: str):
        """Run a single completion for ``prompt``; return an object with ``.text``."""
        from types import SimpleNamespace

        # The Text Completions endpoint only accepts prompts framed as a
        # Human/Assistant turn; frame raw prompts, leave pre-framed ones alone.
        if not prompt.startswith(anthropic.HUMAN_PROMPT):
            prompt = f"{anthropic.HUMAN_PROMPT} {prompt}{anthropic.AI_PROMPT}"
        resp = self.client.completions.create(
            model=self.model_name,
            prompt=prompt,
            max_tokens_to_sample=self.max_tokens,
            temperature=self.temperature,
        )
        return SimpleNamespace(text=resp.completion)
class MistralAPI:
    """Minimal client for Mistral's chat-completions REST endpoint.

    ``complete(prompt)`` sends the prompt as a single user message and returns
    an object whose ``text`` attribute holds the reply ("" if the response
    carries no content).

    NOTE(review): confirm ``"mistral-large"`` is an accepted model id; Mistral
    documents versioned / ``-latest`` names.
    """

    def __init__(self, api_key: str, temperature: float = 0.7, model_name: str = "mistral-large",
                 max_tokens: int = 256, timeout: float = 60.0):
        self.api_key = api_key
        self.temperature = temperature
        self.model_name = model_name
        self.max_tokens = max_tokens
        # Seconds to wait for the HTTP response; requests has no default timeout
        # and would otherwise block forever on a stalled connection.
        self.timeout = timeout
        # The model is selected in the request body, not in the URL path.
        self.endpoint = "https://api.mistral.ai/v1/chat/completions"

    def complete(self, prompt: str):
        """POST ``prompt`` as one user message; return an object with ``.text``.

        Raises:
            requests.HTTPError: if the API answers with a non-2xx status.
        """
        from types import SimpleNamespace

        headers = {"Authorization": f"Bearer {self.api_key}"}
        payload = {
            "model": self.model_name,
            "messages": [{"role": "user", "content": prompt}],
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
        }
        resp = requests.post(self.endpoint, headers=headers, json=payload, timeout=self.timeout)
        # Surface auth/quota/validation errors instead of parsing the error body
        # as if it were an empty completion.
        resp.raise_for_status()
        data = resp.json()
        # Tolerate a missing or empty ``choices`` list without IndexError.
        choices = data.get("choices") or [{}]
        message = choices[0].get("message") or {}
        return SimpleNamespace(text=message.get("content") or "")
def build_api_llm(provider: str, keys: dict, temperature: float = 0.7):
    """
    Instantiate an API LLM based on provider name and supplied keys.

    The key for the chosen provider is taken from ``keys`` first and falls
    back to that provider's environment variable. The created client is also
    assigned to the global llama_index ``Settings.llm``.

    Args:
        provider (str): one of "openai", "anthropic", "mistralai"
            (case-insensitive; surrounding whitespace is ignored)
        keys (dict): mapping provider -> API key; ``None`` is treated as empty
        temperature (float): sampling temperature passed to the client

    Returns:
        LLM instance configured for API calls

    Raises:
        ValueError: if the provider is unsupported or no API key is found.
    """
    env_vars = {
        "openai": "OPENAI_API_KEY",
        "anthropic": "ANTHROPIC_API_KEY",
        "mistralai": "MISTRAL_API_KEY",
    }
    p = provider.strip().lower()
    if p not in env_vars:
        raise ValueError(f"Unsupported provider: {provider}")

    key = (keys or {}).get(p) or os.getenv(env_vars[p])
    if not key:
        # Fail here with an actionable message rather than sending a request
        # with a missing key and getting an opaque auth error later.
        raise ValueError(
            f"No API key for provider {p!r}: pass keys[{p!r}] or set {env_vars[p]}"
        )

    if p == "openai":
        client = OpenAI(api_key=key, temperature=temperature)
    elif p == "anthropic":
        client = Anthropic(api_key=key, temperature=temperature)
    else:
        client = MistralAPI(api_key=key, temperature=temperature)

    # NOTE(review): llama_index's Settings.llm setter may require a llama_index
    # LLM instance; these wrapper classes do not subclass one — confirm this
    # assignment is accepted.
    Settings.llm = client
    return client