import os
import logging
import random
import time
from typing import Dict, List, Optional

from huggingface_hub import InferenceClient

logger = logging.getLogger(__name__)


class QwenClient:
    """
    Hugging Face Qwen model client with retry logic and timeout handling.

    Implements an exponential backoff retry strategy for transient failures.
    """

    def __init__(
        self,
        model: Optional[str] = None,
        timeout: int = 8,
        max_retries: int = 3
    ):
| """ |
| Initialize Qwen client. |
| |
| Args: |
| model: Qwen model name (from env or default) |
| timeout: Request timeout in seconds |
| max_retries: Maximum number of retry attempts |
| """ |
| self.model = model or os.getenv("QWEN_MODEL", "Qwen/Qwen-14B-Chat") |
| self.timeout = timeout |
| self.max_retries = max_retries |
| self.api_key = os.getenv("HUGGINGFACE_API_KEY") |
|
|
| if not self.api_key: |
| raise ValueError( |
| "HUGGINGFACE_API_KEY not found in environment variables. " |
| "Please set it in your .env file." |
| ) |
|
|
| |
| self.client = InferenceClient(model=self.model, token=self.api_key) |
|
|
| logger.info(f"Qwen client initialized with model: {self.model}, timeout: {timeout}s") |
|
|
| def generate( |
| self, |
| messages: List[Dict[str, str]], |
| temperature: float = 0.7, |
| max_tokens: int = 1024 |
| ) -> str: |
| """ |
| Generate response from Qwen model with retry logic. |
| |
| Args: |
| messages: Chat messages array (OpenAI format) |
| temperature: Sampling temperature |
| max_tokens: Maximum tokens to generate |
| |
| Returns: |
| Generated text response |
| |
| Raises: |
| Exception: If all retries exhausted |
| """ |
| import time |
|
|
| for attempt in range(self.max_retries): |
| try: |
| logger.info(f"Qwen inference attempt {attempt + 1}/{self.max_retries}") |
|
|
| |
| prompt = self._build_prompt(messages) |
|
|
| |
| response = self.client.text_generation( |
| prompt=prompt, |
| temperature=temperature, |
| max_new_tokens=max_tokens, |
| do_sample=True, |
| stream=False |
| ) |
                logger.info("Qwen inference successful")
                return response.strip()

            except Exception as e:
                logger.error(f"Qwen inference failed on attempt {attempt + 1}: {str(e)}")
                if attempt == self.max_retries - 1:
                    raise

                # Wait longer on rate limits; otherwise back off exponentially with jitter.
                if "429" in str(e) or "rate limit" in str(e).lower():
                    logger.warning("Rate limit detected, waiting 60 seconds...")
                    time.sleep(60)
                else:
                    wait_time = (2 ** attempt) + random.uniform(0, 1)
                    logger.info(f"Retrying in {wait_time:.2f}s...")
                    time.sleep(wait_time)

        # Unreachable when max_retries >= 1, but keeps the declared return type honest.
        raise RuntimeError("Qwen inference failed: all retries exhausted")

    def _build_prompt(self, messages: List[Dict[str, str]]) -> str:
| """ |
| Build prompt from message array for Qwen. |
| |
| Args: |
| messages: Chat messages in OpenAI format |
| |
| Returns: |
| Formatted prompt string |
| """ |
| prompt_parts = [] |
|
|
| for msg in messages: |
| role = msg.get("role", "user") |
| content = msg.get("content", "") |
|
|
| if role == "system": |
| prompt_parts.append(f"System: {content}") |
| elif role == "user": |
| prompt_parts.append(f"User: {content}") |
| elif role == "assistant": |
| prompt_parts.append(f"Assistant: {content}") |
|
|
| prompt = "\n".join(prompt_parts) |
| prompt += "\nAssistant:" |
|
|
| return prompt |
|
|
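

# ---------------------------------------------------------------------------
# Example usage: an illustrative sketch only. Assumes HUGGINGFACE_API_KEY is
# already set in the environment (e.g. loaded from a .env file elsewhere) and
# that the default QWEN_MODEL shown above is acceptable; the prompt text and
# parameter values here are arbitrary.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    client = QwenClient(timeout=10, max_retries=3)
    reply = client.generate(
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Explain exponential backoff in one sentence."},
        ],
        temperature=0.7,
        max_tokens=256,
    )
    print(reply)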