| |
|
| | import os
|
| | import sys
|
| | import time
|
| | import json
|
| | import gc
|
| | import logging
|
| | import argparse
|
| | import importlib
|
| | import threading
|
| | from typing import Dict, Any, Optional, List, Union, Generator, Tuple
|
| | from pathlib import Path
|
| |
|
| |
|
# Force line-buffered stdout so container logs appear promptly.  Guarded
# because wrapped/replaced streams (pytest capture, some supervisors) may
# not expose reconfigure().
if hasattr(sys.stdout, "reconfigure"):
    sys.stdout.reconfigure(line_buffering=True)

# Root logging configuration; force=True replaces any handlers installed
# by libraries imported before this point.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    force=True
)
logger = logging.getLogger(__name__)

# Per-module debug log file.
file_handler = logging.FileHandler('/tmp/app_debug.log')
file_handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
logger.addHandler(file_handler)

# Root-level container log at DEBUG verbosity.
# BUG FIX: the format string previously used "%(levellevel)s", which raises
# a formatting error on every record emitted through this handler; the
# correct LogRecord attribute is "levelname".
fh = logging.FileHandler("/tmp/container.log")
fh.setLevel(logging.DEBUG)
fh.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s"))
logging.getLogger().addHandler(fh)
logger.info("Logging configured")

# Default data directory for the TLM runtime (caller's env wins).
if not os.environ.get("TLM_DATA_DIR"):
    os.environ["TLM_DATA_DIR"] = "/tmp/tlm_data"

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")

# Best-effort GPU housekeeping; never fatal on CPU-only hosts.
# (The redundant second `import torch` was removed — torch is imported above.)
try:
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        logger.info(f"GPU Memory: {torch.cuda.memory_allocated()/1e9:.2f}GB / {torch.cuda.get_device_properties(0).total_memory/1e9:.2f}GB")
except Exception as e:
    logger.warning(f"Error with PyTorch setup: {e}")
|
| |
|
| |
|
# Load the application configuration and normalise a couple of fields the
# rest of the module expects.  Any failure here is fatal (re-raised) because
# nothing downstream can run without a config.
try:
    from config import app_config, load_config, get_model_architecture_params

    if hasattr(app_config, 'TRANSFORMER_CONFIG'):
        if not hasattr(app_config.TRANSFORMER_CONFIG, 'config_data') and isinstance(app_config.TRANSFORMER_CONFIG, dict):
            # NOTE(review): a plain dict does not accept attribute
            # assignment, so these statements only succeed if
            # TRANSFORMER_CONFIG is a dict *subclass* that permits it —
            # confirm against config.py; otherwise this branch always
            # raises into the except below.
            app_config.TRANSFORMER_CONFIG.config_data = app_config.TRANSFORMER_CONFIG

            app_config.TRANSFORMER_CONFIG.MODEL_NAME = "gpt2"
            app_config.TRANSFORMER_CONFIG.VOCAB_SIZE = 50257

        elif hasattr(app_config.TRANSFORMER_CONFIG, 'specialization'):
            # Normalise `specialization` to a list of strings: a comma
            # separated string is split, a bare string is wrapped.
            if isinstance(app_config.TRANSFORMER_CONFIG.specialization, str):
                if ',' in app_config.TRANSFORMER_CONFIG.specialization:
                    app_config.TRANSFORMER_CONFIG.specialization = [
                        s.strip() for s in app_config.TRANSFORMER_CONFIG.specialization.split(',')
                    ]
                else:
                    app_config.TRANSFORMER_CONFIG.specialization = [app_config.TRANSFORMER_CONFIG.specialization]
                logger.info(f"Fixed specialization: {app_config.TRANSFORMER_CONFIG.specialization}")
except Exception as e:
    logger.error(f"Error loading configuration: {e}")
    raise
|
| |
|
| |
|
# Optional: pull in transformer monkey-patches when the module is available;
# a missing module only downgrades functionality, so warn and continue.
try:
    transformer_patches = importlib.import_module("transformer_patches")
except ImportError:
    logger.warning("Could not import transformer_patches")
|
| |
|
| |
|
# Import the shared service registry and event-system constants.  When the
# core modules are unavailable (e.g. running this file stand-alone) we fall
# back to minimal in-process stand-ins so the module can still import and
# run in degraded mode.
try:
    from service_registry import registry, MODEL, TOKENIZER, MODEL_MANAGER, COMMUNICATOR, PIPELINE, PRETRAINED_MODEL

    from utils.event_system import (
        EVENT_STDP_REQUEST, EVENT_STDP_RESPONSE, EVENT_TOKEN_GENERATED,
        EVENT_USER_INPUT, EVENT_MODEL_REQUEST, EVENT_MODEL_RESPONSE,
        EVENT_RESPONSE_COMPLETE, EVENT_ERROR
    )

    from utils.event_bus import event_bus
except ImportError as e:
    logger.error(f"Failed to import core modules: {e}")

    class Registry:
        """Minimal dict-backed stand-in for the real service registry."""

        def __init__(self):
            self._registry = {}

        def register(self, key, value, overwrite=False):
            # BUG FIX: callers in this file (e.g. initialize_system) pass
            # overwrite=True; the fallback previously rejected that keyword
            # with a TypeError.  Plain dict assignment always overwrites, so
            # the flag is accepted for interface compatibility only.
            self._registry[key] = value

        def get(self, key, default=None):
            return self._registry.get(key, default)

        def has(self, key):
            return key in self._registry

    registry = Registry()
    MODEL = "model"
    TOKENIZER = "tokenizer"
    MODEL_MANAGER = "model_manager"
    COMMUNICATOR = "communicator"
    PIPELINE = "pipeline"
    PRETRAINED_MODEL = "pretrained_model"

    # Event-type names mirrored from utils.event_system.
    EVENT_STDP_REQUEST = "stdp_request"
    EVENT_STDP_RESPONSE = "stdp_response"
    EVENT_TOKEN_GENERATED = "token_generated"
    EVENT_USER_INPUT = "user_input"
    EVENT_MODEL_REQUEST = "model_request"
    EVENT_MODEL_RESPONSE = "model_response"
    EVENT_RESPONSE_COMPLETE = "response_complete"
    EVENT_ERROR = "error"

    class EventBus:
        """No-op stand-in for the real event bus."""

        def publish(self, event_type, data):
            pass

    event_bus = EventBus()
|
| |
|
| | from find_weights import find_transformer_weights, find_snn_weights
|
| |
|
| |
|
# API-layer components; degrade gracefully when they cannot be imported so
# the rest of the module keeps working without the HTTP interface.
try:
    from api_wp import TLMInterface
    from verify_repo import verify_model_repo_access
except Exception as e:
    logger.error(f"Error importing API components: {e}")

    def verify_model_repo_access():
        """Repo access can never be verified without the real module."""
        return False

    class TLMInterface:
        """Inert stand-in used when the real API interface is unavailable."""

        def initialize(self, force=False):
            return False

        def process_input(self, text):
            return {"response": "API unavailable"}
|
| |
|
| |
|
| |
|
def fix_config_file(config_path="config.json"):
    """Normalise known-bad fields in the JSON config file in place.

    Repairs two historical schema problems under ``TRANSFORMER_CONFIG``:
      * ``specialization`` stored as a list -> first element (or
        ``"general"`` when the list is empty).
      * ``DATASET_PATH`` stored as a dict or list -> a single path string
        (or ``""`` when empty).

    The original file is kept as ``<config_path>.bak``.

    Args:
        config_path: path to the JSON config file.

    Returns:
        bool: True when the file was rewritten, False on any error.
    """
    try:
        if not os.path.exists(config_path):
            logger.error(f"Config file not found at {config_path}")
            return False

        with open(config_path, 'r') as f:
            config_data = json.load(f)

        if 'TRANSFORMER_CONFIG' in config_data and isinstance(config_data['TRANSFORMER_CONFIG'], dict):
            transformer_config = config_data['TRANSFORMER_CONFIG']

            # specialization: list -> single string.
            if 'specialization' in transformer_config and isinstance(transformer_config['specialization'], list):
                if transformer_config['specialization']:
                    logger.info(f"Converting specialization from list to string: {transformer_config['specialization'][0]}")
                    transformer_config['specialization'] = transformer_config['specialization'][0]
                else:
                    logger.info("Setting empty specialization list to 'general'")
                    transformer_config['specialization'] = "general"

            # DATASET_PATH: dict-of-list / dict / list -> plain string.
            if 'DATASET_PATH' in transformer_config:
                if isinstance(transformer_config['DATASET_PATH'], dict):
                    if transformer_config['DATASET_PATH']:
                        first_key = next(iter(transformer_config['DATASET_PATH']))
                        path_value = transformer_config['DATASET_PATH'][first_key]
                        if isinstance(path_value, list) and path_value:
                            logger.info(f"Converting DATASET_PATH from dict of list to string: {path_value[0]}")
                            transformer_config['DATASET_PATH'] = path_value[0]
                        else:
                            logger.info(f"Converting DATASET_PATH from dict to string: {path_value}")
                            transformer_config['DATASET_PATH'] = str(path_value)
                    else:
                        transformer_config['DATASET_PATH'] = ""
                elif isinstance(transformer_config['DATASET_PATH'], list):
                    if transformer_config['DATASET_PATH']:
                        logger.info(f"Converting DATASET_PATH from list to string: {transformer_config['DATASET_PATH'][0]}")
                        transformer_config['DATASET_PATH'] = transformer_config['DATASET_PATH'][0]
                    else:
                        transformer_config['DATASET_PATH'] = ""

        # BUG FIX: os.rename raises on Windows when the backup already
        # exists; os.replace overwrites the stale backup on all platforms.
        # (Redundant function-local `import os` / `import json` removed —
        # both are imported at module level.)
        backup_path = f"{config_path}.bak"
        os.replace(config_path, backup_path)

        with open(config_path, 'w') as f:
            json.dump(config_data, f, indent=2)

        logger.info(f"Config file fixed successfully (backup saved to {backup_path})")
        return True

    except Exception as e:
        logger.error(f"Error fixing config file: {e}")
        return False
|
| |
|
def get_component(component_key: str, default: Any = None) -> Any:
    """Look up *component_key* in the registry.

    Returns the registered component, or *default* (with a warning logged)
    when the key is absent.

    BUG FIX: the previous ``return component or default`` silently replaced
    any falsy-but-registered component (empty dict, 0, "") with *default*;
    only a missing component (``None``) now triggers the fallback.
    """
    component = registry.get(component_key)
    if component is None:
        logger.warning(f"Component {component_key} not found in registry - using default")
        return default
    return component
|
| |
|
| |
|
def ensure_critical_components():
    """Make sure critical components exist in registry, create minimal versions if not"""
    logger.info("Ensuring critical components are available in registry")

    # Tokenizer: fall back to a stock BERT tokenizer from HuggingFace.
    if not registry.has(TOKENIZER):
        try:
            logger.info("Creating minimal tokenizer")
            from transformers import AutoTokenizer
            minimal_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
            registry.register(TOKENIZER, minimal_tokenizer)
            logger.info("Registered minimal tokenizer in registry")
        except Exception as e:
            logger.error(f"Failed to create minimal tokenizer: {e}")

    # Model manager: degraded in-process implementation.
    if not registry.has(MODEL_MANAGER):
        logger.info("Creating minimal model manager")
        minimal_manager = _create_minimal_model_manager()
        registry.register(MODEL_MANAGER, minimal_manager)
        logger.info("Registered minimal model manager in registry")

    # Communicator: the factory registers the instance itself.
    if not registry.has(COMMUNICATOR):
        logger.info("Creating communicator via factory")
        from communicator import create_communicator
        create_communicator(registry.get(MODEL_MANAGER))
        logger.info("Registered communicator in registry")

    # Optional STDP communicator, only when the module is installed.
    if not registry.has("communicator_stdp"):
        try:
            # BUG FIX: the `util` submodule must be imported explicitly;
            # `import importlib` alone does not guarantee that
            # `importlib.util` is available as an attribute.
            import importlib.util
            if importlib.util.find_spec("communicator_STDP"):
                module = importlib.import_module("communicator_STDP")

                if hasattr(module, "Communicator_STDP"):
                    logger.info("Creating minimal STDP communicator")
                    minimal_stdp = module.Communicator_STDP()
                    registry.register("communicator_stdp", minimal_stdp)
                    logger.info("Registered minimal STDP communicator")
        except Exception as e:
            logger.warning(f"Could not create STDP communicator: {e}")
|
| |
|
| |
|
def initialize_pipeline():
    """Construct the processing Pipeline and publish it under PIPELINE.

    Returns:
        The new Pipeline instance, or None when the module is missing or
        construction/registration fails.
    """
    try:
        from pipeline import Pipeline

        pipe = Pipeline()
        registry.register(PIPELINE, pipe)
        logger.info("Initialized pipeline successfully")
        return pipe
    except ImportError:
        # Pipeline support is optional; absence is not an error.
        logger.warning("Pipeline module not available - skipping initialization")
        return None
    except Exception as e:
        logger.error(f"Error initializing pipeline: {e}")
        return None
|
| |
|
| |
|
def _create_minimal_model_manager():
    """Create a minimal model manager for basic functionality"""

    class MinimalModelManager:
        # Degraded stand-in for the real model manager.  On construction it
        # aggressively tries to load a full model (downloading weights if
        # needed) and, when every attempt fails, installs an echo-style
        # minimal model so callers always get *something*.

        def __init__(self):
            self.models = {}                 # name -> loaded model instance
            self.model_pool = {}
            self.attempted_full_load = False
            self.retry_count = 0
            self.max_retries = 5
            self.last_attempt_time = 0
            self.retry_backoff = 60          # base seconds; doubled per retry

            # Eagerly try to obtain a real model at construction time.
            self._attempt_load_full_model()

        def _attempt_load_full_model(self):
            """Attempt to load a full model with weights"""
            self.attempted_full_load = True
            self.last_attempt_time = time.time()
            self.retry_count += 1

            try:
                logger.info(f"Attempting to load full model (attempt {self.retry_count}/{self.max_retries})")

                # Phase 1: download weights (primary repo, then fallbacks)
                # and build the custom model from a fixed parameter set.
                try:
                    from load_model_weights import download_model_files

                    repo_id_base = "Wildnerve/tlm-0.05Bx12"

                    logger.info(f"Trying to download weights from {repo_id_base}")
                    result = download_model_files(
                        repo_id_base=repo_id_base,
                        cache_dir=None
                    )

                    if result and "transformer" in result:
                        logger.info(f"Successfully downloaded weights from {repo_id_base}")
                    else:
                        # Primary repo failed; try the checkpoint mirror,
                        # then its Transformer subfolder.
                        fallback_repo = "EvolphTech/Checkpoints"
                        logger.info(f"Trying to download weights from {fallback_repo}")
                        result = download_model_files(
                            repo_id_base=fallback_repo,
                            cache_dir=None
                        )

                        if not result or "transformer" not in result:
                            logger.info(f"Trying to download weights from {fallback_repo}/Transformer")
                            result = download_model_files(
                                repo_id_base=f"{fallback_repo}/Transformer",
                                cache_dir=None
                            )

                    if result and "transformer" in result:
                        # Remember the weights path for later reloads.
                        os.environ["TLM_TRANSFORMER_WEIGHTS"] = result["transformer"]

                        from model_Custm import Wildnerve_tlm01

                        # Fixed architecture matching the downloaded weights
                        # (BERT-sized vocab of 30522).
                        model_params = {
                            "specialization": "general",
                            "vocab_size": 30522,
                            "embedding_dim": 768,
                            "num_heads": 12,
                            "hidden_dim": 768,
                            "num_layers": 6,
                            "output_size": 30522,
                            "dropout": 0.1,
                            "max_seq_length": 512
                        }

                        model = Wildnerve_tlm01(**model_params)

                        try:
                            from load_model_weights import load_weights_into_model

                            # strict=False: tolerate partial key matches.
                            if load_weights_into_model(model, result["transformer"], strict=False):
                                logger.info(f"Loading weights from TLM_TRANSFORMER_WEIGHTS: {result['transformer']}")

                                registry.register(MODEL, model)
                                self.models["default"] = model

                                self.primary_model_name = "Wildnerve-tlm01-0.05Bx12"
                                logger.info(f"Successfully initialized {self.primary_model_name} model")
                                return True
                            else:
                                logger.warning(f"Failed to load weights from {result['transformer']}")
                        except Exception as load_error:
                            logger.warning(f"Failed to load weights from TLM_TRANSFORMER_WEIGHTS: {load_error}")
                except Exception as dl_error:
                    logger.error(f"Error downloading model weights: {dl_error}")

                # Phase 2: try the local model modules directly.
                # NOTE(review): only model_class_names[0] is ever looked up
                # in the loop below — confirm whether the other class names
                # were meant to be paired with their modules.
                model_modules = ["model_Custm", "model_Combn", "model_PrTr"]
                model_class_names = ["Wildnerve_tlm01", "PretrainedTransformer", "CombinedModel"]

                for module_name in model_modules:
                    try:
                        module = importlib.import_module(module_name)
                        if hasattr(module, model_class_names[0]):
                            model_class = getattr(module, model_class_names[0])
                            model = model_class()

                            # Reuse previously downloaded weights if present.
                            if "TLM_TRANSFORMER_WEIGHTS" in os.environ and os.path.exists(os.environ["TLM_TRANSFORMER_WEIGHTS"]):
                                try:
                                    from load_model_weights import load_weights_into_model
                                    if load_weights_into_model(model, os.environ["TLM_TRANSFORMER_WEIGHTS"], strict=False):
                                        logger.info(f"Loading weights from TLM_TRANSFORMER_WEIGHTS: {os.environ['TLM_TRANSFORMER_WEIGHTS']}")
                                    else:
                                        logger.warning(f"Failed to load weights from {os.environ['TLM_TRANSFORMER_WEIGHTS']}")
                                except Exception as load_error:
                                    logger.warning(f"Failed to load weights from TLM_TRANSFORMER_WEIGHTS: {load_error}")

                            registry.register(MODEL, model)
                            self.models["default"] = model
                            logger.info(f"Successfully loaded model from {module_name}")
                            return True
                    except Exception as e:
                        logger.warning(f"Error loading {module_name}: {e}")

                logger.warning("All attempts to load full models failed")

                # Schedule (log) the next retry with exponential backoff.
                if self.retry_count < self.max_retries:
                    backoff = self.retry_backoff * (2 ** (self.retry_count - 1))
                    logger.info(f"Will retry in {backoff} seconds (attempt {self.retry_count}/{self.max_retries})")
                else:
                    logger.warning(f"Reached maximum retry attempts ({self.max_retries})")

                return False
            except Exception as e:
                logger.warning(f"Could not load an actual model: {e}")
                return False
            finally:
                # Whatever happened above, guarantee that *some* model
                # exists so callers never get an empty manager.
                if not self.models:
                    self._create_minimal_model()

        def _create_minimal_model(self):
            """Create a minimal model as fallback"""
            logger.info("Creating minimal model instance")
            try:
                class MinimalModel:
                    # Inert model: zero logits, canned text response.
                    def __init__(self):
                        self._is_minimal = True

                    def forward(self, input_ids):
                        # Zero logits over a BERT-sized vocabulary (30522).
                        return torch.zeros((input_ids.shape[0], input_ids.shape[1], 30522))

                    def generate_with_decoding(self, input_ids):
                        return "I'm running in minimal mode due to initialization issues. Please try again later for full model responses."

                minimal_model = MinimalModel()
                self.models["default"] = minimal_model

                try:
                    registry.register(MODEL, minimal_model)
                    logger.warning("Registered MINIMAL model in registry - FULL MODEL UNAVAILABLE")
                except Exception as reg_error:
                    logger.error(f"Failed to register minimal model: {reg_error}")
            except Exception as minimal_error:
                logger.error(f"Failed to create minimal model: {minimal_error}")

        def select_model_for_prompt(self, prompt):
            # If we are still running the minimal fallback and the backoff
            # window has elapsed, take the opportunity to retry a full load.
            should_retry = (
                len(self.models) > 0 and
                "default" in self.models and
                hasattr(self.models["default"], "_is_minimal") and
                self.models["default"]._is_minimal and
                self.retry_count < self.max_retries and
                (time.time() - self.last_attempt_time) > self.retry_backoff * (2 ** (self.retry_count - 1))
            )

            if should_retry:
                logger.info("Attempting to reload full model on demand")
                self._attempt_load_full_model()

            if self.models and "default" in self.models:
                model = self.models["default"]
                if hasattr(model, '_is_minimal') and model._is_minimal:
                    logger.warning("Using minimal model for prompt - FULL MODEL UNAVAILABLE")
                else:
                    logger.info("Using full model for prompt")

                return model

            logger.debug(f"Minimal model manager received prompt but has no models: {prompt[:30]}...")
            return None

        def get_available_models(self):
            # Mapping of name -> model; at minimum contains "default".
            return self.models

    return MinimalModelManager()
|
| |
|
def _create_minimal_communicator():
    """Create a minimal communicator for basic functionality"""

    class MinimalCommunicator:
        # Fallback communicator: prefers real model inference when a model
        # and tokenizer are registered, otherwise answers from a tiny
        # hard-coded knowledge base.

        def __init__(self):
            # Canned answers keyed by lowercase topic substring.
            self.knowledge_base = {
                "malaysia": """
Malaysia is a Southeast Asian country located on the Malay Peninsula and parts of Borneo island. Key facts:

- Capital: Kuala Lumpur (administrative capital is Putrajaya)
- Population: Approximately 32 million
- Languages: Malay (official), English, Chinese dialects, Tamil
- Government: Federal constitutional elective monarchy
- Currency: Malaysian Ringgit (MYR)
- Major ethnic groups: Malay, Chinese, Indian, indigenous peoples
- Notable landmarks: Petronas Twin Towers, Mount Kinabalu, Langkawi Island
- Cuisine: Famous for dishes like nasi lemak, satay, laksa, and roti canai
- Economy: Significant sectors include manufacturing, services, and agriculture (especially palm oil)

Malaysia gained independence from British colonial rule in 1957 and is known for its diverse culture and rainforests.
""",
                "python": """Python is a high-level programming language known for its readability and versatility.
It's great for web development, data science, AI, and more.""",
                "javascript": "JavaScript is a programming language used primarily for web development.",
                "ai": "Artificial Intelligence refers to systems that can perform tasks requiring human intelligence.",
                "machine learning": "Machine learning is a subset of AI where systems learn from data without explicit programming."
            }
            logger.info("Initialized minimal communicator with basic knowledge base")

            # Snapshot whatever model/tokenizer are registered right now.
            self.model = registry.get(MODEL)
            self.tokenizer = registry.get(TOKENIZER)

            if self.model:
                logger.info(f"Minimal communicator found model: {type(self.model).__name__}")

        def process_input(self, prompt, **kwargs):
            # Returns a dict {"response": str}.  Order of preference:
            # real model inference -> knowledge base -> generic placeholder.
            if self.model and self.tokenizer:
                try:
                    logger.info("Attempting model inference with actual model")
                    inputs = self.tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)

                    if hasattr(self.model, "generate_with_decoding"):
                        response = self.model.generate_with_decoding(inputs.input_ids)
                    elif hasattr(self.model, "generate"):
                        output_ids = self.model.generate(inputs.input_ids)
                        response = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
                    else:
                        # Raw forward pass; greedy argmax over logits.
                        outputs = self.model(inputs.input_ids)
                        response = self.tokenizer.decode(torch.argmax(outputs, dim=-1)[0], skip_special_tokens=True)

                    # Responses of <= 10 chars are treated as degenerate
                    # and fall through to the knowledge base.
                    if response and len(response) > 10:
                        logger.info("Generated model response successfully")
                        return {"response": response}
                except Exception as e:
                    logger.warning(f"Model inference failed: {e}")

            logger.debug(f"Minimal communicator processing: {prompt[:30]}...")
            response = self._get_knowledge_response(prompt)

            if response:
                return {"response": response}
            return {"response": f"I'm analyzing your request about {prompt.split()[0] if prompt.split() else 'this topic'}..."}

        def process_request(self, prompt, model=None):
            # Like process_input, but returns a bare string and accepts an
            # explicit model override (selected by the model manager).
            if model:
                try:
                    logger.info(f"Attempting inference with provided model: {type(model).__name__}")
                    tokenizer = registry.get(TOKENIZER)
                    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)

                    if hasattr(model, "generate_with_decoding"):
                        return model.generate_with_decoding(inputs.input_ids)
                    elif hasattr(model, "generate"):
                        output_ids = model.generate(inputs.input_ids)
                        return tokenizer.decode(output_ids[0], skip_special_tokens=True)
                    else:
                        outputs = model(inputs.input_ids)
                        return tokenizer.decode(torch.argmax(outputs, dim=-1)[0], skip_special_tokens=True)
                except Exception as e:
                    logger.warning(f"Model inference failed: {e}")

            result = self.process_input(prompt)
            return result.get("response", f"I'm analyzing '{prompt[:20]}...'")

        def process_request_streaming(self, prompt, model=None):
            # Fake streaming: compute the full response, then emit it one
            # word at a time with a small delay.
            response = self.process_request(prompt, model)
            words = response.split()

            for word in words:
                yield word + " "
                time.sleep(0.05)

        def _get_knowledge_response(self, prompt):
            """Check if we have knowledge about topics in the prompt"""
            prompt_lower = prompt.lower()

            # First matching topic substring wins (dict insertion order).
            for topic, info in self.knowledge_base.items():
                if topic in prompt_lower:
                    return info

            return None

    logger.info("Created enhanced minimal communicator for fallback functionality")
    communicator = MinimalCommunicator()
    return communicator
|
| |
|
def generate_response(prompt: str, stream: bool = False) -> Union[str, Generator[str, None, None]]:
    """Produce a reply for *prompt*.

    When *stream* is true, returns a generator yielding tokens; otherwise
    returns the complete reply as a string.  Unexpected failures are turned
    into an error string instead of propagating.
    """
    # Guarantee the registry holds at least fallback implementations.
    ensure_critical_components()

    mgr = get_component(MODEL_MANAGER)
    comm = get_component(COMMUNICATOR)
    stdp = get_component("communicator_stdp")

    logger.info(f"Generating response for prompt: {prompt[:50]}...")

    handler = _generate_streaming_response if stream else _generate_complete_response
    try:
        return handler(prompt, mgr, comm, stdp)
    except Exception as e:
        logger.error(f"Error generating response: {e}")
        return f"Error generating response: {str(e)}"
|
| |
|
def _generate_complete_response(prompt: str, model_manager: Any, communicator: Any, communicator_stdp: Any) -> str:
    """Generate a complete response (non-streaming).

    Tries, in order: ``communicator.process_request`` (with a model chosen
    by the manager), ``communicator.process_input``, then the STDP
    communicator.  A candidate is rejected when it merely echoes the prompt
    or ends with "[PAD]".

    Args:
        prompt: the user's input text.
        model_manager: optional manager exposing select_model_for_prompt().
        communicator: primary communicator (process_request/process_input).
        communicator_stdp: optional STDP-based fallback communicator.

    Returns:
        str: the first acceptable response, or a generic placeholder when
        every backend fails.
    """
    logger.info(f"Generating response for prompt: {prompt[:50]}...")
    # (An unused `start_time = time.time()` left over from timing code was
    # removed — its value was never read.)

    # 1) Primary path: process_request with an explicitly selected model.
    if communicator and hasattr(communicator, 'process_request'):
        try:
            model = None
            if model_manager and hasattr(model_manager, 'select_model_for_prompt'):
                model = model_manager.select_model_for_prompt(prompt)

            response = communicator.process_request(prompt, model)

            if response and not response.endswith("[PAD]") and response != prompt:
                logger.info(f"Got response from communicator: {response[:50]}...")
                return response
            else:
                logger.warning("Communicator returned echo or padded response, trying alternative methods")
        except Exception as e:
            logger.error(f"Error in communicator.process_request: {e}")

    # 2) Fallback: dict-based process_input.
    if communicator and hasattr(communicator, 'process_input'):
        try:
            result = communicator.process_input(prompt)
            if isinstance(result, dict) and "response" in result:
                response = result["response"]

                if response and response != prompt and not response.endswith("[PAD]"):
                    logger.info(f"Got response from process_input: {response[:50]}...")
                    return response
                else:
                    logger.warning("process_input returned echo or padded response")
        except Exception as e:
            logger.error(f"Error in communicator.process_input: {e}")

    # 3) Last resort: STDP communicator.
    if communicator_stdp:
        try:
            stdp_response = communicator_stdp.process_request(prompt, None)
            if stdp_response and stdp_response != prompt and not stdp_response.endswith("[PAD]"):
                logger.info(f"Got response from STDP communicator: {stdp_response[:50]}...")
                return stdp_response
            else:
                logger.warning("STDP communicator returned echo or padded response")
        except Exception as e:
            logger.error(f"Error in STDP communicator: {e}")

    # Nothing produced a usable reply.
    return f"I'm processing your request about '{prompt[:20]}...'"
|
| |
|
def _generate_streaming_response(prompt: str, model_manager: Any, communicator: Any, communicator_stdp: Any) -> Generator[str, None, None]:
    """Yield response tokens for *prompt*, preferring native streaming.

    Tries the primary communicator's streaming API, then the STDP
    communicator's; when neither succeeds, simulates streaming by emitting a
    complete response word by word.
    """
    logger.info(f"Generating streaming response for: {prompt[:50]}...")

    # Preferred: the primary communicator streams tokens itself.
    if communicator and hasattr(communicator, 'process_request_streaming'):
        try:
            chosen = None
            if model_manager and hasattr(model_manager, 'select_model_for_prompt'):
                chosen = model_manager.select_model_for_prompt(prompt)

            yield from communicator.process_request_streaming(prompt, chosen)
            return
        except Exception as e:
            logger.error(f"Error in communicator streaming: {e}")

    # Next: the STDP communicator's streaming API.
    if communicator_stdp and hasattr(communicator_stdp, 'process_request_streaming'):
        try:
            yield from communicator_stdp.process_request_streaming(prompt)
            return
        except Exception as e:
            logger.error(f"Error in STDP streaming: {e}")

    # Fallback: fake streaming from a complete response.
    full_reply = _generate_complete_response(prompt, model_manager, communicator, communicator_stdp)
    for word in full_reply.split():
        yield word + " "
        time.sleep(0.05)
|
| |
|
def parse_args():
    """Define and parse the server's command-line interface.

    Returns:
        argparse.Namespace with port, host, share, debug, data_dir,
        model_dir, api_only and initialize attributes.
    """
    p = argparse.ArgumentParser(description="Tiny Language Model Server")
    # Network settings.
    p.add_argument("--port", type=int, default=7860, help="Port to run the server on")
    p.add_argument("--host", type=str, default="0.0.0.0", help="Host to run the server on")
    # Runtime toggles.
    p.add_argument("--share", action="store_true", help="Enable sharing")
    p.add_argument("--debug", action="store_true", help="Enable debug mode")
    # Filesystem locations.
    p.add_argument("--data_dir", type=str, help="Data directory")
    p.add_argument("--model_dir", type=str, help="Model directory")
    # Operating modes.
    p.add_argument("--api_only", action="store_true", help="Run as API server only (no UI)")
    p.add_argument("--initialize", action="store_true", help="Initialize the models and exit")
    return p.parse_args()
|
| |
|
def setup_environment(args):
    """Apply CLI overrides to the environment and ensure data dirs exist.

    Args:
        args: parsed argparse.Namespace with data_dir, model_dir and debug.
    """
    if args.data_dir:
        logger.info(f"Using custom data directory: {args.data_dir}")
        os.environ['TLM_DATA_DIR'] = args.data_dir

    if args.model_dir:
        logger.info(f"Using custom model directory: {args.model_dir}")
        os.environ['TLM_MODEL_DIR'] = args.model_dir

    if args.debug:
        logger.info("Debug mode enabled")
        logging.getLogger().setLevel(logging.DEBUG)

    # Prefer the project helper; otherwise lay out the directories by hand.
    try:
        from config import ensure_data_directories
        ensure_data_directories()
    except ImportError:
        base = os.environ.get('TLM_DATA_DIR', '/tmp/tlm_data')
        os.makedirs(base, exist_ok=True)
        os.makedirs(os.path.join(base, "models"), exist_ok=True)
|
| |
|
def initialize_system():
    """Initialize all components in the correct order"""
    logger.info("Starting system initialization")

    # 1) Tokenizer: stock GPT-2, with the project wrapper as fallback.
    try:
        from transformers import GPT2Tokenizer
        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

        # GPT-2 ships without a pad token; reuse EOS so padding works.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
    except Exception as e:
        logger.warning(f"Could not load GPT-2 tokenizer, falling back to wrapper: {e}")
        from tokenizer import TokenizerWrapper
        tokenizer = TokenizerWrapper(model_name="gpt2")

    # NOTE(review): register(..., overwrite=True) requires the real
    # service_registry implementation to accept that keyword — the
    # in-file fallback Registry historically did not; confirm.
    registry.register(TOKENIZER, tokenizer, overwrite=True)
    logger.info("Tokenizer registered")

    # 2) Pretrained GPT-2 model (optional; failure is logged, not fatal).
    try:
        from model_PrTr import GPT_2 as PretrainedModel
        pretrained = PretrainedModel(model_name="gpt2", tokenizer=tokenizer)
        registry.register(PRETRAINED_MODEL, pretrained, overwrite=True)
        logger.info("GPT-2 pretrained model registered")
    except Exception as e:
        logger.error(f"Failed to initialize GPT-2 model: {e}", exc_info=True)

    # 3) Custom model, sized from the configured architecture params.
    try:
        from model_Custm import Wildnerve_tlm01

        arch_params = get_model_architecture_params()

        model = Wildnerve_tlm01(
            vocab_size=arch_params["vocab_size"],
            specialization="general",
            dataset_path=None,
            model_name="gpt2",
            embedding_dim=arch_params["embedding_dim"],
            num_heads=arch_params["num_heads"],
            hidden_dim=arch_params["hidden_dim"],
            num_layers=arch_params["num_layers"],
            output_size=arch_params["vocab_size"],
            dropout=arch_params["dropout"],
            max_seq_length=arch_params["max_seq_length"],
            pooling_mode="last",
            tokenizer=tokenizer
        )

        registry.register(MODEL, model, overwrite=True)
        logger.info("Custom model registered successfully")
        return True
    except Exception as e:
        logger.error(f"Failed to initialize custom model: {e}", exc_info=True)
        return False
|
| |
|
def main():
    """Main application entry point with consolidated functionality"""
    # Initialization failure is logged but does not abort startup: the
    # server still runs in degraded/fallback mode.
    success = initialize_system()
    logger.info(f"System initialization {'successful' if success else 'failed'}")

    # Imported here so module import stays cheap and side-effect free.
    from app import app
    import uvicorn

    logger.info("Starting TLM application")
    # NOTE(review): uvicorn's `workers` option only takes effect when the
    # app is given as an import string, not an app object — confirm whether
    # multi-worker mode is actually intended here.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=int(os.getenv("PORT", 7860)),
        workers=os.cpu_count() or 1,
        loop="auto"
    )


if __name__ == "__main__":
    main()
|
| |
|