"""Model lifecycle management.

Provides a cached CSV dataset loader (``DatasetManager``), a ``ModelManager`` that
lazily creates one model per specialization, keeps at most ``max_active_models`` of
them alive (LRU + idle eviction), routes prompts by embedding similarity and wraps
generation with simple timing metrics, plus factory helpers that register the
manager in the service registry.
"""
import gc
import importlib
import importlib.util
import inspect
import logging
import os
import sys
import time
from collections import OrderedDict
from pathlib import Path
from threading import Lock, RLock
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
import torch
from nltk.stem import WordNetLemmatizer
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoModelForCausalLM, AutoTokenizer

from config import app_config
from dataset import TensorDataset
from utils.transformer_utils import get_sentence_transformer
from utils.smartHybridAttention import SmartHybridAttention, get_hybrid_attention_config
from service_registry import registry, MODEL, TOKENIZER, MODEL_MANAGER, COMMUNICATOR

logger = logging.getLogger(__name__)

# Every model built in this module is sized for the GPT-2 vocabulary.
_GPT2_VOCAB_SIZE = 50257


class DummyProcess:
    """Stand-in for ``psutil.Process`` used when psutil is not installed.

    Every query returns fixed placeholder numbers, not real measurements.
    """

    def __init__(self, pid=None):
        self.pid = pid or 1

    def memory_info(self):
        class MemInfo:
            def __init__(self):
                self.rss = 1e6
                self.vms = 1e6

        return MemInfo()

    def memory_percent(self):
        return 1.0


class DummyPsutil:
    """Minimal module-like object exposing only ``Process``."""

    @staticmethod
    def Process(pid=None):
        return DummyProcess(pid)


try:
    import psutil

    PSUTIL_AVAILABLE = True
except ImportError:
    logger.warning("psutil not available")
    PSUTIL_AVAILABLE = False
    psutil = DummyPsutil()


def safe_get_config(config_obj, key, default=None):
    """Read ``key`` from a dict or from an attribute-style config object.

    Returns ``default`` when the dict lacks the key or the object lacks the attribute.
    """
    if isinstance(config_obj, dict):
        return config_obj.get(key, default)
    return getattr(config_obj, key, default)


def safe_get_config_value(config_obj, key, default=None):
    """Like ``safe_get_config`` but never raises.

    A primitive ``config_obj`` (int/float/str/bool) is returned as-is, which lets a
    caller pass an already-resolved value instead of a config container.
    """
    try:
        if isinstance(config_obj, dict):
            return config_obj.get(key, default)
        if hasattr(config_obj, key):
            return getattr(config_obj, key, default)
        if isinstance(config_obj, (int, float, str, bool)):
            return config_obj
        return default
    except Exception:
        return default


class DatasetManager:
    """Thread-safe in-memory cache of CSV datasets keyed by specialization."""

    def __init__(self):
        self.datasets: Dict[str, Any] = {}
        self.lock = Lock()

    def load_dataset(self, path: str, specialization: str) -> Any:
        """Return the dataset for ``specialization``, loading it from ``path`` on first use.

        The cache key is the specialization only, so a later call with a different
        ``path`` for the same specialization returns the cached dataset.
        """
        with self.lock:
            if specialization in self.datasets:
                logger.info(f"Using cached dataset for {specialization}")
                return self.datasets[specialization]
            dataset = self._load_and_process_dataset(path, specialization)
            self.datasets[specialization] = dataset
            return dataset

    def _load_and_process_dataset(self, path: str, specialization: str) -> TensorDataset:
        """Read a CSV and wrap it as ``TensorDataset(features, labels)``.

        Every column except ``label`` becomes a float32 feature (so they must be
        numeric); ``label`` becomes a long tensor.

        Raises:
            FileNotFoundError: ``path`` does not exist.
            ValueError: the CSV has no ``label`` column.
        """
        if not os.path.exists(path):
            raise FileNotFoundError(f"Dataset {path} not found.")
        logger.info(f"Loading dataset: {specialization}")
        data = pd.read_csv(path)
        if "label" not in data.columns:
            raise ValueError("Dataset must have a 'label' column.")
        features = data.drop("label", axis=1).values
        labels = data["label"].values
        features_tensor = torch.tensor(features, dtype=torch.float32)
        labels_tensor = torch.tensor(labels, dtype=torch.long)
        return TensorDataset(features_tensor, labels_tensor)

    def get_status(self) -> Dict[str, Any]:
        """Return the cached specialization names and the cache size."""
        return {"loaded_datasets": list(self.datasets.keys()), "cache_size": len(self.datasets)}

    def clear_cache(self):
        """Drop every cached dataset."""
        with self.lock:
            self.datasets.clear()


class ModelManager:
    """Create, cache, route to and evict per-specialization models.

    ``self.models`` maps specialization -> model. ``self.model_pool`` is an
    ``OrderedDict`` used only for LRU ordering (least recently used first), and
    ``self._performance_metrics`` keeps timing/usage figures per specialization.
    """

    # Selected-model key -> Hugging Face hub repo ("org/name") or "module.Class".
    HF_ALIAS = {
        "EvolphTech/Wildnerve-tlm01_Hybrid_Model": "model_Combn.Wildnerve_tlm01_Hybrid_Model",
        "model_Custm": "EvolphTech/Wildnerve-tlm01_Hybrid_Model",
        "model_Custm.py": "EvolphTech/Wildnerve-tlm01_Hybrid_Model",
    }

    def __init__(self, tokenizer=None, max_active_models=5, model_idle_threshold=600):
        self.models: Dict[str, Any] = {}
        # Re-entrant on purpose: get_or_create_model and manage_model_cache hold the
        # lock while calling get_model / remove_model_instance, which lock again.
        # A plain Lock deadlocks on those paths.
        self.lock = RLock()
        self.model_pool: "OrderedDict[str, None]" = OrderedDict()
        self.max_active_models = (
            max_active_models
            if isinstance(max_active_models, int) and max_active_models > 0
            else 2
        )
        self.model_idle_threshold = (
            model_idle_threshold
            if isinstance(model_idle_threshold, int) and model_idle_threshold > 0
            else 600
        )
        self.tokenizer = tokenizer
        self.specializations = self._resolve_specializations()
        logger.info(f"Using {len(self.specializations)} specializations from config")
        self._performance_metrics: Dict[str, Dict[str, float]] = {}

        attention_config = get_hybrid_attention_config()
        self.smart_attention = SmartHybridAttention(
            dim=attention_config["DIM"],
            num_heads=attention_config["NUM_HEADS"],
            window_size=attention_config["WINDOW_SIZE"],
            use_sliding=attention_config["USE_SLIDING"],
            use_global=attention_config["USE_GLOBAL"],
            use_hierarchical=attention_config["USE_HIERARCHICAL"],
            global_token_ratio=attention_config["GLOBAL_TOKEN_RATIO"],
            memory_tokens=attention_config["MEMORY_TOKENS"],
        )
        self.dataset_manager = DatasetManager()

        transformer_config = safe_get_config(app_config, "TRANSFORMER_CONFIG", {})
        model_name = safe_get_config(transformer_config, "MODEL_NAME", "Wildnerve-tlm01-0.05Bx12")
        self.embedding_model = get_sentence_transformer(model_name)
        self.similarity_threshold = safe_get_config(app_config, "SIMILARITY_THRESHOLD", 0.85)
        self.top_k = safe_get_config(app_config, "TOP_K", 3)
        self.prompt_analyzer = None
        self.selected_models = self._get_selected_models()

        try:
            self._load_models()
        except Exception:
            logger.critical("Startup model loading failed, aborting ModelManager init", exc_info=True)
            raise
        logger.info(f"ModelManager initialized with {len(self.specializations)} specializations")

    # ------------------------------------------------------------------ config

    @staticmethod
    def _resolve_specializations() -> List[str]:
        """Pick specialization names from config.

        Preference: ``SPECIALIZATIONS``, then the keys of a non-empty
        ``DATASET_PATHS`` dict, then a built-in minimal set.
        """
        configured = safe_get_config(app_config, "SPECIALIZATIONS", None)
        if configured:
            # A bare string would otherwise be iterated character by character.
            return [configured] if isinstance(configured, str) else list(configured)
        dataset_paths = safe_get_config(app_config, "DATASET_PATHS", None)
        if isinstance(dataset_paths, dict) and dataset_paths:
            return list(dataset_paths.keys())
        return ["general", "programming", "mathematics"]

    def _get_selected_models(self) -> List[str]:
        """Return the configured ``SELECTED_MODEL`` list (default ``["model_Custm.py"]``)."""
        model_files = safe_get_config(app_config, "SELECTED_MODEL", ["model_Custm.py"])
        if isinstance(model_files, str):
            # Callers index [0]; a bare string would yield its first character.
            model_files = [model_files]
        return list(model_files) if model_files else ["model_Custm.py"]

    @staticmethod
    def _resolve_dataset_path(spec: str, data_dir: str) -> str:
        """Return ``DATASET_PATHS[spec]`` (first entry if a list/tuple) or ``<data_dir>/<spec>.csv``."""
        ds_paths = safe_get_config_value(app_config, "DATASET_PATHS", {})
        raw = ds_paths.get(spec) if isinstance(ds_paths, dict) else None
        if isinstance(raw, (list, tuple)):
            raw = raw[0] if raw else None
        return raw or os.path.join(data_dir, f"{spec}.csv")

    # --------------------------------------------------------- model creation

    @staticmethod
    def _pick_model_class(module, key: str, where: str):
        """Return ``module.<key>`` or ``module.Wildnerve_tlm01``; raise ImportError if neither exists."""
        cls = getattr(module, key, None) or getattr(module, "Wildnerve_tlm01", None)
        if cls:
            return cls
        raise ImportError(f"No class '{key}' or 'Wildnerve_tlm01' {where}")

    def _import_model_class(self, model_key: str):
        """Resolve ``model_key`` to a model class; raise on failure.

        Resolution order:
          1. ``HF_ALIAS`` entry naming a hub repo ("org/name");
          2. ``HF_ALIAS`` entry naming "module.Class";
          3. a ``<key>.py`` file next to this module;
          4. a regular import of module ``<key>``.
        For 3 and 4 the class is attribute ``<key>`` or, failing that, ``Wildnerve_tlm01``.
        """
        # str.rstrip(".py") strips characters, not the suffix ("happy.py" -> "ha").
        key = model_key[: -len(".py")] if model_key.endswith(".py") else model_key
        alias = self.HF_ALIAS.get(key)
        try:
            if alias and "/" in alias:
                logger.info(f"Loading HF model from repo '{alias}'")
                # NOTE(review): this loads full weights only to return the class; the
                # caller then instantiates it with custom kwargs (vocab_size,
                # specialization, ...) which a transformers model class may not
                # accept. Confirm this branch is intended.
                model = AutoModelForCausalLM.from_pretrained(alias, use_auth_token=os.getenv("HF_TOKEN"))
                tok = AutoTokenizer.from_pretrained(alias, use_auth_token=os.getenv("HF_TOKEN"))
                model.tokenizer = tok
                return model.__class__

            if alias and "." in alias:
                module_name, cls_name = alias.split(".", 1)
                return getattr(importlib.import_module(module_name), cls_name)

            file_path = os.path.join(os.path.dirname(__file__), f"{key}.py")
            if os.path.isfile(file_path):
                spec = importlib.util.spec_from_file_location(key, file_path)
                mod = importlib.util.module_from_spec(spec)
                spec.loader.exec_module(mod)
                return self._pick_model_class(mod, key, f"in {file_path}")

            mod = importlib.import_module(key)
            return self._pick_model_class(mod, key, f"in module '{key}'")
        except Exception as e:
            logger.error(f"Failed to import model class for '{model_key}': {e}", exc_info=True)
            raise

    def _register_model(self, spec: str, model: Any) -> None:
        """Store ``model`` under ``spec``, mark it most-recently-used and reset its metrics."""
        with self.lock:
            self.models[spec] = model
            self.model_pool[spec] = None
            self.model_pool.move_to_end(spec)
            self._performance_metrics[spec] = {
                "inference_time": 0.0,
                "memory_usage": 0.0,
                "last_accessed": time.time(),
                "num_inferences": 0,
            }

    def _initialize_model_for_specialization(self, spec: str, data_dir: str):
        """Instantiate and register a model for ``spec``.

        Writes a one-row placeholder CSV when the resolved dataset file is missing
        (best effort), logs a warning when construction takes more than 30 s, and
        re-raises any import/constructor error.
        """
        dataset_path = self._resolve_dataset_path(spec, data_dir)
        if not os.path.exists(dataset_path):
            try:
                with open(dataset_path, "w", encoding="utf-8") as f:
                    f.write("text,label\nsample text,0\n")
                logger.info(f"Created placeholder dataset for '{spec}'")
            except Exception as e:
                # Best effort only: model construction proceeds without the file.
                logger.warning(f"Could not create dataset for '{spec}': {e}")

        model_cls = self._import_model_class(self.selected_models[0])
        transformer_config = safe_get_config(app_config, "TRANSFORMER_CONFIG", {})
        params = dict(
            vocab_size=_GPT2_VOCAB_SIZE,
            specialization=spec,
            dataset_path=dataset_path,
            model_name=safe_get_config(transformer_config, "MODEL_NAME", "gpt2"),
            embedding_dim=768,
            num_heads=12,  # 768 / 12 = 64-dim heads
            hidden_dim=768,
            num_layers=12,
            output_size=_GPT2_VOCAB_SIZE,
            dropout=0.1,
            max_seq_length=1024,
            pooling_mode=safe_get_config(transformer_config, "POOLING_MODE", "last"),
            tokenizer=self.tokenizer,
        )

        start = time.time()
        try:
            model = model_cls(**params)
        except Exception as e:
            logger.error(f"Error instantiating '{model_cls.__name__}' for '{spec}': {e}", exc_info=True)
            raise
        elapsed = time.time() - start
        if elapsed > 30:
            logger.warning(f"Model creation for '{spec}' took {elapsed:.1f}s (>30s)")

        self._register_model(spec, model)

    def _load_models(self):
        """Preload the startup specializations (only 'general', to save resources).

        Models are read from/placeholder datasets written to ``$TLM_DATA_DIR``
        (default ``/tmp/tlm_data``). Any failure is logged and re-raised.
        """
        # Widen this list (e.g. self.specializations[:2]) to preload more.
        initial = ["general"]
        data_dir = os.environ.get("TLM_DATA_DIR", "/tmp/tlm_data")
        os.makedirs(data_dir, exist_ok=True)

        for spec in initial:
            logger.info(f"Initializing model for specialization '{spec}'")
            try:
                self._initialize_model_for_specialization(spec, data_dir)
                logger.info(f"Model for '{spec}' loaded successfully")
            except Exception:
                logger.error(f"Failed to load model for '{spec}'", exc_info=True)
                raise

        logger.info(f"{len(self.models)} models loaded at startup (of {len(self.specializations)} total)")
        logger.debug(f"Available specializations: {', '.join(self.specializations)}")

    def _evict_lru(self) -> Optional[str]:
        """Remove the least-recently-used pool entry and its model.

        Returns the evicted key, or None when the pool is empty. A stale pool entry
        with no model is dropped too, so eviction loops always make progress.
        """
        with self.lock:
            if not self.model_pool:
                return None
            oldest = next(iter(self.model_pool))
            if not self.remove_model_instance(oldest):
                self.model_pool.pop(oldest, None)
            return oldest

    def get_or_create_model(self, specialization: str) -> Any:
        """Return the model for ``specialization``, creating it on demand.

        Unknown specializations are redirected to 'general'. When the cache is full
        the LRU model is evicted first. If creation fails, the loaded 'general' model
        is returned if available, else a minimal emergency model (or None).
        """
        with self.lock:
            model = self.get_model(specialization)
            if model is not None:
                logger.info(f"Using existing model for {specialization}")
                return model

            if specialization not in self.specializations and specialization != "general":
                logger.warning(f"Unknown specialization: {specialization}, using general")
                specialization = "general"
                model = self.get_model(specialization)
                if model is not None:
                    logger.info(f"Using existing model for {specialization}")
                    return model

            logger.info(f"Lazily loading model for {specialization}")
            if len(self.models) >= self.max_active_models:
                self._evict_lru()

            data_dir = os.environ.get("TLM_DATA_DIR", "/tmp/tlm_data")
            try:
                self._initialize_model_for_specialization(specialization, data_dir)
                return self.models.get(specialization)
            except Exception as e:
                logger.error(f"Error initializing model: {e}")
                if specialization != "general" and "general" in self.models:
                    return self.models["general"]
                return self._create_minimal_model()

    def _create_minimal_model(self):
        """Build a small 2-layer ``Wildnerve_tlm01`` for emergencies; None if that fails.

        The returned model is marked with ``_is_minimal = True`` and is not registered.
        """
        try:
            from model_Custm import Wildnerve_tlm01

            model = Wildnerve_tlm01(
                vocab_size=_GPT2_VOCAB_SIZE,
                specialization="minimal",
                dataset_path=None,
                model_name="gpt2",
                embedding_dim=768,
                num_heads=12,
                hidden_dim=768,
                num_layers=2,
                output_size=_GPT2_VOCAB_SIZE,
                dropout=0.1,
                max_seq_length=128,
                pooling_mode="last",
                tokenizer=self.tokenizer,
            )
            model._is_minimal = True
            return model
        except Exception as e:
            logger.error(f"Failed to create minimal model: {e}")
            return None

    def get_model(self, specialization: str) -> Any:
        """Return the loaded model for ``specialization`` (or None), marking it most-recently-used."""
        with self.lock:
            model = self.models.get(specialization)
            if model is not None:
                # Assignment adopts models missing from the pool; move_to_end marks MRU.
                self.model_pool[specialization] = None
                self.model_pool.move_to_end(specialization)
                if specialization in self._performance_metrics:
                    self._performance_metrics[specialization]["last_accessed"] = time.time()
            return model

    # ---------------------------------------------------------------- routing

    @staticmethod
    def _as_numpy(value) -> Optional[np.ndarray]:
        """Return ``value`` as a NumPy array (tensors are detached and moved to CPU), else None."""
        if isinstance(value, torch.Tensor):
            return value.detach().cpu().numpy()
        if isinstance(value, np.ndarray):
            return value
        return None

    def route_input(self, input_text: str) -> dict:
        """Pick the loaded specialization whose ``model.embedding`` is most similar to the prompt.

        Returns ``matched_specialization``, ``confidence`` (cosine similarity, 0.0 when
        nothing could be scored) and ``all_scores``. Only loaded models exposing an
        array/tensor ``embedding`` are scored; scoring does not touch LRU order.
        """
        raw_embedding = self.embedding_model.encode(input_text)
        input_embedding = self._as_numpy(raw_embedding)
        if input_embedding is None:
            input_embedding = np.asarray(raw_embedding)

        if getattr(self, "smart_attention", None) is not None:
            try:
                input_tensor = torch.tensor(input_embedding).unsqueeze(0).unsqueeze(0)  # [1, 1, dim]
                # no_grad: the output is only converted to NumPy, and .numpy() refuses
                # tensors that require grad.
                with torch.no_grad():
                    enhanced, _ = self.smart_attention(
                        query=input_tensor, key=input_tensor, value=input_tensor
                    )
                if isinstance(enhanced, torch.Tensor):
                    input_embedding = enhanced.squeeze().detach().cpu().numpy()
                    logger.info("Using SmartHybridAttention for enhanced prompt routing")
            except Exception as e:
                logger.warning(f"Error using SmartHybridAttention: {e}")

        query = input_embedding.reshape(1, -1)
        # Read the cache directly: get_model would bump every model's LRU position.
        with self.lock:
            loaded = [(spec, self.models.get(spec)) for spec in self.specializations]

        similarities: Dict[str, float] = {}
        for spec, model in loaded:
            if model is None:
                continue
            reference = self._as_numpy(getattr(model, "embedding", None))
            if reference is None:
                continue  # e.g. an nn.Embedding layer rather than a prototype vector
            try:
                similarities[spec] = float(cosine_similarity(query, reference.reshape(1, -1))[0][0])
            except ValueError as e:
                logger.debug(f"Skipping '{spec}' in routing: {e}")

        if similarities:
            best_spec, best_score = max(similarities.items(), key=lambda item: item[1])
            return {"matched_specialization": best_spec, "confidence": best_score, "all_scores": similarities}
        return {"matched_specialization": self.specializations[0], "confidence": 0.0, "all_scores": similarities}

    def get_model_for_prompt(self, prompt: str) -> Tuple[Any, str]:
        """Route ``prompt`` and return ``(model, specialization)``.

        On any error falls back to the first loaded model, or ``(None, "none")``.
        Inference counts are recorded by the generate methods, not here.
        """
        try:
            routing_result = self.route_input(prompt)
            specialization = routing_result.get("matched_specialization", self.specializations[0])
            model = self.get_or_create_model(specialization)
            metrics = self._performance_metrics.get(specialization)
            if metrics is not None:
                metrics["last_accessed"] = time.time()
                if hasattr(model, "get_memory_usage"):
                    metrics["memory_usage"] = model.get_memory_usage()
            return model, specialization
        except Exception as e:
            logger.error(f"Error selecting model: {e}")
            available = self.get_available_models()
            if available:
                default_key = next(iter(available))
                return available[default_key], default_key
            logger.error("No models available for routing")
            return None, "none"

    # ------------------------------------------------------------- generation

    def _record_inference(self, specialization: str, elapsed: float) -> None:
        """Fold one inference of ``elapsed`` seconds into the running average for ``specialization``."""
        metrics = self._performance_metrics.get(specialization)
        if metrics is None:
            return
        n = metrics.get("num_inferences", 0) + 1
        metrics["inference_time"] = (metrics.get("inference_time", 0) * (n - 1) + elapsed) / n
        metrics["num_inferences"] = n
        metrics["last_accessed"] = time.time()

    def generate(self, prompt: str, **kwargs):
        """Generate with the routed model; on failure retry once with the first specialization's model."""
        self.validate_input(prompt)
        model, specialization = self.get_model_for_prompt(prompt)
        start_time = time.time()
        try:
            if model is None:
                raise RuntimeError("No model available for routing")
            result = model.generate(prompt=prompt, **kwargs)
            self._record_inference(specialization, time.time() - start_time)
            return result
        except Exception as e:
            logger.error(f"Error generating with {specialization}: {e}")
            default_model = self.get_or_create_model(self.specializations[0])
            if default_model is None:
                raise RuntimeError("No fallback model available") from e
            return default_model.generate(prompt=prompt, **kwargs)

    @staticmethod
    def _stream_tokens(model, prompt: str, **kwargs):
        """Yield from ``model.generate_streaming``, or simulate it word by word from ``model.generate``."""
        if hasattr(model, "generate_streaming"):
            yield from model.generate_streaming(prompt=prompt, **kwargs)
        else:
            logger.info("Simulating streaming generation")
            for word in model.generate(prompt=prompt, **kwargs).split():
                yield word + " "

    def generate_streaming(self, prompt: str, **kwargs):
        """Stream tokens from the routed model.

        If it fails before emitting anything, the first specialization's model is
        streamed instead. If it fails after partial output, the stream ends (the error
        is logged) rather than restarting and re-sending text the consumer already has.
        """
        self.validate_input(prompt)
        model, specialization = self.get_model_for_prompt(prompt)
        start_time = time.time()
        emitted = False
        try:
            for token in self._stream_tokens(model, prompt, **kwargs):
                emitted = True
                yield token
            self._record_inference(specialization, time.time() - start_time)
        except Exception as e:
            logger.error(f"Error in streaming generation: {e}")
            if emitted:
                logger.error("Streaming failed after partial output; not restarting on fallback model")
                return
            default_model = self.get_or_create_model(self.specializations[0])
            yield from self._stream_tokens(default_model, prompt, **kwargs)

    # ------------------------------------------------------------ maintenance

    def remove_model_instance(self, specialization: str) -> bool:
        """Unload the model for ``specialization`` and free memory; True if one was removed."""
        with self.lock:
            if specialization not in self.models:
                return False
            del self.models[specialization]
            self.model_pool.pop(specialization, None)
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            logger.info(f"Removed model for {specialization}")
            return True

    def validate_input(self, input_text: str) -> bool:
        """Reject empty/whitespace text and text longer (in characters) than the configured limit.

        The limit is ``MAX_INPUT_LENGTH``, else ``MAX_SEQ_LENGTH``, else 128.
        Raises ValueError; returns True otherwise.
        """
        if not input_text or len(input_text.strip()) == 0:
            raise ValueError("Empty input text")
        max_length = safe_get_config(
            app_config, "MAX_INPUT_LENGTH", safe_get_config(app_config, "MAX_SEQ_LENGTH", 128)
        )
        if len(input_text) > max_length:
            raise ValueError(f"Input exceeds maximum length of {max_length}")
        return True

    def get_health_status(self) -> Dict[str, Any]:
        """Return process memory (MB / percent), per-model metrics, dataset cache and pool fill ratio."""
        with self.lock:
            process = psutil.Process(os.getpid())
            mem_info = process.memory_info()
            return {
                "active_models": len(self.models),
                "memory_usage": {
                    "rss_mb": mem_info.rss / (1024 * 1024),
                    "vms_mb": mem_info.vms / (1024 * 1024),
                    "percent": process.memory_percent(),
                },
                "model_performance": self._get_model_metrics(),
                "dataset_status": self.dataset_manager.get_status(),
                "cache_efficiency": len(self.model_pool) / max(1, self.max_active_models),
            }

    def _get_model_metrics(self) -> Dict[str, Dict[str, Any]]:
        """Per loaded model: timing metrics plus memory usage.

        Memory comes from ``model.get_memory_usage()`` when available, otherwise
        parameter bytes in MB.
        """
        metrics = {}
        for spec, model in self.models.items():
            base = self._performance_metrics.get(spec, {})
            mem_usage = 0
            if hasattr(model, "get_memory_usage"):
                mem_usage = model.get_memory_usage()
            elif hasattr(model, "parameters"):
                mem_usage = sum(p.numel() * p.element_size() for p in model.parameters()) / (1024 * 1024)
            metrics[spec] = {
                "inference_time": base.get("inference_time", 0),
                "memory_usage_mb": mem_usage,
                "last_accessed": base.get("last_accessed", 0),
                "num_inferences": base.get("num_inferences", 0),
                "model_type": model.__class__.__name__,
            }
        return metrics

    def get_available_models(self) -> Dict[str, Any]:
        """Return a shallow copy of the loaded-model mapping."""
        with self.lock:
            return dict(self.models)

    def shutdown(self):
        """Unload every model and clear the dataset cache; errors are logged, not raised."""
        try:
            logger.info("Initiating shutdown")
            for spec in list(self.models.keys()):
                self.remove_model_instance(spec)
            self.dataset_manager.clear_cache()
            logger.info("Shutdown complete")
        except Exception as e:
            logger.error(f"Error during shutdown: {e}")

    def manage_model_cache(self):
        """Enforce ``max_active_models`` (LRU first), drop idle models, then re-sort the pool.

        After re-sorting, the pool is ordered by ``last_accessed`` ascending so that
        the first entry stays the least recently used one.
        """
        try:
            current = time.time()
            with self.lock:
                while len(self.models) > self.max_active_models:
                    oldest = self._evict_lru()
                    if oldest is None:
                        break
                    logger.info(f"Removed LRU model: {oldest}")

                for spec in list(self.model_pool):
                    last_accessed = self._performance_metrics.get(spec, {}).get("last_accessed", 0)
                    if last_accessed and current - last_accessed > self.model_idle_threshold:
                        self.remove_model_instance(spec)
                        logger.info(f"Removed idle model: {spec}")

                # Oldest first: eviction takes the first entry, so a descending sort
                # would evict the most recently used model.
                sorted_models = sorted(
                    self.model_pool.items(),
                    key=lambda item: self._performance_metrics.get(item[0], {}).get("last_accessed", 0),
                )
                self.model_pool = OrderedDict(sorted_models)
        except Exception as e:
            logger.error(f"Error in cache management: {e}")

    def set_tokenizer(self, tokenizer):
        """Store ``tokenizer`` and push it to every loaded model that has ``set_tokenizer``.

        Per-model failures are logged. Returns ``self`` for chaining.
        """
        self.tokenizer = tokenizer
        with self.lock:
            for name, model in self.models.items():
                setter = getattr(model, "set_tokenizer", None)
                if not callable(setter):
                    continue
                try:
                    setter(tokenizer)
                    logger.debug(f"Updated tokenizer for {name}")
                except Exception as ex:
                    logger.warning(f"Failed to set tokenizer for {name}: {ex}")
        logger.info("Tokenizer updated for models")
        return self

    # ---------------------------------------------------- bulk initialization

    @staticmethod
    def _import_prompt_analyzer():
        """Instantiate ``PromptAnalyzer`` from ``model_List.py`` and register it; None on failure."""
        try:
            model_list_path = Path(__file__).parent / "model_List.py"
            if not model_list_path.exists():
                return None
            spec = importlib.util.find_spec("model_List")
            if not spec:
                return None
            model_list = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(model_list)
            if not hasattr(model_list, "PromptAnalyzer"):
                return None
            analyzer = model_list.PromptAnalyzer()
            registry.register("prompt_analyzer", analyzer)
            logger.info("Imported PromptAnalyzer")
            return analyzer
        except Exception as e:
            logger.error(f"Error importing PromptAnalyzer: {e}")
            return None

    @staticmethod
    def _load_wildnerve_class(model_name: str):
        """Import ``<model_name>.py`` from this directory and return its ``Wildnerve_tlm01``, or None."""
        if not (Path(__file__).parent / f"{model_name}.py").exists():
            return None
        try:
            spec_obj = importlib.util.find_spec(model_name)
            if spec_obj is None:
                return None
            module = importlib.util.module_from_spec(spec_obj)
            spec_obj.loader.exec_module(module)
        except Exception as e:
            logger.error(f"Error loading model module '{model_name}': {e}")
            return None
        return getattr(module, "Wildnerve_tlm01", None)

    def initialize_models(self):
        """Build small models for up to the first five specializations.

        Uses the ``PromptAnalyzer`` (from the registry or ``model_List.py``) to choose
        the model module, instantiates its ``Wildnerve_tlm01`` per specialization, then
        points the attention connector at ``<DATA_DIR>/attention_configuration.json``
        (best effort).

        Returns:
            True if at least one model is loaded afterwards, False otherwise.
        """
        try:
            logger.info("Initializing models from weights")
            prompt_analyzer = registry.get("prompt_analyzer") or self._import_prompt_analyzer()
            self.prompt_analyzer = prompt_analyzer

            if prompt_analyzer and hasattr(prompt_analyzer, "get_selected_models"):
                selected_models_list = prompt_analyzer.get_selected_models()
            else:
                selected_models_list = ["model_Custm.py"]
            logger.info(f"Selected model types: {selected_models_list}")

            model_name = selected_models_list[0].replace(".py", "") if selected_models_list else "model_Custm"
            # Load the module once instead of re-executing it for every specialization.
            model_class = self._load_wildnerve_class(model_name)
            if model_class is not None:
                for spec in self.specializations[:5]:
                    try:
                        model_instance = model_class(
                            vocab_size=_GPT2_VOCAB_SIZE,
                            specialization=spec,
                            dataset_path=None,
                            model_name="gpt2",
                            embedding_dim=768,
                            num_heads=12,
                            hidden_dim=768,
                            num_layers=2,
                            output_size=_GPT2_VOCAB_SIZE,
                            dropout=0.1,
                            max_seq_length=128,
                            pooling_mode="last",
                            tokenizer=self.tokenizer,
                        )
                        # Registration keeps LRU pool and metrics in sync with self.models.
                        self._register_model(spec, model_instance)
                        logger.info(f"Created model for {spec}")
                    except Exception as e:
                        logger.error(f"Error creating model for {spec}: {e}")

            if not self.models:
                logger.error("No models created")
                return False

            try:
                attention_config_path = os.path.join(app_config.DATA_DIR, "attention_configuration.json")
                from utils.attention_connector import get_attention_connector

                attention_connector = get_attention_connector()
                if hasattr(attention_connector, "config_path"):
                    attention_connector.config_path = attention_config_path
                    attention_connector._init_profile_selector()
                    logger.info(f"Initialized attention connector with config: {attention_config_path}")
            except Exception as e:
                logger.warning(f"Failed to initialize attention connector: {e}")

            logger.info(f"Successfully initialized {len(self.models)} models")
            return True
        except Exception as e:
            logger.error(f"Error initializing models: {e}", exc_info=True)
            return False

    def _build_general_model(self):
        """Construct a 6-layer 'general' ``Wildnerve_tlm01`` (not registered)."""
        from model_Custm import Wildnerve_tlm01

        return Wildnerve_tlm01(
            vocab_size=_GPT2_VOCAB_SIZE,
            specialization="general",
            dataset_path=None,
            model_name="gpt2",
            embedding_dim=768,
            num_heads=12,
            hidden_dim=768,
            num_layers=6,
            output_size=_GPT2_VOCAB_SIZE,
            dropout=0.1,
            max_seq_length=512,
            pooling_mode="last",
            tokenizer=self.tokenizer,
        )

    def get_alternative_model_for_prompt(self, prompt: str, current_model=None) -> Any:
        """Return a model other than ``current_model``, or None.

        Order: a freshly built model when the prompt analyzer's ``choose_model``
        returns something truthy; any other loaded model; a freshly built fallback.
        """
        try:
            if self.prompt_analyzer and hasattr(self.prompt_analyzer, "choose_model"):
                model_type = self.prompt_analyzer.choose_model(prompt)
                if model_type:
                    # NOTE(review): model_type is only used as a yes/no signal; the same
                    # Wildnerve_tlm01 configuration is built regardless of its value.
                    alt_model = self._build_general_model()
                    logger.info("Found alternative model via prompt_analyzer")
                    return alt_model

            for name, model in self.get_available_models().items():
                if model is not current_model:
                    logger.info(f"Using alternative model: {name}")
                    return model

            try:
                fallback_model = self._build_general_model()
                logger.info("Created fallback model")
                return fallback_model
            except Exception as e:
                logger.error(f"Error creating fallback model: {e}")
                return None
        except Exception as e:
            logger.error(f"Error getting alternative model: {e}")
            return None

    # ----------------------------------------------------------- input helpers

    def prepare_model_input(self, text: str, model) -> dict:
        """Tokenize ``text`` for ``model``.

        With ``model.tokenizer``: returns ``input_ids`` on the model's device (CPU when
        the model exposes no parameters) plus ``max_length``, ``device``, ``temperature``.
        Without one: returns the raw text and ``max_length``. Any error yields
        ``{"input_text": text}``.
        """
        max_length = safe_get_config_value(app_config, "MAX_SEQ_LENGTH", 512)
        try:
            first_param = next(iter(model.parameters()), None) if hasattr(model, "parameters") else None
            device = first_param.device if first_param is not None else torch.device("cpu")
            tokenizer = getattr(model, "tokenizer", None)
            if tokenizer is None:
                logger.warning("No tokenizer in model; using basic input")
                return {"input_text": text, "max_length": max_length}
            inputs = tokenizer(
                text,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=max_length,
            )
            temperature = safe_get_config(getattr(self, "generation_config", {}), "temperature", 0.7)
            return {
                "input_ids": inputs["input_ids"].to(device),
                "max_length": max_length,
                "device": device,
                "temperature": temperature,
            }
        except Exception as e:
            logger.error(f"Error preparing model input: {e}")
            return {"input_text": text}

    def process_with_context(self, input_text: str, context: Optional[dict] = None) -> dict:
        """Prefix recent conversation history to ``input_text`` and pass it to ``process_input``.

        When the result is a dict, the un-prefixed query is stored under ``original_query``.
        """
        conversation_context = self.get_conversation_context(window_size=3)
        contextualized_prompt = input_text
        if conversation_context:
            contextualized_prompt = (
                f"Previous conversation:\n{conversation_context}\n\nCurrent question: {input_text}"
            )
        # NOTE(review): process_input is not defined in this module; it must be
        # supplied elsewhere (e.g. a subclass) or this raises AttributeError.
        result = self.process_input(contextualized_prompt, context)
        if isinstance(result, dict):
            result["original_query"] = input_text
        return result

    def get_conversation_context(self, window_size: int = 3) -> str:
        """Format the last ``window_size`` exchanges (2 entries each) as "User:/Assistant:" lines.

        Entries are dicts with ``role`` and ``content``; any role other than "user"
        is labelled "Assistant". Returns "" for ``window_size <= 0``.
        NOTE(review): nothing in this module appends to ``conversation_history``.
        """
        if not hasattr(self, "conversation_history"):
            self.conversation_history = []
        if window_size <= 0:
            # history[-0:] would return the entire list.
            return ""
        lines = []
        for entry in self.conversation_history[-window_size * 2:]:
            prefix = "User: " if entry["role"] == "user" else "Assistant: "
            lines.append(f"{prefix}{entry['content']}")
        return "\n".join(lines)


# Factory methods for model manager creation
def create_model_manager(tokenizer=None) -> ModelManager:
    """Build a ModelManager from config, attach a tokenizer and register it.

    The ``tokenizer`` argument wins; otherwise a tokenizer already in the registry is
    used. If construction fails, a manager with ``max_active_models=1`` is built and
    registered instead (that constructor can itself still raise).
    """
    try:
        max_active_models = safe_get_config_value(app_config, "MAX_ACTIVE_MODELS", 2)
        model_idle_threshold = safe_get_config_value(app_config, "MODEL_IDLE_THRESHOLD", 600)
        manager = ModelManager(
            tokenizer=tokenizer,
            max_active_models=max_active_models,
            model_idle_threshold=model_idle_threshold,
        )
        if tokenizer:
            manager.set_tokenizer(tokenizer)
        elif registry.has(TOKENIZER):
            manager.set_tokenizer(registry.get(TOKENIZER))
        registry.register(MODEL_MANAGER, manager)
        return manager
    except Exception as e:
        logger.error(f"Error creating ModelManager: {e}")
        minimal_manager = ModelManager(tokenizer=tokenizer, max_active_models=1)
        registry.register(MODEL_MANAGER, minimal_manager)
        return minimal_manager


def create_model_manager_with_tokenizer(tokenizer):
    """Build a ModelManager with ``tokenizer``, run ``initialize_models`` and register it.

    The tokenizer is passed to the constructor so the startup model is built with it.
    Falls back to a ``max_active_models=1`` manager on error.
    """
    try:
        max_active_models = safe_get_config_value(app_config, "MAX_ACTIVE_MODELS", 2)
        model_idle_threshold = safe_get_config_value(app_config, "MODEL_IDLE_THRESHOLD", 600)
        manager = ModelManager(
            tokenizer=tokenizer,
            max_active_models=max_active_models,
            model_idle_threshold=model_idle_threshold,
        )
        manager.initialize_models()
        registry.register(MODEL_MANAGER, manager)
        return manager
    except Exception as e:
        logger.error(f"Error creating ModelManager with tokenizer: {e}")
        minimal_manager = ModelManager(tokenizer=tokenizer, max_active_models=1)
        registry.register(MODEL_MANAGER, minimal_manager)
        return minimal_manager


def register_models():
    """Register the tokenizer, custom model and pretrained model in the service registry.

    Imports are local to avoid circular dependencies. Returns True on success and
    False (after logging) on any failure.
    """
    from service_registry import registry, MODEL, PRETRAINED_MODEL, TOKENIZER
    from tokenizer import TokenizerWrapper

    try:
        from model_Custm import Wildnerve_tlm01
        from model_PrTr import Wildnerve_tlm01 as PretrainedModel

        tok = TokenizerWrapper()
        registry.register(TOKENIZER, tok, overwrite=True)

        custom = Wildnerve_tlm01(tokenizer=tok)
        registry.register(MODEL, custom, overwrite=True)

        pre = PretrainedModel(tokenizer=tok)
        registry.register(PRETRAINED_MODEL, pre, overwrite=True)
        return True
    except Exception as e:
        logger.error(f"Failed to register models: {e}")
        return False


if __name__ == "__main__":
    tokenizer = registry.get(TOKENIZER)
    if not tokenizer:
        from utils.transformer_utils import get_tokenizer

        # Every model here uses the GPT-2 vocabulary, so load the matching tokenizer.
        tokenizer = get_tokenizer("gpt2")
        registry.register(TOKENIZER, tokenizer)
    model_manager = create_model_manager(tokenizer)
    logger.info(f"Model Manager initialized with {len(model_manager.models)} models")
else:
    model_manager = None
    logger.info("ModelManager module imported; initialization deferred")