# Main.py - Main entry point for Wildnerve-tlm_HF
import os
import sys
import time
import json
import gc
import logging
import argparse
import importlib
import threading
from typing import Dict, Any, Optional, List, Union, Generator, Tuple
from pathlib import Path

# Set up line buffering early
sys.stdout.reconfigure(line_buffering=True)

# Configure logging once at the top level
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    force=True
)
logger = logging.getLogger(__name__)

# Add file handlers for persistent logs
file_handler = logging.FileHandler('/tmp/app_debug.log')
file_handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
logger.addHandler(file_handler)

fh = logging.FileHandler("/tmp/container.log")
fh.setLevel(logging.DEBUG)
# Fixed formatter field: was "%(levellevel)s", which is not a valid logging attribute
fh.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s"))
logging.getLogger().addHandler(fh)

logger.info("Logging configured")

# Force early initialization of vital environment variables
if not os.environ.get("TLM_DATA_DIR"):
    os.environ["TLM_DATA_DIR"] = "/tmp/tlm_data"

# Select GPU if available
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")

# Add GPU memory monitoring
try:
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        logger.info(
            f"GPU Memory: {torch.cuda.memory_allocated()/1e9:.2f}GB / "
            f"{torch.cuda.get_device_properties(0).total_memory/1e9:.2f}GB"
        )
except Exception as e:
    logger.warning(f"Error with PyTorch setup: {e}")

# Import configuration
try:
    from config import app_config, load_config, get_model_architecture_params

    # Create an emergency patch for config.py issue
    if hasattr(app_config, 'TRANSFORMER_CONFIG'):
        if not hasattr(app_config.TRANSFORMER_CONFIG, 'config_data') and isinstance(app_config.TRANSFORMER_CONFIG, dict):
            # TRANSFORMER_CONFIG is a plain dict in this branch, so item assignment is
            # used (attribute assignment on a dict would raise AttributeError).
            # Create a minimal config_data entry to avoid attribute errors downstream.
            app_config.TRANSFORMER_CONFIG['config_data'] = app_config.TRANSFORMER_CONFIG
            # Also ensure MODEL_NAME is gpt2, not distilbert
            app_config.TRANSFORMER_CONFIG['MODEL_NAME'] = "gpt2"
            app_config.TRANSFORMER_CONFIG['VOCAB_SIZE'] = 50257  # GPT-2 vocab size
        elif hasattr(app_config.TRANSFORMER_CONFIG, 'specialization'):
            # Ensure specialization is a list
            if isinstance(app_config.TRANSFORMER_CONFIG.specialization, str):
                if ',' in app_config.TRANSFORMER_CONFIG.specialization:
                    app_config.TRANSFORMER_CONFIG.specialization = [
                        s.strip() for s in app_config.TRANSFORMER_CONFIG.specialization.split(',')
                    ]
                else:
                    app_config.TRANSFORMER_CONFIG.specialization = [app_config.TRANSFORMER_CONFIG.specialization]
                logger.info(f"Fixed specialization: {app_config.TRANSFORMER_CONFIG.specialization}")
except Exception as e:
    logger.error(f"Error loading configuration: {e}")
    raise  # stop startup on config load failure

# Apply transformers patches
try:
    import transformer_patches
except ImportError:
    logger.warning("Could not import transformer_patches")

# Import service registry
try:
    from service_registry import registry, MODEL, TOKENIZER, MODEL_MANAGER, COMMUNICATOR, PIPELINE, PRETRAINED_MODEL

    # Import event system types
    from utils.event_system import (
        EVENT_STDP_REQUEST, EVENT_STDP_RESPONSE, EVENT_TOKEN_GENERATED,
        EVENT_USER_INPUT, EVENT_MODEL_REQUEST, EVENT_MODEL_RESPONSE,
        EVENT_RESPONSE_COMPLETE, EVENT_ERROR
    )
    # Also import event bus for lightweight communication
    from utils.event_bus import event_bus
except ImportError as e:
    logger.error(f"Failed to import core modules: {e}")

    # Define minimal registry and event-system fallbacks
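    # Added note: the classes and constants below are deliberately tiny stand-ins.
    # They only cover the parts of the real service_registry / utils.event_system /
    # utils.event_bus API that this module actually uses (register/get/has, the
    # event-name constants, and publish), so the entry point can still start in a
    # degraded mode when those imports fail.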
    class Registry:
        def __init__(self):
            self._registry = {}

        def register(self, key, value, overwrite=True):
            # 'overwrite' is accepted (and always honoured) for compatibility with
            # callers that pass overwrite=True
            self._registry[key] = value

        def get(self, key, default=None):
            return self._registry.get(key, default)

        def has(self, key):
            return key in self._registry

    registry = Registry()
    MODEL = "model"
    TOKENIZER = "tokenizer"
    MODEL_MANAGER = "model_manager"
    COMMUNICATOR = "communicator"
    PIPELINE = "pipeline"
    PRETRAINED_MODEL = "pretrained_model"  # Added this constant

    # Define minimal event constants
    EVENT_STDP_REQUEST = "stdp_request"
    EVENT_STDP_RESPONSE = "stdp_response"
    EVENT_TOKEN_GENERATED = "token_generated"
    EVENT_USER_INPUT = "user_input"
    EVENT_MODEL_REQUEST = "model_request"
    EVENT_MODEL_RESPONSE = "model_response"
    EVENT_RESPONSE_COMPLETE = "response_complete"
    EVENT_ERROR = "error"

    # Define minimal event bus
    class EventBus:
        def publish(self, event_type, data):
            pass

    event_bus = EventBus()

from find_weights import find_transformer_weights, find_snn_weights

# Import API components
try:
    from api_wp import TLMInterface
    from verify_repo import verify_model_repo_access
except Exception as e:
    logger.error(f"Error importing API components: {e}")

    # Define minimal placeholders
    class TLMInterface:
        def initialize(self, force=False):
            return False

        def process_input(self, text):
            return {"response": "API unavailable"}

    def verify_model_repo_access():
        return False


# --- Helper functions ---
def fix_config_file(config_path="config.json"):
    """Fix the config file directly."""
    import os
    import json
    try:
        # Check if file exists
        if not os.path.exists(config_path):
            logger.error(f"Config file not found at {config_path}")
            return False

        # Read the config file
        with open(config_path, 'r') as f:
            config_data = json.load(f)

        # Fix TRANSFORMER_CONFIG section if it exists
        if 'TRANSFORMER_CONFIG' in config_data and isinstance(config_data['TRANSFORMER_CONFIG'], dict):
            transformer_config = config_data['TRANSFORMER_CONFIG']

            # Fix specialization if it's a list
            if 'specialization' in transformer_config and isinstance(transformer_config['specialization'], list):
                if transformer_config['specialization']:
                    logger.info(f"Converting specialization from list to string: {transformer_config['specialization'][0]}")
                    transformer_config['specialization'] = transformer_config['specialization'][0]
                else:
                    logger.info("Setting empty specialization list to 'general'")
                    transformer_config['specialization'] = "general"

            # Fix DATASET_PATH if it's a dict or list
            if 'DATASET_PATH' in transformer_config:
                if isinstance(transformer_config['DATASET_PATH'], dict):
                    # Take first value from dict
                    if transformer_config['DATASET_PATH']:
                        first_key = next(iter(transformer_config['DATASET_PATH']))
                        path_value = transformer_config['DATASET_PATH'][first_key]
                        if isinstance(path_value, list) and path_value:
                            logger.info(f"Converting DATASET_PATH from dict of list to string: {path_value[0]}")
                            transformer_config['DATASET_PATH'] = path_value[0]
                        else:
                            logger.info(f"Converting DATASET_PATH from dict to string: {path_value}")
                            transformer_config['DATASET_PATH'] = str(path_value)
                    else:
                        transformer_config['DATASET_PATH'] = ""
                elif isinstance(transformer_config['DATASET_PATH'], list):
                    # Take first value from list
                    if transformer_config['DATASET_PATH']:
                        logger.info(f"Converting DATASET_PATH from list to string: {transformer_config['DATASET_PATH'][0]}")
                        transformer_config['DATASET_PATH'] = transformer_config['DATASET_PATH'][0]
                    else:
                        transformer_config['DATASET_PATH'] = ""

        # Create backup of original
        backup_path = f"{config_path}.bak"
        os.rename(config_path, backup_path)

        # Write the fixed config
        with open(config_path, 'w') as f:
            json.dump(config_data, f, indent=2)

        logger.info(f"Config file fixed successfully (backup saved to {backup_path})")
        return True
    except Exception as e:
        logger.error(f"Error fixing config file: {e}")
        return False


def get_component(component_key: str, default: Any = None) -> Any:
    """Get a component from the registry with proper error handling."""
    component = registry.get(component_key)
    if component is None:
        logger.warning(f"Component {component_key} not found in registry - using default")
    return component if component is not None else default


# Function to ensure critical components exist
def ensure_critical_components():
    """Make sure critical components exist in registry, create minimal versions if not."""
    logger.info("Ensuring critical components are available in registry")

    # Ensure tokenizer exists
    if not registry.has(TOKENIZER):
        try:
            logger.info("Creating minimal tokenizer")
            from transformers import AutoTokenizer
            minimal_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
            registry.register(TOKENIZER, minimal_tokenizer)
            logger.info("Registered minimal tokenizer in registry")
        except Exception as e:
            logger.error(f"Failed to create minimal tokenizer: {e}")

    # Ensure model_manager exists
    if not registry.has(MODEL_MANAGER):
        logger.info("Creating minimal model manager")
        minimal_manager = _create_minimal_model_manager()
        registry.register(MODEL_MANAGER, minimal_manager)
        logger.info("Registered minimal model manager in registry")

    # Ensure communicator exists
    if not registry.has(COMMUNICATOR):
        logger.info("Creating communicator via factory")
        try:
            from communicator import create_communicator
            create_communicator(registry.get(MODEL_MANAGER))
            logger.info("Registered communicator in registry")
        except Exception as e:
            # Fall back to the in-module minimal communicator if the factory is unavailable
            logger.warning(f"Communicator factory unavailable ({e}); registering minimal communicator")
            registry.register(COMMUNICATOR, _create_minimal_communicator())

    # Ensure communicator_stdp exists if possible
    if not registry.has("communicator_stdp"):
        try:
            # Try to import STDP communicator or create minimal version
            import importlib.util  # importlib.util is required for find_spec
            if importlib.util.find_spec("communicator_STDP"):
                module = importlib.import_module("communicator_STDP")
                if hasattr(module, "Communicator_STDP"):
                    logger.info("Creating minimal STDP communicator")
                    minimal_stdp = module.Communicator_STDP()
                    registry.register("communicator_stdp", minimal_stdp)
                    logger.info("Registered minimal STDP communicator")
        except Exception as e:
            logger.warning(f"Could not create STDP communicator: {e}")


# Initialize the pipeline components
def initialize_pipeline():
    """Initialize the processing pipeline components."""
    try:
        # Import pipeline dynamically to avoid circular dependencies
        from pipeline import Pipeline

        # Create pipeline instance
        pipeline = Pipeline()

        # Register in service registry
        registry.register(PIPELINE, pipeline)
        logger.info("Initialized pipeline successfully")
        return pipeline
    except ImportError:
        logger.warning("Pipeline module not available - skipping initialization")
        return None
    except Exception as e:
        logger.error(f"Error initializing pipeline: {e}")
        return None


# Helper functions for model management
def _create_minimal_model_manager():
    """Create a minimal model manager for basic functionality."""

    class MinimalModelManager:
        def __init__(self):
            self.models = {}
            self.model_pool = {}  # Add model_pool attribute that was missing
            self.attempted_full_load = False
            self.retry_count = 0
            self.max_retries = 5
            self.last_attempt_time = 0
            self.retry_backoff = 60  # Start with 1 minute between retries

            # First load attempt
            self._attempt_load_full_model()

        def _attempt_load_full_model(self):
            """Attempt to load a full model with weights."""
            self.attempted_full_load = True
            self.last_attempt_time = time.time()
            self.retry_count += 1
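            # Added note: retries back off exponentially from retry_backoff (60s),
            # i.e. roughly 60s, 120s, 240s, ... for attempts 1, 2, 3, ... via
            # retry_backoff * 2 ** (retry_count - 1), both below and in
            # select_model_for_prompt().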
try: logger.info(f"Attempting to load full model (attempt {self.retry_count}/{self.max_retries})") # First try to download model weights - Use primary Wildnerve model repo try: # Import weight downloader from load_model_weights import download_model_files # Use primary Wildnerve repo - this is the CORRECT primary model repo_id_base = "Wildnerve/tlm-0.05Bx12" # Try both default and Transformer/SNN subdirectories logger.info(f"Trying to download weights from {repo_id_base}") result = download_model_files( repo_id_base=repo_id_base, cache_dir=None # Use default cache dir ) if result and "transformer" in result: logger.info(f"Successfully downloaded weights from {repo_id_base}") else: # Try original repo from logs fallback_repo = "EvolphTech/Checkpoints" logger.info(f"Trying to download weights from {fallback_repo}") result = download_model_files( repo_id_base=fallback_repo, cache_dir=None ) if not result or "transformer" not in result: # Try with Transformer subdirectory logger.info(f"Trying to download weights from {fallback_repo}/Transformer") result = download_model_files( repo_id_base=f"{fallback_repo}/Transformer", cache_dir=None ) # Load weights into model if available if result and "transformer" in result: # Set environment variable for other components os.environ["TLM_TRANSFORMER_WEIGHTS"] = result["transformer"] # Import the model classes from model_Custm import Wildnerve_tlm01 # Create model instance with correct parameters for Wildnerve-tlm01-0.05Bx12 model_params = { "specialization": "general", "vocab_size": 30522, "embedding_dim": 768, "num_heads": 12, "hidden_dim": 768, "num_layers": 6, "output_size": 30522, "dropout": 0.1, "max_seq_length": 512 } model = Wildnerve_tlm01(**model_params) # Load weights into model try: # Import the weight loader function from load_model_weights import load_weights_into_model # Load weights if load_weights_into_model(model, result["transformer"], strict=False): logger.info(f"Loading weights from TLM_TRANSFORMER_WEIGHTS: {result['transformer']}") # Register model in registry registry.register(MODEL, model) self.models["default"] = model # Store model name for reference self.primary_model_name = "Wildnerve-tlm01-0.05Bx12" logger.info(f"Successfully initialized {self.primary_model_name} model") return True else: logger.warning(f"Failed to load weights from {result['transformer']}") except Exception as load_error: logger.warning(f"Failed to load weights from TLM_TRANSFORMER_WEIGHTS: {load_error}") except Exception as dl_error: logger.error(f"Error downloading model weights: {dl_error}") # Try loading model implementations model_modules = ["model_Custm", "model_Combn", "model_PrTr"] model_class_names = ["Wildnerve_tlm01", "PretrainedTransformer", "CombinedModel"] # Try each module and class combination for module_name in model_modules: try: module = importlib.import_module(module_name) if hasattr(module, model_class_names[0]): model_class = getattr(module, model_class_names[0]) model = model_class() # Try to load weights if TLM_TRANSFORMER_WEIGHTS is set if "TLM_TRANSFORMER_WEIGHTS" in os.environ and os.path.exists(os.environ["TLM_TRANSFORMER_WEIGHTS"]): try: from load_model_weights import load_weights_into_model if load_weights_into_model(model, os.environ["TLM_TRANSFORMER_WEIGHTS"], strict=False): logger.info(f"Loading weights from TLM_TRANSFORMER_WEIGHTS: {os.environ['TLM_TRANSFORMER_WEIGHTS']}") else: logger.warning(f"Failed to load weights from {os.environ['TLM_TRANSFORMER_WEIGHTS']}") except Exception as load_error: logger.warning(f"Failed to 
                            # Register model in registry
                            registry.register(MODEL, model)
                            self.models["default"] = model
                            logger.info(f"Successfully loaded model from {module_name}")
                            return True
                    except Exception as e:
                        logger.warning(f"Error loading {module_name}: {e}")

                # If we get here, all attempts to load full models failed
                logger.warning("All attempts to load full models failed")

                # Schedule next retry with exponential backoff if needed
                if self.retry_count < self.max_retries:
                    backoff = self.retry_backoff * (2 ** (self.retry_count - 1))
                    logger.info(f"Will retry in {backoff} seconds (attempt {self.retry_count}/{self.max_retries})")
                else:
                    logger.warning(f"Reached maximum retry attempts ({self.max_retries})")

                return False
            except Exception as e:
                logger.warning(f"Could not load an actual model: {e}")
                return False
            finally:
                # Create minimal model as fallback
                if not self.models:
                    # If we have no models at all, create a minimal one
                    self._create_minimal_model()

        def _create_minimal_model(self):
            """Create a minimal model as fallback."""
            logger.info("Creating minimal model instance")
            try:
                class MinimalModel:
                    def __init__(self):
                        self._is_minimal = True

                    def forward(self, input_ids):
                        return torch.zeros((input_ids.shape[0], input_ids.shape[1], 30522))

                    def generate_with_decoding(self, input_ids):
                        return "I'm running in minimal mode due to initialization issues. Please try again later for full model responses."

                minimal_model = MinimalModel()
                self.models["default"] = minimal_model
                try:
                    registry.register(MODEL, minimal_model)
                    logger.warning("Registered MINIMAL model in registry - FULL MODEL UNAVAILABLE")
                except Exception as reg_error:
                    logger.error(f"Failed to register minimal model: {reg_error}")
            except Exception as minimal_error:
                logger.error(f"Failed to create minimal model: {minimal_error}")

        def select_model_for_prompt(self, prompt):
            # Check if we should retry loading a full model
            should_retry = (
                len(self.models) > 0
                and "default" in self.models
                and hasattr(self.models["default"], "_is_minimal")
                and self.models["default"]._is_minimal
                and self.retry_count < self.max_retries
                and (time.time() - self.last_attempt_time) > self.retry_backoff * (2 ** (self.retry_count - 1))
            )

            if should_retry:
                # Try to load full model again
                logger.info("Attempting to reload full model on demand")
                self._attempt_load_full_model()

            # Return the model
            if self.models and "default" in self.models:
                model = self.models["default"]
                if hasattr(model, '_is_minimal') and model._is_minimal:
                    logger.warning("Using minimal model for prompt - FULL MODEL UNAVAILABLE")
                else:
                    logger.info("Using full model for prompt")
                return model

            logger.debug(f"Minimal model manager received prompt but has no models: {prompt[:30]}...")
            return None

        def get_available_models(self):
            return self.models

    return MinimalModelManager()


def _create_minimal_communicator():
    """Create a minimal communicator for basic functionality."""

    class MinimalCommunicator:
        def __init__(self):
            # Knowledge base for minimal mode
            self.knowledge_base = {
                "malaysia": """
Malaysia is a Southeast Asian country located on the Malay Peninsula and parts of Borneo island.

Key facts:
- Capital: Kuala Lumpur (administrative capital is Putrajaya)
- Population: Approximately 32 million
- Languages: Malay (official), English, Chinese dialects, Tamil
- Government: Federal constitutional elective monarchy
- Currency: Malaysian Ringgit (MYR)
- Major ethnic groups: Malay, Chinese, Indian, indigenous peoples
- Notable landmarks: Petronas Twin Towers, Mount Kinabalu, Langkawi Island
- Cuisine: Famous for dishes like nasi lemak, satay, laksa, and roti canai
- Economy: Significant sectors include manufacturing, services, and agriculture (especially palm oil)

Malaysia gained independence from British colonial rule in 1957 and is known for its diverse culture and rainforests.
""",
                "python": """Python is a high-level programming language known for its readability and versatility. It's great for web development, data science, AI, and more.""",
                "javascript": "JavaScript is a programming language used primarily for web development.",
                "ai": "Artificial Intelligence refers to systems that can perform tasks requiring human intelligence.",
                "machine learning": "Machine learning is a subset of AI where systems learn from data without explicit programming."
            }
            logger.info("Initialized minimal communicator with basic knowledge base")

            # Try to get model from registry
            self.model = registry.get(MODEL)
            self.tokenizer = registry.get(TOKENIZER)
            if self.model:
                logger.info(f"Minimal communicator found model: {type(self.model).__name__}")

        def process_input(self, prompt, **kwargs):
            # First try using a real model if available
            if self.model and self.tokenizer:
                try:
                    logger.info("Attempting model inference with actual model")
                    inputs = self.tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)

                    if hasattr(self.model, "generate_with_decoding"):
                        response = self.model.generate_with_decoding(inputs.input_ids)
                    elif hasattr(self.model, "generate"):
                        output_ids = self.model.generate(inputs.input_ids)
                        response = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
                    else:
                        # Forward pass
                        outputs = self.model(inputs.input_ids)
                        response = self.tokenizer.decode(torch.argmax(outputs, dim=-1)[0], skip_special_tokens=True)

                    if response and len(response) > 10:  # Require reasonably long response
                        logger.info("Generated model response successfully")
                        return {"response": response}
                except Exception as e:
                    logger.warning(f"Model inference failed: {e}")

            # Check if prompt contains keywords we can respond to meaningfully
            logger.debug(f"Minimal communicator processing: {prompt[:30]}...")
            response = self._get_knowledge_response(prompt)
            if response:
                return {"response": response}

            return {"response": f"I'm analyzing your request about {prompt.split()[0] if prompt.split() else 'this topic'}..."}

        def process_request(self, prompt, model=None):
            # Use provided model if given
            if model:
                try:
                    logger.info(f"Attempting inference with provided model: {type(model).__name__}")
                    tokenizer = registry.get(TOKENIZER)
                    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)

                    if hasattr(model, "generate_with_decoding"):
                        return model.generate_with_decoding(inputs.input_ids)
                    elif hasattr(model, "generate"):
                        output_ids = model.generate(inputs.input_ids)
                        return tokenizer.decode(output_ids[0], skip_special_tokens=True)
                    else:
                        outputs = model(inputs.input_ids)
                        return tokenizer.decode(torch.argmax(outputs, dim=-1)[0], skip_special_tokens=True)
                except Exception as e:
                    logger.warning(f"Model inference failed: {e}")

            # Fall back to process_input
            result = self.process_input(prompt)
            return result.get("response", f"I'm analyzing '{prompt[:20]}...'")

        def process_request_streaming(self, prompt, model=None):
            response = self.process_request(prompt, model)
            words = response.split()
            for word in words:
                yield word + " "
                time.sleep(0.05)  # Simulate streaming

        def _get_knowledge_response(self, prompt):
            """Check if we have knowledge about topics in the prompt."""
            prompt_lower = prompt.lower()

            # Check each topic in knowledge base
            for topic, info in self.knowledge_base.items():
                if topic in prompt_lower:
                    return info

            # No matching topic found
            return None

    logger.info("Created enhanced minimal communicator for fallback functionality")
    communicator = MinimalCommunicator()
    return communicator


def generate_response(prompt: str, stream: bool = False) -> Union[str, Generator[str, None, None]]:
    """Generate a response using the appropriate models and communicators."""
    # Ensure components exist before accessing them
    ensure_critical_components()

    # Get components from registry
    model_manager = get_component(MODEL_MANAGER)
    communicator = get_component(COMMUNICATOR)
    communicator_stdp = get_component("communicator_stdp")

    logger.info(f"Generating response for prompt: {prompt[:50]}...")
    try:
        # Process input and generate response
        if stream:
            return _generate_streaming_response(prompt, model_manager, communicator, communicator_stdp)
        else:
            return _generate_complete_response(prompt, model_manager, communicator, communicator_stdp)
    except Exception as e:
        logger.error(f"Error generating response: {e}")
        return f"Error generating response: {str(e)}"


def _generate_complete_response(prompt: str, model_manager: Any, communicator: Any, communicator_stdp: Any) -> str:
    """Generate a complete response (non-streaming)."""
    logger.info(f"Generating response for prompt: {prompt[:50]}...")

    # Start time for performance measurement
    start_time = time.time()

    # Try getting a response from the model through the communicator first
    if communicator and hasattr(communicator, 'process_request'):
        try:
            # Get the model
            model = None
            if model_manager and hasattr(model_manager, 'select_model_for_prompt'):
                model = model_manager.select_model_for_prompt(prompt)

            # Process through communicator
            response = communicator.process_request(prompt, model)
            if response and not response.endswith("[PAD]") and response != prompt:
                logger.info(f"Got response from communicator: {response[:50]}...")
                return response
            else:
                logger.warning("Communicator returned echo or padded response, trying alternative methods")
        except Exception as e:
            logger.error(f"Error in communicator.process_request: {e}")

    # Try process_input method which returns a dictionary
    if communicator and hasattr(communicator, 'process_input'):
        try:
            result = communicator.process_input(prompt)
            if isinstance(result, dict) and "response" in result:
                response = result["response"]
                # Check if it's just echoing the prompt
                if response and response != prompt and not response.endswith("[PAD]"):
                    logger.info(f"Got response from process_input: {response[:50]}...")
                    return response
                else:
                    logger.warning("process_input returned echo or padded response")
        except Exception as e:
            logger.error(f"Error in communicator.process_input: {e}")

    # Try STDP communicator as an alternative path
    if communicator_stdp:
        try:
            stdp_response = communicator_stdp.process_request(prompt, None)
            if stdp_response and stdp_response != prompt and not stdp_response.endswith("[PAD]"):
                logger.info(f"Got response from STDP communicator: {stdp_response[:50]}...")
                return stdp_response
            else:
                logger.warning("STDP communicator returned echo or padded response")
        except Exception as e:
logger.error(f"Error in STDP communicator: {e}") # Final fallback return f"I'm processing your request about '{prompt[:20]}...'" def _generate_streaming_response(prompt: str, model_manager: Any, communicator: Any, communicator_stdp: Any) -> Generator[str, None, None]: """Generate a streaming response (yields tokens).""" logger.info(f"Generating streaming response for: {prompt[:50]}...") # Try using the streaming methods first if communicator and hasattr(communicator, 'process_request_streaming'): try: # Get model model = None if model_manager and hasattr(model_manager, 'select_model_for_prompt'): model = model_manager.select_model_for_prompt(prompt) # Process through streaming handler for token in communicator.process_request_streaming(prompt, model): yield token return except Exception as e: logger.error(f"Error in communicator streaming: {e}") # Try STDP communicator if available if communicator_stdp and hasattr(communicator_stdp, 'process_request_streaming'): try: for token in communicator_stdp.process_request_streaming(prompt): yield token return except Exception as e: logger.error(f"Error in STDP streaming: {e}") # Fallback to non-streaming response and simulate streaming response = _generate_complete_response(prompt, model_manager, communicator, communicator_stdp) words = response.split() for word in words: yield word + " " time.sleep(0.05) # Simulate streaming def parse_args(): """Parse command line arguments.""" parser = argparse.ArgumentParser(description="Tiny Language Model Server") parser.add_argument("--port", type=int, default=7860, help="Port to run the server on") parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to run the server on") parser.add_argument("--share", action="store_true", help="Enable sharing") parser.add_argument("--debug", action="store_true", help="Enable debug mode") parser.add_argument("--data_dir", type=str, help="Data directory") parser.add_argument("--model_dir", type=str, help="Model directory") parser.add_argument("--api_only", action="store_true", help="Run as API server only (no UI)") parser.add_argument("--initialize", action="store_true", help="Initialize the models and exit") return parser.parse_args() def setup_environment(args): """Set up environment variables and directories.""" if args.data_dir: logger.info(f"Using custom data directory: {args.data_dir}") os.environ['TLM_DATA_DIR'] = args.data_dir if args.model_dir: logger.info(f"Using custom model directory: {args.model_dir}") os.environ['TLM_MODEL_DIR'] = args.model_dir if args.debug: logger.info("Debug mode enabled") logging.getLogger().setLevel(logging.DEBUG) try: from config import ensure_data_directories ensure_data_directories() except ImportError: data_dir = os.environ.get('TLM_DATA_DIR', '/tmp/tlm_data') os.makedirs(data_dir, exist_ok=True) model_dir = os.path.join(data_dir, "models") os.makedirs(model_dir, exist_ok=True) def initialize_system(): """Initialize all components in the correct order""" logger.info("Starting system initialization") # First tokenizer - Use GPT-2 tokenizer instead of BERT try: from transformers import GPT2Tokenizer tokenizer = GPT2Tokenizer.from_pretrained("gpt2") # GPT-2 tokenizer doesn't have a pad_token by default, so we set it if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token except Exception as e: logger.warning(f"Could not load GPT-2 tokenizer, falling back to wrapper: {e}") from tokenizer import TokenizerWrapper tokenizer = TokenizerWrapper(model_name="gpt2") # Then register tokenizer 
    registry.register(TOKENIZER, tokenizer, overwrite=True)
    logger.info("Tokenizer registered")

    # Initialize pretrained model (GPT-2)
    try:
        from model_PrTr import GPT_2 as PretrainedModel
        pretrained = PretrainedModel(model_name="gpt2", tokenizer=tokenizer)
        registry.register(PRETRAINED_MODEL, pretrained, overwrite=True)
        logger.info("GPT-2 pretrained model registered")
    except Exception as e:
        logger.error(f"Failed to initialize GPT-2 model: {e}", exc_info=True)

    # Now load custom model
    try:
        from model_Custm import Wildnerve_tlm01

        # Use architecture parameters from config
        arch_params = get_model_architecture_params()
        model = Wildnerve_tlm01(
            vocab_size=arch_params["vocab_size"],
            specialization="general",
            dataset_path=None,
            model_name="gpt2",
            embedding_dim=arch_params["embedding_dim"],
            num_heads=arch_params["num_heads"],
            hidden_dim=arch_params["hidden_dim"],
            num_layers=arch_params["num_layers"],
            output_size=arch_params["vocab_size"],
            dropout=arch_params["dropout"],
            max_seq_length=arch_params["max_seq_length"],
            pooling_mode="last",
            tokenizer=tokenizer
        )

        # Register model
        registry.register(MODEL, model, overwrite=True)
        logger.info("Custom model registered successfully")
        return True
    except Exception as e:
        logger.error(f"Failed to initialize custom model: {e}", exc_info=True)
        return False


def main():
    """Main application entry point with consolidated functionality."""
    # Initialize the system first
    success = initialize_system()
    logger.info(f"System initialization {'successful' if success else 'failed'}")

    # Start the server
    from app import app  # early sanity check; uvicorn loads it again via the import string
    import uvicorn

    logger.info("Starting TLM application")
    uvicorn.run(
        # uvicorn requires an import string (not the app object) when workers > 1
        "app:app",
        host="0.0.0.0",
        port=int(os.getenv("PORT", 7860)),
        workers=os.cpu_count() or 1,
        loop="auto"
    )


if __name__ == "__main__":
    main()
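

# Illustrative usage sketch (added; not wired into server startup). It simply
# exercises the helpers defined above from a REPL or test; the prompt text is an
# arbitrary example.
def _example_generate_response_usage():
    """Sketch of calling generate_response() directly after initialization."""
    initialize_system()
    ensure_critical_components()

    # Non-streaming: returns a single string
    print(generate_response("Tell me about Malaysia"))

    # Streaming: returns a generator of whitespace-delimited tokens
    for token in generate_response("Tell me about Malaysia", stream=True):
        print(token, end="", flush=True)
    print()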