# NOTE: removed stray Hugging Face upload-page residue ("WildnerveAI's picture /
# Upload 4 files / c9d7656 verified") — it was not valid Python source.
# model_List.py - Model selection and analysis component with advanced features
import os
import re
import json
import time
import math
import nltk
try:
nltk.data.find('tokenizers/punkt')
except LookupError:
nltk.download("punkt")
import torch
import logging
import numpy as np
import importlib.util
from enum import Enum # Add this import for Enum
from service_registry import registry, MODEL, PRETRAINED_MODEL
from sklearn.metrics.pairwise import cosine_similarity
from typing import List, Tuple, Dict, Type, Any, Optional
logger = logging.getLogger(__name__)
# More robust config import
try:
from config import app_config
except ImportError:
logger.error("Failed to import app_config from config")
# Create minimal app_config
app_config = {
"PROMPT_ANALYZER_CONFIG": {
"MODEL_NAME": "gpt2",
"DATASET_PATH": None,
"SPECIALIZATION": None,
"HIDDEN_DIM": 768,
"MAX_CACHE_SIZE": 10
}
}
# Add SmartHybridAttention imports
from utils.smartHybridAttention import SmartHybridAttention, get_hybrid_attention_config
# Fix: Import get_sentence_transformer properly
try:
from utils.transformer_utils import get_sentence_transformer
except ImportError:
# Create a fallback implementation if the import fails
def get_sentence_transformer(model_name):
try:
from sentence_transformers import SentenceTransformer
return SentenceTransformer(model_name)
except ImportError:
logger.error("sentence_transformers package not available")
# Return a minimal placeholder that won't crash initialization
class MinimalSentenceTransformer:
def __init__(self, *args, **kwargs):
pass
def encode(self, text):
return [0.0] * 384 # Return zero vector with typical dimension
return MinimalSentenceTransformer()
from model_Custm import Wildnerve_tlm01 as CustomModel
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class ModelType(Enum):
    """Source-file identifiers for the selectable model implementations."""
    CUSTOM = "model_Custm.py"  # Wildnerve-tlm01 custom implementation
    PRETRAINED = "model_PrTr.py"  # GPT2 pretrained models
    # COMBINED = "model_Combn.py"  # Hybrid approach with both (not yet enabled)
# Replace generic Auto* classes with specific GPT-2 classes
from transformers import GPT2Tokenizer, GPT2LMHeadModel
class PromptAnalyzer:
    """
    Enhanced prompt analyzer that combines:
    - Simple reliable keyword matching for basic topic detection
    - Advanced embedding-based analysis with SentenceTransformer when available
    - Perplexity calculations with GPT-2 for complexity assessment
    - SmartHybridAttention for analyzing complex or long prompts
    - Performance tracking and caching for efficiency

    An instance is registered in the service registry under "prompt_analyzer"
    at module import time.
    """
def __init__(self, model_name=None, dataset_path=None, specialization=None, hidden_dim=None):
self.logger = logging.getLogger(__name__)
# Load config with better error handling
try:
if hasattr(app_config, "PROMPT_ANALYZER_CONFIG"):
self.config_data = app_config.PROMPT_ANALYZER_CONFIG
elif isinstance(app_config, dict) and "PROMPT_ANALYZER_CONFIG" in app_config:
self.config_data = app_config["PROMPT_ANALYZER_CONFIG"]
else:
self.config_data = {
"MODEL_NAME": "gpt2",
"DATASET_PATH": None,
"SPECIALIZATION": None,
"HIDDEN_DIM": 768,
"MAX_CACHE_SIZE": 10
}
except Exception as e:
self.logger.warning(f"Error loading config: {e}, using defaults")
self.config_data = {
"MODEL_NAME": "gpt2",
"DATASET_PATH": None,
"SPECIALIZATION": None,
"HIDDEN_DIM": 768,
"MAX_CACHE_SIZE": 10
}
# Use provided values or config values with safe getters
self.model_name = model_name or self._safe_get("MODEL_NAME", "gpt2")
self.dataset_path = dataset_path or self._safe_get("DATASET_PATH")
self.specialization = specialization or self._safe_get("SPECIALIZATION")
self.hidden_dim = hidden_dim or self._safe_get("HIDDEN_DIM", 768)
self.logger.info(f"Initialized PromptAnalyzer with {self.model_name}")
self._model_cache: Dict[str, Type] = {}
self._performance_metrics: Dict[str, Dict[str, float]] = {}
# Load predefined topics from config or fall back to defaults
self._load_predefined_topics()
# Always use a proper SentenceTransformer model - fix this to avoid warnings
if hasattr(self, 'sentence_model'):
del self.sentence_model # Remove any existing instance
# Use a proper SentenceTransformer model
self.sentence_model = get_sentence_transformer('sentence-transformers/all-MiniLM-L6-v2')
self.logger.info(f"Using SentenceTransformer model: sentence-transformers/all-MiniLM-L6-v2")
# Use specific GPT-2 classes instead of Auto* classes
self.tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
# Fix missing pad token in GPT-2
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
self.model = GPT2LMHeadModel.from_pretrained("gpt2")
self.model.eval()
logger.info(f"Initialized PromptAnalyzer with {self.model_name}, specialization: {self.specialization}, hidden_dim: {self.hidden_dim}")
if self.dataset_path:
logger.info(f"Using dataset from: {self.dataset_path}")
# For caching and performance tracking
self._model_cache = {}
self._performance_metrics = {}
# Initialize model_class attribute
self.model_class = None
# Initialize attention mechanism
self.attention = None
# Try to load advanced analysis tools with proper error handling
self._init_advanced_tools()
# Load configuration for analysis
self.similarity_threshold = getattr(app_config, "SIMILARITY_THRESHOLD", 0.85)
self.max_cache_size = 10
try:
# Try to get from config if available
if hasattr(app_config, 'PROMPT_ANALYZER_CONFIG'):
self.max_cache_size = getattr(app_config.PROMPT_ANALYZER_CONFIG, "MAX_CACHE_SIZE", 10)
except Exception:
pass
def _safe_get(self, key, default=None):
"""Safely get a configuration value regardless of config type"""
try:
if isinstance(self.config_data, dict):
return self.config_data.get(key, default)
elif hasattr(self.config_data, key):
return getattr(self.config_data, key, default)
return default
except:
return default
    def _load_predefined_topics(self):
        """Populate self.predefined_topics.

        Resolution order: app_config.TOPIC_KEYWORDS, then a cached
        topic_keywords.json under DATA_DIR, then hardcoded defaults (which are
        written back to DATA_DIR on a best-effort basis).
        """
        try:
            # 1) Keywords supplied directly by the application config.
            if hasattr(app_config, 'TOPIC_KEYWORDS') and app_config.TOPIC_KEYWORDS:
                logger.info("Loading topic keywords from config")
                self.predefined_topics = app_config.TOPIC_KEYWORDS
                return
            # 2) A JSON cache written by a previous run.
            topic_file = os.path.join(app_config.DATA_DIR, "topic_keywords.json")
            if os.path.exists(topic_file):
                with open(topic_file, 'r') as f:
                    self.predefined_topics = json.load(f)
                logger.info(f"Loaded {len(self.predefined_topics)} topic categories from {topic_file}")
                return
        except Exception as e:
            # Also reached when app_config is the dict fallback (no DATA_DIR attribute).
            logger.warning(f"Error loading topic keywords: {e}, using defaults")
        # 3) Fall back to default hardcoded topics.
        logger.info("Using default hardcoded topic keywords")
        self.predefined_topics = {
            "programming": [
                "python", "java", "javascript", "typescript", "rust", "go", "golang",
                # ...existing keywords...
            ],
            "computer_science": [
                # ...existing keywords...
            ],
            "software_engineering": [
                # ...existing keywords...
            ],
            "web_development": [
                # ...existing keywords...
            ]
        }
        # Best-effort cache write so path (2) succeeds on the next run;
        # failures are only logged at debug level.
        try:
            os.makedirs(app_config.DATA_DIR, exist_ok=True)
            with open(os.path.join(app_config.DATA_DIR, "topic_keywords.json"), 'w') as f:
                json.dump(self.predefined_topics, f, indent=2)
        except Exception as e:
            logger.debug(f"Could not cache topic keywords: {e}")
def _init_advanced_tools(self):
"""Initialize advanced analysis tools with proper error handling and fallbacks"""
self.sentence_model = None
self.gpt2_model = None
self.gpt2_tokenizer = None
# For embedding model, implement multiple fallbacks
MAX_RETRIES = 3
embedding_models = [
'sentence-transformers/all-MiniLM-L6-v2', # Primary choice
'sentence-transformers/paraphrase-MiniLM-L3-v2', # Smaller fallback
'sentence-transformers/distilbert-base-nli-mean-tokens' # Last resort
]
for retry in range(MAX_RETRIES):
for model_name in embedding_models:
try:
from utils.transformer_utils import get_sentence_transformer
self.sentence_model = get_sentence_transformer(model_name)
self.logger.info(f"Successfully loaded SentenceTransformer: {model_name}")
break
except Exception as e:
self.logger.warning(f"Failed to load embedding model {model_name}: {e}")
if self.sentence_model:
break
# Wait before retry
time.sleep(2)
# Create keyword-based fallback if embedding loading completely fails
if not self.sentence_model:
self.logger.warning("All embedding models failed to load - using keyword fallback")
self._use_keyword_fallback = True
else:
self._use_keyword_fallback = False
# Initialize SmartHybridAttention
try:
attention_config = get_hybrid_attention_config()
self.attention = SmartHybridAttention(
dim=attention_config.get("DIM", 768),
num_heads=attention_config.get("NUM_HEADS", 8),
window_size=attention_config.get("WINDOW_SIZE", 256),
use_sliding=attention_config.get("USE_SLIDING", True),
use_global=attention_config.get("USE_GLOBAL", True),
use_hierarchical=attention_config.get("USE_HIERARCHICAL", False)
)
self.logger.info("Initialized SmartHybridAttention for prompt analysis")
except Exception as e:
self.logger.warning(f"Failed to initialize SmartHybridAttention: {e}")
self.attention = None
def _track_model_performance(self, model_type: str, start_time: float) -> None:
"""Track model loading and performance metrics.
Args:
model_type: Type of model being tracked
start_time: Start time of operation
"""
end_time = time.time()
if model_type not in self._performance_metrics:
self._performance_metrics[model_type] = {
'load_time': 0.0,
'usage_count': 0,
'avg_response_time': 0.0
}
# Ensure we're not creating circular references that might impact serialization
metrics = self._performance_metrics[model_type]
metrics['load_time'] = end_time - start_time
metrics['usage_count'] += 1
# Update average response time
current_avg = metrics['avg_response_time']
metrics['avg_response_time'] = (
(current_avg * (metrics['usage_count'] - 1) + (end_time - start_time))
/ metrics['usage_count']
)
def manage_cache(self, max_cache_size: int = None) -> None:
"""Manage model cache size and cleanup least used models"""
try:
# Use provided value or default
if max_cache_size is None:
max_cache_size = self.max_cache_size
if len(self._model_cache) > max_cache_size:
# Sort models by usage count
sorted_models = sorted(
self._performance_metrics.items(),
key=lambda x: (x[1]['usage_count'], -x[1]['avg_response_time'])
)
# Remove least used models
for model_type, _ in sorted_models[:-max_cache_size]:
self._model_cache.pop(model_type, None)
logger.info(f"Removed {model_type} from cache due to low usage")
# Log cache cleanup
logger.info(f"Cache cleaned up. Current size: {len(self._model_cache)}")
except Exception as e:
logger.error(f"Error managing cache: {e}")
def _load_model_class(self, model_type: str) -> Type:
"""Load model class with caching"""
start_time = time.time()
try:
# If model is already cached, return it directly
if model_type in self._model_cache:
self._track_model_performance(model_type, start_time)
return self._model_cache[model_type]
# Clean up model name
clean_model_type = model_type.replace('.py', '')
# Handle different model types
if clean_model_type == "model_PrTr" or clean_model_type.endswith("PrTr"):
try:
module = importlib.import_module("model_PrTr")
model_class = getattr(module, "Wildnerve_tlm01")
except Exception as e:
logger.warning(f"Error loading model_PrTr: {e}")
# Fallback to default model
module = importlib.import_module("model_Custm")
model_class = getattr(module, "Wildnerve_tlm01")
else:
# Default to getting Wildnerve_tlm01
try:
module_name = clean_model_type
if not module_name.startswith("model_"):
module_name = f"model_{module_name}"
module = importlib.import_module(module_name)
model_class = getattr(module, "Wildnerve_tlm01")
except Exception as e:
logger.warning(f"Error loading {model_type}: {e}, falling back to CustomModel")
# Fallback to main model
module = importlib.import_module("model_Custm")
model_class = getattr(module, "Wildnerve_tlm01")
# Cache and track the model class
self._model_cache[model_type] = model_class
self._track_model_performance(model_type, start_time)
return model_class
except Exception as e:
logger.error(f"Error loading model class {model_type}: {e}")
# Try to get the default model as fallback
try:
module = importlib.import_module("model_Custm")
return getattr(module, "Wildnerve_tlm01")
except Exception:
# This should never happen, but just in case
from types import new_class
return new_class("DummyModel", (), {})
    def _analyze_with_attention(self, prompt: str) -> Optional[Dict[str, float]]:
        """Score topics by weighting keyword hits with attention-derived
        sentence importance.

        Returns a {topic: score} dict, or None when attention/embeddings are
        unavailable, the prompt has at most one sentence, or analysis fails.
        """
        if not self.attention or not self.sentence_model:
            return None
        try:
            # Split into sentences for better analysis.
            sentences = nltk.sent_tokenize(prompt)
            if len(sentences) <= 1:
                return None  # Not complex enough for attention analysis
            # Get one embedding per sentence.
            # assumes encode() returns a fixed-size 1-D vector — TODO confirm
            sentence_embeddings = [self.sentence_model.encode(s) for s in sentences]
            embeddings_tensor = torch.tensor(sentence_embeddings).unsqueeze(1)  # [seq_len, batch, dim]
            # NOTE(review): SmartHybridAttention is expected to return
            # (attended, weights) with weights reducible over dims (0, 1) to a
            # per-sentence vector — verify against its implementation.
            attended_embeddings, attention_weights = self.attention(
                query=embeddings_tensor,
                key=embeddings_tensor,
                value=embeddings_tensor,
                input_text=prompt  # Pass original text for content-aware attention
            )
            # Per-sentence importance = mean attention mass received.
            importance = attention_weights.mean(dim=(0,1)).squeeze()
            if len(importance.shape) == 0:  # Handle single sentence case
                importance = importance.unsqueeze(0)
            # Restrict scoring to the (up to) three most important sentences.
            top_indices = torch.argsort(importance, descending=True)[:min(3, len(sentences))]
            topic_scores = {topic: 0.0 for topic in self.predefined_topics}
            for idx in top_indices:
                sentence = sentences[idx.item()]
                # Normalize by the total importance over ALL sentences.
                weight = importance[idx].item() / importance.sum().item()
                for topic, keywords in self.predefined_topics.items():
                    sent_lower = sentence.lower()
                    # Substring keyword count, scaled by sentence importance;
                    # the 1.5 factor boosts attention-weighted evidence.
                    sent_score = sum(1 for keyword in keywords if keyword in sent_lower)
                    topic_scores[topic] += sent_score * weight * 1.5
            return topic_scores
        except Exception as e:
            self.logger.error(f"Error in attention-based analysis: {e}")
            return None
def _analyze_with_keywords(self, prompt: str) -> Tuple[str, float]:
"""Analyze prompt using only keywords when embeddings are unavailable"""
prompt_lower = prompt.lower()
technical_matches = 0
total_words = len(prompt_lower.split())
# Count matches across all technical categories
for category, keywords in self.predefined_topics.items():
for keyword in keywords:
if keyword in prompt_lower:
technical_matches += 1
# Simple ratio calculation
match_ratio = technical_matches / max(1, min(15, total_words))
if match_ratio > 0.1: # Even a single match in a short query is significant
return "model_Custm", match_ratio
else:
return "model_PrTr", 0.7
    def analyze_prompt(self, prompt: str) -> Tuple[str, float]:
        """Classify *prompt* as technical or general.

        Blends keyword matching, embedding similarity against fixed reference
        sentences, and (for long prompts) attention-based sentence weighting.

        Returns:
            ("model_Custm", score) for technical prompts,
            ("model_PrTr", score) otherwise.
        """
        # Keyword-only fallback when embeddings never loaded.
        if hasattr(self, '_use_keyword_fallback') and self._use_keyword_fallback:
            return self._analyze_with_keywords(prompt)
        # Case-insensitive matching throughout.
        prompt_lower = prompt.lower()
        technical_matches = 0
        word_count = len(prompt_lower.split())
        # Set intersection gives O(1) membership tests for single-word keywords.
        prompt_words = set(prompt_lower.split())
        for category, keywords in self.predefined_topics.items():
            keywords_set = set(keywords)
            matches = prompt_words.intersection(keywords_set)
            technical_matches += len(matches)
            # Multi-word keywords can never appear in the word set, so they
            # are substring-checked against the whole prompt instead.
            for keyword in keywords:
                if " " in keyword and keyword in prompt_lower:
                    technical_matches += 1
        # Normalize by word count, capped at 20 so long prompts aren't diluted.
        keyword_ratio = technical_matches / max(1, min(20, word_count))
        # Attention analysis only for longer prompts (cost/benefit trade-off).
        attention_scores = None
        if len(prompt) > 100 and self.attention:
            try:
                attention_scores = self._analyze_with_attention(prompt)
            except Exception as e:
                self.logger.warning(f"Error in attention analysis: {e}")
        try:
            # Embedding similarity against fixed technical/general references.
            prompt_embedding = self.sentence_model.encode(prompt)
            technical_reference = "Write code to solve a programming problem using algorithms and data structures."
            general_reference = "Tell me about daily life topics like weather, food, or general conversation."
            technical_embedding = self.sentence_model.encode(technical_reference)
            general_embedding = self.sentence_model.encode(general_reference)
            technical_similarity = cosine_similarity([prompt_embedding], [technical_embedding])[0][0]
            general_similarity = cosine_similarity([prompt_embedding], [general_embedding])[0][0]
            # Blend signals: keywords 30%, semantic similarity 40%, and (when
            # available) attention 30%.
            # NOTE(review): general_similarity is computed but never used in
            # the score — confirm whether it was meant to be subtracted.
            technical_score = 0.3 * keyword_ratio + 0.4 * technical_similarity
            if attention_scores:
                # Average the four technical category scores from attention.
                tech_attention_score = (
                    attention_scores.get("programming", 0) +
                    attention_scores.get("computer_science", 0) +
                    attention_scores.get("software_engineering", 0) +
                    attention_scores.get("web_development", 0)
                ) / 4.0  # Normalize
                technical_score += 0.3 * tech_attention_score
            # Decide based on combined score.
            if technical_score > 0.3:  # Threshold - tune this as needed
                return "model_Custm", technical_score
            else:
                return "model_PrTr", 1.0 - technical_score
        except Exception as e:
            self.logger.error(f"Error in prompt analysis: {e}")
            # Runtime embedding failure: fall back to plain keyword counts.
            if technical_matches > 0:
                return "model_Custm", 0.7
            else:
                return "model_PrTr", 0.7
def analyze(self, prompt: str) -> int:
"""Legacy compatibility method that returns a candidate index."""
model_type, confidence = self.analyze_prompt(prompt)
# Map model_type to candidate index
if model_type == "model_Custm":
return 0 # Index 0 corresponds to model_Custm
else:
return 1 # Index 1 corresponds to model_PrTr
def choose_model(self, prompt: str = None) -> Type:
"""Enhanced model selection that combines config and analysis"""
try:
start_time = time.time()
# If we have a cached model class, return it
if self.model_class:
return self.model_class
# Get candidate index from analysis if prompt provided
candidate_index = 0
if prompt:
candidate_index = self.analyze(prompt)
# Get selected models list
selected_models = self.get_selected_models()
# Ensure index is within bounds
if candidate_index >= len(selected_models):
candidate_index %= len(selected_models)
# Get model type
model_type = selected_models[candidate_index]
# Load and return model class
model_class = self._load_model_class(model_type)
self.model_class = model_class # Cache for later
self._track_model_performance(model_type, start_time)
return model_class
except Exception as e:
logger.error(f"Error in model selection: {e}")
# Always fallback to a valid model
try:
from model_Custm import Wildnerve_tlm01
return Wildnerve_tlm01
except Exception:
logger.critical("Failed to import default model!")
# This function must return something, so create a dummy class
class DummyModel:
def __init__(self, **kwargs): pass
return DummyModel
def get_selected_models(self) -> list:
"""Return the list of selected model types for use in the system"""
# First try getting from config
try:
if hasattr(app_config, 'SELECTED_MODEL'):
models = app_config.SELECTED_MODEL
if models:
return models
except Exception as e:
logger.warning(f"Error reading SELECTED_MODEL from config: {e}")
# Default model types with fallbacks in case primary fails
return ["model_Custm.py", "model_PrTr.py"]
def get_model_instance(self, prompt: str = None) -> Any:
"""Get an initialized model instance based on the analyzed prompt."""
model_class = self.choose_model(prompt)
try:
return model_class()
except Exception as e:
logger.error(f"Error initializing model: {e}")
try:
from model_Custm import Wildnerve_tlm01
return Wildnerve_tlm01()
except Exception:
logger.critical("Could not instantiate any model!")
return None
    def get_performance_metrics(self) -> Dict[str, Dict[str, float]]:
        """Return the per-model metrics (load_time, usage_count,
        avg_response_time). The internal dict is returned directly, so callers
        should treat it as read-only."""
        return self._performance_metrics
# Register the PromptAnalyzer in the service registry to resolve dependencies.
# NOTE: this runs at import time and constructs a full analyzer (loads GPT-2
# weights and attempts to load embedding models), so importing this module is
# expensive and may hit the network.
registry.register("prompt_analyzer", PromptAnalyzer())
def main():
    """Smoke-test the registered analyzer with one sample technical prompt."""
    # For testing purposes; in production, model_manager retrieves the analyzer.
    analyzer = registry.get("prompt_analyzer")
    sample_prompt = "I'm having trouble debugging my Python code for a sorting algorithm."
    # Bug fix: analyze_prompt() returns (model_type, confidence), not
    # (primary_topic, subtopics) — the old names and log labels were wrong.
    model_type, confidence = analyzer.analyze_prompt(sample_prompt)
    selected = analyzer.choose_model(sample_prompt)
    logger.info(f"Sample prompt analysis:\nModel type: {model_type}\nConfidence: {confidence}\nSelected Model: {selected}")
    # analyze() returns a candidate model index (0/1), not a complexity score.
    if hasattr(analyzer, 'sentence_model') and analyzer.sentence_model:
        candidate_index = analyzer.analyze(sample_prompt)
        logger.info(f"Candidate model index: {candidate_index}")


if __name__ == "__main__":
    main()