Spaces:
Sleeping
Sleeping
| import os | |
| import configparser | |
| import logging | |
| logger = logging.getLogger(__name__) | |
| import requests # Need this for _call_hf_endpoint, but we will define the function here | |
| import httpx | |
| import json | |
| from typing import Dict, Any, List | |
| from langchain_core.documents import Document | |
| from dotenv import load_dotenv | |
| load_dotenv() # Load environment variables from .env file if present | |
def getconfig(configfile_path: str) -> configparser.ConfigParser:
    """Read an INI-style config file into a ``ConfigParser``.

    Args:
        configfile_path: Path of the config file to load.

    Returns:
        The parser populated from the file. When the file does not exist an
        empty parser is returned instead, so callers can fall back to
        environment variables.
    """
    config = configparser.ConfigParser()
    try:
        # Use a context manager so the handle is closed deterministically,
        # even when parsing raises; a bare open() left it to the garbage
        # collector.
        with open(configfile_path) as config_file:
            config.read_file(config_file)
    except FileNotFoundError:
        logger.error(f"Warning: Config file not found at {configfile_path}. Relying on environment variables.")
        return configparser.ConfigParser()
    return config
def get_auth_for_generator(provider: str) -> dict:
    """Return the authentication settings for a generator provider.

    Args:
        provider: One of "openai", "huggingface", "anthropic" or "cohere".

    Returns:
        A dict of the form ``{"api_key": <value>}`` where the value is read
        from the provider's environment variable.

    Raises:
        ValueError: If ``provider`` is not one of the supported names.
        RuntimeError: If the provider's environment variable is unset or empty.
    """
    # Environment variable that holds the API key for each supported provider.
    env_var_by_provider = {
        "openai": "OPENAI_API_KEY",
        "huggingface": "HF_TOKEN",
        "anthropic": "ANTHROPIC_API_KEY",
        "cohere": "COHERE_API_KEY",
    }

    if provider not in env_var_by_provider:
        unsupported = f"Unsupported provider: {provider}"
        logger.error(unsupported)
        raise ValueError(unsupported)

    credentials = {"api_key": os.getenv(env_var_by_provider[provider])}
    if credentials["api_key"]:
        return credentials

    missing = f"Missing API key for provider '{provider}'. Please set the appropriate environment variable."
    logger.error(missing)
    raise RuntimeError(missing)
def get_config_value(config: configparser.ConfigParser, section: str, key: str, env_var: str, fallback: Any = None) -> Any:
    """
    Look up a setting with the precedence: environment variable, then config
    file, then fallback.

    Args:
        config: Parsed configuration to consult when the environment variable is unset.
        section: Config section that holds the option.
        key: Option name inside ``section``.
        env_var: Name of the environment variable that overrides the config file.
        fallback: Value returned when neither source has the setting; ``None``
            marks the setting as required.

    Returns:
        The environment value (always a string), else the config-file value,
        else ``fallback``.

    Raises:
        ValueError: If the setting is found nowhere and ``fallback`` is ``None``.
    """
    # Highest priority: an explicitly set environment variable (even if empty).
    override = os.getenv(env_var)
    if override is not None:
        return override

    try:
        file_value = config.get(section, key)
    except (configparser.NoSectionError, configparser.NoOptionError):
        if fallback is None:
            # Required setting is missing everywhere: log both halves, then fail.
            headline = f"Configuration missing: Required value for [{section}]{key} was not found "
            detail = f"in 'params.cfg' and the environment variable {env_var} is not set."
            logger.error(headline)
            logger.error(detail)
            raise ValueError(headline + detail)
        return fallback
    else:
        return file_value
def _call_hf_endpoint(url: str, token: str, payload: Dict[str, Any]) -> Dict[str, Any]:
    """POST a JSON payload to a Hugging Face endpoint with bearer authentication.

    Args:
        url: Endpoint URL to call.
        token: Token sent in the ``Authorization: Bearer`` header.
        payload: JSON-serialisable request body.

    Returns:
        The decoded JSON body of the response.

    Raises:
        Exception: With a dedicated message when the endpoint answers 503 or 404.
        requests.exceptions.RequestException: For transport failures and other
            error statuses; these are logged and re-raised unchanged.
    """
    request_headers = {
        "Authorization": f"Bearer {token}",
        "Content-Type": "application/json",
    }
    try:
        logger.info(f"Calling endpoint {url}")
        reply = requests.post(url, headers=request_headers, json=payload, timeout=60)
        status = reply.status_code

        # 503 and 404 get their own messages; they are plain Exceptions, so the
        # handler below does not intercept them.
        if status == 503:
            logger.warning(f"HF Endpoint 503: Service overloaded/starting up at {url}")
            raise Exception("HF Service Unavailable (503)")
        if status == 404:
            logger.error(f"HF Endpoint 404: Model not found at {url}")
            raise Exception("HF Model Not Found (404)")

        reply.raise_for_status()
        return reply.json()
    except requests.exceptions.RequestException as err:
        logger.error(f"Error calling HF endpoint ({url}): {err}")
        raise
async def _acall_hf_endpoint(url: str, token: str, payload: Dict[str, Any]) -> Dict[str, Any]:
    """Asynchronously call a Hugging Face Inference Endpoint using httpx.

    Args:
        url: Endpoint URL to call.
        token: Token sent in the ``Authorization: Bearer`` header.
        payload: JSON-serialisable request body.

    Returns:
        The decoded JSON body of the response.

    Raises:
        Exception: With a dedicated message when the endpoint answers 503 or 404.
        httpx.HTTPError: For transport failures (``httpx.RequestError``) and
            other error statuses (``httpx.HTTPStatusError``); these are logged
            and re-raised unchanged.
    """
    headers = {
        "Authorization": f"Bearer {token}",
        "Content-Type": "application/json",
    }
    # A short-lived client per call; the context manager closes its connections.
    async with httpx.AsyncClient(timeout=60.0) as client:
        try:
            logger.info(f"Async Calling endpoint {url}")
            response = await client.post(url, headers=headers, json=payload)
            if response.status_code == 503:
                logger.warning(f"HF Endpoint 503: Service overloaded/starting up at {url}")
                raise Exception("HF Service Unavailable (503)")
            elif response.status_code == 404:
                logger.error(f"HF Endpoint 404: Model not found at {url}")
                raise Exception("HF Model Not Found (404)")
            response.raise_for_status()
            return response.json()
        except httpx.HTTPError as e:
            # httpx.HTTPError is the common base of RequestError and
            # HTTPStatusError, so failures raised by raise_for_status() are
            # logged too, matching the synchronous helper. The same exception
            # object is re-raised, so callers see no difference.
            logger.error(f"Error calling HF endpoint ({url}): {e}")
            raise
def build_conversation_context(messages: List, max_turns: int = 3, max_chars: int = 8000) -> str:
    """
    Build conversation context from structured messages to send to generator.

    Always tries to keep the first user and first assistant messages, then
    appends the last N turns. A "turn" is one user message plus the next
    assistant message that appears after it.

    Args:
        messages: List of Message objects with 'role' and 'content' attributes
        max_turns: Maximum number of recent user-assistant exchanges to include
            (taken from the end; the first user message is not counted again).
            Values <= 0 keep only the first user/assistant messages.
        max_chars: Maximum total characters in the returned context, including
            the blank-line separators between messages (default 8000)

    Returns:
        The included messages, each rendered as "USER: <content>" or
        "ASSISTANT: <content>" and separated from the next by one blank line.
        Returns an empty string when ``messages`` is empty.
    """
    if not messages:
        return ""

    separator = "\n\n"
    context_parts: List[str] = []
    char_count = 0

    def try_add(text: str) -> bool:
        """Append text if it, plus its separator, still fits in max_chars."""
        nonlocal char_count
        # Every part after the first costs one separator in the joined output.
        cost = len(text) + (len(separator) if context_parts else 0)
        if char_count + cost > max_chars:
            return False
        context_parts.append(text)
        char_count += cost
        return True

    # The opening exchange is kept whenever it fits; an oversized first message
    # is skipped without stopping the rest of the build.
    first_user_msg = next((m for m in messages if m.role == 'user'), None)
    first_assistant_msg = next((m for m in messages if m.role == 'assistant'), None)
    if first_user_msg is not None:
        try_add(f"USER: {first_user_msg.content}")
    if first_assistant_msg is not None:
        try_add(f"ASSISTANT: {first_assistant_msg.content}")

    # Track user messages by position rather than by value: list.index()
    # compares by equality, so a repeated question would be paired with the
    # reply to its first occurrence.
    user_positions = [i for i, m in enumerate(messages) if m.role == 'user']
    recent_positions = user_positions[1:][-max_turns:] if max_turns > 0 else []

    for position in recent_positions:
        reply = next(
            (m for m in messages[position + 1:] if m.role == 'assistant'),
            None,
        )
        if not try_add(f"USER: {messages[position].content}"):
            logger.info(f"Stopping context build: would exceed max_chars ({max_chars})")
            break
        if reply is not None and not try_add(f"ASSISTANT: {reply.content}"):
            logger.info(f"Stopping context build: would exceed max_chars ({max_chars})")
            break

    context = separator.join(context_parts)
    logger.debug(f"Built conversation context: {len(context_parts)} messages, {char_count} chars")
    return context