Spaces:
Sleeping
Sleeping
File size: 8,905 Bytes
3194955 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 |
import os
import configparser
import logging
logger = logging.getLogger(__name__)
import requests # Need this for _call_hf_endpoint, but we will define the function here
import httpx
import json
from typing import Dict, Any, List
from langchain_core.documents import Document
from dotenv import load_dotenv
load_dotenv() # Load environment variables from .env file if present
def getconfig(configfile_path: str) -> configparser.ConfigParser:
    """Read an INI-style config file.

    Args:
        configfile_path: Path to the configuration file (e.g. 'params.cfg').

    Returns:
        A populated ConfigParser, or an empty one when the file does not
        exist (callers then fall back to environment variables).
    """
    config = configparser.ConfigParser()
    try:
        # Use a context manager so the file handle is always closed;
        # the original `config.read_file(open(path))` leaked the handle.
        with open(configfile_path) as config_file:
            config.read_file(config_file)
        return config
    except FileNotFoundError:
        logger.error(f"Warning: Config file not found at {configfile_path}. Relying on environment variables.")
        return configparser.ConfigParser()
def get_auth_for_generator(provider: str) -> dict:
    """Get authentication configuration for different providers"""
    # Each supported provider maps to the environment variable that holds
    # its API key; the key itself is only read for the requested provider.
    provider_env_vars = {
        "openai": "OPENAI_API_KEY",
        "huggingface": "HF_TOKEN",
        "anthropic": "ANTHROPIC_API_KEY",
        "cohere": "COHERE_API_KEY",
    }
    env_var = provider_env_vars.get(provider)
    if env_var is None:
        logger.error(f"Unsupported provider: {provider}")
        raise ValueError(f"Unsupported provider: {provider}")
    api_key = os.getenv(env_var)
    if not api_key:
        logger.error(f"Missing API key for provider '{provider}'. Please set the appropriate environment variable.")
        raise RuntimeError(f"Missing API key for provider '{provider}'. Please set the appropriate environment variable.")
    return {"api_key": api_key}
def get_config_value(config: configparser.ConfigParser, section: str, key: str, env_var: str, fallback: Any = None) -> Any:
    """
    Retrieves a config value, prioritizing: Environment Variable > Config File > Fallback.

    Args:
        config: Parsed configuration (may be empty when the file was missing).
        section: Config-file section to look in.
        key: Option name within the section.
        env_var: Environment variable that overrides the config file.
        fallback: Returned when neither source provides a value; None means
            the value is required.

    Returns:
        The resolved value (env vars and config values are strings).

    Raises:
        ValueError: If the value is missing everywhere and no fallback is set.
    """
    # 1. Check Environment Variable (Highest Priority)
    env_value = os.getenv(env_var)
    if env_value is not None:
        return env_value
    # 2. Check Config File
    try:
        return config.get(section, key)
    except (configparser.NoSectionError, configparser.NoOptionError):
        # 3. Use Fallback
        if fallback is not None:
            return fallback
        # 4. Error if essential config is missing. Build the message once so
        # the log is a single complete record (it was previously split across
        # two logger.error calls, producing two half-sentence log lines).
        message = (
            f"Configuration missing: Required value for [{section}]{key} was not found "
            f"in 'params.cfg' and the environment variable {env_var} is not set."
        )
        logger.error(message)
        raise ValueError(message)
def _call_hf_endpoint(url: str, token: str, payload: Dict[str, Any]) -> Dict[str, Any]:
    """Synchronously POST a JSON payload to a Hugging Face Inference Endpoint.

    503 (endpoint cold/overloaded) and 404 (model missing) are surfaced as
    explicit exceptions; any other HTTP error goes through raise_for_status().
    Transport failures are logged and re-raised.
    """
    request_headers = {
        "Authorization": f"Bearer {token}",
        "Content-Type": "application/json",
    }
    try:
        logger.info(f"Calling endpoint {url}")
        resp = requests.post(url, headers=request_headers, json=payload, timeout=60)
        status = resp.status_code
        if status == 503:
            # Endpoint is scaling up or overloaded; caller may retry later.
            logger.warning(f"HF Endpoint 503: Service overloaded/starting up at {url}")
            raise Exception("HF Service Unavailable (503)")
        if status == 404:
            logger.error(f"HF Endpoint 404: Model not found at {url}")
            raise Exception("HF Model Not Found (404)")
        resp.raise_for_status()
        return resp.json()
    except requests.exceptions.RequestException as e:
        logger.error(f"Error calling HF endpoint ({url}): {e}")
        raise
async def _acall_hf_endpoint(url: str, token: str, payload: Dict[str, Any]) -> Dict[str, Any]:
    """Asynchronously POST a JSON payload to a Hugging Face Inference Endpoint.

    Mirrors _call_hf_endpoint but uses httpx.AsyncClient so the event loop
    is not blocked while waiting on the endpoint.
    """
    request_headers = {
        "Authorization": f"Bearer {token}",
        "Content-Type": "application/json",
    }
    # A fresh client per call keeps connection lifetime scoped to the request.
    async with httpx.AsyncClient(timeout=60.0) as client:
        try:
            logger.info(f"Async Calling endpoint {url}")
            resp = await client.post(url, headers=request_headers, json=payload)
            status = resp.status_code
            if status == 503:
                # Endpoint is scaling up or overloaded; caller may retry later.
                logger.warning(f"HF Endpoint 503: Service overloaded/starting up at {url}")
                raise Exception("HF Service Unavailable (503)")
            if status == 404:
                logger.error(f"HF Endpoint 404: Model not found at {url}")
                raise Exception("HF Model Not Found (404)")
            resp.raise_for_status()
            return resp.json()
        except httpx.RequestError as e:
            logger.error(f"Error calling HF endpoint ({url}): {e}")
            raise
def build_conversation_context(messages: List, max_turns: int = 3, max_chars: int = 8000) -> str:
    """
    Build conversation context from structured messages to send to generator.

    Always keeps the first user and assistant messages, plus the last N turns.
    A "turn" is one user message + following assistant response.

    Args:
        messages: List of Message objects with 'role' and 'content' attributes
        max_turns: Maximum number of user-assistant exchange pairs to include (from the end)
        max_chars: Maximum total characters in context (default 8000)

    Returns:
        Formatted conversation context string like:
        "USER: query1\nASSISTANT: response1\n\nUSER: query2\nASSISTANT: response2"
    """
    if not messages:
        return ""

    context_parts: List[str] = []
    included_ids = set()  # identities of messages already emitted (dedup guard)
    char_count = 0
    msgs_included = 0

    def _try_add(msg, prefix: str) -> bool:
        """Append one formatted message if new and within budget.

        Returns False only on a budget overflow; an already-included message
        counts as success so callers don't stop the build. Previously the
        first assistant reply could be emitted twice when it also served as
        the response to a "recent" user message (e.g. user1, user2, assistant1).
        """
        nonlocal char_count, msgs_included
        if id(msg) in included_ids:
            return True
        text = f"{prefix}: {msg.content}"
        if char_count + len(text) > max_chars:
            return False
        context_parts.append(text)
        included_ids.add(id(msg))
        char_count += len(text)
        msgs_included += 1
        return True

    # Always include the first user and assistant messages (if they fit;
    # an oversized first message is skipped, matching prior behavior).
    first_user = next((m for m in messages if m.role == 'user'), None)
    first_assistant = next((m for m in messages if m.role == 'assistant'), None)
    if first_user is not None:
        _try_add(first_user, "USER")
    if first_assistant is not None:
        _try_add(first_assistant, "ASSISTANT")

    # Last N user messages (excluding the first, already handled above).
    # Work with positions instead of messages.index(), which matches by ==
    # and could pick the wrong occurrence for value-equal messages.
    user_positions = [i for i, m in enumerate(messages) if m.role == 'user']
    recent_user_positions = user_positions[1:][-max_turns:] if len(user_positions) > 1 else []

    for pos in recent_user_positions:
        user_msg = messages[pos]
        if not _try_add(user_msg, "USER"):
            logger.info(f"Stopping context build: would exceed max_chars ({max_chars})")
            break
        # The assistant response is the next assistant message after this user turn.
        assistant_msg = next(
            (m for m in messages[pos + 1:] if m.role == 'assistant'), None
        )
        if assistant_msg is not None and not _try_add(assistant_msg, "ASSISTANT"):
            logger.info(f"Stopping context build: would exceed max_chars ({max_chars})")
            break

    context = "\n\n".join(context_parts)
    logger.debug(f"Built conversation context: {msgs_included} messages, {char_count} chars")
    return context