Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| HuggingFace Qwen 2.5 Model Client | |
| Handles inference for router, main, and complex models with cost tracking | |
| """ | |
| import os | |
| import time | |
| import logging | |
| from typing import Dict, Any, List, Optional | |
| from dataclasses import dataclass | |
| from enum import Enum | |
| from huggingface_hub import InferenceClient | |
| from langchain_huggingface import HuggingFaceEndpoint | |
| from langchain_core.language_models.llms import LLM | |
| # Configure logging | |
| logging.basicConfig(level=logging.INFO) | |
| logger = logging.getLogger(__name__) | |
class ModelTier(Enum):
    """Model complexity tiers, ordered from cheapest to most capable."""

    ROUTER = "router"    # fast, cheap routing decisions
    MAIN = "main"        # balanced performance/cost
    COMPLEX = "complex"  # best performance for hard tasks
@dataclass
class ModelConfig:
    """Configuration for each model.

    Bug fix: the class was declared with bare field annotations but
    instantiated with keyword arguments (see QwenClient.__init__), which
    fails without the @dataclass decorator generating __init__.
    """
    name: str                            # HF repo id, e.g. "Qwen/Qwen2.5-7B-Instruct"
    tier: ModelTier                      # which complexity tier this config serves
    max_tokens: int                      # default generation cap
    temperature: float                   # sampling temperature
    cost_per_token: float                # estimated cost per token (USD)
    timeout: int                         # request timeout in seconds
    requires_special_auth: bool = False  # for Nebius API models
@dataclass
class InferenceResult:
    """Result of model inference with metadata.

    Bug fix: the class was declared with bare field annotations but
    instantiated with keyword arguments throughout QwenClient, which
    fails without the @dataclass decorator generating __init__.
    """
    response: str                 # generated text ("" on failure)
    model_used: str               # HF repo id of the model that served the call
    tokens_used: int              # rough prompt+response token estimate
    cost_estimate: float          # estimated cost of this call (USD)
    response_time: float          # wall-clock seconds for the call
    success: bool                 # True when a non-empty response was produced
    error: Optional[str] = None   # error message when success is False
class QwenClient:
    """HuggingFace client with fallback model support.

    Wraps the HuggingFace Inference API for three Qwen 2.5 chat models
    (7B router / 32B main / 72B complex), tracks an estimated spend against
    a hard budget, and selects a tier per request from question complexity
    and remaining budget.
    """

    def __init__(self, hf_token: Optional[str] = None):
        """Initialize the client with HuggingFace token for Qwen models only.

        Args:
            hf_token: Explicit token; falls back to the HUGGINGFACE_TOKEN
                then HF_TOKEN environment variables.

        Raises:
            ValueError: If no token is available.
            RuntimeError: If no Qwen model initializes or the connectivity
                test fails.
        """
        self.hf_token = hf_token or os.getenv("HUGGINGFACE_TOKEN") or os.getenv("HF_TOKEN")
        if not self.hf_token:
            raise ValueError("HuggingFace token is required for Qwen model access. Please provide HF_TOKEN or login with inference permissions.")
        # Initialize cost tracking first — select_model_tier reads these
        # during the connectivity test triggered below.
        self.total_cost = 0.0
        self.request_count = 0
        self.budget_limit = 0.10  # $0.10 total budget
        # Define Qwen model configurations (only these models)
        self.models = {
            ModelTier.ROUTER: ModelConfig(
                name="Qwen/Qwen2.5-7B-Instruct",
                tier=ModelTier.ROUTER,
                max_tokens=512,
                temperature=0.1,
                cost_per_token=0.0003,
                timeout=15,
                requires_special_auth=True
            ),
            ModelTier.MAIN: ModelConfig(
                name="Qwen/Qwen2.5-32B-Instruct",
                tier=ModelTier.MAIN,
                max_tokens=1024,
                temperature=0.1,
                cost_per_token=0.0008,
                timeout=25,
                requires_special_auth=True
            ),
            ModelTier.COMPLEX: ModelConfig(
                name="Qwen/Qwen2.5-72B-Instruct",
                tier=ModelTier.COMPLEX,
                max_tokens=2048,
                temperature=0.1,
                cost_per_token=0.0015,
                timeout=35,
                requires_special_auth=True
            )
        }
        # Initialize clients (tier -> client, None for unavailable tiers)
        self.inference_clients = {}
        self.langchain_clients = {}
        self._initialize_clients()

    def _initialize_clients(self):
        """Initialize HuggingFace clients for Qwen models only.

        Raises RuntimeError if no model comes up or the smoke test fails.
        """
        logger.info("🎯 Initializing Qwen models via HuggingFace Inference API...")
        success = self._try_initialize_models(self.models, "Qwen")
        if not success:
            raise RuntimeError("Failed to initialize any Qwen models. Please check your HF_TOKEN has inference permissions and try again.")
        # Test the main model to ensure it's working
        logger.info("🧪 Testing Qwen model connectivity...")
        try:
            test_result = self.generate("Hello", max_tokens=10)
            if test_result.success and test_result.response.strip():
                logger.info(f"✅ Qwen models ready: '{test_result.response.strip()}'")
            else:
                logger.error(f"❌ Qwen model test failed: {test_result}")
                raise RuntimeError("Qwen models failed connectivity test")
        except Exception as e:
            logger.error(f"❌ Qwen model test exception: {e}")
            # Chain the cause so the original failure is preserved in tracebacks
            raise RuntimeError(f"Qwen model initialization failed: {e}") from e

    def _try_initialize_models(self, model_configs: Dict, model_type: str) -> bool:
        """Try to initialize Qwen models.

        Each tier is auth-tested with a tiny chat completion before its
        clients are created. Failed tiers are skipped (or recorded as None
        on hard errors) so callers can fall back.

        Returns:
            True if at least one tier initialized successfully.
        """
        success_count = 0
        for tier, config in model_configs.items():
            try:
                # Test Qwen model authentication
                test_client = InferenceClient(
                    model=config.name,
                    token=self.hf_token
                )
                # Quick test to verify authentication and model access
                try:
                    test_client.chat_completion(
                        messages=[{"role": "user", "content": "Hello"}],
                        model=config.name,
                        max_tokens=5,
                        temperature=0.1
                    )
                    logger.info(f"✅ {model_type} auth test passed for {config.name}")
                except Exception as auth_error:
                    logger.warning(f"❌ {model_type} auth failed for {config.name}: {auth_error}")
                    continue
                # Initialize the clients
                self.inference_clients[tier] = InferenceClient(
                    model=config.name,
                    token=self.hf_token
                )
                self.langchain_clients[tier] = HuggingFaceEndpoint(
                    repo_id=config.name,
                    max_new_tokens=config.max_tokens,
                    temperature=config.temperature,
                    huggingfacehub_api_token=self.hf_token,
                    timeout=config.timeout
                )
                logger.info(f"✅ Initialized {model_type} {tier.value} model: {config.name}")
                success_count += 1
            except Exception as e:
                logger.warning(f"❌ Failed to initialize {model_type} {tier.value} model: {e}")
                self.inference_clients[tier] = None
                self.langchain_clients[tier] = None
        return success_count > 0

    def get_model_status(self) -> Dict[str, bool]:
        """Check which models are available (both API and LangChain clients)."""
        status = {}
        for tier in ModelTier:
            status[tier.value] = (
                self.inference_clients.get(tier) is not None and
                self.langchain_clients.get(tier) is not None
            )
        return status

    def select_model_tier(self, complexity: str = "medium", budget_conscious: bool = True, question_text: str = "") -> ModelTier:
        """Smart model selection based on task complexity, budget, and question analysis.

        Args:
            complexity: Caller's hint ("simple" | "medium" | "complex");
                may be overridden by keyword analysis of question_text.
            budget_conscious: When True, downgrade tiers as spend approaches
                the budget limit.
            question_text: Optional question used for keyword-based
                complexity auto-detection.

        Raises:
            RuntimeError: If no model client is available at all.
        """
        # Check budget constraints
        budget_used_percent = (self.total_cost / self.budget_limit) * 100
        if budget_conscious and budget_used_percent > 80:
            logger.warning(f"Budget critical ({budget_used_percent:.1f}% used), forcing router model")
            return ModelTier.ROUTER
        elif budget_conscious and budget_used_percent > 60:
            logger.warning(f"Budget warning ({budget_used_percent:.1f}% used), limiting complex model usage")
            complexity = "simple" if complexity == "complex" else complexity
        # Enhanced complexity analysis based on question content
        if question_text:
            question_lower = question_text.lower()
            # Indicators for complex reasoning (use 72B model)
            complex_indicators = [
                "analyze", "explain why", "reasoning", "logic", "complex", "difficult",
                "multi-step", "calculate and explain", "compare and contrast",
                "what is the relationship", "how does", "why is", "prove that",
                "step by step", "detailed analysis", "comprehensive"
            ]
            # Indicators for simple tasks (use 7B model)
            simple_indicators = [
                "what is", "who is", "when", "where", "simple", "quick",
                "yes or no", "true or false", "list", "name", "find"
            ]
            # Math and coding indicators (use 32B model - good balance)
            math_indicators = [
                "calculate", "compute", "solve", "equation", "formula", "math",
                "number", "total", "sum", "average", "percentage", "code", "program"
            ]
            # File processing indicators (use 32B+ models)
            file_indicators = [
                "image", "picture", "photo", "audio", "sound", "video", "file",
                "document", "excel", "csv", "data", "chart", "graph"
            ]
            # Count indicators
            complex_score = sum(1 for indicator in complex_indicators if indicator in question_lower)
            simple_score = sum(1 for indicator in simple_indicators if indicator in question_lower)
            math_score = sum(1 for indicator in math_indicators if indicator in question_lower)
            file_score = sum(1 for indicator in file_indicators if indicator in question_lower)
            # Auto-detect complexity based on content
            if complex_score >= 2 or len(question_text) > 200:
                complexity = "complex"
            elif file_score >= 1 or math_score >= 2:
                complexity = "medium"
            elif simple_score >= 2 and complex_score == 0:
                complexity = "simple"
        # Select based on complexity with budget awareness
        if complexity == "complex" and budget_used_percent < 70:
            selected_tier = ModelTier.COMPLEX
        elif complexity == "simple" or budget_used_percent > 75:
            selected_tier = ModelTier.ROUTER
        else:
            selected_tier = ModelTier.MAIN
        # Fallback if selected model unavailable (for/else: raise only when
        # the loop found no available fallback at all)
        if not self.inference_clients.get(selected_tier):
            logger.warning(f"Selected model {selected_tier.value} unavailable, falling back")
            for fallback in [ModelTier.MAIN, ModelTier.ROUTER, ModelTier.COMPLEX]:
                if self.inference_clients.get(fallback):
                    selected_tier = fallback
                    break
            else:
                raise RuntimeError("No models available")
        # Log selection reasoning
        logger.info(f"Selected {selected_tier.value} model (complexity: {complexity}, budget: {budget_used_percent:.1f}%)")
        return selected_tier

    async def generate_async(self,
                             prompt: str,
                             tier: Optional[ModelTier] = None,
                             max_tokens: Optional[int] = None) -> InferenceResult:
        """Async text generation with Qwen models via HuggingFace Inference API.

        Never raises for generation failures — errors are reported via
        InferenceResult.success/error so callers can fall back.

        Args:
            prompt: User prompt sent as a single chat message.
            tier: Explicit tier; auto-selected from the prompt when None.
            max_tokens: Per-call cap; defaults to the tier's configured cap.
        """
        if tier is None:
            tier = self.select_model_tier(question_text=prompt)
        config = self.models[tier]
        client = self.inference_clients.get(tier)
        if not client:
            return InferenceResult(
                response="",
                model_used=config.name,
                tokens_used=0,
                cost_estimate=0.0,
                response_time=0.0,
                success=False,
                error=f"Qwen model {tier.value} not available"
            )
        start_time = time.time()
        try:
            # Use specified max_tokens or model default
            tokens = max_tokens or config.max_tokens
            # Qwen models use chat completion API
            messages = [{"role": "user", "content": prompt}]
            logger.info(f"🤖 Generating with {config.name}...")
            response = client.chat_completion(
                messages=messages,
                model=config.name,
                max_tokens=tokens,
                temperature=config.temperature
            )
            # Extract response from chat completion
            if response and response.choices:
                response_text = response.choices[0].message.content
            else:
                raise ValueError(f"No response received from {config.name}")
            response_time = time.time() - start_time
            # Clean up response text
            response_text = str(response_text).strip()
            if not response_text:
                raise ValueError(f"Empty response from {config.name}")
            # Estimate tokens used (rough approximation: whitespace words)
            estimated_tokens = len(prompt.split()) + len(response_text.split())
            cost_estimate = estimated_tokens * config.cost_per_token
            # Update tracking
            self.total_cost += cost_estimate
            self.request_count += 1
            logger.info(f"✅ Generated with {tier.value} model in {response_time:.2f}s")
            return InferenceResult(
                response=response_text,
                model_used=config.name,
                tokens_used=estimated_tokens,
                cost_estimate=cost_estimate,
                response_time=response_time,
                success=True
            )
        except Exception as e:
            response_time = time.time() - start_time
            error_msg = str(e)
            logger.error(f"❌ Generation failed with {tier.value} model ({config.name}): {error_msg}")
            return InferenceResult(
                response="",
                model_used=config.name,
                tokens_used=0,
                cost_estimate=0.0,
                response_time=response_time,
                success=False,
                error=error_msg
            )

    def generate(self,
                 prompt: str,
                 tier: Optional[ModelTier] = None,
                 max_tokens: Optional[int] = None) -> InferenceResult:
        """Synchronous text generation (wrapper for async).

        Fix: the previous asyncio.get_event_loop()/run_until_complete pattern
        is deprecated since Python 3.10 and raises in 3.12 when no loop
        exists; it also leaked a never-closed loop. asyncio.run() creates
        and tears down its own loop per call.
        """
        import asyncio
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # Normal path: no event loop running in this thread.
            return asyncio.run(
                self.generate_async(prompt, tier, max_tokens)
            )
        # A loop is already running; run_until_complete would also raise here.
        raise RuntimeError("QwenClient.generate() cannot be called from a running event loop; await generate_async() instead")

    def get_langchain_llm(self, tier: ModelTier) -> Optional[LLM]:
        """Get LangChain LLM instance for agent integration (None if tier unavailable)."""
        return self.langchain_clients.get(tier)

    def get_usage_stats(self) -> Dict[str, Any]:
        """Get current usage and cost statistics."""
        return {
            "total_cost": self.total_cost,
            "request_count": self.request_count,
            "budget_limit": self.budget_limit,
            "budget_remaining": self.budget_limit - self.total_cost,
            "budget_used_percent": (self.total_cost / self.budget_limit) * 100,
            "average_cost_per_request": self.total_cost / max(self.request_count, 1),
            "models_available": self.get_model_status()
        }

    def reset_usage_tracking(self):
        """Reset usage statistics (for testing/development)."""
        self.total_cost = 0.0
        self.request_count = 0
        logger.info("Usage tracking reset")
# Test functions
def test_model_connection(client: QwenClient, tier: ModelTier):
    """Smoke-test a single model tier and return whether the call succeeded."""
    probe = "Hello! Please respond with 'Connection successful' if you can read this."
    logger.info(f"Testing {tier.value} model...")
    outcome = client.generate(probe, tier=tier, max_tokens=50)
    if not outcome.success:
        logger.error(f"❌ {tier.value} model test failed: {outcome.error}")
        return outcome.success
    logger.info(f"✅ {tier.value} model test successful: {outcome.response[:50]}...")
    logger.info(f" Response time: {outcome.response_time:.2f}s")
    logger.info(f" Cost estimate: ${outcome.cost_estimate:.6f}")
    return outcome.success
def test_all_models():
    """Exercise every Qwen tier and log a pass/fail summary plus usage stats."""
    logger.info("🧪 Testing all Qwen models...")
    client = QwenClient()
    results = {tier: test_model_connection(client, tier) for tier in ModelTier}
    logger.info("📊 Test Results Summary:")
    for tier, success in results.items():
        status = "✅ PASS" if success else "❌ FAIL"
        logger.info(f" {tier.value:8}: {status}")
    logger.info("💰 Usage Statistics:")
    for key, value in client.get_usage_stats().items():
        if key != "models_available":
            logger.info(f" {key}: {value}")
    return results
if __name__ == "__main__":
    # Pull HF_TOKEN / HUGGINGFACE_TOKEN from a local .env before testing.
    from dotenv import load_dotenv

    load_dotenv()
    # Run tests when script executed directly
    test_all_models()