File size: 5,389 Bytes
3c3e122 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 | """
Insta-AutoApp LLM Client
HuggingFace Inference API wrapper with retry logic.
"""
import time
import logging
import requests
from typing import Optional
from config import (
HF_API_TOKEN,
HF_API_URL,
HF_MODEL_ID,
MAX_RETRIES,
RETRY_DELAY,
REQUEST_TIMEOUT
)
# Logger named after this module so its level and handlers can be tuned
# independently of the rest of the application.
logger = logging.getLogger(name=__name__)
class LLMClientError(Exception):
    """Error raised when a request to the LLM service fails.

    Construction stays compatible with plain ``Exception``: positional
    arguments are passed through unchanged, so ``str(err)`` is the message.

    Attributes:
        retryable: ``True`` (the default) when repeating the same request
            might succeed (timeouts, connection problems, 408/429 and 5xx
            responses); ``False`` when it would fail again unchanged
            (rejected token, other 4xx responses).
    """

    def __init__(self, *args: object, retryable: bool = True) -> None:
        super().__init__(*args)
        self.retryable = retryable


class LLMClient:
    """Client for the HuggingFace Inference API with retry logic.

    Endpoint URL, token, model id, retry count, retry delay and request
    timeout are all read from the project ``config`` module.
    """

    # 4xx statuses that describe a temporary condition and are worth
    # retrying: 408 Request Timeout and 429 Too Many Requests.
    _TRANSIENT_CLIENT_STATUSES = frozenset({408, 429})

    def __init__(self):
        self.api_url = HF_API_URL
        self.headers = {
            "Authorization": f"Bearer {HF_API_TOKEN}",
            "Content-Type": "application/json",
        }
        # Kept as a public attribute; not referenced elsewhere in this class.
        self.model_id = HF_MODEL_ID
        if not HF_API_TOKEN:
            logger.warning("HF_API_TOKEN not set. API calls will fail.")

    def _make_request(self, prompt: str, max_new_tokens: int = 1024) -> str:
        """Send a single generation request (no retries) and return the text.

        Args:
            prompt: The full prompt to send to the model.
            max_new_tokens: Maximum number of tokens to generate.

        Returns:
            The generated text with surrounding whitespace stripped.

        Raises:
            LLMClientError: If the transport fails, the server answers with
                an HTTP error status, or the body cannot be turned into
                generated text. Its ``retryable`` flag tells the caller
                whether repeating the request is worthwhile.
        """
        payload = {
            "inputs": prompt,
            "parameters": {
                "max_new_tokens": max_new_tokens,
                "temperature": 0.7,
                "top_p": 0.9,
                "do_sample": True,
                "return_full_text": False,
            },
        }

        # Only the network call sits inside the try, so the LLMClientErrors
        # raised further down are never re-wrapped; causes are chained so the
        # original requests exception stays visible in tracebacks.
        try:
            response = requests.post(
                self.api_url,
                headers=self.headers,
                json=payload,
                timeout=REQUEST_TIMEOUT,
            )
        except requests.exceptions.Timeout as exc:
            raise LLMClientError(
                "Request timed out. The service may be overloaded."
            ) from exc
        except requests.exceptions.ConnectionError as exc:
            raise LLMClientError(
                "Could not connect to the AI service. "
                "Please check your internet connection."
            ) from exc
        except requests.exceptions.RequestException as exc:
            raise LLMClientError(f"Request failed: {exc}") from exc

        if response.status_code >= 400:
            raise self._error_for_status(response.status_code, response.text)

        try:
            result = response.json()
        except ValueError as exc:
            # json.JSONDecodeError and requests' JSONDecodeError both derive
            # from ValueError; without this, a non-JSON body would escape as
            # a non-LLMClientError and bypass the retry loop in generate().
            raise LLMClientError(
                "Unexpected response format: body is not valid JSON."
            ) from exc

        return self._extract_generated_text(result)

    @classmethod
    def _error_for_status(cls, status_code: int, body: str) -> LLMClientError:
        """Build the error for an HTTP error response (``status_code >= 400``).

        Args:
            status_code: HTTP status code of the response.
            body: Response body text, included in generic 4xx messages.

        Returns:
            An LLMClientError whose ``retryable`` flag is False for
            non-transient client errors (401 and 4xx other than 408/429).
        """
        if status_code == 401:
            return LLMClientError(
                "Invalid API token. Please check your HF_API_TOKEN.",
                retryable=False,
            )
        if status_code == 503:
            return LLMClientError("Model is loading. Please try again in a moment.")
        if status_code >= 500:
            # No "Retrying..." suffix: this method cannot know whether the
            # caller will actually retry.
            return LLMClientError(f"Server error (HTTP {status_code}).")
        return LLMClientError(
            f"Request error (HTTP {status_code}): {body}",
            retryable=status_code in cls._TRANSIENT_CLIENT_STATUSES,
        )

    @staticmethod
    def _extract_generated_text(result: object) -> str:
        """Pull the generated text out of a decoded API response body.

        Accepted shapes: a non-empty list whose first item is a dict with a
        ``"generated_text"`` string, or such a dict on its own.

        Args:
            result: The decoded JSON body.

        Returns:
            The ``"generated_text"`` value with surrounding whitespace stripped.

        Raises:
            LLMClientError: If the body carries an ``"error"`` key or has
                none of the accepted shapes.
        """
        if isinstance(result, list):
            # Reject empty lists and non-dict items (for a bare string, `in`
            # would do a substring test and indexing would raise TypeError).
            first = result[0] if result else None
            if not (isinstance(first, dict) and "generated_text" in first):
                raise LLMClientError(f"Unexpected response format: {result}")
            text = first["generated_text"]
        elif isinstance(result, dict):
            if "generated_text" in result:
                text = result["generated_text"]
            elif "error" in result:
                raise LLMClientError(f"API error: {result['error']}")
            else:
                raise LLMClientError(f"Unexpected response format: {result}")
        else:
            raise LLMClientError(f"Unexpected response type: {type(result)}")

        if not isinstance(text, str):
            raise LLMClientError(f"Unexpected response format: {result}")
        return text.strip()

    def generate(self, prompt: str, max_new_tokens: int = 1024) -> Optional[str]:
        """Generate text, retrying transient failures.

        Makes up to ``MAX_RETRIES`` attempts (always at least one), sleeping
        ``RETRY_DELAY`` seconds between attempts. An error whose
        ``retryable`` flag is False (e.g. a rejected API token or a 404)
        ends the loop immediately.

        Args:
            prompt: The full prompt to send to the model.
            max_new_tokens: Maximum number of tokens to generate.

        Returns:
            The generated text, or None if every attempt failed.
        """
        max_attempts = max(1, MAX_RETRIES)
        last_error: Optional[LLMClientError] = None
        attempt = 0
        for attempt in range(1, max_attempts + 1):
            try:
                return self._make_request(prompt, max_new_tokens)
            except LLMClientError as exc:
                last_error = exc
                logger.warning(
                    "LLM request failed (attempt %d/%d): %s",
                    attempt, max_attempts, exc,
                )
                if not getattr(exc, "retryable", True):
                    break
                # No point sleeping after the final attempt.
                if attempt < max_attempts:
                    time.sleep(RETRY_DELAY)

        # Report the attempts actually made, which is fewer than
        # max_attempts when a non-retryable error stopped the loop early.
        logger.error(
            "LLM request failed after %d attempt(s). Last error: %s",
            attempt, last_error,
        )
        return None

    def is_configured(self) -> bool:
        """Return True when an API token is configured (HF_API_TOKEN is truthy)."""
        return bool(HF_API_TOKEN)
# Module-level cache for the shared client; populated on the first call to
# get_llm_client() and reused afterwards.
_llm_client: Optional[LLMClient] = None


def get_llm_client() -> LLMClient:
    """Return the shared LLMClient, constructing it on first use.

    Returns:
        The same LLMClient instance on every call within this module's
        lifetime (until ``_llm_client`` is reset).
    """
    global _llm_client
    if _llm_client is not None:
        return _llm_client
    _llm_client = LLMClient()
    return _llm_client
|