|
|
|
|
|
import os
|
|
|
import re
|
|
|
import json
|
|
|
import time
|
|
|
import math
|
|
|
import nltk
|
|
|
# Ensure NLTK's "punkt" sentence tokenizer is available at import time; it is
# used later by _analyze_with_attention via nltk.sent_tokenize. Downloads once
# on first run if missing.
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    nltk.download("punkt")
|
|
|
|
|
|
import torch
|
|
|
import logging
|
|
|
import numpy as np
|
|
|
import importlib.util
|
|
|
from enum import Enum
|
|
|
from service_registry import registry, MODEL, PRETRAINED_MODEL
|
|
|
from sklearn.metrics.pairwise import cosine_similarity
|
|
|
from typing import List, Tuple, Dict, Type, Any, Optional
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
# Application configuration. When the config module is unavailable (e.g. this
# file imported standalone), fall back to a minimal dict so the module still
# imports; the rest of the code handles both attribute-style and dict configs.
try:
    from config import app_config
except ImportError:
    logger.error("Failed to import app_config from config")

    # Dict-based stand-in mirroring the keys read elsewhere in this module.
    app_config = {
        "PROMPT_ANALYZER_CONFIG": {
            "MODEL_NAME": "gpt2",
            "DATASET_PATH": None,
            "SPECIALIZATION": None,
            "HIDDEN_DIM": 768,
            "MAX_CACHE_SIZE": 10
        }
    }
|
|
|
|
|
|
|
|
|
from utils.smartHybridAttention import SmartHybridAttention, get_hybrid_attention_config
|
|
|
|
|
|
|
|
|
# Prefer the project's shared loader; otherwise define a local fallback.
try:
    from utils.transformer_utils import get_sentence_transformer
except ImportError:

    def get_sentence_transformer(model_name):
        """Fallback loader for a sentence-embedding model.

        Tries the sentence_transformers package directly; if that is also
        unavailable, returns an inert stub whose encode() yields a 384-dim
        zero vector (384 matches MiniLM-L6-v2's output size).
        """
        try:
            from sentence_transformers import SentenceTransformer
            return SentenceTransformer(model_name)
        except ImportError:
            logger.error("sentence_transformers package not available")

            class MinimalSentenceTransformer:
                # Stub used so callers can still call .encode() without crashing.
                def __init__(self, *args, **kwargs):
                    pass

                def encode(self, text):
                    # Fixed zero vector: no semantic signal, but shape-compatible.
                    return [0.0] * 384

            return MinimalSentenceTransformer()
|
|
|
|
|
|
from model_Custm import Wildnerve_tlm01 as CustomModel
|
|
|
|
|
|
# Configure root logging and rebind the module logger.
# NOTE(review): `logger` was already created near the top of the file; this
# rebinding is redundant but harmless — consider removing one of the two.
logging.basicConfig(level=logging.INFO)

logger = logging.getLogger(__name__)


class ModelType(Enum):
    """File names of the two selectable model implementations.

    NOTE(review): not referenced elsewhere in this file; presumably consumed
    by other modules — confirm before removing.
    """

    CUSTOM = "model_Custm.py"
    PRETRAINED = "model_PrTr.py"
|
|
|
|
|
|
|
|
|
|
|
|
from transformers import GPT2Tokenizer, GPT2LMHeadModel
|
|
|
|
|
|
class PromptAnalyzer:
    """
    Enhanced prompt analyzer that combines:

    - Simple reliable keyword matching for basic topic detection
    - Advanced embedding-based analysis with SentenceTransformer when available
    - Perplexity calculations with GPT-2 for complexity assessment
    - SmartHybridAttention for analyzing complex or long prompts
    - Performance tracking and caching for efficiency

    Typical use: analyze_prompt() classifies a prompt as technical/general,
    choose_model()/get_model_instance() map that classification onto one of
    the project's model classes.
    """
|
|
|
def __init__(self, model_name=None, dataset_path=None, specialization=None, hidden_dim=None):
|
|
|
self.logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
try:
|
|
|
if hasattr(app_config, "PROMPT_ANALYZER_CONFIG"):
|
|
|
self.config_data = app_config.PROMPT_ANALYZER_CONFIG
|
|
|
elif isinstance(app_config, dict) and "PROMPT_ANALYZER_CONFIG" in app_config:
|
|
|
self.config_data = app_config["PROMPT_ANALYZER_CONFIG"]
|
|
|
else:
|
|
|
self.config_data = {
|
|
|
"MODEL_NAME": "gpt2",
|
|
|
"DATASET_PATH": None,
|
|
|
"SPECIALIZATION": None,
|
|
|
"HIDDEN_DIM": 768,
|
|
|
"MAX_CACHE_SIZE": 10
|
|
|
}
|
|
|
except Exception as e:
|
|
|
self.logger.warning(f"Error loading config: {e}, using defaults")
|
|
|
self.config_data = {
|
|
|
"MODEL_NAME": "gpt2",
|
|
|
"DATASET_PATH": None,
|
|
|
"SPECIALIZATION": None,
|
|
|
"HIDDEN_DIM": 768,
|
|
|
"MAX_CACHE_SIZE": 10
|
|
|
}
|
|
|
|
|
|
|
|
|
self.model_name = model_name or self._safe_get("MODEL_NAME", "gpt2")
|
|
|
self.dataset_path = dataset_path or self._safe_get("DATASET_PATH")
|
|
|
self.specialization = specialization or self._safe_get("SPECIALIZATION")
|
|
|
self.hidden_dim = hidden_dim or self._safe_get("HIDDEN_DIM", 768)
|
|
|
|
|
|
self.logger.info(f"Initialized PromptAnalyzer with {self.model_name}")
|
|
|
self._model_cache: Dict[str, Type] = {}
|
|
|
self._performance_metrics: Dict[str, Dict[str, float]] = {}
|
|
|
|
|
|
|
|
|
self._load_predefined_topics()
|
|
|
|
|
|
|
|
|
if hasattr(self, 'sentence_model'):
|
|
|
del self.sentence_model
|
|
|
|
|
|
|
|
|
self.sentence_model = get_sentence_transformer('sentence-transformers/all-MiniLM-L6-v2')
|
|
|
self.logger.info(f"Using SentenceTransformer model: sentence-transformers/all-MiniLM-L6-v2")
|
|
|
|
|
|
|
|
|
self.tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
|
|
|
|
|
|
if self.tokenizer.pad_token is None:
|
|
|
self.tokenizer.pad_token = self.tokenizer.eos_token
|
|
|
|
|
|
self.model = GPT2LMHeadModel.from_pretrained("gpt2")
|
|
|
self.model.eval()
|
|
|
|
|
|
logger.info(f"Initialized PromptAnalyzer with {self.model_name}, specialization: {self.specialization}, hidden_dim: {self.hidden_dim}")
|
|
|
if self.dataset_path:
|
|
|
logger.info(f"Using dataset from: {self.dataset_path}")
|
|
|
|
|
|
|
|
|
self._model_cache = {}
|
|
|
self._performance_metrics = {}
|
|
|
|
|
|
|
|
|
self.model_class = None
|
|
|
|
|
|
|
|
|
self.attention = None
|
|
|
|
|
|
|
|
|
self._init_advanced_tools()
|
|
|
|
|
|
|
|
|
self.similarity_threshold = getattr(app_config, "SIMILARITY_THRESHOLD", 0.85)
|
|
|
self.max_cache_size = 10
|
|
|
try:
|
|
|
|
|
|
if hasattr(app_config, 'PROMPT_ANALYZER_CONFIG'):
|
|
|
self.max_cache_size = getattr(app_config.PROMPT_ANALYZER_CONFIG, "MAX_CACHE_SIZE", 10)
|
|
|
except Exception:
|
|
|
pass
|
|
|
|
|
|
def _safe_get(self, key, default=None):
|
|
|
"""Safely get a configuration value regardless of config type"""
|
|
|
try:
|
|
|
if isinstance(self.config_data, dict):
|
|
|
return self.config_data.get(key, default)
|
|
|
elif hasattr(self.config_data, key):
|
|
|
return getattr(self.config_data, key, default)
|
|
|
return default
|
|
|
except:
|
|
|
return default
|
|
|
|
|
|
def _load_predefined_topics(self):
|
|
|
"""Load topic keywords from config file or use defaults with caching"""
|
|
|
|
|
|
try:
|
|
|
if hasattr(app_config, 'TOPIC_KEYWORDS') and app_config.TOPIC_KEYWORDS:
|
|
|
logger.info("Loading topic keywords from config")
|
|
|
self.predefined_topics = app_config.TOPIC_KEYWORDS
|
|
|
return
|
|
|
|
|
|
|
|
|
topic_file = os.path.join(app_config.DATA_DIR, "topic_keywords.json")
|
|
|
if os.path.exists(topic_file):
|
|
|
with open(topic_file, 'r') as f:
|
|
|
self.predefined_topics = json.load(f)
|
|
|
logger.info(f"Loaded {len(self.predefined_topics)} topic categories from {topic_file}")
|
|
|
return
|
|
|
except Exception as e:
|
|
|
logger.warning(f"Error loading topic keywords: {e}, using defaults")
|
|
|
|
|
|
|
|
|
logger.info("Using default hardcoded topic keywords")
|
|
|
self.predefined_topics = {
|
|
|
"programming": [
|
|
|
"python", "java", "javascript", "typescript", "rust", "go", "golang",
|
|
|
|
|
|
],
|
|
|
"computer_science": [
|
|
|
|
|
|
],
|
|
|
"software_engineering": [
|
|
|
|
|
|
],
|
|
|
"web_development": [
|
|
|
|
|
|
]
|
|
|
}
|
|
|
|
|
|
|
|
|
try:
|
|
|
os.makedirs(app_config.DATA_DIR, exist_ok=True)
|
|
|
with open(os.path.join(app_config.DATA_DIR, "topic_keywords.json"), 'w') as f:
|
|
|
json.dump(self.predefined_topics, f, indent=2)
|
|
|
except Exception as e:
|
|
|
logger.debug(f"Could not cache topic keywords: {e}")
|
|
|
|
|
|
def _init_advanced_tools(self):
|
|
|
"""Initialize advanced analysis tools with proper error handling and fallbacks"""
|
|
|
self.sentence_model = None
|
|
|
self.gpt2_model = None
|
|
|
self.gpt2_tokenizer = None
|
|
|
|
|
|
|
|
|
MAX_RETRIES = 3
|
|
|
embedding_models = [
|
|
|
'sentence-transformers/all-MiniLM-L6-v2',
|
|
|
'sentence-transformers/paraphrase-MiniLM-L3-v2',
|
|
|
'sentence-transformers/distilbert-base-nli-mean-tokens'
|
|
|
]
|
|
|
|
|
|
for retry in range(MAX_RETRIES):
|
|
|
for model_name in embedding_models:
|
|
|
try:
|
|
|
from utils.transformer_utils import get_sentence_transformer
|
|
|
self.sentence_model = get_sentence_transformer(model_name)
|
|
|
self.logger.info(f"Successfully loaded SentenceTransformer: {model_name}")
|
|
|
break
|
|
|
except Exception as e:
|
|
|
self.logger.warning(f"Failed to load embedding model {model_name}: {e}")
|
|
|
|
|
|
if self.sentence_model:
|
|
|
break
|
|
|
|
|
|
|
|
|
time.sleep(2)
|
|
|
|
|
|
|
|
|
if not self.sentence_model:
|
|
|
self.logger.warning("All embedding models failed to load - using keyword fallback")
|
|
|
self._use_keyword_fallback = True
|
|
|
else:
|
|
|
self._use_keyword_fallback = False
|
|
|
|
|
|
|
|
|
try:
|
|
|
attention_config = get_hybrid_attention_config()
|
|
|
self.attention = SmartHybridAttention(
|
|
|
dim=attention_config.get("DIM", 768),
|
|
|
num_heads=attention_config.get("NUM_HEADS", 8),
|
|
|
window_size=attention_config.get("WINDOW_SIZE", 256),
|
|
|
use_sliding=attention_config.get("USE_SLIDING", True),
|
|
|
use_global=attention_config.get("USE_GLOBAL", True),
|
|
|
use_hierarchical=attention_config.get("USE_HIERARCHICAL", False)
|
|
|
)
|
|
|
self.logger.info("Initialized SmartHybridAttention for prompt analysis")
|
|
|
except Exception as e:
|
|
|
self.logger.warning(f"Failed to initialize SmartHybridAttention: {e}")
|
|
|
self.attention = None
|
|
|
|
|
|
def _track_model_performance(self, model_type: str, start_time: float) -> None:
|
|
|
"""Track model loading and performance metrics.
|
|
|
|
|
|
Args:
|
|
|
model_type: Type of model being tracked
|
|
|
start_time: Start time of operation
|
|
|
"""
|
|
|
end_time = time.time()
|
|
|
if model_type not in self._performance_metrics:
|
|
|
self._performance_metrics[model_type] = {
|
|
|
'load_time': 0.0,
|
|
|
'usage_count': 0,
|
|
|
'avg_response_time': 0.0
|
|
|
}
|
|
|
|
|
|
|
|
|
metrics = self._performance_metrics[model_type]
|
|
|
metrics['load_time'] = end_time - start_time
|
|
|
metrics['usage_count'] += 1
|
|
|
|
|
|
|
|
|
current_avg = metrics['avg_response_time']
|
|
|
metrics['avg_response_time'] = (
|
|
|
(current_avg * (metrics['usage_count'] - 1) + (end_time - start_time))
|
|
|
/ metrics['usage_count']
|
|
|
)
|
|
|
|
|
|
def manage_cache(self, max_cache_size: int = None) -> None:
|
|
|
"""Manage model cache size and cleanup least used models"""
|
|
|
try:
|
|
|
|
|
|
if max_cache_size is None:
|
|
|
max_cache_size = self.max_cache_size
|
|
|
|
|
|
if len(self._model_cache) > max_cache_size:
|
|
|
|
|
|
sorted_models = sorted(
|
|
|
self._performance_metrics.items(),
|
|
|
key=lambda x: (x[1]['usage_count'], -x[1]['avg_response_time'])
|
|
|
)
|
|
|
|
|
|
|
|
|
for model_type, _ in sorted_models[:-max_cache_size]:
|
|
|
self._model_cache.pop(model_type, None)
|
|
|
logger.info(f"Removed {model_type} from cache due to low usage")
|
|
|
|
|
|
|
|
|
logger.info(f"Cache cleaned up. Current size: {len(self._model_cache)}")
|
|
|
except Exception as e:
|
|
|
logger.error(f"Error managing cache: {e}")
|
|
|
|
|
|
def _load_model_class(self, model_type: str) -> Type:
|
|
|
"""Load model class with caching"""
|
|
|
start_time = time.time()
|
|
|
try:
|
|
|
|
|
|
if model_type in self._model_cache:
|
|
|
self._track_model_performance(model_type, start_time)
|
|
|
return self._model_cache[model_type]
|
|
|
|
|
|
|
|
|
clean_model_type = model_type.replace('.py', '')
|
|
|
|
|
|
|
|
|
if clean_model_type == "model_PrTr" or clean_model_type.endswith("PrTr"):
|
|
|
try:
|
|
|
module = importlib.import_module("model_PrTr")
|
|
|
model_class = getattr(module, "Wildnerve_tlm01")
|
|
|
except Exception as e:
|
|
|
logger.warning(f"Error loading model_PrTr: {e}")
|
|
|
|
|
|
module = importlib.import_module("model_Custm")
|
|
|
model_class = getattr(module, "Wildnerve_tlm01")
|
|
|
else:
|
|
|
|
|
|
try:
|
|
|
module_name = clean_model_type
|
|
|
if not module_name.startswith("model_"):
|
|
|
module_name = f"model_{module_name}"
|
|
|
module = importlib.import_module(module_name)
|
|
|
model_class = getattr(module, "Wildnerve_tlm01")
|
|
|
except Exception as e:
|
|
|
logger.warning(f"Error loading {model_type}: {e}, falling back to CustomModel")
|
|
|
|
|
|
module = importlib.import_module("model_Custm")
|
|
|
model_class = getattr(module, "Wildnerve_tlm01")
|
|
|
|
|
|
|
|
|
self._model_cache[model_type] = model_class
|
|
|
self._track_model_performance(model_type, start_time)
|
|
|
return model_class
|
|
|
|
|
|
except Exception as e:
|
|
|
logger.error(f"Error loading model class {model_type}: {e}")
|
|
|
|
|
|
try:
|
|
|
module = importlib.import_module("model_Custm")
|
|
|
return getattr(module, "Wildnerve_tlm01")
|
|
|
except Exception:
|
|
|
|
|
|
from types import new_class
|
|
|
return new_class("DummyModel", (), {})
|
|
|
|
|
|
    def _analyze_with_attention(self, prompt):
        """Use SmartHybridAttention to analyze complex prompts.

        Splits the prompt into sentences, self-attends over their embeddings,
        and scores topics using keyword hits on the most-attended sentences.

        Returns:
            Dict mapping topic name -> weighted keyword score, or None when the
            attention/embedding tools are unavailable, the prompt is a single
            sentence, or any step fails.
        """
        if not self.attention or not self.sentence_model:
            return None

        try:
            # Sentence segmentation via NLTK's punkt tokenizer.
            sentences = nltk.sent_tokenize(prompt)

            # Attention over one sentence carries no signal; let the caller
            # fall back to embedding/keyword analysis.
            if len(sentences) <= 1:
                return None

            # Shape (num_sentences, 1, embed_dim): sequence of sentence
            # embeddings with a batch dimension of 1.
            # NOTE(review): assumes SmartHybridAttention accepts this layout
            # and returns (output, weights) — confirm against its definition.
            sentence_embeddings = [self.sentence_model.encode(s) for s in sentences]
            embeddings_tensor = torch.tensor(sentence_embeddings).unsqueeze(1)

            # Self-attention: every sentence attends to every other sentence.
            attended_embeddings, attention_weights = self.attention(
                query=embeddings_tensor,
                key=embeddings_tensor,
                value=embeddings_tensor,
                input_text=prompt
            )

            # Collapse the leading dimensions into one importance score per
            # sentence; re-add a dim if squeeze() produced a 0-d tensor.
            importance = attention_weights.mean(dim=(0,1)).squeeze()
            if len(importance.shape) == 0:
                importance = importance.unsqueeze(0)

            # Keep only the (up to) 3 most-attended sentences.
            top_indices = torch.argsort(importance, descending=True)[:min(3, len(sentences))]

            topic_scores = {topic: 0.0 for topic in self.predefined_topics}
            for idx in top_indices:
                sentence = sentences[idx.item()]
                # Attention weight normalized over all sentences.
                weight = importance[idx].item() / importance.sum().item()

                # Substring keyword counts, scaled by attention weight and a
                # fixed 1.5 boost factor.
                for topic, keywords in self.predefined_topics.items():
                    sent_lower = sentence.lower()
                    sent_score = sum(1 for keyword in keywords if keyword in sent_lower)
                    topic_scores[topic] += sent_score * weight * 1.5

            return topic_scores
        except Exception as e:
            self.logger.error(f"Error in attention-based analysis: {e}")
            return None
|
|
|
|
|
|
def _analyze_with_keywords(self, prompt: str) -> Tuple[str, float]:
|
|
|
"""Analyze prompt using only keywords when embeddings are unavailable"""
|
|
|
prompt_lower = prompt.lower()
|
|
|
technical_matches = 0
|
|
|
total_words = len(prompt_lower.split())
|
|
|
|
|
|
|
|
|
for category, keywords in self.predefined_topics.items():
|
|
|
for keyword in keywords:
|
|
|
if keyword in prompt_lower:
|
|
|
technical_matches += 1
|
|
|
|
|
|
|
|
|
match_ratio = technical_matches / max(1, min(15, total_words))
|
|
|
|
|
|
if match_ratio > 0.1:
|
|
|
return "model_Custm", match_ratio
|
|
|
else:
|
|
|
return "model_PrTr", 0.7
|
|
|
|
|
|
    def analyze_prompt(self, prompt: str) -> Tuple[str, float]:
        """Classify a prompt as technical or general.

        Blends three signals: keyword hits against predefined topics,
        embedding similarity to technical/general reference sentences, and
        (for long prompts) attention-weighted topic scores.

        Returns:
            ("model_Custm", score) for technical prompts, or
            ("model_PrTr", score) for general prompts; score is a confidence
            in [0, 1]-ish range (not strictly normalized).
        """
        # No embedding model available: use the pure keyword path.
        if hasattr(self, '_use_keyword_fallback') and self._use_keyword_fallback:
            return self._analyze_with_keywords(prompt)

        prompt_lower = prompt.lower()

        technical_matches = 0
        word_count = len(prompt_lower.split())

        # Whole-word matching via set intersection.
        prompt_words = set(prompt_lower.split())

        for category, keywords in self.predefined_topics.items():
            keywords_set = set(keywords)
            matches = prompt_words.intersection(keywords_set)
            technical_matches += len(matches)

            # Multi-word keywords matched as phrases.
            for keyword in keywords:
                if " " in keyword and keyword in prompt_lower:
                    technical_matches += 1

        # Normalize by word count, capped at 20 so long prompts aren't penalized.
        keyword_ratio = technical_matches / max(1, min(20, word_count))

        # Attention signal only for longer prompts (> 100 chars).
        attention_scores = None
        if len(prompt) > 100 and self.attention:
            try:
                attention_scores = self._analyze_with_attention(prompt)
            except Exception as e:
                self.logger.warning(f"Error in attention analysis: {e}")

        try:
            # Embedding similarity against fixed reference sentences.
            prompt_embedding = self.sentence_model.encode(prompt)

            technical_reference = "Write code to solve a programming problem using algorithms and data structures."
            general_reference = "Tell me about daily life topics like weather, food, or general conversation."

            technical_embedding = self.sentence_model.encode(technical_reference)
            general_embedding = self.sentence_model.encode(general_reference)

            technical_similarity = cosine_similarity([prompt_embedding], [technical_embedding])[0][0]
            # NOTE(review): general_similarity is computed but never used below —
            # presumably intended to offset technical_score; confirm or remove.
            general_similarity = cosine_similarity([prompt_embedding], [general_embedding])[0][0]

            # Weighted blend: 0.3 keywords + 0.4 embedding (+ 0.3 attention when
            # available; otherwise the weights sum to 0.7).
            technical_score = 0.3 * keyword_ratio + 0.4 * technical_similarity

            if attention_scores:
                # Average the four technical topic buckets.
                tech_attention_score = (
                    attention_scores.get("programming", 0) +
                    attention_scores.get("computer_science", 0) +
                    attention_scores.get("software_engineering", 0) +
                    attention_scores.get("web_development", 0)
                ) / 4.0
                technical_score += 0.3 * tech_attention_score

            # 0.3 threshold decides technical vs. general.
            if technical_score > 0.3:
                return "model_Custm", technical_score
            else:
                return "model_PrTr", 1.0 - technical_score

        except Exception as e:
            self.logger.error(f"Error in prompt analysis: {e}")

            # Embedding path failed: fall back to the keyword count alone.
            if technical_matches > 0:
                return "model_Custm", 0.7
            else:
                return "model_PrTr", 0.7
|
|
|
|
|
|
def analyze(self, prompt: str) -> int:
|
|
|
"""Legacy compatibility method that returns a candidate index."""
|
|
|
model_type, confidence = self.analyze_prompt(prompt)
|
|
|
|
|
|
|
|
|
if model_type == "model_Custm":
|
|
|
return 0
|
|
|
else:
|
|
|
return 1
|
|
|
|
|
|
def choose_model(self, prompt: str = None) -> Type:
|
|
|
"""Enhanced model selection that combines config and analysis"""
|
|
|
try:
|
|
|
start_time = time.time()
|
|
|
|
|
|
|
|
|
if self.model_class:
|
|
|
return self.model_class
|
|
|
|
|
|
|
|
|
candidate_index = 0
|
|
|
if prompt:
|
|
|
candidate_index = self.analyze(prompt)
|
|
|
|
|
|
|
|
|
selected_models = self.get_selected_models()
|
|
|
|
|
|
|
|
|
if candidate_index >= len(selected_models):
|
|
|
candidate_index %= len(selected_models)
|
|
|
|
|
|
|
|
|
model_type = selected_models[candidate_index]
|
|
|
|
|
|
|
|
|
model_class = self._load_model_class(model_type)
|
|
|
self.model_class = model_class
|
|
|
self._track_model_performance(model_type, start_time)
|
|
|
return model_class
|
|
|
|
|
|
except Exception as e:
|
|
|
logger.error(f"Error in model selection: {e}")
|
|
|
|
|
|
try:
|
|
|
from model_Custm import Wildnerve_tlm01
|
|
|
return Wildnerve_tlm01
|
|
|
except Exception:
|
|
|
logger.critical("Failed to import default model!")
|
|
|
|
|
|
class DummyModel:
|
|
|
def __init__(self, **kwargs): pass
|
|
|
return DummyModel
|
|
|
|
|
|
def get_selected_models(self) -> list:
|
|
|
"""Return the list of selected model types for use in the system"""
|
|
|
|
|
|
try:
|
|
|
if hasattr(app_config, 'SELECTED_MODEL'):
|
|
|
models = app_config.SELECTED_MODEL
|
|
|
if models:
|
|
|
return models
|
|
|
except Exception as e:
|
|
|
logger.warning(f"Error reading SELECTED_MODEL from config: {e}")
|
|
|
|
|
|
|
|
|
return ["model_Custm.py", "model_PrTr.py"]
|
|
|
|
|
|
def get_model_instance(self, prompt: str = None) -> Any:
|
|
|
"""Get an initialized model instance based on the analyzed prompt."""
|
|
|
model_class = self.choose_model(prompt)
|
|
|
try:
|
|
|
return model_class()
|
|
|
except Exception as e:
|
|
|
logger.error(f"Error initializing model: {e}")
|
|
|
try:
|
|
|
from model_Custm import Wildnerve_tlm01
|
|
|
return Wildnerve_tlm01()
|
|
|
except Exception:
|
|
|
logger.critical("Could not instantiate any model!")
|
|
|
return None
|
|
|
|
|
|
    def get_performance_metrics(self) -> Dict[str, Dict[str, float]]:
        """Return the live per-model metrics dict (load_time, usage_count,
        avg_response_time per model type).

        Note: this is the internal dict, not a copy — mutations by the caller
        affect the analyzer's state.
        """
        return self._performance_metrics
|
|
|
|
|
|
|
|
|
# Register a shared analyzer in the service registry at import time.
# NOTE(review): this constructs PromptAnalyzer (including GPT-2 loading) as an
# import side effect — confirm that is intended for every importer.
registry.register("prompt_analyzer", PromptAnalyzer())
|
|
|
|
|
|
def main():
    """Smoke-test the registered PromptAnalyzer against a sample prompt."""
    analyzer = registry.get("prompt_analyzer")

    sample_prompt = "I'm having trouble debugging my Python code for a sorting algorithm."

    # analyze_prompt returns (model_type, confidence); the original unpacked
    # these into misleading "primary_topic"/"subtopics" names and mislabeled
    # them in the log output.
    model_type, confidence = analyzer.analyze_prompt(sample_prompt)
    selected = analyzer.choose_model(sample_prompt)
    logger.info(f"Sample prompt analysis:\nModel Type: {model_type}\nConfidence: {confidence}\nSelected Model: {selected}")

    # Only exercise the index-based legacy API when embeddings are available.
    if hasattr(analyzer, 'sentence_model') and analyzer.sentence_model:
        complexity_index = analyzer.analyze(sample_prompt)
        logger.info(f"Complexity analysis index: {complexity_index}")


if __name__ == "__main__":
    main()