| """ |
| Insta-AutoApp LLM Client |
| HuggingFace Inference API wrapper with retry logic. |
| """ |
|
|
| import time |
| import logging |
| import requests |
| from typing import Optional |
|
|
| from config import ( |
| HF_API_TOKEN, |
| HF_API_URL, |
| HF_MODEL_ID, |
| MAX_RETRIES, |
| RETRY_DELAY, |
| REQUEST_TIMEOUT |
| ) |
|
|
# Module-level logger, named after this module via __name__; used by
# LLMClient for its warning and error messages.
logger = logging.getLogger(__name__)
|
|
|
|
class LLMClientError(Exception):
    """Error raised when a request to the LLM service fails.

    Attributes:
        retryable: ``True`` (the default) when repeating the same request
            might succeed, ``False`` when the failure is permanent for this
            request (for example a rejected API token).
    """

    def __init__(self, *args, retryable: bool = True):
        super().__init__(*args)
        self.retryable = retryable


class LLMClient:
    """Client for HuggingFace Inference API with retry logic."""

    # Client-side HTTP statuses that fail identically on every attempt with
    # the same request, so retrying them only wastes time. Every other status
    # (e.g. 408, 429 and all 5xx) stays retryable.
    _NON_RETRYABLE_STATUS = frozenset({400, 401, 403, 404, 405, 413, 422})

    def __init__(self):
        """Read endpoint, model id and credentials from the config module."""
        self.api_url = HF_API_URL
        self.headers = {
            "Authorization": f"Bearer {HF_API_TOKEN}",
            "Content-Type": "application/json",
        }
        self.model_id = HF_MODEL_ID

        if not HF_API_TOKEN:
            logger.warning("HF_API_TOKEN not set. API calls will fail.")

    @classmethod
    def _error_for_status(cls, status_code: int, body: str) -> LLMClientError:
        """Build the LLMClientError describing an HTTP error status (>= 400).

        Args:
            status_code: HTTP status code of the response.
            body: Response body text, included in the message for 4xx errors.

        Returns:
            The exception to raise; its ``retryable`` flag is ``False`` for
            statuses listed in ``_NON_RETRYABLE_STATUS``.
        """
        retryable = status_code not in cls._NON_RETRYABLE_STATUS
        if status_code == 401:
            message = "Invalid API token. Please check your HF_API_TOKEN."
        elif status_code == 503:
            message = "Model is loading. Please try again in a moment."
        elif status_code >= 500:
            message = f"Server error (HTTP {status_code}). Retrying..."
        else:
            message = f"Request error (HTTP {status_code}): {body}"
        return LLMClientError(message, retryable=retryable)

    @staticmethod
    def _extract_generated_text(result) -> str:
        """Pull the generated text out of a decoded JSON response.

        Accepts either a non-empty list whose first element is a dict with a
        ``generated_text`` key, or a single such dict.

        Args:
            result: The decoded JSON body of a successful response.

        Returns:
            The generated text with surrounding whitespace stripped.

        Raises:
            LLMClientError: If the payload carries an ``error`` key or does
                not have one of the accepted shapes.
        """
        if isinstance(result, list):
            # An empty list or a non-dict first element is a malformed
            # payload; report it instead of failing with a TypeError.
            record = result[0] if result else None
            if not isinstance(record, dict):
                raise LLMClientError(f"Unexpected response format: {result}")
        elif isinstance(result, dict):
            if "generated_text" not in result and "error" in result:
                raise LLMClientError(f"API error: {result['error']}")
            record = result
        else:
            raise LLMClientError(f"Unexpected response type: {type(result)}")

        text = record.get("generated_text")
        if not isinstance(text, str):
            # Covers both a missing key and a non-string value.
            raise LLMClientError(f"Unexpected response format: {result}")
        return text.strip()

    def _make_request(self, prompt: str, max_new_tokens: int = 1024) -> str:
        """
        Make a single request to the HuggingFace Inference API.

        Args:
            prompt: The full prompt to send to the model
            max_new_tokens: Maximum tokens to generate

        Returns:
            The generated text response

        Raises:
            LLMClientError: If the request fails. ``retryable`` is ``False``
                on the exception when repeating the request cannot help.
        """
        payload = {
            "inputs": prompt,
            "parameters": {
                "max_new_tokens": max_new_tokens,
                "temperature": 0.7,
                "top_p": 0.9,
                "do_sample": True,
                "return_full_text": False,
            },
        }

        try:
            response = requests.post(
                self.api_url,
                headers=self.headers,
                json=payload,
                timeout=REQUEST_TIMEOUT,
            )
        except requests.exceptions.Timeout as exc:
            raise LLMClientError(
                "Request timed out. The service may be overloaded."
            ) from exc
        except requests.exceptions.ConnectionError as exc:
            raise LLMClientError(
                "Could not connect to the AI service. Please check your internet connection."
            ) from exc
        except requests.exceptions.RequestException as exc:
            raise LLMClientError(f"Request failed: {exc}") from exc

        if response.status_code >= 400:
            raise self._error_for_status(response.status_code, response.text)

        try:
            result = response.json()
        except ValueError as exc:
            # Depending on the requests version, a bad body raises either a
            # plain ValueError subclass or requests' own JSONDecodeError
            # (which is also a ValueError); catch both here.
            raise LLMClientError(
                f"Request failed: response body is not valid JSON ({exc})"
            ) from exc

        return self._extract_generated_text(result)

    def generate(self, prompt: str, max_new_tokens: int = 1024) -> Optional[str]:
        """
        Generate text with automatic retry on transient failures.

        Makes up to MAX_RETRIES attempts, sleeping RETRY_DELAY seconds
        between them. Stops immediately, without sleeping, when a failure
        is marked non-retryable.

        Args:
            prompt: The full prompt to send to the model
            max_new_tokens: Maximum tokens to generate

        Returns:
            The generated text, or None if the request could not be completed
        """
        last_error = None

        for attempt in range(MAX_RETRIES):
            try:
                return self._make_request(prompt, max_new_tokens)
            except LLMClientError as e:
                last_error = e

                if not e.retryable:
                    # Same request would fail the same way; give up now.
                    logger.error(
                        "LLM request failed with a non-retryable error: %s", e
                    )
                    return None

                logger.warning(
                    "LLM request failed (attempt %d/%d): %s",
                    attempt + 1,
                    MAX_RETRIES,
                    e,
                )

                # No point sleeping after the final attempt.
                if attempt < MAX_RETRIES - 1:
                    time.sleep(RETRY_DELAY)

        logger.error(
            "All %d LLM request attempts failed. Last error: %s",
            MAX_RETRIES,
            last_error,
        )
        return None

    def is_configured(self) -> bool:
        """Check if the client is properly configured with an API token."""
        return bool(HF_API_TOKEN)
|
|
|
|
| |
# Lazily-created shared client; stays None until get_llm_client() is first called.
_llm_client: Optional[LLMClient] = None


def get_llm_client() -> LLMClient:
    """Return the shared LLMClient, constructing it on the first call."""
    global _llm_client
    if _llm_client is not None:
        return _llm_client
    _llm_client = LLMClient()
    return _llm_client
|
|