Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| HuggingFace Qwen 2.5 Model Client | |
| Handles inference for router, main, and complex models with cost tracking | |
| """ | |
| import os | |
| import time | |
| import logging | |
| from typing import Dict, Any, List, Optional | |
| from dataclasses import dataclass | |
| from enum import Enum | |
| from huggingface_hub import InferenceClient | |
| from langchain_huggingface import HuggingFaceEndpoint | |
| from langchain_core.language_models.llms import LLM | |
# Logging setup: configure the root logger for INFO and above (basicConfig is a
# no-op if the root logger already has handlers), then create the module-scoped
# logger used throughout this file.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)
class ModelTier(Enum):
    """Closed set of model tiers, from cheapest to most capable.

    Member values are the lowercase labels used in log messages and as the
    keys returned by ``QwenClient.get_model_status()``.
    """

    ROUTER = "router"    # cheapest tier for quick routing decisions (configured as Qwen2.5-7B-Instruct)
    MAIN = "main"        # balanced default tier (configured as Qwen2.5-32B-Instruct)
    COMPLEX = "complex"  # most capable tier for hard tasks (configured as Qwen2.5-72B-Instruct)
@dataclass
class ModelConfig:
    """Configuration for one Qwen model tier.

    ``QwenClient.__init__`` builds these with keyword arguments. That requires
    the ``__init__`` generated by ``@dataclass``; a plain class with only field
    annotations accepts no arguments and would raise ``TypeError``.
    """

    name: str  # HuggingFace repo id, e.g. "Qwen/Qwen2.5-7B-Instruct"
    tier: "ModelTier"  # tier served by this config (quoted: forward reference)
    max_tokens: int  # default generation cap for this tier
    temperature: float  # sampling temperature
    cost_per_token: float  # Estimated cost per token, used for local budget tracking
    timeout: int  # request timeout forwarded to the LangChain endpoint
@dataclass
class InferenceResult:
    """Outcome of one generation call plus its bookkeeping metadata.

    ``QwenClient.generate_async`` creates these with keyword arguments, which
    requires the ``__init__`` generated by ``@dataclass``.
    """

    response: str  # generated text; "" when the call failed
    model_used: str  # repo id of the model that was (or would have been) called
    tokens_used: int  # estimated tokens consumed; 0 on failure
    cost_estimate: float  # tokens_used * cost_per_token; 0.0 on failure
    response_time: float  # wall-clock seconds spent on the call
    success: bool  # True when a response was produced
    error: Optional[str] = None  # error message when success is False
class QwenClient:
    """HuggingFace client for the Qwen 2.5 model family.

    One ``InferenceClient`` (used for chat completions) and one LangChain
    ``HuggingFaceEndpoint`` are created per ``ModelTier``. ``select_model_tier``
    picks a tier from task complexity, question wording and budget usage.
    Every successful generation adds its estimated cost to ``total_cost``.
    """

    # Keyword indicators used by ``select_model_tier`` to classify a question.
    # They are matched against the lower-cased text only at the start of a
    # word (see ``_count_indicators``). So "update" no longer counts as "data"
    # and "decode" no longer counts as "code", while inflections such as
    # "files" or "numbers" still match.
    _COMPLEX_INDICATORS = (  # complex reasoning -> COMPLEX tier
        "analyze", "explain why", "reasoning", "logic", "complex", "difficult",
        "multi-step", "calculate and explain", "compare and contrast",
        "what is the relationship", "how does", "why is", "prove that",
        "step by step", "detailed analysis", "comprehensive",
    )
    _SIMPLE_INDICATORS = (  # short factual lookups -> ROUTER tier
        "what is", "who is", "when", "where", "simple", "quick",
        "yes or no", "true or false", "list", "name", "find",
    )
    _MATH_INDICATORS = (  # math / coding -> MAIN tier
        "calculate", "compute", "solve", "equation", "formula", "math",
        "number", "total", "sum", "average", "percentage", "code", "program",
    )
    _FILE_INDICATORS = (  # file / media processing -> MAIN tier
        "image", "picture", "photo", "audio", "sound", "video", "file",
        "document", "excel", "csv", "data", "chart", "graph",
    )

    def __init__(self, hf_token: Optional[str] = None):
        """Create inference and LangChain clients for every tier.

        Args:
            hf_token: HuggingFace API token. Falls back to the
                ``HUGGINGFACE_TOKEN`` environment variable. A missing token
                only logs a warning.
        """
        self.hf_token = hf_token or os.getenv("HUGGINGFACE_TOKEN")
        if not self.hf_token:
            logger.warning("No HuggingFace token provided. API access may be limited.")

        # One configuration per tier. cost_per_token values are rough
        # estimates used only for the local budget tracking below.
        self.models = {
            ModelTier.ROUTER: ModelConfig(
                name="Qwen/Qwen2.5-7B-Instruct",  # fast router for classification
                tier=ModelTier.ROUTER,
                max_tokens=512,
                temperature=0.1,
                cost_per_token=0.0003,
                timeout=15,
            ),
            ModelTier.MAIN: ModelConfig(
                name="Qwen/Qwen2.5-32B-Instruct",  # balanced model for main tasks
                tier=ModelTier.MAIN,
                max_tokens=1024,
                temperature=0.1,
                cost_per_token=0.0008,
                timeout=25,
            ),
            ModelTier.COMPLEX: ModelConfig(
                name="Qwen/Qwen2.5-72B-Instruct",  # largest model for complex reasoning
                tier=ModelTier.COMPLEX,
                max_tokens=2048,
                temperature=0.1,
                cost_per_token=0.0015,
                timeout=35,
            ),
        }

        # Tier -> client, or None when that tier failed to initialize.
        self.inference_clients = {}
        self.langchain_clients = {}
        self._initialize_clients()

        # Cost tracking
        self.total_cost = 0.0
        self.request_count = 0
        self.budget_limit = 0.10  # $0.10 total budget

    def _initialize_clients(self):
        """Create both clients for each configured tier.

        A failure is logged, and both entries for that tier are set to None so
        ``get_model_status`` reports the tier as unavailable.
        """
        for tier, config in self.models.items():
            try:
                # HuggingFace InferenceClient for direct API calls
                self.inference_clients[tier] = InferenceClient(
                    model=config.name,
                    token=self.hf_token,
                )
                # LangChain wrapper for integration
                self.langchain_clients[tier] = HuggingFaceEndpoint(
                    repo_id=config.name,
                    max_new_tokens=config.max_tokens,
                    temperature=config.temperature,
                    huggingfacehub_api_token=self.hf_token,
                    timeout=config.timeout,
                )
                logger.info(f"✅ Initialized {tier.value} model: {config.name}")
            except Exception as e:
                logger.error(f"❌ Failed to initialize {tier.value} model: {e}")
                self.inference_clients[tier] = None
                self.langchain_clients[tier] = None

    def get_model_status(self) -> Dict[str, bool]:
        """Map each tier's value to True when both of its clients exist."""
        status = {}
        for tier in ModelTier:
            status[tier.value] = (
                self.inference_clients.get(tier) is not None and
                self.langchain_clients.get(tier) is not None
            )
        return status

    def _budget_used_percent(self) -> float:
        """Return the share of ``budget_limit`` spent so far, as a percentage.

        A non-positive budget is treated as fully used (100%) instead of
        raising ``ZeroDivisionError``.
        """
        if self.budget_limit <= 0:
            return 100.0
        return (self.total_cost / self.budget_limit) * 100

    @staticmethod
    def _count_indicators(text: str, indicators) -> int:
        """Count how many *indicators* occur in *text* starting at a word boundary."""
        import re
        return sum(1 for indicator in indicators if re.search(r"\b" + re.escape(indicator), text))

    def _infer_complexity(self, question_text: str, complexity: str) -> str:
        """Return the complexity implied by *question_text*, or *complexity* if no rule fires."""
        question_lower = question_text.lower()
        complex_score = self._count_indicators(question_lower, self._COMPLEX_INDICATORS)
        simple_score = self._count_indicators(question_lower, self._SIMPLE_INDICATORS)
        math_score = self._count_indicators(question_lower, self._MATH_INDICATORS)
        file_score = self._count_indicators(question_lower, self._FILE_INDICATORS)

        if complex_score >= 2 or len(question_text) > 200:
            return "complex"
        if file_score >= 1 or math_score >= 2:
            return "medium"
        if simple_score >= 2 and complex_score == 0:
            return "simple"
        return complexity

    def select_model_tier(self, complexity: str = "medium", budget_conscious: bool = True, question_text: str = "") -> ModelTier:
        """Pick a model tier from task complexity, question wording and budget usage.

        Args:
            complexity: "simple", "medium" or "complex". It may be overridden by
                the analysis of *question_text*.
            budget_conscious: when True, budget usage above 80% forces the router.
                Usage above 60% downgrades "complex" to "simple".
            question_text: optional question used to infer complexity.

        Returns:
            The selected tier. If its client is unavailable, a fallback is tried
            in the order MAIN, ROUTER, COMPLEX. The budget-critical early return
            of ROUTER skips this fallback.

        Raises:
            RuntimeError: if no tier has an initialized inference client.
        """
        budget_used_percent = self._budget_used_percent()

        # Hard stop: near the budget ceiling only the cheapest tier is allowed.
        if budget_conscious and budget_used_percent > 80:
            logger.warning(f"Budget critical ({budget_used_percent:.1f}% used), forcing router model")
            return ModelTier.ROUTER

        # Content analysis runs BEFORE the budget cap below. Previously it ran
        # after the cap and could set complexity back to "complex", undoing it.
        if question_text:
            complexity = self._infer_complexity(question_text, complexity)

        if budget_conscious and budget_used_percent > 60:
            logger.warning(f"Budget warning ({budget_used_percent:.1f}% used), limiting complex model usage")
            complexity = "simple" if complexity == "complex" else complexity

        # Select based on complexity with budget awareness
        if complexity == "complex" and budget_used_percent < 70:
            selected_tier = ModelTier.COMPLEX
        elif complexity == "simple" or budget_used_percent > 75:
            selected_tier = ModelTier.ROUTER
        else:
            selected_tier = ModelTier.MAIN

        # Fallback if selected model unavailable
        if self.inference_clients.get(selected_tier) is None:
            logger.warning(f"Selected model {selected_tier.value} unavailable, falling back")
            for fallback in (ModelTier.MAIN, ModelTier.ROUTER, ModelTier.COMPLEX):
                if self.inference_clients.get(fallback) is not None:
                    selected_tier = fallback
                    break
            else:
                raise RuntimeError("No models available")

        logger.info(f"Selected {selected_tier.value} model (complexity: {complexity}, budget: {budget_used_percent:.1f}%)")
        return selected_tier

    async def generate_async(self,
                             prompt: str,
                             tier: Optional[ModelTier] = None,
                             max_tokens: Optional[int] = None) -> InferenceResult:
        """Generate a reply to *prompt* with the given (or auto-selected) tier.

        Args:
            prompt: user message, sent as a single-turn chat.
            tier: model tier. When None, ``select_model_tier()`` is called with
                its defaults.
            max_tokens: generation cap. Falls back to the tier's ``max_tokens``.

        Returns:
            An ``InferenceResult``. API failures are reported through
            ``success=False`` and ``error`` rather than raised.

        Raises:
            RuntimeError: from ``select_model_tier`` when *tier* is None and no
                model is available.
        """
        import asyncio
        import functools

        if tier is None:
            tier = self.select_model_tier()

        config = self.models[tier]
        client = self.inference_clients.get(tier)
        if client is None:
            return InferenceResult(
                response="",
                model_used=config.name,
                tokens_used=0,
                cost_estimate=0.0,
                response_time=0.0,
                success=False,
                error=f"Model {tier.value} not available",
            )

        start_time = time.time()
        try:
            # Use specified max_tokens or model default
            tokens = max_tokens or config.max_tokens
            messages = [{"role": "user", "content": prompt}]
            # chat_completion is a blocking HTTP call. Run it in the default
            # executor so awaiting this coroutine does not stall the event loop.
            loop = asyncio.get_running_loop()
            response = await loop.run_in_executor(
                None,
                functools.partial(
                    client.chat_completion,
                    messages=messages,
                    model=config.name,
                    max_tokens=tokens,
                    temperature=config.temperature,
                ),
            )
            response_time = time.time() - start_time

            # The API may return no choices, or a choice whose content is None.
            if not response or not response.choices:
                raise ValueError("No response received from model")
            response_text = response.choices[0].message.content
            if response_text is None:
                raise ValueError("No response received from model")

            # Rough token estimate: whitespace-separated words in prompt + reply.
            estimated_tokens = len(prompt.split()) + len(response_text.split())
            cost_estimate = estimated_tokens * config.cost_per_token

            # Update tracking
            self.total_cost += cost_estimate
            self.request_count += 1

            logger.info(f"✅ Generated response using {tier.value} model in {response_time:.2f}s")
            return InferenceResult(
                response=response_text,
                model_used=config.name,
                tokens_used=estimated_tokens,
                cost_estimate=cost_estimate,
                response_time=response_time,
                success=True,
            )
        except Exception as e:
            response_time = time.time() - start_time
            logger.error(f"❌ Generation failed with {tier.value} model: {e}")
            return InferenceResult(
                response="",
                model_used=config.name,
                tokens_used=0,
                cost_estimate=0.0,
                response_time=response_time,
                success=False,
                error=str(e),
            )

    def generate(self,
                 prompt: str,
                 tier: Optional[ModelTier] = None,
                 max_tokens: Optional[int] = None) -> InferenceResult:
        """Synchronous wrapper around ``generate_async``.

        With no event loop running in the current thread, the coroutine runs
        via ``asyncio.run``. If a loop is already running (the caller is itself
        async), ``run_until_complete`` on that loop would raise ``RuntimeError``.
        In that case the coroutine runs on a fresh loop in a worker thread, and
        this call blocks until it completes.
        """
        import asyncio
        import concurrent.futures

        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # Ordinary synchronous caller: no loop is running in this thread.
            return asyncio.run(self.generate_async(prompt, tier, max_tokens))

        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
            future = pool.submit(asyncio.run, self.generate_async(prompt, tier, max_tokens))
            return future.result()

    def get_langchain_llm(self, tier: ModelTier) -> Optional[LLM]:
        """Return the LangChain endpoint for *tier*, or None if it is unavailable."""
        return self.langchain_clients.get(tier)

    def get_usage_stats(self) -> Dict[str, Any]:
        """Return current cost, request count, budget figures and model availability."""
        return {
            "total_cost": self.total_cost,
            "request_count": self.request_count,
            "budget_limit": self.budget_limit,
            "budget_remaining": self.budget_limit - self.total_cost,
            "budget_used_percent": self._budget_used_percent(),
            "average_cost_per_request": self.total_cost / max(self.request_count, 1),
            "models_available": self.get_model_status(),
        }

    def reset_usage_tracking(self):
        """Zero the cost and request counters (for testing/development)."""
        self.total_cost = 0.0
        self.request_count = 0
        logger.info("Usage tracking reset")
# Test functions
def test_model_connection(client: QwenClient, tier: ModelTier):
    """Send a short probe prompt to one tier and log the outcome.

    Returns the ``success`` flag of the resulting ``InferenceResult``.
    """
    label = tier.value
    probe = "Hello! Please respond with 'Connection successful' if you can read this."
    logger.info(f"Testing {label} model...")

    outcome = client.generate(probe, tier=tier, max_tokens=50)
    if not outcome.success:
        logger.error(f"❌ {label} model test failed: {outcome.error}")
        return outcome.success

    logger.info(f"✅ {label} model test successful: {outcome.response[:50]}...")
    logger.info(f" Response time: {outcome.response_time:.2f}s")
    logger.info(f" Cost estimate: ${outcome.cost_estimate:.6f}")
    return outcome.success
def test_all_models():
    """Probe every ModelTier with a fresh QwenClient.

    Logs a pass/fail summary and the client's usage statistics (excluding the
    ``models_available`` entry). Returns a ``{tier: passed}`` dict.
    """
    logger.info("🧪 Testing all Qwen models...")
    client = QwenClient()
    results = {tier: test_model_connection(client, tier) for tier in ModelTier}

    logger.info("📊 Test Results Summary:")
    for tier, passed in results.items():
        logger.info(f" {tier.value:8}: {'✅ PASS' if passed else '❌ FAIL'}")

    logger.info("💰 Usage Statistics:")
    usage = client.get_usage_stats()
    for key in (k for k in usage if k != "models_available"):
        logger.info(f" {key}: {usage[key]}")
    return results
if __name__ == "__main__":
    # Script entry point only: load variables from a local .env file first,
    # so QwenClient() can find HUGGINGFACE_TOKEN via os.getenv.
    import dotenv

    dotenv.load_dotenv()
    test_all_models()