|
|
import gc, os, sys, time, torch, logging, inspect, numpy as np, pandas as pd, importlib.util
|
|
|
from pathlib import Path
|
|
|
from threading import Lock
|
|
|
from collections import OrderedDict
|
|
|
from nltk.stem import WordNetLemmatizer
|
|
|
from typing import List, Dict, Any, Tuple, Optional
|
|
|
from sklearn.metrics.pairwise import cosine_similarity
|
|
|
from config import app_config
|
|
|
from dataset import TensorDataset
|
|
|
from utils.transformer_utils import get_sentence_transformer
|
|
|
from utils.smartHybridAttention import SmartHybridAttention, get_hybrid_attention_config
|
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
from service_registry import registry, MODEL, TOKENIZER, MODEL_MANAGER, COMMUNICATOR
|
|
|
|
|
|
logger = logging.getLogger(__name__)

# psutil is optional. When it cannot be imported, a tiny stand-in providing
# Process(pid).memory_info() / .memory_percent() is bound to the name
# ``psutil`` instead; it reports fixed placeholder numbers.
try:
    import psutil
except ImportError:
    logger.warning("psutil not available")
    PSUTIL_AVAILABLE = False

    class DummyProcess:
        """Placeholder for ``psutil.Process`` that reports constant memory figures."""

        def __init__(self, pid=None):
            # Any falsy pid (None or 0) is replaced by 1.
            self.pid = pid if pid else 1

        def memory_info(self):
            """Return an object whose ``rss`` and ``vms`` are both 1e6."""
            class MemInfo:
                def __init__(self):
                    self.rss = 1e6
                    self.vms = 1e6

            return MemInfo()

        def memory_percent(self):
            """Always report 1.0 percent."""
            return 1.0

    class DummyPsutil:
        """Module-like stand-in exposing only ``Process``."""

        @staticmethod
        def Process(pid=None):
            return DummyProcess(pid)

    psutil = DummyPsutil()
else:
    PSUTIL_AVAILABLE = True
|
|
|
|
|
|
def safe_get_config(config_obj, key, default=None):
    """Read *key* from a dict-style or attribute-style config object.

    Dicts are queried with ``.get(key, default)``; any other object is queried
    by attribute. When neither yields the key, *default* is returned.
    """
    if isinstance(config_obj, dict):
        return config_obj.get(key, default)
    if not hasattr(config_obj, key):
        return default
    return getattr(config_obj, key, default)
|
|
|
|
|
|
def safe_get_config_value(config_obj, key, default=None):
    """Best-effort lookup of *key* on a config value of unknown shape.

    Resolution order:
      - dict: ``config_obj.get(key, default)``
      - object that has attribute *key*: that attribute
      - bare int/float/str/bool: the value itself (the "config" is the value)
      - anything else: *default*

    Any ordinary exception raised during the lookup (for example by a
    property getter) also yields *default*. Only ``Exception`` subclasses are
    swallowed, so KeyboardInterrupt and SystemExit still propagate.
    """
    try:
        if isinstance(config_obj, dict):
            return config_obj.get(key, default)
        if hasattr(config_obj, key):
            return getattr(config_obj, key, default)
        if isinstance(config_obj, (int, float, str, bool)):
            return config_obj
        return default
    except Exception:
        return default
|
|
|
|
|
|
class DatasetManager:
    """Load labelled CSV files into TensorDatasets, cached per specialization."""

    def __init__(self):
        # specialization name -> loaded dataset
        self.datasets: Dict[str, Any] = {}
        # Guards every read and write of self.datasets.
        self.lock = Lock()

    def load_dataset(self, path: str, specialization: str) -> Any:
        """Return the dataset for *specialization*, loading it from *path* on first use.

        NOTE(review): the cache is keyed by specialization only, so a later
        call with a different *path* for an already-cached specialization
        returns the first dataset.

        Raises:
            FileNotFoundError: *path* does not exist.
            ValueError: the CSV lacks a "label" column or has non-numeric features.
        """
        with self.lock:
            if specialization in self.datasets:
                logger.info(f"Using cached dataset for {specialization}")
                return self.datasets[specialization]
            dataset = self._load_and_process_dataset(path, specialization)
            self.datasets[specialization] = dataset
            return dataset

    def _load_and_process_dataset(self, path: str, specialization: str) -> TensorDataset:
        """Read the CSV at *path* into ``TensorDataset(features, labels)``.

        Every column except "label" is a feature and is converted to float32;
        "label" is converted to int64 (``torch.long``).
        """
        if not os.path.exists(path):
            raise FileNotFoundError(f"Dataset {path} not found.")
        logger.info(f"Loading dataset: {specialization}")
        data = pd.read_csv(path)
        if "label" not in data.columns:
            raise ValueError("Dataset must have a 'label' column.")

        feature_frame = data.drop(columns="label")
        # Fail with a readable message instead of an opaque tensor-conversion
        # error when e.g. a free-text column is present.
        non_numeric = feature_frame.select_dtypes(exclude=["number", "bool"]).columns.tolist()
        if non_numeric:
            raise ValueError(
                f"Dataset feature columns must be numeric; non-numeric columns: {non_numeric}"
            )

        # to_numpy(dtype=...) also handles mixed int/bool columns, which
        # .values would turn into an object array.
        features_tensor = torch.tensor(feature_frame.to_numpy(dtype="float32"), dtype=torch.float32)
        labels_tensor = torch.tensor(data["label"].values, dtype=torch.long)
        return TensorDataset(features_tensor, labels_tensor)

    def get_status(self) -> Dict[str, Any]:
        """Return the cached specialization names and the cache size."""
        # Read under the lock: iterating the dict while load_dataset inserts
        # into it can raise "dictionary changed size during iteration".
        with self.lock:
            return {"loaded_datasets": list(self.datasets.keys()), "cache_size": len(self.datasets)}

    def clear_cache(self):
        """Drop every cached dataset."""
        with self.lock:
            self.datasets.clear()
|
|
|
|
|
|
class ModelManager:
    """Create, cache, route to and evict per-specialization models.

    Loaded models live in ``self.models`` keyed by specialization name.
    ``self.model_pool`` is an OrderedDict whose key order tracks access:
    ``get_model`` moves a key to the end, so ``next(iter(self.model_pool))``
    is the least recently used specialization and is evicted first.
    """

    # Selected-model key -> either a Hugging Face Hub repo id ("org/name") or a
    # "module.ClassName" import path; consumed by _import_model_class.
    HF_ALIAS = {
        "EvolphTech/Wildnerve-tlm01_Hybrid_Model": "model_Combn.Wildnerve_tlm01_Hybrid_Model",
        "model_Custm": "EvolphTech/Wildnerve-tlm01_Hybrid_Model",
        "model_Custm.py": "EvolphTech/Wildnerve-tlm01_Hybrid_Model",
    }

    def __init__(self, tokenizer=None, max_active_models=5, model_idle_threshold=600):
        """Build the manager and eagerly load the startup model(s).

        Args:
            tokenizer: Tokenizer handed to every model this manager creates.
            max_active_models: Resident-model budget; a non-int or non-positive
                value falls back to 2.
            model_idle_threshold: Seconds without access after which
                manage_model_cache() evicts a model; a non-int or non-positive
                value falls back to 600.

        Raises:
            Exception: whatever startup model loading raised (logged first).
        """
        # Re-entrant on purpose: get_or_create_model() and manage_model_cache()
        # hold the lock while calling get_model()/remove_model_instance(),
        # which acquire it again. A plain Lock deadlocks on that nested acquire.
        from threading import RLock

        self.models = {}
        self.lock = RLock()
        self.model_pool = OrderedDict()
        self.max_active_models = (
            max_active_models if isinstance(max_active_models, int) and max_active_models > 0 else 2
        )
        self.model_idle_threshold = (
            model_idle_threshold
            if isinstance(model_idle_threshold, int) and model_idle_threshold > 0
            else 600
        )
        self.tokenizer = tokenizer

        if hasattr(app_config, 'SPECIALIZATIONS') and app_config.SPECIALIZATIONS:
            self.specializations = app_config.SPECIALIZATIONS
        elif hasattr(app_config, 'DATASET_PATHS') and isinstance(app_config.DATASET_PATHS, dict):
            self.specializations = list(app_config.DATASET_PATHS.keys())
        else:
            self.specializations = ["general", "programming", "mathematics"]
        logger.info(f"Using {len(self.specializations)} specializations from config")

        # Per-specialization stats: inference_time (running mean of wall-clock
        # seconds), memory_usage, last_accessed (time.time()), num_inferences.
        self._performance_metrics = {}

        attention_config = get_hybrid_attention_config()
        self.smart_attention = SmartHybridAttention(
            dim=attention_config["DIM"],
            num_heads=attention_config["NUM_HEADS"],
            window_size=attention_config["WINDOW_SIZE"],
            use_sliding=attention_config["USE_SLIDING"],
            use_global=attention_config["USE_GLOBAL"],
            use_hierarchical=attention_config["USE_HIERARCHICAL"],
            global_token_ratio=attention_config["GLOBAL_TOKEN_RATIO"],
            memory_tokens=attention_config["MEMORY_TOKENS"],
        )
        self.dataset_manager = DatasetManager()

        transformer_config = safe_get_config(app_config, "TRANSFORMER_CONFIG", {})
        model_name = safe_get_config(transformer_config, "MODEL_NAME", "Wildnerve-tlm01-0.05Bx12")
        self.embedding_model = get_sentence_transformer(model_name)
        self.similarity_threshold = safe_get_config(app_config, "SIMILARITY_THRESHOLD", 0.85)
        self.top_k = safe_get_config(app_config, "TOP_K", 3)
        self.prompt_analyzer = None
        self.selected_models = self._get_selected_models()

        try:
            self._load_models()
        except Exception:
            logger.critical("Startup model loading failed, aborting ModelManager init", exc_info=True)
            raise

        logger.info(f"ModelManager initialized with {len(self.specializations)} specializations")

    def _get_selected_models(self) -> List[str]:
        """Return app_config.SELECTED_MODEL, defaulting to ["model_Custm.py"] when unset/empty."""
        model_files = safe_get_config(app_config, "SELECTED_MODEL", ["model_Custm.py"])
        return model_files if model_files else ["model_Custm.py"]

    def _default_specialization(self) -> str:
        """Return the first configured specialization, or "general" if none are configured."""
        return self.specializations[0] if self.specializations else "general"

    @staticmethod
    def _strip_py_suffix(name: str) -> str:
        """Return *name* without one trailing ".py" extension.

        ``str.rstrip(".py")`` is not a substitute: it strips any trailing run
        of the characters '.', 'p' and 'y' (e.g. "happy.py" -> "ha").
        """
        return name[:-3] if name.endswith(".py") else name

    def _import_model_class(self, model_key: str):
        """Resolve *model_key* to a model class; raise on failure.

        Resolution order:
          1. HF_ALIAS maps the key to a Hub repo id ("org/name"): load it and
             return the loaded model's class.
          2. HF_ALIAS maps the key to "module.ClassName": import and return it.
          3. A "<key>.py" file beside this module: execute it and return the
             class named <key> or ``Wildnerve_tlm01``.
          4. Otherwise import module <key> and look for the same class names.

        Raises:
            ImportError: the module loads but has no suitable class.
            Exception: any import/download error (logged, then re-raised).
        """
        key = self._strip_py_suffix(model_key)
        alias = self.HF_ALIAS.get(key)
        try:
            if alias and "/" in alias:
                logger.info(f"Loading HF model from repo '{alias}'")
                # NOTE(review): the full model and tokenizer are downloaded only
                # to read the class; this instance is discarded and the caller
                # re-instantiates the class with its own kwargs. Confirm the Hub
                # class accepts those kwargs.
                model = AutoModelForCausalLM.from_pretrained(alias, use_auth_token=os.getenv("HF_TOKEN"))
                tok = AutoTokenizer.from_pretrained(alias, use_auth_token=os.getenv("HF_TOKEN"))
                model.tokenizer = tok
                return model.__class__

            if alias and "." in alias:
                module_name, cls_name = alias.split(".", 1)
                mod = importlib.import_module(module_name)
                return getattr(mod, cls_name)

            file_path = os.path.join(os.path.dirname(__file__), f"{key}.py")
            if os.path.isfile(file_path):
                spec = importlib.util.spec_from_file_location(key, file_path)
                mod = importlib.util.module_from_spec(spec)
                spec.loader.exec_module(mod)
                cls = getattr(mod, key, None) or getattr(mod, "Wildnerve_tlm01", None)
                if cls:
                    return cls
                raise ImportError(f"No class '{key}' or 'Wildnerve_tlm01' in {file_path}")

            mod = importlib.import_module(key)
            cls = getattr(mod, key, None) or getattr(mod, "Wildnerve_tlm01", None)
            if cls:
                return cls
            raise ImportError(f"No class '{key}' or 'Wildnerve_tlm01' in module '{key}'")
        except Exception as e:
            logger.error(f"Failed to import model class for '{model_key}': {e}", exc_info=True)
            raise

    def _resolve_dataset_path(self, spec: str, data_dir: str) -> str:
        """Return DATASET_PATHS[spec] (first entry if a list/tuple) or "<data_dir>/<spec>.csv"."""
        ds_paths = safe_get_config_value(app_config, "DATASET_PATHS", {})
        raw = ds_paths.get(spec) if isinstance(ds_paths, dict) else None
        if isinstance(raw, (list, tuple)):
            raw = raw[0] if raw else None
        return raw or os.path.join(data_dir, f"{spec}.csv")

    def _register_model(self, spec: str, model) -> None:
        """Store *model* under *spec*, mark it most recently used and reset its metrics."""
        with self.lock:
            self.models[spec] = model
            self.model_pool[spec] = None
            self.model_pool.move_to_end(spec)
            self._performance_metrics[spec] = {
                "inference_time": 0.0,
                "memory_usage": 0.0,
                "last_accessed": time.time(),
                "num_inferences": 0,
            }

    def _initialize_model_for_specialization(self, spec: str, data_dir: str):
        """Instantiate and register the model for *spec*.

        Writes a one-row placeholder CSV when the resolved dataset file does
        not exist, and logs a warning when construction takes more than 30 s.

        Raises:
            Exception: from class import or model construction (logged first).
        """
        dataset_path = self._resolve_dataset_path(spec, data_dir)
        if not os.path.exists(dataset_path):
            try:
                with open(dataset_path, "w") as f:
                    f.write("text,label\nsample text,0\n")
                logger.info(f"Created placeholder dataset for '{spec}'")
            except Exception as e:
                logger.warning(f"Could not create dataset for '{spec}': {e}")

        model_cls = self._import_model_class(self.selected_models[0])
        # TRANSFORMER_CONFIG may be a dict or an attribute-style object.
        transformer_config = safe_get_config(app_config, "TRANSFORMER_CONFIG", {})
        params = dict(
            vocab_size=50257,
            specialization=spec,
            dataset_path=dataset_path,
            model_name=safe_get_config(transformer_config, "MODEL_NAME", "gpt2"),
            embedding_dim=768,
            num_heads=12,
            hidden_dim=768,
            num_layers=12,
            output_size=50257,
            dropout=0.1,
            max_seq_length=1024,
            pooling_mode=safe_get_config(transformer_config, "POOLING_MODE", "last"),
            tokenizer=self.tokenizer,
        )

        start = time.time()
        try:
            model = model_cls(**params)
        except Exception as e:
            logger.error(f"Error instantiating '{model_cls.__name__}' for '{spec}': {e}", exc_info=True)
            raise
        elapsed = time.time() - start
        if elapsed > 30:
            logger.warning(f"Model creation for '{spec}' took {elapsed:.1f}s (>30s)")

        self._register_model(spec, model)

    def _load_models(self):
        """Preload the startup specializations (only 'general', to save resources).

        Uses $TLM_DATA_DIR (default /tmp/tlm_data) for placeholder datasets and
        re-raises the first loading failure.
        """
        initial = ["general"]
        data_dir = os.environ.get("TLM_DATA_DIR", "/tmp/tlm_data")
        os.makedirs(data_dir, exist_ok=True)

        for spec in initial:
            logger.info(f"Initializing model for specialization '{spec}'")
            try:
                self._initialize_model_for_specialization(spec, data_dir)
                logger.info(f"Model for '{spec}' loaded successfully")
            except Exception:
                logger.error(f"Failed to load model for '{spec}'", exc_info=True)
                raise

        logger.info(f"{len(self.models)} models loaded at startup (of {len(self.specializations)} total)")
        logger.debug(f"Available specializations: {', '.join(self.specializations)}")

    def get_or_create_model(self, specialization: str) -> Any:
        """Return the model for *specialization*, loading it on demand.

        Unknown specializations are redirected to "general". When the budget
        is full, the least recently used model is evicted first. If loading
        fails, returns the "general" model when available, otherwise a
        minimal emergency model (which may itself be None).
        """
        with self.lock:
            model = self.get_model(specialization)
            if model is not None:
                logger.info(f"Using existing model for {specialization}")
                return model

            if specialization not in self.specializations and specialization != "general":
                logger.warning(f"Unknown specialization: {specialization}, using general")
                specialization = "general"
                # The fallback may already be resident; don't rebuild it.
                model = self.get_model(specialization)
                if model is not None:
                    return model

            logger.info(f"Lazily loading model for {specialization}")

            # Evict the least recently used model to stay within budget.
            if len(self.models) >= self.max_active_models and self.model_pool:
                lru_specialization = next(iter(self.model_pool))
                self.remove_model_instance(lru_specialization)

            data_dir = os.environ.get("TLM_DATA_DIR", "/tmp/tlm_data")
            try:
                os.makedirs(data_dir, exist_ok=True)
                self._initialize_model_for_specialization(specialization, data_dir)
                return self.models.get(specialization)
            except Exception as e:
                logger.error(f"Error initializing model: {e}", exc_info=True)
                if specialization != "general" and "general" in self.models:
                    return self.models["general"]
                return self._create_minimal_model()

    def _create_minimal_model(self):
        """Build a small 2-layer ``Wildnerve_tlm01`` for emergencies; None on failure."""
        try:
            from model_Custm import Wildnerve_tlm01
            model = Wildnerve_tlm01(
                vocab_size=50257,
                specialization="minimal",
                dataset_path=None,
                model_name="gpt2",
                embedding_dim=768,
                num_heads=12,
                hidden_dim=768,
                num_layers=2,
                output_size=50257,
                dropout=0.1,
                max_seq_length=128,
                pooling_mode="last",
                tokenizer=self.tokenizer,
            )
            model._is_minimal = True
            return model
        except Exception as e:
            logger.error(f"Failed to create minimal model: {e}", exc_info=True)
            return None

    def get_model(self, specialization: str) -> Any:
        """Return the loaded model for *specialization* (or None) and mark it most recently used."""
        with self.lock:
            model = self.models.get(specialization)
            if model is not None:
                # Assign before move_to_end so models registered without a pool
                # entry do not raise KeyError here.
                self.model_pool[specialization] = None
                self.model_pool.move_to_end(specialization)
                if specialization in self._performance_metrics:
                    self._performance_metrics[specialization]["last_accessed"] = time.time()
            return model

    def route_input(self, input_text: str) -> dict:
        """Pick the loaded specialization whose ``model.embedding`` is most similar to *input_text*.

        Returns:
            {"matched_specialization", "confidence", "all_scores"}; confidence
            is 0.0 and the default specialization is returned when no loaded
            model could be scored.
        """
        input_embedding = self.embedding_model.encode(input_text)

        if getattr(self, "smart_attention", None):
            try:
                # no_grad so the result can be converted with .numpy(); a tensor
                # that requires grad refuses that conversion.
                with torch.no_grad():
                    input_tensor = torch.tensor(input_embedding).unsqueeze(0).unsqueeze(0)
                    enhanced, _ = self.smart_attention(query=input_tensor, key=input_tensor, value=input_tensor)
                if isinstance(enhanced, torch.Tensor):
                    input_embedding = enhanced.squeeze().cpu().numpy()
                    logger.info("Using SmartHybridAttention for enhanced prompt routing")
            except Exception as e:
                logger.warning(f"Error using SmartHybridAttention: {e}")

        # Peek at the loaded models without get_model(): routing must not
        # refresh every model's access time, or LRU/idle eviction never works.
        with self.lock:
            loaded = dict(self.models)

        query = np.asarray(input_embedding).reshape(1, -1)
        similarities = {}
        for spec in self.specializations:
            model = loaded.get(spec)
            if model is None or not hasattr(model, "embedding"):
                continue
            try:
                reference = np.asarray(model.embedding).reshape(1, -1)
                similarities[spec] = cosine_similarity(query, reference)[0][0]
            except Exception as e:
                # One incomparable embedding should not abort routing entirely.
                logger.debug(f"Skipping '{spec}' in routing: {e}")

        if similarities:
            best_spec, best_score = max(similarities.items(), key=lambda item: item[1])
            return {"matched_specialization": best_spec, "confidence": best_score, "all_scores": similarities}
        return {"matched_specialization": self._default_specialization(), "confidence": 0.0, "all_scores": similarities}

    def get_model_for_prompt(self, prompt: str) -> Tuple[Any, str]:
        """Route *prompt* and return ``(model, specialization)``.

        On any routing/loading error, falls back to the first loaded model, or
        ``(None, "none")`` when nothing is loaded. Inference counts and timings
        are recorded by generate()/generate_streaming(), not here.
        """
        try:
            routing_result = self.route_input(prompt)
            specialization = routing_result.get("matched_specialization", self._default_specialization())
            model = self.get_or_create_model(specialization)
            metrics = self._performance_metrics.get(specialization)
            if metrics is not None and hasattr(model, "get_memory_usage"):
                metrics["memory_usage"] = model.get_memory_usage()
            return model, specialization
        except Exception as e:
            logger.error(f"Error selecting model: {e}")
            with self.lock:
                default_key = next(iter(self.models), None)
                default_model = self.models.get(default_key) if default_key is not None else None
            if default_key is not None:
                return default_model, default_key
            logger.error("No models available for routing")
            return None, "none"

    def _record_inference(self, specialization: str, elapsed: float) -> None:
        """Fold one inference lasting *elapsed* seconds into *specialization*'s running stats."""
        m = self._performance_metrics.get(specialization)
        if m is None:
            return
        n = m.get("num_inferences", 0) + 1
        m["inference_time"] = ((m.get("inference_time", 0) * (n - 1)) + elapsed) / n
        m["num_inferences"] = n
        m["last_accessed"] = time.time()

    def generate(self, prompt: str, **kwargs):
        """Generate for *prompt* with the routed model.

        If that fails, retries once with the default specialization's model.

        Raises:
            ValueError: invalid input (see validate_input).
        """
        self.validate_input(prompt)
        model, specialization = self.get_model_for_prompt(prompt)
        start_time = time.time()
        try:
            result = model.generate(prompt=prompt, **kwargs)
        except Exception as e:
            logger.error(f"Error generating with {specialization}: {e}")
            default_model = self.get_or_create_model(self._default_specialization())
            if default_model is None:
                raise
            return default_model.generate(prompt=prompt, **kwargs)
        self._record_inference(specialization, time.time() - start_time)
        return result

    def generate_streaming(self, prompt: str, **kwargs):
        """Yield generated text for *prompt* incrementally.

        Uses the model's own ``generate_streaming`` when present; otherwise
        yields the words of a full ``generate`` result, each followed by a
        space. Falls back to the default model only if the failure happens
        before anything was yielded; a mid-stream failure is re-raised,
        because replaying a fallback answer would duplicate output.
        """
        self.validate_input(prompt)
        model, specialization = self.get_model_for_prompt(prompt)
        start_time = time.time()
        emitted = False
        try:
            if hasattr(model, "generate_streaming"):
                stream = model.generate_streaming(prompt=prompt, **kwargs)
            else:
                logger.info("Simulating streaming generation")
                stream = (word + " " for word in model.generate(prompt=prompt, **kwargs).split())
            for token in stream:
                emitted = True
                yield token
        except Exception as e:
            logger.error(f"Error in streaming generation: {e}")
            if emitted:
                raise
            default_model = self.get_or_create_model(self._default_specialization())
            if default_model is None:
                raise
            if hasattr(default_model, "generate_streaming"):
                for token in default_model.generate_streaming(prompt=prompt, **kwargs):
                    yield token
            else:
                fallback_result = default_model.generate(prompt=prompt, **kwargs)
                for word in fallback_result.split():
                    yield word + " "
            return
        self._record_inference(specialization, time.time() - start_time)

    def remove_model_instance(self, specialization: str) -> bool:
        """Unload *specialization*'s model and free memory; return True if one was removed."""
        with self.lock:
            if specialization not in self.models:
                return False
            del self.models[specialization]
            self.model_pool.pop(specialization, None)
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            logger.info(f"Removed model for {specialization}")
            return True

    def validate_input(self, input_text: str) -> bool:
        """Return True for acceptable input.

        Raises:
            ValueError: input is empty/whitespace, or its character length
                exceeds MAX_INPUT_LENGTH (else MAX_SEQ_LENGTH, else 128).
        """
        if not input_text or not input_text.strip():
            raise ValueError("Empty input text")
        max_length = safe_get_config(app_config, "MAX_INPUT_LENGTH", safe_get_config(app_config, "MAX_SEQ_LENGTH", 128))
        if len(input_text) > max_length:
            raise ValueError(f"Input exceeds maximum length of {max_length}")
        return True

    def get_health_status(self) -> Dict[str, Any]:
        """Return process memory, per-model metrics, dataset cache and pool-fill statistics."""
        with self.lock:
            process = psutil.Process(os.getpid())
            mem_info = process.memory_info()
            return {
                "active_models": len(self.models),
                "memory_usage": {
                    "rss_mb": mem_info.rss / (1024 * 1024),
                    "vms_mb": mem_info.vms / (1024 * 1024),
                    "percent": process.memory_percent(),
                },
                "model_performance": self._get_model_metrics(),
                "dataset_status": self.dataset_manager.get_status(),
                "cache_efficiency": len(self.model_pool) / max(1, self.max_active_models),
            }

    def _get_model_metrics(self) -> Dict[str, Dict[str, Any]]:
        """Return per-loaded-model stats.

        Memory comes from ``model.get_memory_usage()`` if available, else the
        parameter byte size in MiB, else 0.
        """
        metrics = {}
        for spec, model in self.models.items():
            base = self._performance_metrics.get(spec, {})
            mem_usage = 0
            if hasattr(model, "get_memory_usage"):
                mem_usage = model.get_memory_usage()
            elif hasattr(model, "parameters"):
                mem_usage = sum(p.numel() * p.element_size() for p in model.parameters()) / (1024 * 1024)
            metrics[spec] = {
                "inference_time": base.get("inference_time", 0),
                "memory_usage_mb": mem_usage,
                "last_accessed": base.get("last_accessed", 0),
                "num_inferences": base.get("num_inferences", 0),
                "model_type": model.__class__.__name__,
            }
        return metrics

    def get_available_models(self) -> Dict[str, Any]:
        """Return a shallow copy of the loaded-model mapping."""
        with self.lock:
            return dict(self.models)

    def shutdown(self):
        """Unload every model and clear the dataset cache; errors are logged, not raised."""
        try:
            logger.info("Initiating shutdown")
            for spec in list(self.models.keys()):
                self.remove_model_instance(spec)
            self.dataset_manager.clear_cache()
            logger.info("Shutdown complete")
        except Exception as e:
            logger.error(f"Error during shutdown: {e}", exc_info=True)

    def manage_model_cache(self):
        """Enforce the model budget and idle timeout, then reorder the pool oldest-access-first.

        Errors are logged, not raised.
        """
        try:
            now = time.time()
            with self.lock:
                # Evict from the LRU end until within budget.
                while len(self.models) > self.max_active_models and self.model_pool:
                    oldest = next(iter(self.model_pool))
                    if not self.remove_model_instance(oldest):
                        # Stale pool key with no model behind it; drop it so the
                        # loop cannot spin forever on the same entry.
                        self.model_pool.pop(oldest, None)
                        continue
                    logger.info(f"Removed LRU model: {oldest}")

                for spec in list(self.model_pool):
                    last = self._performance_metrics.get(spec, {}).get("last_accessed", 0)
                    if last and now - last > self.model_idle_threshold:
                        if self.remove_model_instance(spec):
                            logger.info(f"Removed idle model: {spec}")

                # Ascending by last access, so next(iter(model_pool)) remains the
                # least recently used entry that eviction expects.
                self.model_pool = OrderedDict(
                    sorted(
                        self.model_pool.items(),
                        key=lambda item: self._performance_metrics.get(item[0], {}).get("last_accessed", 0),
                    )
                )
        except Exception as e:
            logger.error(f"Error in cache management: {e}", exc_info=True)

    def set_tokenizer(self, tokenizer):
        """Use *tokenizer* for future models and push it to loaded models exposing ``set_tokenizer``.

        Returns:
            self, for chaining.
        """
        self.tokenizer = tokenizer
        with self.lock:
            for name, model in self.models.items():
                if not hasattr(model, "set_tokenizer"):
                    continue
                try:
                    model.set_tokenizer(tokenizer)
                    logger.debug(f"Updated tokenizer for {name}")
                except Exception as ex:
                    logger.warning(f"Failed to set tokenizer for {name}: {ex}")
        logger.info("Tokenizer updated for models")
        return self

    def _load_prompt_analyzer(self):
        """Import ``model_List.PromptAnalyzer`` beside this file, register and return it.

        Returns None if unavailable or on error (which is logged).
        """
        try:
            if not (Path(__file__).parent / "model_List.py").exists():
                return None
            spec = importlib.util.find_spec("model_List")
            if not spec:
                return None
            model_list = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(model_list)
            if not hasattr(model_list, "PromptAnalyzer"):
                return None
            prompt_analyzer = model_list.PromptAnalyzer()
            registry.register("prompt_analyzer", prompt_analyzer)
            logger.info("Imported PromptAnalyzer")
            return prompt_analyzer
        except Exception as e:
            logger.error(f"Error importing PromptAnalyzer: {e}")
            return None

    def _load_local_model_class(self, model_file: str):
        """Import ``Wildnerve_tlm01`` from *model_file* beside this module.

        Returns None if the file, its import spec or the class is missing.
        """
        module_name = self._strip_py_suffix(model_file)
        if not (Path(__file__).parent / f"{module_name}.py").exists():
            return None
        spec_obj = importlib.util.find_spec(module_name)
        if not spec_obj:
            return None
        module = importlib.util.module_from_spec(spec_obj)
        spec_obj.loader.exec_module(module)
        return getattr(module, "Wildnerve_tlm01", None)

    def initialize_models(self):
        """Create lightweight models for up to five configured specializations.

        Steps:
          1. Resolve a prompt analyzer: the registry entry "prompt_analyzer",
             else ``model_List.PromptAnalyzer``.
          2. Ask it for the selected model file (default "model_Custm.py").
          3. Build a 2-layer ``Wildnerve_tlm01`` per specialization.
          4. Point the attention connector at
             "<app_config.DATA_DIR>/attention_configuration.json".

        Returns:
            True if at least one model is loaded afterwards, False otherwise
            (errors are logged, not raised).
        """
        try:
            logger.info("Initializing models from weights")
            prompt_analyzer = registry.get("prompt_analyzer")
            if not prompt_analyzer:
                prompt_analyzer = self._load_prompt_analyzer()
            self.prompt_analyzer = prompt_analyzer

            if prompt_analyzer and hasattr(prompt_analyzer, "get_selected_models"):
                selected_models_list = prompt_analyzer.get_selected_models()
            else:
                selected_models_list = ["model_Custm.py"]
            logger.info(f"Selected model types: {selected_models_list}")
            model_file = selected_models_list[0] if selected_models_list else "model_Custm.py"

            # Import the module once rather than once per specialization.
            try:
                model_class = self._load_local_model_class(model_file)
            except Exception as e:
                logger.error(f"Error loading model class from {model_file}: {e}")
                model_class = None

            if model_class is not None:
                embedding_dim = 768
                num_heads = 12 if embedding_dim % 12 == 0 else 1
                for spec in self.specializations[:5]:
                    try:
                        model_instance = model_class(
                            vocab_size=50257,
                            specialization=spec,
                            dataset_path=None,
                            model_name="gpt2",
                            embedding_dim=embedding_dim,
                            num_heads=num_heads,
                            hidden_dim=768,
                            num_layers=2,
                            output_size=50257,
                            dropout=0.1,
                            max_seq_length=128,
                            pooling_mode="last",
                            tokenizer=self.tokenizer,
                        )
                        # Also records the model in model_pool/metrics so that
                        # get_model() and eviction can see it.
                        self._register_model(spec, model_instance)
                        logger.info(f"Created model for {spec}")
                    except Exception as e:
                        logger.error(f"Error creating model for {spec}: {e}")

            if not self.models:
                logger.error("No models created")
                return False

            try:
                attention_config_path = os.path.join(app_config.DATA_DIR, "attention_configuration.json")
                from utils.attention_connector import get_attention_connector
                attention_connector = get_attention_connector()
                if hasattr(attention_connector, "config_path"):
                    attention_connector.config_path = attention_config_path
                    attention_connector._init_profile_selector()
                    logger.info(f"Initialized attention connector with config: {attention_config_path}")
            except Exception as e:
                logger.warning(f"Failed to initialize attention connector: {e}")

            logger.info(f"Successfully initialized {len(self.models)} models")
            return True
        except Exception as e:
            logger.error(f"Error initializing models: {e}", exc_info=True)
            return False

    def get_alternative_model_for_prompt(self, prompt: str, current_model=None) -> Any:
        """Return a model other than *current_model* for *prompt*, or None.

        Tries, in order:
          1. A freshly built 6-layer "general" model, when the prompt analyzer
             chooses a model type.
          2. Any other loaded model.
          3. A freshly built fallback model.
        """
        try:
            if self.prompt_analyzer and hasattr(self.prompt_analyzer, "choose_model"):
                model_type = self.prompt_analyzer.choose_model(prompt)
                if model_type:
                    from model_Custm import Wildnerve_tlm01
                    alt_model = Wildnerve_tlm01(
                        vocab_size=50257,
                        specialization="general",
                        dataset_path=None,
                        model_name="gpt2",
                        embedding_dim=768,
                        num_heads=12,
                        hidden_dim=768,
                        num_layers=6,
                        output_size=50257,
                        dropout=0.1,
                        max_seq_length=512,
                        pooling_mode="last",
                        tokenizer=self.tokenizer,
                    )
                    if alt_model != current_model:
                        logger.info("Found alternative model via prompt_analyzer")
                        return alt_model

            for name, model in self.get_available_models().items():
                if model != current_model:
                    logger.info(f"Using alternative model: {name}")
                    return model

            try:
                from model_Custm import Wildnerve_tlm01
                fallback_model = Wildnerve_tlm01(
                    vocab_size=50257,
                    specialization="general",
                    model_name="gpt2",
                    embedding_dim=768,
                    num_heads=12,
                    hidden_dim=768,
                    num_layers=6,
                    output_size=50257,
                    dropout=0.1,
                    max_seq_length=512,
                    pooling_mode="last",
                    tokenizer=self.tokenizer,
                )
                logger.info("Created fallback model")
                return fallback_model
            except Exception as e:
                logger.error(f"Error creating fallback model: {e}")
                return None
        except Exception as e:
            logger.error(f"Error getting alternative model: {e}")
            return None

    def prepare_model_input(self, text: str, model) -> dict:
        """Tokenize *text* for *model*.

        Returns:
            - {"input_ids", "max_length", "device", "temperature"} when the
              model has a truthy ``tokenizer``;
            - {"input_text", "max_length"} when it has none;
            - {"input_text"} if tokenization fails.
        """
        try:
            device = next(model.parameters()).device
        except (AttributeError, StopIteration, TypeError):
            # Models without (torch) parameters: default to CPU instead of raising.
            device = torch.device("cpu")
        max_length = safe_get_config_value(app_config, "MAX_SEQ_LENGTH", 512)
        try:
            tokenizer = getattr(model, "tokenizer", None)
            if tokenizer:
                inputs = tokenizer(
                    text,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=max_length,
                )
                input_ids = inputs["input_ids"].to(device)
                generation_config = getattr(self, "generation_config", None) or {}
                return {
                    "input_ids": input_ids,
                    "max_length": max_length,
                    "device": device,
                    "temperature": generation_config.get("temperature", 0.7),
                }
            logger.warning("No tokenizer in model; using basic input")
            return {"input_text": text, "max_length": max_length}
        except Exception as e:
            logger.error(f"Error preparing model input: {e}")
            return {"input_text": text}

    def process_with_context(self, input_text: str, context: Optional[dict] = None) -> dict:
        """Prefix *input_text* with recent conversation history and process it.

        When the result is a dict, the un-prefixed query is stored under
        "original_query".

        NOTE(review): ``self.process_input`` is not defined on this class in
        this file; presumably a subclass or another module supplies it; confirm.
        """
        conversation_context = self.get_conversation_context(window_size=3)
        contextualized_prompt = input_text
        if conversation_context:
            contextualized_prompt = f"Previous conversation:\n{conversation_context}\n\nCurrent question: {input_text}"
        result = self.process_input(contextualized_prompt, context)
        if isinstance(result, dict):
            result["original_query"] = input_text
        return result

    def get_conversation_context(self, window_size: int = 3) -> str:
        """Format the most recent history entries as "User: ..."/"Assistant: ..." lines.

        Uses up to ``2 * window_size`` trailing entries of
        ``self.conversation_history`` (presumably one user and one assistant
        entry per exchange). Each entry needs "role" and "content" keys.
        A non-positive *window_size* yields "" (slicing with ``[-0:]`` would
        otherwise return the whole history).
        """
        if not hasattr(self, "conversation_history"):
            self.conversation_history = []
        if window_size <= 0:
            return ""
        lines = []
        for entry in self.conversation_history[-window_size * 2:]:
            prefix = "User: " if entry["role"] == "user" else "Assistant: "
            lines.append(f"{prefix}{entry['content']}")
        return "\n".join(lines)
|
|
|
|
|
|
|
|
|
def create_model_manager(tokenizer=None) -> ModelManager:
    """Build a ModelManager from app_config, attach a tokenizer and register it.

    The tokenizer is *tokenizer* if given, otherwise the registry's TOKENIZER
    service when one is registered. If construction with the configured
    limits fails, a single-model manager is built, given the same tokenizer,
    registered and returned; an error while doing that propagates.
    """
    def _attach_tokenizer(manager):
        # Same precedence on the normal and the fallback path.
        if tokenizer:
            manager.set_tokenizer(tokenizer)
        elif registry.has(TOKENIZER):
            manager.set_tokenizer(registry.get(TOKENIZER))

    try:
        max_active_models = safe_get_config_value(app_config, "MAX_ACTIVE_MODELS", 2)
        model_idle_threshold = safe_get_config_value(app_config, "MODEL_IDLE_THRESHOLD", 600)
        manager = ModelManager(
            tokenizer=tokenizer,
            max_active_models=max_active_models,
            model_idle_threshold=model_idle_threshold,
        )
        _attach_tokenizer(manager)
        registry.register(MODEL_MANAGER, manager)
        return manager
    except Exception as e:
        logger.error(f"Error creating ModelManager: {e}", exc_info=True)
        minimal_manager = ModelManager(tokenizer=tokenizer, max_active_models=1)
        _attach_tokenizer(minimal_manager)
        registry.register(MODEL_MANAGER, minimal_manager)
        return minimal_manager
|
|
|
|
|
|
def create_model_manager_with_tokenizer(tokenizer):
    """Build a ModelManager using *tokenizer*, run initialize_models() and register it.

    The tokenizer is passed to the constructor so that the models preloaded
    during ``__init__`` are built with it, not with None. If anything fails,
    a single-model manager using the same tokenizer is registered and
    returned instead; an error while doing that propagates.
    """
    try:
        max_active_models = safe_get_config_value(app_config, "MAX_ACTIVE_MODELS", 2)
        model_idle_threshold = safe_get_config_value(app_config, "MODEL_IDLE_THRESHOLD", 600)
        manager = ModelManager(
            tokenizer=tokenizer,
            max_active_models=max_active_models,
            model_idle_threshold=model_idle_threshold,
        )
        if not manager.initialize_models():
            logger.warning("initialize_models() reported failure; continuing with the preloaded models")
        registry.register(MODEL_MANAGER, manager)
        return manager
    except Exception as e:
        logger.error(f"Error creating ModelManager with tokenizer: {e}", exc_info=True)
        minimal_manager = ModelManager(tokenizer=tokenizer, max_active_models=1)
        registry.register(MODEL_MANAGER, minimal_manager)
        return minimal_manager
|
|
|
|
|
|
if __name__ != "__main__":
    # Imported as a library: construction is left to create_model_manager().
    model_manager = None
    logger.info("ModelManager module imported; initialization deferred")
else:
    # Script entry point: make sure a tokenizer is registered, then build the manager.
    # NOTE(review): a "bert-base-uncased" tokenizer is loaded here, while the
    # models in this file are built with vocab_size=50257 / model_name="gpt2";
    # confirm the vocabularies are meant to differ.
    tokenizer = registry.get(TOKENIZER)
    if not tokenizer:
        from utils.transformer_utils import get_tokenizer

        tokenizer = get_tokenizer("bert-base-uncased")
        registry.register(TOKENIZER, tokenizer)
    model_manager = create_model_manager(tokenizer)
    logger.info(f"Model Manager initialized with {len(model_manager.models)} models")
|
|
|
|
|
|
|
|
|
def register_models():
    """Register a shared tokenizer, the custom model and the pretrained model.

    Imports are done here rather than at module level to avoid circular
    dependencies. All of them sit inside the try, so a missing module is
    reported the same way as any other failure.

    Returns:
        True on success; False (after logging the error) on any failure.
    """
    try:
        from service_registry import registry, MODEL, PRETRAINED_MODEL, TOKENIZER
        from tokenizer import TokenizerWrapper
        from model_Custm import Wildnerve_tlm01
        from model_PrTr import Wildnerve_tlm01 as PretrainedModel

        # One tokenizer instance is shared by both models and the registry.
        tok = TokenizerWrapper()
        registry.register(TOKENIZER, tok, overwrite=True)

        custom = Wildnerve_tlm01(tokenizer=tok)
        registry.register(MODEL, custom, overwrite=True)

        pre = PretrainedModel(tokenizer=tok)
        registry.register(PRETRAINED_MODEL, pre, overwrite=True)
        return True
    except Exception as e:
        logger.error(f"Failed to register models: {e}", exc_info=True)
        return False
|
|
|
|