File size: 5,389 Bytes
3c3e122
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
"""
Insta-AutoApp LLM Client
HuggingFace Inference API wrapper with retry logic.
"""

import time
import logging
import requests
from typing import Optional

from config import (
    HF_API_TOKEN,
    HF_API_URL,
    HF_MODEL_ID,
    MAX_RETRIES,
    RETRY_DELAY,
    REQUEST_TIMEOUT
)

# Module-level logger; LLMClient uses it for the missing-token warning and
# for per-attempt / final request-failure messages.
logger = logging.getLogger(__name__)


class LLMClientError(Exception):
    """Error raised by the LLM client when a call to the inference API fails.

    The exception message is a human-readable description of the failure.
    """


class LLMClient:
    """Client for HuggingFace Inference API with retry logic.

    ``generate`` retries transient failures (timeouts, connection problems,
    5xx responses, 408/429, malformed bodies) and gives up immediately on
    errors that cannot succeed when repeated unchanged (bad token, other 4xx).
    """

    # 4xx statuses that still describe a transient condition (request timeout,
    # rate limiting) and are therefore worth retrying. Any other 4xx means the
    # request itself was rejected, so resending it unchanged cannot succeed.
    _RETRYABLE_CLIENT_STATUSES = frozenset({408, 429})

    def __init__(self):
        self.api_url = HF_API_URL
        self.headers = {
            "Authorization": f"Bearer {HF_API_TOKEN}",
            "Content-Type": "application/json"
        }
        self.model_id = HF_MODEL_ID

        if not HF_API_TOKEN:
            logger.warning("HF_API_TOKEN not set. API calls will fail.")

    @staticmethod
    def _non_retryable(message: str) -> "LLMClientError":
        """Build an LLMClientError that ``generate`` must not retry.

        The marker is an attribute rather than a subclass so that callers
        catching ``LLMClientError`` keep working unchanged.
        """
        error = LLMClientError(message)
        error.retryable = False
        return error

    def _check_status(self, response) -> None:
        """Raise LLMClientError if *response* carries an HTTP error status (>= 400)."""
        status = response.status_code
        if status == 401:
            raise self._non_retryable("Invalid API token. Please check your HF_API_TOKEN.")
        if status == 503:
            raise LLMClientError("Model is loading. Please try again in a moment.")
        if status >= 500:
            raise LLMClientError(f"Server error (HTTP {status}). Retrying...")
        if status >= 400:
            message = f"Request error (HTTP {status}): {response.text}"
            if status in self._RETRYABLE_CLIENT_STATUSES:
                raise LLMClientError(message)
            raise self._non_retryable(message)

    @staticmethod
    def _extract_text(result) -> str:
        """Return the stripped generated text from a decoded JSON response.

        Accepts either a non-empty list whose first element is a dict with a
        string ``"generated_text"``, or a dict with ``"generated_text"``.
        A dict carrying ``"error"`` is reported as an API error; every other
        shape raises LLMClientError instead of a stray TypeError/AttributeError.
        """
        if isinstance(result, list):
            first = result[0] if result else None
            if isinstance(first, dict) and isinstance(first.get("generated_text"), str):
                return first["generated_text"].strip()
            raise LLMClientError(f"Unexpected response format: {result}")
        if isinstance(result, dict):
            text = result.get("generated_text")
            if isinstance(text, str):
                return text.strip()
            if "error" in result:
                raise LLMClientError(f"API error: {result['error']}")
            raise LLMClientError(f"Unexpected response format: {result}")
        raise LLMClientError(f"Unexpected response type: {type(result)}")

    def _make_request(self, prompt: str, max_new_tokens: int = 1024) -> str:
        """
        Make a single request to the HuggingFace Inference API.

        Args:
            prompt: The full prompt to send to the model
            max_new_tokens: Maximum tokens to generate

        Returns:
            The generated text response, with surrounding whitespace stripped

        Raises:
            LLMClientError: If the request fails. Errors that cannot succeed on
                a retry (invalid token, rejected request) have
                ``retryable = False`` set on them.
        """
        payload = {
            "inputs": prompt,
            "parameters": {
                "max_new_tokens": max_new_tokens,
                "temperature": 0.7,
                "top_p": 0.9,
                "do_sample": True,
                "return_full_text": False
            }
        }

        # Only the network call sits inside this try: transport failures are
        # translated here, while status and body problems are reported below.
        try:
            response = requests.post(
                self.api_url,
                headers=self.headers,
                json=payload,
                timeout=REQUEST_TIMEOUT
            )
        except requests.exceptions.Timeout as e:
            raise LLMClientError("Request timed out. The service may be overloaded.") from e
        except requests.exceptions.ConnectionError as e:
            raise LLMClientError(
                "Could not connect to the AI service. Please check your internet connection."
            ) from e
        except requests.exceptions.RequestException as e:
            raise LLMClientError(f"Request failed: {str(e)}") from e

        self._check_status(response)

        try:
            result = response.json()
        except ValueError as e:
            # Covers json.JSONDecodeError and requests' own JSONDecodeError
            # (a ValueError subclass), so a non-JSON body surfaces as an
            # LLMClientError instead of crashing generate()'s callers.
            raise LLMClientError(f"Request failed: {str(e)}") from e

        return self._extract_text(result)

    def generate(self, prompt: str, max_new_tokens: int = 1024) -> Optional[str]:
        """
        Generate text with automatic retry on transient failures.

        Args:
            prompt: The full prompt to send to the model
            max_new_tokens: Maximum tokens to generate

        Returns:
            The generated text, or None if every attempt fails or a
            non-retryable error (e.g. invalid token) occurs
        """
        # Always make at least one attempt, even if MAX_RETRIES is configured
        # as zero or negative (range() would otherwise yield nothing).
        attempts = max(1, MAX_RETRIES)
        attempts_made = 0
        last_error = None

        for attempt in range(1, attempts + 1):
            attempts_made = attempt
            try:
                return self._make_request(prompt, max_new_tokens)
            except LLMClientError as e:
                last_error = e
                logger.warning(
                    "LLM request failed (attempt %d/%d): %s", attempt, attempts, e
                )

                # Repeating a request the server rejected outright is pointless.
                if not getattr(e, "retryable", True):
                    break

                # Wait before retrying (except after the last attempt)
                if attempt < attempts:
                    time.sleep(RETRY_DELAY)

        logger.error(
            "LLM request failed after %d of %d attempt(s). Last error: %s",
            attempts_made, attempts, last_error
        )
        return None

    def is_configured(self) -> bool:
        """Check if the client is properly configured with an API token."""
        return bool(HF_API_TOKEN)


# Shared LLMClient for this module; created lazily by get_llm_client().
# No lock guards creation.
_llm_client: Optional[LLMClient] = None


def get_llm_client() -> LLMClient:
    """Return the module's shared LLMClient, constructing it on first call.

    Later calls return the same instance.
    """
    global _llm_client
    if _llm_client is not None:
        return _llm_client
    _llm_client = LLMClient()
    return _llm_client