Spaces:
Sleeping
Sleeping
| import ast | |
| import datetime | |
| import os | |
| from typing import List, Dict, Generator | |
| from dotenv import load_dotenv | |
| from tenacity import ( | |
| retry, | |
| stop_after_attempt, | |
| wait_random_exponential, | |
| ) # for exponential backoff | |
| # Load environment variables from .env file | |
| load_dotenv() | |
| def anthropic_completion_with_backoff(client, *args, **kwargs): | |
| return client.beta.prompt_caching.messages.create(*args, **kwargs) | |
| def get_anthropic(model: str, | |
| prompt: str, | |
| temperature: float = 0, | |
| max_tokens: int = 4096, | |
| system: str = '', | |
| chat_history: List[Dict] = None, | |
| secrets: Dict = {}, | |
| verbose=False) -> \ | |
| Generator[dict, None, None]: | |
| model = model.replace('anthropic:', '') | |
| # https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching | |
| import anthropic | |
| clawd_key = secrets.get('ANTHROPIC_API_KEY') | |
| clawd_client = anthropic.Anthropic(api_key=clawd_key) if clawd_key else None | |
| if chat_history is None: | |
| chat_history = [] | |
| messages = [] | |
| # Add conversation history, removing cache_control from all but the last two user messages | |
| for i, message in enumerate(chat_history): | |
| if message["role"] == "user": | |
| if i >= len(chat_history) - 3: # Last two user messages | |
| messages.append(message) | |
| else: | |
| messages.append({ | |
| "role": "user", | |
| "content": [{"type": "text", "text": message["content"][0]["text"]}] | |
| }) | |
| else: | |
| messages.append(message) | |
| # Add the new user message | |
| messages.append({ | |
| "role": "user", | |
| "content": [ | |
| { | |
| "type": "text", | |
| "text": prompt, | |
| "cache_control": {"type": "ephemeral"} | |
| } | |
| ] | |
| }) | |
| response = anthropic_completion_with_backoff(clawd_client, | |
| model=model, | |
| max_tokens=max_tokens, | |
| temperature=temperature, | |
| system=system, | |
| messages=messages, | |
| stream=True | |
| ) | |
| output_tokens = 0 | |
| input_tokens = 0 | |
| cache_creation_input_tokens = 0 | |
| cache_read_input_tokens = 0 | |
| for chunk in response: | |
| if chunk.type == "content_block_start": | |
| # This is where we might find usage info in the future | |
| pass | |
| elif chunk.type == "content_block_delta": | |
| yield dict(text=chunk.delta.text) | |
| elif chunk.type == "message_delta": | |
| output_tokens = dict(chunk.usage).get('output_tokens', 0) | |
| elif chunk.type == "message_start": | |
| usage = chunk.message.usage | |
| input_tokens = dict(usage).get('input_tokens', 0) | |
| cache_creation_input_tokens = dict(usage).get('cache_creation_input_tokens', 0) | |
| cache_read_input_tokens = dict(usage).get('cache_read_input_tokens', 0) | |
| else: | |
| if verbose: | |
| print("Unknown chunk type:", chunk.type) | |
| print("Chunk:", chunk) | |
| if verbose: | |
| # After streaming is complete, print the usage information | |
| print(f"Output tokens: {output_tokens}") | |
| print(f"Input tokens: {input_tokens}") | |
| print(f"Cache creation input tokens: {cache_creation_input_tokens}") | |
| print(f"Cache read input tokens: {cache_read_input_tokens}") | |
| yield dict(output_tokens=output_tokens, input_tokens=input_tokens, | |
| cache_creation_input_tokens=cache_creation_input_tokens, | |
| cache_read_input_tokens=cache_read_input_tokens) | |
| def openai_completion_with_backoff(client, *args, **kwargs): | |
| return client.chat.completions.create(*args, **kwargs) | |
| def get_openai(model: str, | |
| prompt: str, | |
| temperature: float = 0, | |
| max_tokens: int = 4096, | |
| system: str = '', | |
| chat_history: List[Dict] = None, | |
| secrets: Dict = {}, | |
| verbose=False) -> Generator[dict, None, None]: | |
| if model.startswith('ollama:'): | |
| model = model.replace('ollama:', '') | |
| openai_key = secrets.get('OLLAMA_OPENAI_API_KEY') | |
| openai_base_url = secrets.get('OLLAMA_OPENAI_BASE_URL', 'http://localhost:11434/v1/') | |
| else: | |
| model = model.replace('openai:', '') | |
| openai_key = secrets.get('OPENAI_API_KEY') | |
| openai_base_url = secrets.get('OPENAI_BASE_URL', 'https://api.openai.com/v1') | |
| from openai import OpenAI | |
| openai_client = OpenAI(api_key=openai_key, base_url=openai_base_url) if openai_key else None | |
| if chat_history is None: | |
| chat_history = [] | |
| chat_history_copy = chat_history.copy() | |
| for mi, message in enumerate(chat_history_copy): | |
| if isinstance(message["content"], list): | |
| chat_history_copy[mi]["content"] = message["content"][0]["text"] | |
| chat_history = chat_history_copy | |
| messages = [{"role": "system", "content": system}] + chat_history + [{"role": "user", "content": prompt}] | |
| response = openai_completion_with_backoff(openai_client, | |
| model=model, | |
| messages=messages, | |
| temperature=temperature, | |
| max_tokens=max_tokens, | |
| stream=True, | |
| ) | |
| output_tokens = 0 | |
| input_tokens = 0 | |
| for chunk in response: | |
| if chunk.choices[0].delta.content: | |
| yield dict(text=chunk.choices[0].delta.content) | |
| if chunk.usage: | |
| output_tokens = chunk.usage.completion_tokens | |
| input_tokens = chunk.usage.prompt_tokens | |
| if verbose: | |
| print(f"Output tokens: {output_tokens}") | |
| print(f"Input tokens: {input_tokens}") | |
| yield dict(output_tokens=output_tokens, input_tokens=input_tokens) | |
| def openai_messages_to_gemini_history(messages): | |
| """Converts OpenAI messages to Gemini history format. | |
| Args: | |
| messages: A list of OpenAI messages, each with "role" and "content" keys. | |
| Returns: | |
| A list of dictionaries representing the chat history for Gemini. | |
| """ | |
| history = [] | |
| for message in messages: | |
| if isinstance(message["content"], list): | |
| message["content"] = message["content"][0]["text"] | |
| if message["role"] == "user": | |
| history.append({"role": "user", "parts": [{"text": message["content"]}]}) | |
| elif message["role"] == "assistant": | |
| history.append({"role": "model", "parts": [{"text": message["content"]}]}) | |
| # Optionally handle system messages if needed | |
| # elif message["role"] == "system": | |
| # history.append({"role": "system", "parts": [{"text": message["content"]}]}) | |
| return history | |
| def gemini_send_message_with_backoff(chat, prompt, stream=True): | |
| return chat.send_message(prompt, stream=stream) | |
| def gemini_generate_content_with_backoff(model, prompt, stream=True): | |
| return model.generate_content(prompt, stream=stream) | |
| def get_google(model: str, | |
| prompt: str, | |
| temperature: float = 0, | |
| max_tokens: int = 4096, | |
| system: str = '', | |
| chat_history: List[Dict] = None, | |
| secrets: Dict = {}, | |
| verbose=False) -> Generator[dict, None, None]: | |
| model = model.replace('google:', '').replace('gemini:', '') | |
| import google.generativeai as genai | |
| gemini_key = secrets.get("GEMINI_API_KEY") | |
| genai.configure(api_key=gemini_key) | |
| # Create the model | |
| generation_config = { | |
| "temperature": temperature, | |
| "top_p": 0.95, | |
| "top_k": 64, | |
| "max_output_tokens": max_tokens, | |
| "response_mime_type": "text/plain", | |
| } | |
| if chat_history is None: | |
| chat_history = [] | |
| chat_history = chat_history.copy() | |
| chat_history = openai_messages_to_gemini_history(chat_history) | |
| # NOTE: assume want own control. Too many false positives by Google. | |
| from google.generativeai.types import HarmCategory | |
| from google.generativeai.types import HarmBlockThreshold | |
| safety_settings = { | |
| HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE, | |
| HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE, | |
| HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE, | |
| HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE, | |
| } | |
| cache = None | |
| # disable cache for now until work into things well | |
| use_cache = False | |
| if use_cache and model == 'gemini-1.5-pro': | |
| from google.generativeai import caching | |
| # Estimate token count (this is a rough estimate, you may need a more accurate method) | |
| estimated_tokens = len(prompt.split()) + sum(len(msg['content'].split()) for msg in chat_history) | |
| if estimated_tokens > 32000: | |
| cache = caching.CachedContent.create( | |
| model=model, | |
| display_name=f'cache_{datetime.datetime.now().isoformat()}', | |
| system_instruction=system, | |
| contents=[prompt] + [msg['content'] for msg in chat_history], | |
| ttl=datetime.timedelta(minutes=5), # Set an appropriate TTL. Short for now for cost savings. | |
| ) | |
| gemini_model = genai.GenerativeModel.from_cached_content(cached_content=cache) | |
| else: | |
| gemini_model = genai.GenerativeModel(model_name=model, | |
| generation_config=generation_config, | |
| safety_settings=safety_settings) | |
| else: | |
| gemini_model = genai.GenerativeModel(model_name=model, | |
| generation_config=generation_config, | |
| safety_settings=safety_settings) | |
| if cache: | |
| response = gemini_generate_content_with_backoff(gemini_model, prompt, stream=True) | |
| else: | |
| chat = gemini_model.start_chat(history=chat_history) | |
| response = gemini_send_message_with_backoff(chat, prompt, stream=True) | |
| output_tokens = 0 | |
| input_tokens = 0 | |
| cache_read_input_tokens = 0 | |
| cache_creation_input_tokens = 0 | |
| for chunk in response: | |
| if chunk.text: | |
| yield dict(text=chunk.text) | |
| if chunk.usage_metadata: | |
| output_tokens = chunk.usage_metadata.candidates_token_count | |
| input_tokens = chunk.usage_metadata.prompt_token_count | |
| cache_read_input_tokens = chunk.usage_metadata.cached_content_token_count | |
| cache_creation_input_tokens = 0 # This might need to be updated if available in the API | |
| if verbose: | |
| print(f"Output tokens: {output_tokens}") | |
| print(f"Input tokens: {input_tokens}") | |
| print(f"Cached tokens: {cache_read_input_tokens}") | |
| yield dict(output_tokens=output_tokens, input_tokens=input_tokens, | |
| cache_read_input_tokens=cache_read_input_tokens, | |
| cache_creation_input_tokens=cache_creation_input_tokens) | |
| def delete_cache(cache): | |
| if cache: | |
| cache.delete() | |
| print(f"Cache {cache.display_name} deleted.") | |
| else: | |
| print("No cache to delete.") | |
| def get_groq(model: str, | |
| prompt: str, | |
| temperature: float = 0, | |
| max_tokens: int = 4096, | |
| system: str = '', | |
| chat_history: List[Dict] = None, | |
| secrets: Dict = {}, | |
| verbose=False) -> Generator[dict, None, None]: | |
| model = model.replace('groq:', '') | |
| from groq import Groq | |
| groq_key = secrets.get("GROQ_API_KEY") | |
| client = Groq(api_key=groq_key) | |
| if chat_history is None: | |
| chat_history = [] | |
| chat_history = chat_history.copy() | |
| messages = [{"role": "system", "content": system}] + chat_history + [{"role": "user", "content": prompt}] | |
| stream = openai_completion_with_backoff(client, | |
| messages=messages, | |
| model=model, | |
| temperature=temperature, | |
| max_tokens=max_tokens, | |
| stream=True, | |
| ) | |
| output_tokens = 0 | |
| input_tokens = 0 | |
| for chunk in stream: | |
| if chunk.choices[0].delta.content: | |
| yield dict(text=chunk.choices[0].delta.content) | |
| if chunk.usage: | |
| output_tokens = chunk.usage.completion_tokens | |
| input_tokens = chunk.usage.prompt_tokens | |
| if verbose: | |
| print(f"Output tokens: {output_tokens}") | |
| print(f"Input tokens: {input_tokens}") | |
| yield dict(output_tokens=output_tokens, input_tokens=input_tokens) | |
| def get_cerebras(model: str, | |
| prompt: str, | |
| temperature: float = 0, | |
| max_tokens: int = 4096, | |
| system: str = '', | |
| chat_history: List[Dict] = None, | |
| secrets: Dict = {}, | |
| verbose=False) -> Generator[dict, None, None]: | |
| # context_length is only 8207 | |
| model = model.replace('cerebras:', '') | |
| from cerebras.cloud.sdk import Cerebras | |
| api_key = secrets.get("CEREBRAS_OPENAI_API_KEY") | |
| client = Cerebras(api_key=api_key) | |
| if chat_history is None: | |
| chat_history = [] | |
| chat_history = chat_history.copy() | |
| messages = [{"role": "system", "content": system}] + chat_history + [{"role": "user", "content": prompt}] | |
| stream = openai_completion_with_backoff(client, | |
| messages=messages, | |
| model=model, | |
| temperature=temperature, | |
| max_tokens=max_tokens, | |
| stream=True, | |
| ) | |
| output_tokens = 0 | |
| input_tokens = 0 | |
| for chunk in stream: | |
| if chunk.choices[0].delta.content: | |
| yield dict(text=chunk.choices[0].delta.content) | |
| if chunk.usage: | |
| output_tokens = chunk.usage.completion_tokens | |
| input_tokens = chunk.usage.prompt_tokens | |
| if verbose: | |
| print(f"Output tokens: {output_tokens}") | |
| print(f"Input tokens: {input_tokens}") | |
| yield dict(output_tokens=output_tokens, input_tokens=input_tokens) | |
| def get_openai_azure(model: str, | |
| prompt: str, | |
| temperature: float = 0, | |
| max_tokens: int = 4096, | |
| system: str = '', | |
| chat_history: List[Dict] = None, | |
| secrets: Dict = {}, | |
| verbose=False) -> Generator[dict, None, None]: | |
| model = model.replace('azure:', '').replace('openai_azure:', '') | |
| from openai import AzureOpenAI | |
| azure_endpoint = secrets.get("AZURE_OPENAI_ENDPOINT") # e.g. https://project.openai.azure.com | |
| azure_key = secrets.get("AZURE_OPENAI_API_KEY") | |
| azure_deployment = secrets.get("AZURE_OPENAI_DEPLOYMENT") # i.e. deployment name with some models deployed | |
| azure_api_version = secrets.get('AZURE_OPENAI_API_VERSION', '2024-07-01-preview') | |
| assert azure_endpoint is not None, "Azure OpenAI endpoint not set" | |
| assert azure_key is not None, "Azure OpenAI API key not set" | |
| assert azure_deployment is not None, "Azure OpenAI deployment not set" | |
| client = AzureOpenAI( | |
| azure_endpoint=azure_endpoint, | |
| api_key=azure_key, | |
| api_version=azure_api_version, | |
| azure_deployment=azure_deployment, | |
| ) | |
| if chat_history is None: | |
| chat_history = [] | |
| messages = [{"role": "system", "content": system}] + chat_history + [{"role": "user", "content": prompt}] | |
| response = openai_completion_with_backoff(client, | |
| model=model, | |
| messages=messages, | |
| temperature=temperature, | |
| max_tokens=max_tokens, | |
| stream=True | |
| ) | |
| output_tokens = 0 | |
| input_tokens = 0 | |
| for chunk in response: | |
| if chunk.choices and chunk.choices[0].delta.content: | |
| yield dict(text=chunk.choices[0].delta.content) | |
| if chunk.usage: | |
| output_tokens = chunk.usage.completion_tokens | |
| input_tokens = chunk.usage.prompt_tokens | |
| if verbose: | |
| print(f"Output tokens: {output_tokens}") | |
| print(f"Input tokens: {input_tokens}") | |
| yield dict(output_tokens=output_tokens, input_tokens=input_tokens) | |
| def to_list(x): | |
| if x: | |
| try: | |
| ollama_model_list = ast.literal_eval(x) | |
| assert isinstance(ollama_model_list, list) | |
| except: | |
| x = [x] | |
| else: | |
| x = [] | |
| return x | |
| def get_model_names(secrets, on_hf_spaces=False): | |
| if not on_hf_spaces: | |
| secrets = os.environ | |
| if secrets.get('ANTHROPIC_API_KEY'): | |
| anthropic_models = ['claude-3-5-sonnet-20240620', 'claude-3-haiku-20240307', 'claude-3-opus-20240229'] | |
| else: | |
| anthropic_models = [] | |
| if secrets.get('OPENAI_API_KEY'): | |
| if secrets.get('OPENAI_MODEL_NAME'): | |
| openai_models = to_list(secrets.get('OPENAI_MODEL_NAME')) | |
| else: | |
| openai_models = ['gpt-4o', 'gpt-4-turbo-2024-04-09', 'gpt-4o-mini'] | |
| else: | |
| openai_models = [] | |
| if secrets.get('AZURE_OPENAI_API_KEY'): | |
| if secrets.get('AZURE_OPENAI_MODEL_NAME'): | |
| azure_models = to_list(secrets.get('AZURE_OPENAI_MODEL_NAME')) | |
| else: | |
| azure_models = ['gpt-4o', 'gpt-4-turbo-2024-04-09', 'gpt-4o-mini'] | |
| else: | |
| azure_models = [] | |
| if secrets.get('GEMINI_API_KEY'): | |
| google_models = ['gemini-1.5-pro-latest', 'gemini-1.5-flash-latest'] | |
| else: | |
| google_models = [] | |
| if secrets.get('GROQ_API_KEY'): | |
| groq_models = ['llama-3.1-70b-versatile', | |
| 'llama-3.1-8b-instant', | |
| 'llama3-groq-70b-8192-tool-use-preview', | |
| 'llama3-groq-8b-8192-tool-use-preview', | |
| 'mixtral-8x7b-32768'] | |
| else: | |
| groq_models = [] | |
| if secrets.get('CEREBRAS_OPENAI_API_KEY'): | |
| cerebras_models = ['llama3.1-70b', 'llama3.1-8b'] | |
| else: | |
| cerebras_models = [] | |
| if secrets.get('OLLAMA_OPENAI_API_KEY'): | |
| ollama_model = os.environ['OLLAMA_OPENAI_MODEL_NAME'] | |
| ollama_model = to_list(ollama_model) | |
| else: | |
| ollama_model = [] | |
| groq_models = ['groq:' + x for x in groq_models] | |
| cerebras_models = ['cerebras:' + x for x in cerebras_models] | |
| azure_models = ['azure:' + x for x in azure_models] | |
| openai_models = ['openai:' + x for x in openai_models] | |
| google_models = ['google:' + x for x in google_models] | |
| anthropic_models = ['anthropic:' + x for x in anthropic_models] | |
| ollama = ['ollama:' + x if 'ollama:' not in x else x for x in ollama_model] | |
| return anthropic_models + openai_models + google_models + groq_models + cerebras_models + azure_models + ollama | |
| def get_model_api(model: str): | |
| assert model not in ['', None], "Model not set, need to add API key to have models appear and select one." | |
| if model.startswith('anthropic:'): | |
| return get_anthropic | |
| elif model.startswith('openai:') or model.startswith('ollama:'): | |
| return get_openai | |
| elif model.startswith('google:'): | |
| return get_google | |
| elif model.startswith('groq:'): | |
| return get_groq | |
| elif model.startswith('cerebras:'): | |
| return get_cerebras | |
| elif model.startswith('azure:'): | |
| return get_openai_azure | |
| else: | |
| raise ValueError( | |
| f"Unsupported model: {model}. Ensure to add prefix (e.g. openai:, google:, groq:, cerebras:, azure:, ollama:, anthropic:)") | |