| | |
"""
HuggingFace Qwen 2.5 model client.

Handles inference for the router, main, and complex model tiers with cost tracking.
"""
| |
|
import asyncio
import functools
import logging
import os
import time
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, List, Optional

from huggingface_hub import InferenceClient
from langchain_core.language_models.llms import LLM
from langchain_huggingface import HuggingFaceEndpoint
| |
|
| | |
# Configure the root logger at INFO level when this module is imported, then
# create the module-level logger used by everything below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
| |
|
class ModelTier(Enum):
    """Closed set of model tiers the client can route a request to.

    The string values are what gets logged and used as keys in status reports.
    """

    # Smallest configured model; selected for simple questions or when budget is tight.
    ROUTER = "router"
    # Default tier for medium-complexity work.
    MAIN = "main"
    # Largest configured model; selected for complex questions while budget allows.
    COMPLEX = "complex"
| |
|
@dataclass
class ModelConfig:
    """Static settings describing one model behind a tier."""

    # Model repository id passed to the inference clients.
    name: str
    # Tier this configuration belongs to.
    tier: ModelTier
    # Default generation length cap for this model.
    max_tokens: int
    # Sampling temperature used for generation.
    temperature: float
    # Multiplier applied to the estimated token count to produce a cost estimate.
    cost_per_token: float
    # Timeout value handed to the LangChain endpoint.
    timeout: int
    # True for models that are probed for access and driven via chat completion.
    requires_special_auth: bool = False
| |
|
@dataclass
class InferenceResult:
    """Outcome of a single generation request, including bookkeeping metadata."""

    # Generated text; empty when the request failed.
    response: str
    # Name of the model configuration that handled (or would have handled) the request.
    model_used: str
    # Estimated token count for prompt plus response.
    tokens_used: int
    # Estimated cost of this request.
    cost_estimate: float
    # Elapsed time of the request in seconds.
    response_time: float
    # Whether a usable response was produced.
    success: bool
    # Error message for failed requests; None otherwise.
    error: Optional[str] = None
| |
|
class QwenClient:
    """HuggingFace client with fallback model support.

    On construction the client tries to initialise the Qwen 2.5 models (only
    when a token is available) and otherwise falls back to the FLAN-T5 models.
    It keeps one ``InferenceClient`` and one LangChain ``HuggingFaceEndpoint``
    per tier, selects a tier per request, and tracks an estimated running cost
    against ``budget_limit``.
    """

    # (tier, max_tokens, cost_per_token, timeout) - identical for both model families.
    _TIER_SETTINGS = (
        (ModelTier.ROUTER, 512, 0.0003, 15),
        (ModelTier.MAIN, 1024, 0.0008, 25),
        (ModelTier.COMPLEX, 2048, 0.0015, 35),
    )

    _FALLBACK_MODEL_NAMES = {
        ModelTier.ROUTER: "google/flan-t5-small",
        ModelTier.MAIN: "google/flan-t5-base",
        ModelTier.COMPLEX: "google/flan-t5-large",
    }

    _QWEN_MODEL_NAMES = {
        ModelTier.ROUTER: "Qwen/Qwen2.5-7B-Instruct",
        ModelTier.MAIN: "Qwen/Qwen2.5-32B-Instruct",
        ModelTier.COMPLEX: "Qwen/Qwen2.5-72B-Instruct",
    }

    # Substrings looked for (case-insensitively) in the question text.
    _COMPLEX_INDICATORS = (
        "analyze", "explain why", "reasoning", "logic", "complex", "difficult",
        "multi-step", "calculate and explain", "compare and contrast",
        "what is the relationship", "how does", "why is", "prove that",
        "step by step", "detailed analysis", "comprehensive",
    )
    _SIMPLE_INDICATORS = (
        "what is", "who is", "when", "where", "simple", "quick",
        "yes or no", "true or false", "list", "name", "find",
    )
    _MATH_INDICATORS = (
        "calculate", "compute", "solve", "equation", "formula", "math",
        "number", "total", "sum", "average", "percentage", "code", "program",
    )
    _FILE_INDICATORS = (
        "image", "picture", "photo", "audio", "sound", "video", "file",
        "document", "excel", "csv", "data", "chart", "graph",
    )

    # Substrings of an error message that are treated as an authentication failure.
    _AUTH_ERROR_MARKERS = ("api_key", "nebius", "unauthorized")

    def __init__(self, hf_token: Optional[str] = None, budget_limit: float = 0.10):
        """Initialize the client with a HuggingFace token.

        Args:
            hf_token: Access token. When omitted, ``HUGGINGFACE_TOKEN`` and then
                ``HF_TOKEN`` are read from the environment.
            budget_limit: Spending cap that the estimated running cost is
                compared against when selecting a tier.
        """
        self.hf_token = hf_token or os.getenv("HUGGINGFACE_TOKEN") or os.getenv("HF_TOKEN")
        if not self.hf_token:
            logger.warning("No HuggingFace token provided. API access may be limited.")

        # Usage counters are set up before any client initialisation so the
        # instance is fully formed whatever the initialisation path touches.
        self.total_cost = 0.0
        self.request_count = 0
        self.budget_limit = budget_limit

        # ``models`` is the active configuration; it starts as the fallback set
        # and is swapped for ``qwen_models`` if those initialise successfully.
        self.models = self._build_models(self._FALLBACK_MODEL_NAMES, requires_special_auth=False)
        self.qwen_models = self._build_models(self._QWEN_MODEL_NAMES, requires_special_auth=True)

        self.inference_clients = {}
        self.langchain_clients = {}
        self._initialize_clients()

    @classmethod
    def _build_models(cls, names: Dict[ModelTier, str],
                      requires_special_auth: bool) -> Dict[ModelTier, ModelConfig]:
        """Build a tier -> ModelConfig table for one model family."""
        return {
            tier: ModelConfig(
                name=names[tier],
                tier=tier,
                max_tokens=max_tokens,
                temperature=0.1,
                cost_per_token=cost_per_token,
                timeout=timeout,
                requires_special_auth=requires_special_auth,
            )
            for tier, max_tokens, cost_per_token, timeout in cls._TIER_SETTINGS
        }

    def _initialize_clients(self):
        """Initialize HuggingFace clients with fallback support.

        Qwen models are attempted only when a token is present. If at least one
        Qwen tier initialises, the Qwen table becomes the active configuration;
        otherwise the fallback models are initialised.
        """
        if self.hf_token:
            logger.info("🎯 Attempting to initialize Qwen models...")
            if self._try_initialize_models(self.qwen_models, "Qwen"):
                logger.info("✅ Qwen models initialized successfully")
                self.models = self.qwen_models
                return
            logger.warning("⚠️ Qwen models failed, falling back to standard models")

        logger.info("🔄 Initializing fallback models...")
        if not self._try_initialize_models(self.models, "Fallback"):
            logger.error("❌ All model initialization failed")

    def _probe_chat_access(self, client: InferenceClient, config: ModelConfig,
                           model_type: str) -> bool:
        """Return True if a tiny chat completion against ``config`` succeeds.

        The probe uses ``chat_completion`` because that is the call
        ``generate_async`` makes for models with ``requires_special_auth``.
        """
        try:
            client.chat_completion(
                messages=[{"role": "user", "content": "Hello"}],
                model=config.name,
                max_tokens=5,
                temperature=0.1,
            )
        except Exception as auth_error:
            logger.warning(f"❌ {model_type} auth failed for {config.name}: {auth_error}")
            return False
        logger.info(f"✅ {model_type} auth test passed for {config.name}")
        return True

    def _try_initialize_models(self, model_configs: Dict, model_type: str) -> bool:
        """Try to initialize a set of models.

        Args:
            model_configs: Mapping of tier to ModelConfig to initialise.
            model_type: Label used in log messages.

        Returns:
            True if at least one tier was initialised.
        """
        success_count = 0

        for tier, config in model_configs.items():
            try:
                inference_client = InferenceClient(model=config.name, token=self.hf_token)

                needs_probe = config.requires_special_auth and self.hf_token
                if needs_probe and not self._probe_chat_access(inference_client, config, model_type):
                    # Clear the tier so a client from an earlier initialisation
                    # cannot be mistaken for a working one.
                    self.inference_clients[tier] = None
                    self.langchain_clients[tier] = None
                    continue

                langchain_client = HuggingFaceEndpoint(
                    repo_id=config.name,
                    max_new_tokens=config.max_tokens,
                    temperature=config.temperature,
                    huggingfacehub_api_token=self.hf_token,
                    timeout=config.timeout,
                )
            except Exception as e:
                logger.warning(f"❌ Failed to initialize {model_type} {tier.value} model: {e}")
                self.inference_clients[tier] = None
                self.langchain_clients[tier] = None
                continue

            self.inference_clients[tier] = inference_client
            self.langchain_clients[tier] = langchain_client
            logger.info(f"✅ Initialized {model_type} {tier.value} model: {config.name}")
            success_count += 1

        return success_count > 0

    def get_model_status(self) -> Dict[str, bool]:
        """Check which models are available.

        Returns:
            Mapping of tier value to True when both the inference client and
            the LangChain client exist for that tier.
        """
        return {
            tier.value: (
                self.inference_clients.get(tier) is not None
                and self.langchain_clients.get(tier) is not None
            )
            for tier in ModelTier
        }

    def _budget_used_percent(self) -> float:
        """Return the share of the budget consumed so far, as a percentage.

        A non-positive ``budget_limit`` is reported as 100% used instead of
        dividing by zero.
        """
        if self.budget_limit <= 0:
            return 100.0
        return (self.total_cost / self.budget_limit) * 100

    def _classify_question(self, question_text: str, default: str) -> str:
        """Derive a complexity label from keyword hits in ``question_text``.

        Matching is plain case-insensitive substring search. Returns
        ``default`` when no rule fires.
        """
        question_lower = question_text.lower()

        def hits(indicators) -> int:
            return sum(1 for indicator in indicators if indicator in question_lower)

        complex_score = hits(self._COMPLEX_INDICATORS)
        simple_score = hits(self._SIMPLE_INDICATORS)
        math_score = hits(self._MATH_INDICATORS)
        file_score = hits(self._FILE_INDICATORS)

        if complex_score >= 2 or len(question_text) > 200:
            return "complex"
        if file_score >= 1 or math_score >= 2:
            return "medium"
        if simple_score >= 2 and complex_score == 0:
            return "simple"
        return default

    def select_model_tier(self, complexity: str = "medium", budget_conscious: bool = True,
                          question_text: str = "") -> ModelTier:
        """Smart model selection based on task complexity, budget, and question analysis.

        Args:
            complexity: Caller's estimate: "simple", "medium" or "complex".
            budget_conscious: When True, budget usage above 80% forces the
                router tier and usage above 60% downgrades "complex" requests.
            question_text: Optional question; when given, keyword analysis may
                override ``complexity``.

        Returns:
            The selected tier, replaced by an available one if the first choice
            has no inference client.

        Raises:
            RuntimeError: If the selected tier is unavailable and no other tier
                has an inference client.
        """
        budget_used_percent = self._budget_used_percent()

        if budget_conscious and budget_used_percent > 80:
            logger.warning(f"Budget critical ({budget_used_percent:.1f}% used), forcing router model")
            return ModelTier.ROUTER

        limit_complex = budget_conscious and budget_used_percent > 60
        if limit_complex:
            logger.warning(f"Budget warning ({budget_used_percent:.1f}% used), limiting complex model usage")

        if question_text:
            complexity = self._classify_question(question_text, complexity)

        # Apply the budget downgrade after question analysis; doing it before
        # would let the analysis re-promote the request to "complex".
        if limit_complex and complexity == "complex":
            complexity = "simple"

        if complexity == "complex" and budget_used_percent < 70:
            selected_tier = ModelTier.COMPLEX
        elif complexity == "simple" or budget_used_percent > 75:
            selected_tier = ModelTier.ROUTER
        else:
            selected_tier = ModelTier.MAIN

        if not self.inference_clients.get(selected_tier):
            logger.warning(f"Selected model {selected_tier.value} unavailable, falling back")
            for fallback in (ModelTier.MAIN, ModelTier.ROUTER, ModelTier.COMPLEX):
                if self.inference_clients.get(fallback):
                    selected_tier = fallback
                    break
            else:
                raise RuntimeError("No models available")

        logger.info(f"Selected {selected_tier.value} model (complexity: {complexity}, budget: {budget_used_percent:.1f}%)")
        return selected_tier

    @staticmethod
    async def _call_blocking(func, *args, **kwargs):
        """Run a blocking client call in the loop's default executor.

        The inference client methods used here are synchronous; awaiting them
        through an executor keeps the event loop free while they run.
        """
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, functools.partial(func, *args, **kwargs))

    async def _chat_generate(self, client, config: ModelConfig, prompt: str, tokens: int) -> str:
        """Generate via chat completion; raise ValueError on an empty reply."""
        response = await self._call_blocking(
            client.chat_completion,
            messages=[{"role": "user", "content": prompt}],
            model=config.name,
            max_tokens=tokens,
            temperature=config.temperature,
        )
        if not response or not response.choices:
            raise ValueError("No response received from model")

        content = response.choices[0].message.content
        # Without this check a None content would later be stringified to the
        # literal text "None" and reported as a successful answer.
        if not content or not content.strip():
            raise ValueError("No response received from model")
        return content

    async def _text_generate(self, client, config: ModelConfig, prompt: str, tokens: int) -> str:
        """Generate via text generation, retrying once with the raw prompt."""
        formatted_prompt = f"Question: {prompt}\nAnswer:"
        response_text = await self._call_blocking(
            client.text_generation,
            formatted_prompt,
            max_new_tokens=tokens,
            temperature=config.temperature,
            return_full_text=False,
            do_sample=config.temperature > 0,
        )

        if not response_text or not response_text.strip():
            logger.warning(f"Empty response from {config.name}, trying alternative...")
            # Second attempt: unformatted prompt, shorter output, higher temperature.
            response_text = await self._call_blocking(
                client.text_generation,
                prompt,
                max_new_tokens=min(tokens, 100),
                temperature=0.7,
                return_full_text=False,
            )

        if not response_text or not response_text.strip():
            raise ValueError(f"No response received from {config.name} after multiple attempts")
        return response_text

    async def generate_async(self,
                             prompt: str,
                             tier: Optional[ModelTier] = None,
                             max_tokens: Optional[int] = None) -> InferenceResult:
        """Async text generation with the specified model tier.

        Args:
            prompt: Text to send to the model.
            tier: Tier to use; chosen by ``select_model_tier()`` when None.
            max_tokens: Generation cap; the tier's configured ``max_tokens``
                is used when this is None (or 0).

        Returns:
            An InferenceResult. Failures during generation are reported with
            ``success=False`` and ``error`` set, not raised.

        Raises:
            RuntimeError: Propagated from ``select_model_tier()`` when ``tier``
                is None and no model is available.
        """
        if tier is None:
            tier = self.select_model_tier()

        config = self.models[tier]
        client = self.inference_clients.get(tier)

        if not client:
            return InferenceResult(
                response="",
                model_used=config.name,
                tokens_used=0,
                cost_estimate=0.0,
                response_time=0.0,
                success=False,
                error=f"Model {tier.value} not available",
            )

        start_time = time.time()

        try:
            tokens = max_tokens or config.max_tokens

            if config.requires_special_auth:
                response_text = await self._chat_generate(client, config, prompt, tokens)
            else:
                response_text = await self._text_generate(client, config, prompt, tokens)

            response_time = time.time() - start_time
            response_text = str(response_text).strip()

            # Token usage is approximated by whitespace-separated word counts.
            estimated_tokens = len(prompt.split()) + len(response_text.split())
            cost_estimate = estimated_tokens * config.cost_per_token

            self.total_cost += cost_estimate
            self.request_count += 1

            logger.info(f"✅ Generated response using {tier.value} model in {response_time:.2f}s")

            return InferenceResult(
                response=response_text,
                model_used=config.name,
                tokens_used=estimated_tokens,
                cost_estimate=cost_estimate,
                response_time=response_time,
                success=True,
            )

        except Exception as e:
            response_time = time.time() - start_time
            error_msg = str(e)
            error_lower = error_msg.lower()

            if any(marker in error_lower for marker in self._AUTH_ERROR_MARKERS):
                logger.error(f"❌ Authentication failed with {tier.value} model: {error_msg}")

                if config.requires_special_auth:
                    logger.info("🔄 Attempting to fallback to standard models due to auth failure...")
                    self._initialize_fallback_emergency()

                    # Retry once on the freshly initialised fallback model. The
                    # retry cannot recurse again because fallback models do not
                    # set requires_special_auth.
                    fallback_client = self.inference_clients.get(tier)
                    if fallback_client and not self.models[tier].requires_special_auth:
                        logger.info(f"🔄 Retrying with fallback model...")
                        return await self.generate_async(prompt, tier, max_tokens)
            else:
                logger.error(f"❌ Generation failed with {tier.value} model: {error_msg}")

            return InferenceResult(
                response="",
                model_used=config.name,
                tokens_used=0,
                cost_estimate=0.0,
                response_time=response_time,
                success=False,
                error=error_msg,
            )

    def _initialize_fallback_emergency(self):
        """Emergency fallback to standard models when auth fails.

        Replaces the active configuration with the fallback table and
        re-initialises the clients for every tier.
        """
        logger.warning("🚨 Emergency fallback: Switching to standard HF models")
        self.models = self._build_models(self._FALLBACK_MODEL_NAMES, requires_special_auth=False)
        self._try_initialize_models(self.models, "Emergency Fallback")

    def _run_in_private_loop(self,
                             prompt: str,
                             tier: Optional[ModelTier],
                             max_tokens: Optional[int]) -> InferenceResult:
        """Run ``generate_async`` to completion on a fresh, private event loop.

        The loop is never installed as the thread's current loop, so any loop
        the application manages itself is left untouched.
        """
        loop = asyncio.new_event_loop()
        try:
            return loop.run_until_complete(self.generate_async(prompt, tier, max_tokens))
        finally:
            loop.close()

    def generate(self,
                 prompt: str,
                 tier: Optional[ModelTier] = None,
                 max_tokens: Optional[int] = None) -> InferenceResult:
        """Synchronous text generation (wrapper for async).

        Works both from plain synchronous code and from a thread that already
        has a running event loop. In the latter case the coroutine is driven on
        a worker thread, because a running loop cannot be re-entered with
        ``run_until_complete``. The calling thread blocks until the result is
        ready either way.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No loop is running in this thread: drive the coroutine directly.
            return self._run_in_private_loop(prompt, tier, max_tokens)

        with ThreadPoolExecutor(max_workers=1) as pool:
            return pool.submit(self._run_in_private_loop, prompt, tier, max_tokens).result()

    def get_langchain_llm(self, tier: ModelTier) -> Optional[LLM]:
        """Get LangChain LLM instance for agent integration.

        Returns None when the tier has no initialised LangChain client.
        """
        return self.langchain_clients.get(tier)

    def get_usage_stats(self) -> Dict[str, Any]:
        """Get current usage and cost statistics.

        Returns:
            Dict with the running cost, request count, budget figures, average
            cost per request, and per-tier availability.
        """
        return {
            "total_cost": self.total_cost,
            "request_count": self.request_count,
            "budget_limit": self.budget_limit,
            "budget_remaining": self.budget_limit - self.total_cost,
            "budget_used_percent": self._budget_used_percent(),
            "average_cost_per_request": self.total_cost / max(self.request_count, 1),
            "models_available": self.get_model_status(),
        }

    def reset_usage_tracking(self):
        """Reset usage statistics (for testing/development)."""
        self.total_cost = 0.0
        self.request_count = 0
        logger.info("Usage tracking reset")
| |
|
| | |
def test_model_connection(client: QwenClient, tier: ModelTier):
    """Send one short prompt to ``tier`` and log how it went.

    Returns the ``success`` flag of the generation result.
    """
    logger.info(f"Testing {tier.value} model...")
    result = client.generate(
        "Hello! Please respond with 'Connection successful' if you can read this.",
        tier=tier,
        max_tokens=50,
    )

    # Failure path first; the success path logs the extra metrics below.
    if not result.success:
        logger.error(f"❌ {tier.value} model test failed: {result.error}")
        return result.success

    logger.info(f"✅ {tier.value} model test successful: {result.response[:50]}...")
    logger.info(f" Response time: {result.response_time:.2f}s")
    logger.info(f" Cost estimate: ${result.cost_estimate:.6f}")
    return result.success
| |
|
def test_all_models():
    """Run the connection test against every tier and log a summary.

    Returns a mapping of tier to the outcome of its connection test.
    """
    logger.info("🧪 Testing all Qwen models...")

    client = QwenClient()

    # Tiers are exercised in enum definition order.
    results = {tier: test_model_connection(client, tier) for tier in ModelTier}

    logger.info("📊 Test Results Summary:")
    for tier, success in results.items():
        if success:
            status = "✅ PASS"
        else:
            status = "❌ FAIL"
        logger.info(f" {tier.value:8}: {status}")

    logger.info("💰 Usage Statistics:")
    stats = client.get_usage_stats()
    for key, value in stats.items():
        # Per-tier availability is skipped in this summary.
        if key == "models_available":
            continue
        logger.info(f" {key}: {value}")

    return results
| |
|
# Script entry point: load variables from a .env file, then run the
# connection tests against every model tier.
if __name__ == "__main__":
    # Imported here so python-dotenv is only required when run as a script.
    from dotenv import load_dotenv
    load_dotenv()

    test_all_models()