# model_List.py - Model selection and analysis component with advanced features

import os
import re
import json
import time
import math
import nltk

# Ensure the punkt sentence tokenizer is available for nltk.sent_tokenize.
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    nltk.download("punkt")

import torch
import logging
import numpy as np
import importlib.util
from enum import Enum  # Add this import for Enum

from service_registry import registry, MODEL, PRETRAINED_MODEL
from sklearn.metrics.pairwise import cosine_similarity
from typing import List, Tuple, Dict, Type, Any, Optional

logger = logging.getLogger(__name__)

# More robust config import: fall back to a minimal dict-shaped config so the
# rest of the module can run even when the project config is missing.
try:
    from config import app_config
except ImportError:
    logger.error("Failed to import app_config from config")
    # Create minimal app_config
    app_config = {
        "PROMPT_ANALYZER_CONFIG": {
            "MODEL_NAME": "gpt2",
            "DATASET_PATH": None,
            "SPECIALIZATION": None,
            "HIDDEN_DIM": 768,
            "MAX_CACHE_SIZE": 10,
        }
    }

# Add SmartHybridAttention imports
from utils.smartHybridAttention import SmartHybridAttention, get_hybrid_attention_config

# Import get_sentence_transformer, with a layered fallback so initialization
# never crashes: project helper -> sentence_transformers -> zero-vector stub.
try:
    from utils.transformer_utils import get_sentence_transformer
except ImportError:
    def get_sentence_transformer(model_name):
        """Fallback loader used when utils.transformer_utils is unavailable.

        Returns a real SentenceTransformer when the package is installed,
        otherwise a stub whose encode() yields a 384-dim zero vector.
        """
        try:
            from sentence_transformers import SentenceTransformer
            return SentenceTransformer(model_name)
        except ImportError:
            logger.error("sentence_transformers package not available")

            # Minimal placeholder that won't crash initialization.
            class MinimalSentenceTransformer:
                def __init__(self, *args, **kwargs):
                    pass

                def encode(self, text):
                    # Zero vector with the typical MiniLM embedding dimension.
                    return [0.0] * 384

            return MinimalSentenceTransformer()

from model_Custm import Wildnerve_tlm01 as CustomModel

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class ModelType(Enum):
    """Known model implementation files selectable by the analyzer."""

    CUSTOM = "model_Custm.py"        # Wildnerve-tlm01 custom implementation
    PRETRAINED = "model_PrTr.py"     # GPT2 pretrained models
    # COMBINED = "model_Combn.py"    # Hybrid approach with both


# Use specific GPT-2 classes instead of generic Auto* classes.
from transformers import GPT2Tokenizer, GPT2LMHeadModel


class PromptAnalyzer:
    """
    Enhanced prompt analyzer that combines:
    - Simple reliable keyword matching for basic topic detection
    - Advanced embedding-based analysis with SentenceTransformer when available
    - Perplexity calculations with GPT-2 for complexity assessment
    - SmartHybridAttention for analyzing complex or long prompts
    - Performance tracking and caching for efficiency
    """

    def __init__(self, model_name=None, dataset_path=None, specialization=None,
                 hidden_dim=None):
        """Build the analyzer from explicit arguments and/or app_config.

        Args:
            model_name: Optional override for the configured model name.
            dataset_path: Optional override for the configured dataset path.
            specialization: Optional override for the configured specialization.
            hidden_dim: Optional override for the configured hidden dimension.
        """
        self.logger = logging.getLogger(__name__)

        # Load config with better error handling; app_config may be either an
        # object with attributes or the dict fallback created above.
        try:
            if hasattr(app_config, "PROMPT_ANALYZER_CONFIG"):
                self.config_data = app_config.PROMPT_ANALYZER_CONFIG
            elif isinstance(app_config, dict) and "PROMPT_ANALYZER_CONFIG" in app_config:
                self.config_data = app_config["PROMPT_ANALYZER_CONFIG"]
            else:
                self.config_data = {
                    "MODEL_NAME": "gpt2",
                    "DATASET_PATH": None,
                    "SPECIALIZATION": None,
                    "HIDDEN_DIM": 768,
                    "MAX_CACHE_SIZE": 10,
                }
        except Exception as e:
            self.logger.warning(f"Error loading config: {e}, using defaults")
            self.config_data = {
                "MODEL_NAME": "gpt2",
                "DATASET_PATH": None,
                "SPECIALIZATION": None,
                "HIDDEN_DIM": 768,
                "MAX_CACHE_SIZE": 10,
            }

        # Use provided values or config values with safe getters.
        self.model_name = model_name or self._safe_get("MODEL_NAME", "gpt2")
        self.dataset_path = dataset_path or self._safe_get("DATASET_PATH")
        self.specialization = specialization or self._safe_get("SPECIALIZATION")
        self.hidden_dim = hidden_dim or self._safe_get("HIDDEN_DIM", 768)

        self.logger.info(f"Initialized PromptAnalyzer with {self.model_name}")

        # Model-class cache and per-model performance metrics.
        # FIX: previously these were initialized twice; once is enough.
        self._model_cache: Dict[str, Type] = {}
        self._performance_metrics: Dict[str, Dict[str, float]] = {}

        # Load predefined topics from config or fall back to defaults.
        self._load_predefined_topics()

        # NOTE: the sentence-transformer embedding model is loaded (with
        # retries and fallbacks) in _init_advanced_tools() below.  The old
        # code loaded 'sentence-transformers/all-MiniLM-L6-v2' here and then
        # immediately discarded it in _init_advanced_tools(); that redundant
        # double-load has been removed.

        # Use specific GPT-2 classes instead of Auto* classes.
        self.tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
        # Fix missing pad token in GPT-2.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = GPT2LMHeadModel.from_pretrained("gpt2")
        self.model.eval()

        logger.info(
            f"Initialized PromptAnalyzer with {self.model_name}, "
            f"specialization: {self.specialization}, hidden_dim: {self.hidden_dim}"
        )
        if self.dataset_path:
            logger.info(f"Using dataset from: {self.dataset_path}")

        # Cached chosen model class (populated lazily by choose_model).
        self.model_class = None
        # Attention mechanism (populated by _init_advanced_tools if possible).
        self.attention = None

        # Try to load advanced analysis tools with proper error handling.
        self._init_advanced_tools()

        # Load configuration for analysis.  app_config may be a dict, so use
        # dict-aware access instead of getattr (which would always return the
        # default for a dict).
        if isinstance(app_config, dict):
            self.similarity_threshold = app_config.get("SIMILARITY_THRESHOLD", 0.85)
        else:
            self.similarity_threshold = getattr(app_config, "SIMILARITY_THRESHOLD", 0.85)
        # MAX_CACHE_SIZE lives inside PROMPT_ANALYZER_CONFIG, which is exactly
        # what self.config_data holds; _safe_get handles dict or object form.
        self.max_cache_size = self._safe_get("MAX_CACHE_SIZE", 10)

    def _safe_get(self, key, default=None):
        """Safely get a configuration value regardless of config type."""
        try:
            if isinstance(self.config_data, dict):
                return self.config_data.get(key, default)
            elif hasattr(self.config_data, key):
                return getattr(self.config_data, key, default)
            return default
        except Exception:  # FIX: was a bare except
            return default

    def _load_predefined_topics(self):
        """Load topic keywords from config file or use defaults with caching."""
        # Try to load from config first.
        try:
            if hasattr(app_config, 'TOPIC_KEYWORDS') and app_config.TOPIC_KEYWORDS:
                logger.info("Loading topic keywords from config")
                self.predefined_topics = app_config.TOPIC_KEYWORDS
                return

            # Try loading from a JSON file in the data directory.
            topic_file = os.path.join(app_config.DATA_DIR, "topic_keywords.json")
            if os.path.exists(topic_file):
                with open(topic_file, 'r') as f:
                    self.predefined_topics = json.load(f)
                logger.info(
                    f"Loaded {len(self.predefined_topics)} topic categories from {topic_file}"
                )
                return
        except Exception as e:
            logger.warning(f"Error loading topic keywords: {e}, using defaults")

        # Fall back to default hardcoded topics.
        logger.info("Using default hardcoded topic keywords")
        self.predefined_topics = {
            "programming": [
                "python", "java", "javascript", "typescript", "rust", "go", "golang",
                # ...existing keywords...
            ],
            "computer_science": [
                # ...existing keywords...
            ],
            "software_engineering": [
                # ...existing keywords...
            ],
            "web_development": [
                # ...existing keywords...
            ],
        }

        # Cache the topics to a file for future use (best-effort).
        try:
            os.makedirs(app_config.DATA_DIR, exist_ok=True)
            with open(os.path.join(app_config.DATA_DIR, "topic_keywords.json"), 'w') as f:
                json.dump(self.predefined_topics, f, indent=2)
        except Exception as e:
            logger.debug(f"Could not cache topic keywords: {e}")

    def _init_advanced_tools(self):
        """Initialize advanced analysis tools with proper error handling and fallbacks."""
        self.sentence_model = None
        self.gpt2_model = None
        self.gpt2_tokenizer = None

        # For the embedding model, implement multiple fallbacks.
        MAX_RETRIES = 3
        embedding_models = [
            'sentence-transformers/all-MiniLM-L6-v2',               # Primary choice
            'sentence-transformers/paraphrase-MiniLM-L3-v2',        # Smaller fallback
            'sentence-transformers/distilbert-base-nli-mean-tokens' # Last resort
        ]

        for retry in range(MAX_RETRIES):
            for model_name in embedding_models:
                try:
                    from utils.transformer_utils import get_sentence_transformer
                    self.sentence_model = get_sentence_transformer(model_name)
                    self.logger.info(f"Successfully loaded SentenceTransformer: {model_name}")
                    break
                except Exception as e:
                    self.logger.warning(f"Failed to load embedding model {model_name}: {e}")
            if self.sentence_model:
                break
            # Wait before retry.
            time.sleep(2)

        # Create keyword-based fallback if embedding loading completely fails.
        if not self.sentence_model:
            self.logger.warning("All embedding models failed to load - using keyword fallback")
            self._use_keyword_fallback = True
        else:
            self._use_keyword_fallback = False

        # Initialize SmartHybridAttention (optional; None on failure).
        try:
            attention_config = get_hybrid_attention_config()
            self.attention = SmartHybridAttention(
                dim=attention_config.get("DIM", 768),
                num_heads=attention_config.get("NUM_HEADS", 8),
                window_size=attention_config.get("WINDOW_SIZE", 256),
                use_sliding=attention_config.get("USE_SLIDING", True),
                use_global=attention_config.get("USE_GLOBAL", True),
                use_hierarchical=attention_config.get("USE_HIERARCHICAL", False),
            )
            self.logger.info("Initialized SmartHybridAttention for prompt analysis")
        except Exception as e:
            self.logger.warning(f"Failed to initialize SmartHybridAttention: {e}")
            self.attention = None

    def _track_model_performance(self, model_type: str, start_time: float) -> None:
        """Track model loading and performance metrics.

        Args:
            model_type: Type of model being tracked.
            start_time: Start time of operation (time.time() seconds).
        """
        end_time = time.time()
        if model_type not in self._performance_metrics:
            self._performance_metrics[model_type] = {
                'load_time': 0.0,
                'usage_count': 0,
                'avg_response_time': 0.0,
            }
        # Ensure we're not creating circular references that might impact
        # serialization.
        metrics = self._performance_metrics[model_type]
        metrics['load_time'] = end_time - start_time
        metrics['usage_count'] += 1
        # Update running average response time.
        current_avg = metrics['avg_response_time']
        metrics['avg_response_time'] = (
            (current_avg * (metrics['usage_count'] - 1) + (end_time - start_time))
            / metrics['usage_count']
        )

    def manage_cache(self, max_cache_size: int = None) -> None:
        """Manage model cache size and evict the least used models.

        Args:
            max_cache_size: Optional cap; defaults to self.max_cache_size.
        """
        try:
            if max_cache_size is None:
                max_cache_size = self.max_cache_size

            if len(self._model_cache) > max_cache_size:
                # Sort models by usage count (ties broken by slower avg time).
                sorted_models = sorted(
                    self._performance_metrics.items(),
                    key=lambda x: (x[1]['usage_count'], -x[1]['avg_response_time']),
                )
                # Remove least used models.
                for model_type, _ in sorted_models[:-max_cache_size]:
                    self._model_cache.pop(model_type, None)
                    logger.info(f"Removed {model_type} from cache due to low usage")

                logger.info(f"Cache cleaned up. Current size: {len(self._model_cache)}")
        except Exception as e:
            logger.error(f"Error managing cache: {e}")

    def _load_model_class(self, model_type: str) -> Type:
        """Load a model class by module name, with caching and fallbacks.

        Args:
            model_type: Module file name such as "model_Custm.py".

        Returns:
            The Wildnerve_tlm01 class from the requested module, or the
            default model (model_Custm) / a dummy class on failure.
        """
        start_time = time.time()
        try:
            # If model is already cached, return it directly.
            if model_type in self._model_cache:
                self._track_model_performance(model_type, start_time)
                return self._model_cache[model_type]

            # Clean up model name ("model_PrTr.py" -> "model_PrTr").
            clean_model_type = model_type.replace('.py', '')

            # Handle different model types.
            if clean_model_type == "model_PrTr" or clean_model_type.endswith("PrTr"):
                try:
                    module = importlib.import_module("model_PrTr")
                    model_class = getattr(module, "Wildnerve_tlm01")
                except Exception as e:
                    logger.warning(f"Error loading model_PrTr: {e}")
                    # Fallback to default model.
                    module = importlib.import_module("model_Custm")
                    model_class = getattr(module, "Wildnerve_tlm01")
            else:
                # Default to getting Wildnerve_tlm01 from the named module.
                try:
                    module_name = clean_model_type
                    if not module_name.startswith("model_"):
                        module_name = f"model_{module_name}"
                    module = importlib.import_module(module_name)
                    model_class = getattr(module, "Wildnerve_tlm01")
                except Exception as e:
                    logger.warning(f"Error loading {model_type}: {e}, falling back to CustomModel")
                    # Fallback to main model.
                    module = importlib.import_module("model_Custm")
                    model_class = getattr(module, "Wildnerve_tlm01")

            # Cache and track the model class.
            self._model_cache[model_type] = model_class
            self._track_model_performance(model_type, start_time)
            return model_class

        except Exception as e:
            logger.error(f"Error loading model class {model_type}: {e}")
            # Try to get the default model as fallback.
            try:
                module = importlib.import_module("model_Custm")
                return getattr(module, "Wildnerve_tlm01")
            except Exception:
                # This should never happen, but just in case.
                from types import new_class
                return new_class("DummyModel", (), {})

    def _analyze_with_attention(self, prompt):
        """Use SmartHybridAttention to score topics in complex prompts.

        Returns a {topic: weighted_score} dict, or None when attention /
        embeddings are unavailable or the prompt is a single sentence.
        """
        if not self.attention or not self.sentence_model:
            return None
        try:
            # Split into sentences for better analysis.
            sentences = nltk.sent_tokenize(prompt)
            if len(sentences) <= 1:
                return None  # Not complex enough for attention analysis.

            # Get embeddings for each sentence.
            sentence_embeddings = [self.sentence_model.encode(s) for s in sentences]
            # assumes shape [seq_len, batch=1, dim] — TODO confirm against
            # SmartHybridAttention's expected layout.
            embeddings_tensor = torch.tensor(sentence_embeddings).unsqueeze(1)

            # Apply attention to identify important relationships between
            # sentences.
            attended_embeddings, attention_weights = self.attention(
                query=embeddings_tensor,
                key=embeddings_tensor,
                value=embeddings_tensor,
                input_text=prompt,  # Pass original text for content-aware attention.
            )

            # Calculate importance of each sentence based on attention weights.
            importance = attention_weights.mean(dim=(0, 1)).squeeze()
            if len(importance.shape) == 0:
                # Handle single sentence case.
                importance = importance.unsqueeze(0)

            # Get top sentences by importance.
            top_indices = torch.argsort(importance, descending=True)[:min(3, len(sentences))]

            # Weight topic analysis by sentence importance.
            topic_scores = {topic: 0.0 for topic in self.predefined_topics}
            for idx in top_indices:
                sentence = sentences[idx.item()]
                weight = importance[idx].item() / importance.sum().item()
                # Analyze this important sentence.
                for topic, keywords in self.predefined_topics.items():
                    sent_lower = sentence.lower()
                    sent_score = sum(1 for keyword in keywords if keyword in sent_lower)
                    # Boost importance of attention-weighted scores.
                    topic_scores[topic] += sent_score * weight * 1.5

            return topic_scores
        except Exception as e:
            self.logger.error(f"Error in attention-based analysis: {e}")
            return None

    def _analyze_with_keywords(self, prompt: str) -> Tuple[str, float]:
        """Analyze prompt using only keywords when embeddings are unavailable.

        Returns:
            ("model_Custm", ratio) for technical prompts, otherwise
            ("model_PrTr", 0.7).
        """
        prompt_lower = prompt.lower()
        technical_matches = 0
        total_words = len(prompt_lower.split())

        # Count matches across all technical categories.
        for category, keywords in self.predefined_topics.items():
            for keyword in keywords:
                if keyword in prompt_lower:
                    technical_matches += 1

        # Simple ratio calculation (word count capped at 15).
        match_ratio = technical_matches / max(1, min(15, total_words))
        if match_ratio > 0.1:
            # Even a single match in a short query is significant.
            return "model_Custm", match_ratio
        else:
            return "model_PrTr", 0.7

    def analyze_prompt(self, prompt: str) -> Tuple[str, float]:
        """Analyze if a prompt is technical or general and return the
        appropriate model type and confidence score."""
        # Check if we need to use keyword fallback due to embedding failure.
        if hasattr(self, '_use_keyword_fallback') and self._use_keyword_fallback:
            return self._analyze_with_keywords(prompt)

        # Convert prompt to lowercase for case-insensitive matching.
        prompt_lower = prompt.lower()

        # Check for technical keywords from predefined topics - use
        # memory-efficient approach.
        technical_matches = 0
        word_count = len(prompt_lower.split())

        # Use a set-based intersection approach for better performance on
        # longer texts.
        prompt_words = set(prompt_lower.split())

        # Count keyword matches across all technical categories more efficiently.
        for category, keywords in self.predefined_topics.items():
            # Convert keywords to set for O(1) lookups - helps with longer texts.
            keywords_set = set(keywords)
            matches = prompt_words.intersection(keywords_set)
            technical_matches += len(matches)

            # Also check for multi-word keywords not caught by simple splitting.
            for keyword in keywords:
                if " " in keyword and keyword in prompt_lower:
                    technical_matches += 1

        # Calculate keyword match ratio (normalized by word count, capped at 20).
        keyword_ratio = technical_matches / max(1, min(20, word_count))

        # Get attention-based analysis for complex prompts.
        attention_scores = None
        if len(prompt) > 100 and self.attention:  # Only for longer prompts.
            try:
                attention_scores = self._analyze_with_attention(prompt)
            except Exception as e:
                self.logger.warning(f"Error in attention analysis: {e}")

        # Use embedding similarity for semantic understanding.
        try:
            # Get embedding of the prompt.
            prompt_embedding = self.sentence_model.encode(prompt)

            # Example technical and general reference texts.
            technical_reference = "Write code to solve a programming problem using algorithms and data structures."
            general_reference = "Tell me about daily life topics like weather, food, or general conversation."

            # Get embeddings for reference texts.
            technical_embedding = self.sentence_model.encode(technical_reference)
            general_embedding = self.sentence_model.encode(general_reference)

            # Calculate cosine similarities.
            technical_similarity = cosine_similarity([prompt_embedding], [technical_embedding])[0][0]
            general_similarity = cosine_similarity([prompt_embedding], [general_embedding])[0][0]

            # Calculate technical score combining all signals:
            # 1. Keyword matching (30%)
            # 2. Semantic similarity (40%)
            # 3. Attention analysis if available (30%)
            technical_score = 0.3 * keyword_ratio + 0.4 * technical_similarity

            # Add attention score contribution if available.
            if attention_scores:
                # Calculate tech score from attention - mean of the four
                # technical categories.
                tech_attention_score = (
                    attention_scores.get("programming", 0)
                    + attention_scores.get("computer_science", 0)
                    + attention_scores.get("software_engineering", 0)
                    + attention_scores.get("web_development", 0)
                ) / 4.0  # Normalize
                technical_score += 0.3 * tech_attention_score

            # Decide based on combined score.
            if technical_score > 0.3:  # Threshold - tune this as needed.
                return "model_Custm", technical_score
            else:
                return "model_PrTr", 1.0 - technical_score

        except Exception as e:
            self.logger.error(f"Error in prompt analysis: {e}")
            # Fallback to simple keyword matching.
            if technical_matches > 0:
                return "model_Custm", 0.7
            else:
                return "model_PrTr", 0.7

    def analyze(self, prompt: str) -> int:
        """Legacy compatibility method that returns a candidate index."""
        model_type, confidence = self.analyze_prompt(prompt)
        # Map model_type to candidate index.
        if model_type == "model_Custm":
            return 0  # Index 0 corresponds to model_Custm
        else:
            return 1  # Index 1 corresponds to model_PrTr

    def choose_model(self, prompt: str = None) -> Type:
        """Enhanced model selection that combines config and analysis."""
        try:
            start_time = time.time()

            # If we have a cached model class, return it.
            if self.model_class:
                return self.model_class

            # Get candidate index from analysis if prompt provided.
            candidate_index = 0
            if prompt:
                candidate_index = self.analyze(prompt)

            # Get selected models list.
            selected_models = self.get_selected_models()

            # Ensure index is within bounds.
            if candidate_index >= len(selected_models):
                candidate_index %= len(selected_models)

            # Get model type.
            model_type = selected_models[candidate_index]

            # Load and return model class.
            model_class = self._load_model_class(model_type)
            self.model_class = model_class  # Cache for later.
            self._track_model_performance(model_type, start_time)
            return model_class

        except Exception as e:
            logger.error(f"Error in model selection: {e}")
            # Always fall back to a valid model.
            try:
                from model_Custm import Wildnerve_tlm01
                return Wildnerve_tlm01
            except Exception:
                logger.critical("Failed to import default model!")

                # This function must return something, so create a dummy class.
                class DummyModel:
                    def __init__(self, **kwargs):
                        pass

                return DummyModel

    def get_selected_models(self) -> list:
        """Return the list of selected model types for use in the system."""
        # First try getting from config.
        try:
            if hasattr(app_config, 'SELECTED_MODEL'):
                models = app_config.SELECTED_MODEL
                if models:
                    return models
        except Exception as e:
            logger.warning(f"Error reading SELECTED_MODEL from config: {e}")

        # Default model types with fallbacks in case primary fails.
        return ["model_Custm.py", "model_PrTr.py"]

    def get_model_instance(self, prompt: str = None) -> Any:
        """Get an initialized model instance based on the analyzed prompt."""
        model_class = self.choose_model(prompt)
        try:
            return model_class()
        except Exception as e:
            logger.error(f"Error initializing model: {e}")
            try:
                from model_Custm import Wildnerve_tlm01
                return Wildnerve_tlm01()
            except Exception:
                logger.critical("Could not instantiate any model!")
                return None

    def get_performance_metrics(self) -> Dict[str, Dict[str, float]]:
        """Get performance metrics for all models."""
        return self._performance_metrics


# Register the PromptAnalyzer in the service registry to resolve dependencies.
registry.register("prompt_analyzer", PromptAnalyzer())


def main():
    # For testing purposes; in production, model_manager will retrieve the
    # analyzer.
    analyzer = registry.get("prompt_analyzer")
    sample_prompt = "I'm having trouble debugging my Python code for a sorting algorithm."
    # FIX: analyze_prompt returns (model_type, confidence), not topics.
    model_type, confidence = analyzer.analyze_prompt(sample_prompt)
    selected = analyzer.choose_model(sample_prompt)
    logger.info(
        f"Sample prompt analysis:\nModel type: {model_type}\n"
        f"Confidence: {confidence}\nSelected Model: {selected}"
    )
    # Test the advanced analysis.
    if hasattr(analyzer, 'sentence_model') and analyzer.sentence_model:
        candidate_index = analyzer.analyze(sample_prompt)
        logger.info(f"Candidate model index: {candidate_index}")


if __name__ == "__main__":
    main()