# Uploaded by NeilDriscoll ("Upload 13 files", commit 3c3e122, verified).
"""
Insta-AutoApp LLM Client
HuggingFace Inference API wrapper with retry logic.
"""
import time
import logging
import requests
from typing import Optional
from config import (
HF_API_TOKEN,
HF_API_URL,
HF_MODEL_ID,
MAX_RETRIES,
RETRY_DELAY,
REQUEST_TIMEOUT
)
# Module-level logger, named after this module so log records can be traced to it.
logger = logging.getLogger(__name__)
class LLMClientError(Exception):
    """Error raised when a request to the LLM service fails or its reply is unusable."""
class LLMClient:
    """Client for the HuggingFace Inference API with retry logic.

    ``generate`` is the public entry point: it retries failed requests up to
    ``MAX_RETRIES`` times and returns ``None`` instead of raising.
    ``_make_request`` performs a single HTTP round trip and raises
    ``LLMClientError`` on any failure.
    """

    # 4xx statuses that can succeed on a later attempt (request timeout and
    # rate limiting). Every other 4xx is treated as a permanent failure.
    _RETRYABLE_CLIENT_STATUSES = frozenset({408, 429})

    def __init__(self):
        """Capture the endpoint, model id and auth header from the config values."""
        self.api_url = HF_API_URL
        self.headers = {
            "Authorization": f"Bearer {HF_API_TOKEN}",
            "Content-Type": "application/json",
        }
        self.model_id = HF_MODEL_ID

        if not HF_API_TOKEN:
            logger.warning("HF_API_TOKEN not set. API calls will fail.")

    @staticmethod
    def _is_retryable_status(status_code: int) -> bool:
        """Return True if an HTTP error status may succeed when retried.

        All 5xx statuses are retryable; among 4xx only the ones listed in
        ``_RETRYABLE_CLIENT_STATUSES`` are.
        """
        if status_code >= 500:
            return True
        return status_code in LLMClient._RETRYABLE_CLIENT_STATUSES

    @staticmethod
    def _failure(message: str, retryable: bool) -> Exception:
        """Build an LLMClientError tagged with whether a retry is worthwhile.

        The flag is stored as a plain ``retryable`` attribute on the instance
        so ``generate`` can decide without parsing the message text.
        """
        error = LLMClientError(message)
        error.retryable = retryable
        return error

    def _check_status(self, response) -> None:
        """Raise LLMClientError if ``response`` carries an HTTP error status."""
        status = response.status_code
        if status < 400:
            return

        if status == 401:
            message = "Invalid API token. Please check your HF_API_TOKEN."
        elif status == 503:
            message = "Model is loading. Please try again in a moment."
        elif status >= 500:
            message = f"Server error (HTTP {status}). Retrying..."
        else:
            message = f"Request error (HTTP {status}): {response.text}"
        raise self._failure(message, self._is_retryable_status(status))

    @staticmethod
    def _parse_response(result) -> str:
        """Extract the stripped generated text from a decoded JSON body.

        Accepts either a non-empty list whose first element is a dict holding
        ``generated_text``, or a dict holding ``generated_text`` directly.

        Raises:
            LLMClientError: If the body has any other shape, reports an
                ``error`` key, or ``generated_text`` is not a string.
        """
        if isinstance(result, list) and result:
            entry = result[0]
        elif isinstance(result, dict):
            entry = result
            if "generated_text" not in result and "error" in result:
                raise LLMClientError(f"API error: {result['error']}")
        else:
            # Also reached for an empty list, which carries no generation.
            raise LLMClientError(f"Unexpected response type: {type(result)}")

        # Guard against non-dict list items and non-string payloads so the
        # caller only ever sees LLMClientError, never TypeError/AttributeError.
        text = entry.get("generated_text") if isinstance(entry, dict) else None
        if not isinstance(text, str):
            raise LLMClientError(f"Unexpected response format: {result}")
        return text.strip()

    def _make_request(self, prompt: str, max_new_tokens: int = 1024) -> str:
        """Make a single request to the HuggingFace Inference API.

        Args:
            prompt: The full prompt to send to the model
            max_new_tokens: Maximum tokens to generate

        Returns:
            The generated text response, stripped of surrounding whitespace

        Raises:
            LLMClientError: If the request fails, the server answers with an
                error status, or the body cannot be interpreted. Errors for
                permanent HTTP failures carry ``retryable = False``.
        """
        payload = {
            "inputs": prompt,
            "parameters": {
                "max_new_tokens": max_new_tokens,
                "temperature": 0.7,
                "top_p": 0.9,
                "do_sample": True,
                "return_full_text": False,
            },
        }

        try:
            response = requests.post(
                self.api_url,
                headers=self.headers,
                json=payload,
                timeout=REQUEST_TIMEOUT,
            )
        except requests.exceptions.Timeout as e:
            raise LLMClientError(
                "Request timed out. The service may be overloaded."
            ) from e
        except requests.exceptions.ConnectionError as e:
            raise LLMClientError(
                "Could not connect to the AI service. Please check your internet connection."
            ) from e
        except requests.exceptions.RequestException as e:
            raise LLMClientError(f"Request failed: {str(e)}") from e

        self._check_status(response)

        try:
            result = response.json()
        except ValueError as e:
            # Covers both json.JSONDecodeError and the requests-specific
            # JSON error, so a malformed body never escapes as a raw ValueError.
            raise LLMClientError(f"Request failed: {str(e)}") from e

        return self._parse_response(result)

    def generate(self, prompt: str, max_new_tokens: int = 1024) -> Optional[str]:
        """Generate text with automatic retry on transient failures.

        Args:
            prompt: The full prompt to send to the model
            max_new_tokens: Maximum tokens to generate

        Returns:
            The generated text, or None if every attempt failed or a
            non-retryable error (e.g. invalid token, bad request) occurred
        """
        last_error = None
        attempts_made = 0

        for attempt in range(1, MAX_RETRIES + 1):
            attempts_made = attempt
            try:
                return self._make_request(prompt, max_new_tokens)
            except LLMClientError as e:
                last_error = e
                logger.warning(
                    "LLM request failed (attempt %s/%s): %s",
                    attempt, MAX_RETRIES, e,
                )

                # Errors without the flag are treated as retryable.
                if not getattr(e, "retryable", True):
                    break

                # Wait before retrying (except on last attempt)
                if attempt < MAX_RETRIES:
                    time.sleep(RETRY_DELAY)

        # Report the real attempt count: the loop may have stopped early.
        logger.error(
            "LLM request failed after %s of %s attempt(s). Last error: %s",
            attempts_made, MAX_RETRIES, last_error,
        )
        return None

    def is_configured(self) -> bool:
        """Check if the client is properly configured with an API token."""
        return bool(HF_API_TOKEN)
# Shared client instance; stays None until get_llm_client() is first called.
_llm_client: Optional[LLMClient] = None


def get_llm_client() -> LLMClient:
    """Return the module-wide LLMClient, constructing it on first use."""
    global _llm_client
    if _llm_client is not None:
        return _llm_client
    _llm_client = LLMClient()
    return _llm_client