File size: 8,905 Bytes
3194955
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
import os
import configparser
import logging
logger = logging.getLogger(__name__)
import requests # Need this for _call_hf_endpoint, but we will define the function here
import httpx
import json
from typing import Dict, Any, List
from langchain_core.documents import Document

from dotenv import load_dotenv
load_dotenv()  # Load environment variables from .env file if present

def getconfig(configfile_path: str) -> configparser.ConfigParser:
    """Read an INI-style config file.

    Args:
        configfile_path: Path to the config file (e.g. 'params.cfg').

    Returns:
        The populated ConfigParser, or an empty ConfigParser when the
        file is missing (callers then fall back to environment variables).
    """
    config = configparser.ConfigParser()
    try:
        # Context manager so the file handle is always closed
        # (the previous version leaked the handle returned by open()).
        with open(configfile_path) as config_file:
            config.read_file(config_file)
        return config
    except FileNotFoundError:
        # A missing file is an expected, recoverable situation here,
        # so log it at warning level rather than error.
        logger.warning(f"Config file not found at {configfile_path}. Relying on environment variables.")
        return configparser.ConfigParser()


def get_auth_for_generator(provider: str) -> dict:
    """Return the authentication settings for a generator provider.

    The API key is read from the provider's environment variable
    (e.g. OPENAI_API_KEY for "openai", HF_TOKEN for "huggingface").

    Args:
        provider: One of "openai", "huggingface", "anthropic", "cohere".

    Returns:
        A dict with an "api_key" entry.

    Raises:
        ValueError: If the provider name is not recognized.
        RuntimeError: If the provider's API key is not set in the environment.
    """
    env_var_by_provider = {
        "openai": "OPENAI_API_KEY",
        "huggingface": "HF_TOKEN",
        "anthropic": "ANTHROPIC_API_KEY",
        "cohere": "COHERE_API_KEY",
    }

    if provider not in env_var_by_provider:
        logger.error(f"Unsupported provider: {provider}")
        raise ValueError(f"Unsupported provider: {provider}")

    api_key = os.getenv(env_var_by_provider[provider])
    if not api_key:
        logger.error(f"Missing API key for provider '{provider}'. Please set the appropriate environment variable.")
        raise RuntimeError(f"Missing API key for provider '{provider}'. Please set the appropriate environment variable.")

    return {"api_key": api_key}



def get_config_value(config: configparser.ConfigParser, section: str, key: str, env_var: str, fallback: Any = None) -> Any:
    """
    Retrieves a config value, prioritizing: Environment Variable > Config File > Fallback.

    Args:
        config: Parsed config file (may be empty).
        section: Config file section name.
        key: Option name within the section.
        env_var: Environment variable that overrides the config file.
        fallback: Value returned when neither env var nor config entry exists.

    Returns:
        The resolved value (env vars and config entries come back as strings).

    Raises:
        ValueError: When no env var, config entry, or fallback is available.
    """
    # 1. Check Environment Variable (Highest Priority)
    env_value = os.getenv(env_var)
    if env_value is not None:
        return env_value

    # 2. Check Config File
    try:
        return config.get(section, key)
    except (configparser.NoSectionError, configparser.NoOptionError):
        # 3. Use Fallback (note: `is not None`, so falsy fallbacks like "" work)
        if fallback is not None:
            return fallback

        # 4. Error if essential config is missing.
        # Emit a single coherent log record (previously this was split
        # across two logger.error calls, producing two half-sentences).
        message = (
            f"Configuration missing: Required value for [{section}]{key} was not found "
            f"in 'params.cfg' and the environment variable {env_var} is not set."
        )
        logger.error(message)
        raise ValueError(message)

def _call_hf_endpoint(url: str, token: str, payload: Dict[str, Any]) -> Dict[str, Any]:
    """Make an authenticated POST request to a Hugging Face Endpoint.

    Args:
        url: Full endpoint URL.
        token: HF access token, sent as a Bearer credential.
        payload: JSON-serializable request body.

    Returns:
        The decoded JSON response.

    Raises:
        Exception: On a 503 (service starting/overloaded) or 404 (model
            not found) response.
        requests.exceptions.RequestException: On transport errors or any
            other non-2xx status (via raise_for_status; logged before
            re-raising).
    """
    headers = {
        "Authorization": f"Bearer {token}",
        "Content-Type": "application/json",
    }
    try:
        logger.info(f"Calling endpoint {url}")
        response = requests.post(url, headers=headers, json=payload, timeout=60)
        # Fixed: the 503 branch body was indented one extra level (it still
        # parsed, but was inconsistent with the 404 branch and the rest of
        # the file, inviting editing mistakes).
        if response.status_code == 503:
            logger.warning(f"HF Endpoint 503: Service overloaded/starting up at {url}")
            raise Exception("HF Service Unavailable (503)")
        elif response.status_code == 404:
            logger.error(f"HF Endpoint 404: Model not found at {url}")
            raise Exception("HF Model Not Found (404)")
        response.raise_for_status()
        return response.json()
    except requests.exceptions.RequestException as e:
        logger.error(f"Error calling HF endpoint ({url}): {e}")
        raise

async def _acall_hf_endpoint(url: str, token: str, payload: Dict[str, Any]) -> Dict[str, Any]:
    """Asynchronously POST *payload* to a Hugging Face Inference Endpoint.

    Args:
        url: Full endpoint URL.
        token: HF access token, sent as a Bearer credential.
        payload: JSON-serializable request body.

    Returns:
        The decoded JSON response.

    Raises:
        Exception: On a 503 (service starting/overloaded) or 404 (model
            not found) response.
        httpx.RequestError: On transport failures (logged, then re-raised).
        httpx.HTTPStatusError: On any other non-2xx status.
    """
    request_headers = {
        "Authorization": f"Bearer {token}",
        "Content-Type": "application/json",
    }
    async with httpx.AsyncClient(timeout=60.0) as client:
        # Keep the try narrow: only the transport call can raise RequestError.
        try:
            logger.info(f"Async Calling endpoint {url}")
            response = await client.post(url, headers=request_headers, json=payload)
        except httpx.RequestError as e:
            logger.error(f"Error calling HF endpoint ({url}): {e}")
            raise

        status = response.status_code
        if status == 503:
            logger.warning(f"HF Endpoint 503: Service overloaded/starting up at {url}")
            raise Exception("HF Service Unavailable (503)")
        if status == 404:
            logger.error(f"HF Endpoint 404: Model not found at {url}")
            raise Exception("HF Model Not Found (404)")

        response.raise_for_status()
        return response.json()


def build_conversation_context(messages: List, max_turns: int = 3, max_chars: int = 8000) -> str:
    """
    Assemble a conversation-history string for the generator prompt.

    The first user and first assistant messages are always kept (budget
    permitting), followed by up to ``max_turns`` of the most recent
    user/assistant exchanges, all capped at ``max_chars`` characters.

    A "turn" is one user message plus the assistant response that follows it.

    Args:
        messages: Message objects exposing ``role`` and ``content`` attributes.
        max_turns: Maximum number of trailing user/assistant pairs to include.
        max_chars: Overall character budget for the context string (default 8000).

    Returns:
        Blocks like "USER: query1\\nASSISTANT: response1" joined by blank
        lines, or "" when there are no messages.
    """
    if not messages:
        return ""

    parts: List[str] = []
    used_chars = 0
    included = 0

    # The opening user/assistant exchange anchors the conversation.
    first_user = next((m for m in messages if m.role == 'user'), None)
    first_assistant = next((m for m in messages if m.role == 'assistant'), None)

    for anchor, label in ((first_user, "USER"), (first_assistant, "ASSISTANT")):
        if not anchor:
            continue
        anchor_text = f"{label}: {anchor.content}"
        # Anchors are silently skipped (not aborted) when they would
        # exceed the budget.
        if used_chars + len(anchor_text) <= max_chars:
            parts.append(anchor_text)
            used_chars += len(anchor_text)
            included += 1

    # Most recent user messages, excluding the anchor user message.
    all_users = [m for m in messages if m.role == 'user']
    tail_users = all_users[1:][-max_turns:] if len(all_users) > 1 else []

    tail_parts: List[str] = []
    turns_done = 0
    for user_msg in tail_users:
        # This guard is load-bearing for max_turns <= 0: the [-max_turns:]
        # slice above would keep everything when max_turns == 0.
        if turns_done >= max_turns:
            break

        # The reply is the first assistant message after this user message.
        search_from = messages.index(user_msg) + 1
        reply = next((m for m in messages[search_from:] if m.role == 'assistant'), None)

        user_text = f"USER: {user_msg.content}"
        if used_chars + len(user_text) > max_chars:
            logger.info(f"Stopping context build: would exceed max_chars ({max_chars})")
            break
        tail_parts.append(user_text)
        used_chars += len(user_text)
        included += 1

        if reply:
            reply_text = f"ASSISTANT: {reply.content}"
            if used_chars + len(reply_text) > max_chars:
                logger.info(f"Stopping context build: would exceed max_chars ({max_chars})")
                break
            tail_parts.append(reply_text)
            used_chars += len(reply_text)
            included += 1

        turns_done += 1

    parts.extend(tail_parts)
    context = "\n\n".join(parts)

    logger.debug(f"Built conversation context: {included} messages, {used_chars} chars")
    return context