import json
import logging
import os
from collections.abc import Iterator

import requests


# Import-time root-logger setup: INFO and above, rendered as
# "timestamp - logger name - level - message". (logging.basicConfig is a
# no-op when the root logger already has handlers.)
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Per-module logger; its records propagate to the root handlers set up above.
logger = logging.getLogger(__name__)


# Name of the environment variable holding each provider's API key, keyed by
# the upper-cased provider name (the form _get_api_key looks it up by).
API_KEYS = dict(
    HUGGINGFACE="HF_TOKEN",
    GROQ="GROQ_API_KEY",
    OPENROUTER="OPENROUTER_API_KEY",
    TOGETHERAI="TOGETHERAI_API_KEY",
    COHERE="COHERE_API_KEY",
    XAI="XAI_API_KEY",
    OPENAI="OPENAI_API_KEY",
    GOOGLE="GOOGLE_API_KEY",
)


# Endpoint per provider, keyed by the upper-cased provider name.
# Most entries are complete chat endpoints. GOOGLE's ends in "/models/" and is
# a prefix: generate_stream appends "<model_id>:streamGenerateContent".
# HUGGINGFACE's also ends in "/models/" (presumably a model-ID prefix too —
# generate_stream does not currently send requests to it).
API_URLS = dict(
    HUGGINGFACE="https://api-inference.huggingface.co/models/",
    GROQ="https://api.groq.com/openai/v1/chat/completions",
    OPENROUTER="https://openrouter.ai/api/v1/chat/completions",
    TOGETHERAI="https://api.together.ai/v1/chat/completions",
    COHERE="https://api.cohere.ai/v1/chat",
    XAI="https://api.x.ai/v1/chat/completions",
    OPENAI="https://api.openai.com/v1/chat/completions",
    GOOGLE="https://generativelanguage.googleapis.com/v1beta/models/",
)


# Model catalog offered in the UI, keyed by lower-case provider name.
#   "default": model ID that get_default_model_for_provider() maps back to its
#              display name (it must also appear among "models" to be found).
#   "models":  UI display name -> model ID sent to the provider's API.
# NOTE(review): these IDs are a point-in-time snapshot and providers retire
# models (e.g. confirm the Groq IDs and "grok-1" are still served) — verify
# against each provider's model list before relying on them.
MODELS_BY_PROVIDER = {
    "groq": {
        "default": "llama3-8b-8192",
        "models": {
            "Llama 3 8B (Groq)": "llama3-8b-8192",
            "Llama 3 70B (Groq)": "llama3-70b-8192",
            "Mixtral 8x7B (Groq)": "mixtral-8x7b-32768",
            "Gemma 7B (Groq)": "gemma-7b-it",
        }
    },
    "openrouter": {
        "default": "nousresearch/llama-3-8b-instruct",
        "models": {
            "Nous Llama-3 8B Instruct (OpenRouter)": "nousresearch/llama-3-8b-instruct",
            "Mistral 7B Instruct v0.2 (OpenRouter)": "mistralai/mistral-7b-instruct:free",
            "Gemma 7B Instruct (OpenRouter)": "google/gemma-7b-it:free",
            "Mixtral 8x7B Instruct v0.1 (OpenRouter)": "mistralai/mixtral-8x7b-instruct",
            "Llama 2 70B Chat (OpenRouter)": "meta-llama/llama-2-70b-chat",
            "Neural Chat 7B v3.1 (OpenRouter)": "intel/neural-chat-7b-v3-1",
            "Goliath 120B (OpenRouter)": "twob/goliath-v2-120b",
        }
    },
    "togetherai": {
        "default": "meta-llama/Llama-3-8b-chat-hf",
        "models": {
            "Llama 3 8B Chat (TogetherAI)": "meta-llama/Llama-3-8b-chat-hf",
            "Llama 3 70B Chat (TogetherAI)": "meta-llama/Llama-3-70b-chat-hf",
            "Mixtral 8x7B Instruct (TogetherAI)": "mistralai/Mixtral-8x7B-Instruct-v0.1",
            "Gemma 7B Instruct (TogetherAI)": "google/gemma-7b-it",
            "RedPajama INCITE Chat 3B (TogetherAI)": "togethercomputer/RedPajama-INCITE-Chat-3B-v1",
        }
    },
    # Unlike the other providers, these display names carry no "(Provider)" suffix.
    "google": {
        "default": "gemini-1.5-flash-latest",
        "models": {
            "Gemini 1.5 Flash (Latest)": "gemini-1.5-flash-latest",
            "Gemini 1.5 Pro (Latest)": "gemini-1.5-pro-latest",
        }
    },
    "cohere": {
        "default": "command-light",
        "models": {
            "Command R (Cohere)": "command-r",
            "Command R+ (Cohere)": "command-r-plus",
            "Command Light (Cohere)": "command-light",
            "Command (Cohere)": "command",
        }
    },
    # Listed for selection, but generate_stream currently answers every
    # "huggingface" request with an error message instead of calling the API.
    "huggingface": {
        "default": "HuggingFaceH4/zephyr-7b-beta",
        "models": {
            "Zephyr 7B Beta (H4/HF Inf.)": "HuggingFaceH4/zephyr-7b-beta",
            "Mistral 7B Instruct v0.2 (HF Inf.)": "mistralai/Mistral-7B-Instruct-v0.2",
            "Llama 2 13B Chat (Meta/HF Inf.)": "meta-llama/Llama-2-13b-chat-hf",
            "OpenAssistant/oasst-sft-4-pythia-12b (HF Inf.)": "OpenAssistant/oasst-sft-4-pythia-12b",
        }
    },
    "openai": {
        "default": "gpt-3.5-turbo",
        "models": {
            "GPT-4o (OpenAI)": "gpt-4o",
            "GPT-4o mini (OpenAI)": "gpt-4o-mini",
            "GPT-4 Turbo (OpenAI)": "gpt-4-turbo",
            "GPT-3.5 Turbo (OpenAI)": "gpt-3.5-turbo",
        }
    },
    "xai": {
        "default": "grok-1",
        "models": {
            "Grok-1 (xAI)": "grok-1",
        }
    }
}


def _get_api_key(provider: str, ui_api_key_override: str | None = None) -> str | None:
    """Resolve the API key to use for *provider*.

    Precedence:
      1. ``ui_api_key_override``, if it contains any non-whitespace text.
      2. The environment variable named in ``API_KEYS`` for the upper-cased
         provider name, if set to a non-blank value.

    A blank or whitespace-only override falls through to the environment
    variable rather than producing an empty key.

    Args:
        provider: Provider name, case-insensitive.
        ui_api_key_override: Key entered in the UI, or None/"" for none.

    Returns:
        The stripped key, or None (after logging a warning) if neither
        source yields a non-blank value.
    """
    override = (ui_api_key_override or "").strip()
    if override:
        return override

    # API_KEYS maps "HUGGINGFACE" -> "HF_TOKEN" etc., so no provider-specific
    # fallbacks are needed here.
    env_var_name = API_KEYS.get(provider.upper())
    if env_var_name:
        env_key = (os.getenv(env_var_name) or "").strip()
        if env_key:
            return env_key

    logger.warning(
        "API Key not found for provider '%s'. Checked UI override and environment variable '%s'.",
        provider,
        env_var_name or "N/A",
    )
    return None


def get_available_providers() -> list[str]:
    """Return every provider name in ``MODELS_BY_PROVIDER``, sorted alphabetically."""
    # Iterating a dict yields its keys, so sorted() already returns a new list.
    return sorted(MODELS_BY_PROVIDER)


def get_models_for_provider(provider: str) -> list[str]:
    """Return the sorted model display names for *provider* (case-insensitive).

    An unknown provider yields an empty list.
    """
    catalog = MODELS_BY_PROVIDER.get(provider.lower(), {}).get("models", {})
    return sorted(catalog)


def get_default_model_for_provider(provider: str) -> str | None:
    """Return the display name of *provider*'s default model.

    The configured ``"default"`` model ID is mapped back to its display name.
    If no display name carries that ID, the alphabetically first display name
    is returned. An unknown provider, or one without models, gives None.
    """
    entry = MODELS_BY_PROVIDER.get(provider.lower(), {})
    catalog = entry.get("models", {})
    wanted_id = entry.get("default")

    if wanted_id:
        match = next((name for name, mid in catalog.items() if mid == wanted_id), None)
        if match is not None:
            return match

    # min() of the keys is the same as the first element of sorted(keys).
    return min(catalog) if catalog else None


def get_model_id_from_display_name(provider: str, display_name: str) -> str | None:
    """Translate a UI display name into the provider's model ID.

    Returns None when the provider (case-insensitive) or display name is unknown.
    """
    provider_entry = MODELS_BY_PROVIDER.get(provider.lower(), {})
    return provider_entry.get("models", {}).get(display_name)


# Providers whose endpoints speak the OpenAI chat-completions protocol:
# Bearer auth, {"model", "messages", "stream"} payload, SSE "data:" lines.
_OPENAI_COMPATIBLE_PROVIDERS = frozenset({"groq", "openrouter", "togetherai", "openai", "xai"})

# Passed to requests as ``timeout``; requests applies it to connecting and to
# each wait for more bytes, not to the total duration of the stream.
_STREAM_TIMEOUT_SECONDS = 180


def _iter_stream_lines(response: "requests.Response") -> Iterator[str]:
    """Yield each line of a streaming *response*, decoded as UTF-8 and stripped.

    Lines are split on raw newline bytes before decoding, so a multi-byte
    character cut across two network chunks is reassembled intact. Stripping
    also drops the carriage return of CRLF line endings. A final unterminated
    line is yielded as well.
    """
    pending = b""
    for chunk in response.iter_content(chunk_size=8192):
        pending += chunk
        # split() always returns at least one item; the last is the partial tail.
        *complete, pending = pending.split(b"\n")
        for raw_line in complete:
            yield raw_line.decode("utf-8", errors="ignore").strip()
    if pending:
        yield pending.decode("utf-8", errors="ignore").strip()


def _sse_data(line: str) -> str | None:
    """Return the payload of an SSE ``data:`` line, or None for any other line."""
    if line.startswith("data:"):
        return line[len("data:"):].strip()
    return None


def _raise_for_status(response: "requests.Response") -> None:
    """Raise ``HTTPError`` for a 4xx/5xx *response*, buffering its body first.

    Callers hold the response in a ``with`` block, which closes it while the
    exception propagates. Reading ``content`` first caches the error body, so
    ``generate_stream``'s handler can still read ``e.response.text``.
    """
    if response.status_code >= 400:
        _ = response.content
    response.raise_for_status()


def _openai_delta_text(event: object) -> str:
    """Return ``choices[0].delta.content`` of a chat-completions chunk, or ""."""
    if not isinstance(event, dict):
        return ""
    choices = event.get("choices")
    if not isinstance(choices, list) or not choices:
        return ""
    first = choices[0]
    delta = first.get("delta") if isinstance(first, dict) else None
    content = delta.get("content") if isinstance(delta, dict) else None
    return content if isinstance(content, str) else ""


def _stream_openai_compatible(
    provider_lower: str, base_url: str, api_key: str, model_id: str, messages: list[dict]
) -> Iterator[str]:
    """Stream text deltas from an OpenAI-compatible chat-completions endpoint."""
    headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
    if provider_lower == "openrouter":
        # Optional OpenRouter headers identifying the calling application.
        headers["HTTP-Referer"] = os.getenv("SPACE_HOST") or "https://github.com/your_username/ai-space-builder"
        headers["X-Title"] = "AI Space Builder"
    payload = {"model": model_id, "messages": messages, "stream": True}

    with requests.post(
        base_url, headers=headers, json=payload, stream=True, timeout=_STREAM_TIMEOUT_SECONDS
    ) as response:
        _raise_for_status(response)
        for line in _iter_stream_lines(response):
            data = _sse_data(line)
            if not data:
                continue  # blank event separators, SSE comments, non-data fields
            if data == "[DONE]":
                return  # end-of-stream sentinel; leaving the with-block closes the response
            try:
                event = json.loads(data)
            except json.JSONDecodeError:
                logger.warning("Failed to decode JSON from stream line: %s", line)
                continue
            text = _openai_delta_text(event)
            if text:
                yield text


def _google_event_text(event: object) -> str:
    """Concatenate the text parts of the first candidate in a Gemini chunk, or ""."""
    if not isinstance(event, dict):
        return ""
    candidates = event.get("candidates")
    if not isinstance(candidates, list) or not candidates or not isinstance(candidates[0], dict):
        return ""
    parts = (candidates[0].get("content") or {}).get("parts") or []
    return "".join(part.get("text", "") for part in parts if isinstance(part, dict))


def _stream_google(base_url: str, api_key: str, model_id: str, messages: list[dict]) -> Iterator[str]:
    """Stream text from the Gemini ``streamGenerateContent`` REST endpoint.

    ``alt=sse`` makes each chunk arrive as a single ``data: {json}`` line.
    Without it the endpoint streams one pretty-printed JSON array whose lines
    cannot be parsed one at a time.

    The key is sent in the ``x-goog-api-key`` header, not the query string.
    That keeps it out of the URL embedded in ``HTTPError`` messages, which
    ``generate_stream`` logs and yields to the user.
    """
    system_instruction = None
    contents = []
    for msg in messages:
        if msg["role"] == "system":
            system_instruction = msg["content"]  # the last system message wins
        else:
            role = "model" if msg["role"] == "assistant" else msg["role"]
            contents.append({"role": role, "parts": [{"text": msg["content"]}]})

    payload = {
        "contents": contents,
        "safetySettings": [
            {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
            {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
            {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
            {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
        ],
        "generationConfig": {
            "temperature": 0.7,
        },
    }
    if system_instruction:
        payload["system_instruction"] = {"parts": [{"text": system_instruction}]}

    request_url = f"{base_url}{model_id}:streamGenerateContent?alt=sse"
    headers = {"Content-Type": "application/json", "x-goog-api-key": api_key}

    with requests.post(
        request_url, headers=headers, json=payload, stream=True, timeout=_STREAM_TIMEOUT_SECONDS
    ) as response:
        _raise_for_status(response)
        for line in _iter_stream_lines(response):
            data = _sse_data(line)
            if not data:
                continue
            try:
                event = json.loads(data)
            except json.JSONDecodeError:
                logger.warning("Failed to decode JSON from Google stream chunk: %s", data)
                continue
            text = _google_event_text(event)
            if text:
                yield text


def _stream_cohere(base_url: str, api_key: str, model_id: str, messages: list[dict]) -> Iterator[str]:
    """Stream text from Cohere's v1 ``/chat`` endpoint.

    Message mapping:
      * the last user/assistant turn becomes ``message``;
      * earlier turns become ``chat_history``;
      * the last system message becomes ``preamble``.

    The v1 stream is newline-delimited JSON carrying an ``event_type`` field.
    SSE-framed events (``event:`` / ``data:`` lines) are accepted as well.
    """
    preamble = None
    turns = []
    for msg in messages:
        if msg["role"] == "system":
            preamble = msg["content"]
        elif msg["role"] in ("user", "assistant"):
            turns.append(msg)

    current_message = turns[-1]["content"] if turns else ""
    if not current_message:
        yield "Error: User message not found for Cohere API call."
        return

    payload = {
        "model": model_id,
        "message": current_message,
        "stream": True,
        "temperature": 0.7,
    }
    # Cohere v1 documents chat_history roles as upper-case "USER" / "CHATBOT".
    chat_history = [
        {"role": "CHATBOT" if m["role"] == "assistant" else "USER", "message": m["content"]}
        for m in turns[:-1]
    ]
    if chat_history:
        payload["chat_history"] = chat_history
    if preamble:
        payload["preamble"] = preamble

    headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}

    with requests.post(
        base_url, headers=headers, json=payload, stream=True, timeout=_STREAM_TIMEOUT_SECONDS
    ) as response:
        _raise_for_status(response)
        sse_event = None  # event name from an SSE "event:" line, if the server uses SSE framing
        for line in _iter_stream_lines(response):
            if not line:
                sse_event = None  # a blank line ends an SSE event
                continue
            if line.startswith(":"):
                continue  # SSE comment / keep-alive
            if line.startswith("event:"):
                sse_event = line[len("event:"):].strip()
                continue
            data = _sse_data(line)
            raw = line if data is None else data
            if not raw:
                continue
            try:
                event = json.loads(raw)
            except json.JSONDecodeError:
                logger.warning("Cohere: Failed to decode event data JSON: %s", raw)
                continue
            if not isinstance(event, dict):
                continue
            event_type = event.get("event_type", sse_event)
            if event_type == "text-generation" and event.get("text"):
                yield event["text"]
            elif event_type == "stream-end":
                return


def generate_stream(
    provider: str,
    model_display_name: str,
    api_key_override: str | None,
    messages: list[dict],
) -> Iterator[str]:
    """Stream a chat completion from *provider*, yielding text fragments.

    Args:
        provider: Provider name, case-insensitive (a key of MODELS_BY_PROVIDER).
        model_display_name: A display name listed for that provider.
        api_key_override: Key entered in the UI; when empty, the provider's
            environment variable is used instead.
        messages: Chat history as ``{"role": ..., "content": ...}`` dicts
            using the roles "system", "user" and "assistant".

    Yields:
        Text fragments as they arrive. Configuration, HTTP and network
        failures are not raised. Each is reported as a single yielded string
        starting with "Error:", "API HTTP Error", "API Request Error" or
        "An unexpected error occurred".
    """
    provider_lower = provider.lower()

    # Check the provider first, so an unknown name is reported as such rather
    # than as a missing API key.
    base_url = API_URLS.get(provider.upper())
    if not base_url:
        yield f"Error: Unknown provider '{provider}' or missing API URL configuration."
        return

    api_key = _get_api_key(provider_lower, api_key_override)
    if not api_key:
        env_var_name = API_KEYS.get(provider.upper(), 'N/A')
        yield f"Error: API Key not found for {provider}. Please set it in the UI override or environment variable '{env_var_name}'."
        return

    model_id = get_model_id_from_display_name(provider_lower, model_display_name)
    if not model_id:
        yield f"Error: Unknown model '{model_display_name}' for provider '{provider}'. Please select a valid model."
        return

    logger.info("Calling %s/%s (ID: %s) stream...", provider, model_display_name, model_id)

    try:
        if provider_lower in _OPENAI_COMPATIBLE_PROVIDERS:
            yield from _stream_openai_compatible(provider_lower, base_url, api_key, model_id, messages)
        elif provider_lower == "google":
            yield from _stream_google(base_url, api_key, model_id, messages)
        elif provider_lower == "cohere":
            yield from _stream_cohere(base_url, api_key, model_id, messages)
        elif provider_lower == "huggingface":
            yield "Error: Direct Hugging Face Inference API streaming for chat models is experimental and model-dependent. Consider using OpenRouter or TogetherAI for HF models with standardized streaming."
        else:
            yield f"Error: Unsupported provider '{provider}' for streaming chat."

    except requests.exceptions.HTTPError as e:
        status_code = e.response.status_code if e.response is not None else 'N/A'
        error_text = e.response.text if e.response is not None else 'No response text'
        logger.error("HTTP error during streaming for %s/%s: %s", provider, model_id, e)
        yield f"API HTTP Error ({status_code}): {error_text}\nDetails: {e}"
    except requests.exceptions.RequestException as e:
        logger.error("Request error during streaming for %s/%s: %s", provider, model_id, e)
        yield f"API Request Error: Could not connect or receive response from {provider} ({e})"
    except Exception as e:
        # Top-level boundary of the generator: report instead of crashing the UI.
        logger.exception("Unexpected error during streaming for %s/%s:", provider, model_id)
        yield f"An unexpected error occurred: {e}"