# Uploaded by WildnerveAI — "Upload main.py", commit 26929f2 (verified).
# (Hugging Face page header converted to a comment so the file parses.)
# Main.py - Main entry point for Wildnerve-tlm_HF
import os
import sys
import time
import json
import gc
import logging
import argparse
import importlib
import threading
from typing import Dict, Any, Optional, List, Union, Generator, Tuple
from pathlib import Path
# Set up line buffering early so log lines reach container logs promptly.
# Guard: substituted streams (e.g. test capture objects) may lack reconfigure.
if hasattr(sys.stdout, "reconfigure"):
    sys.stdout.reconfigure(line_buffering=True)

# Configure logging once at the top level
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    force=True
)
logger = logging.getLogger(__name__)

# Add file handlers for persistent logs
file_handler = logging.FileHandler('/tmp/app_debug.log')
file_handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
logger.addHandler(file_handler)
fh = logging.FileHandler("/tmp/container.log")
fh.setLevel(logging.DEBUG)
# BUG FIX: the original format used %(levellevel)s, which is not a LogRecord
# attribute and raises on every record routed through this handler; the
# correct placeholder is %(levelname)s.
fh.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s"))
logging.getLogger().addHandler(fh)
logger.info("Logging configured")

# Force early initialization of vital environment variables
if not os.environ.get("TLM_DATA_DIR"):
    os.environ["TLM_DATA_DIR"] = "/tmp/tlm_data"

# Select GPU if available
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")

# Add GPU memory monitoring (torch is already imported above; the duplicate
# `import torch` inside the try block was removed).
try:
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        logger.info(f"GPU Memory: {torch.cuda.memory_allocated()/1e9:.2f}GB / {torch.cuda.get_device_properties(0).total_memory/1e9:.2f}GB")
except Exception as e:
    logger.warning(f"Error with PyTorch setup: {e}")
# Import configuration and apply emergency in-memory patches to the loaded
# config object. Any failure here is fatal: startup cannot proceed without a
# valid configuration.
try:
    from config import app_config, load_config, get_model_architecture_params
    # Create an emergency patch for config.py issue
    if hasattr(app_config, 'TRANSFORMER_CONFIG'):
        # NOTE(review): attribute assignment on a plain builtin dict raises
        # AttributeError, so this branch presumably targets a dict *subclass*
        # that supports attribute access — confirm against config.py.
        if not hasattr(app_config.TRANSFORMER_CONFIG, 'config_data') and isinstance(app_config.TRANSFORMER_CONFIG, dict):
            # Create a minimal config_data attribute to avoid attribute errors
            app_config.TRANSFORMER_CONFIG.config_data = app_config.TRANSFORMER_CONFIG
            # Also ensure MODEL_NAME is gpt2, not distilbert
            app_config.TRANSFORMER_CONFIG.MODEL_NAME = "gpt2"
            app_config.TRANSFORMER_CONFIG.VOCAB_SIZE = 50257  # GPT-2 vocab size
        elif hasattr(app_config.TRANSFORMER_CONFIG, 'specialization'):
            # Ensure specialization is a list
            if isinstance(app_config.TRANSFORMER_CONFIG.specialization, str):
                if ',' in app_config.TRANSFORMER_CONFIG.specialization:
                    # Comma-separated string -> list of trimmed entries
                    app_config.TRANSFORMER_CONFIG.specialization = [
                        s.strip() for s in app_config.TRANSFORMER_CONFIG.specialization.split(',')
                    ]
                else:
                    # Single string -> one-element list
                    app_config.TRANSFORMER_CONFIG.specialization = [app_config.TRANSFORMER_CONFIG.specialization]
                logger.info(f"Fixed specialization: {app_config.TRANSFORMER_CONFIG.specialization}")
except Exception as e:
    logger.error(f"Error loading configuration: {e}")
    raise  # stop startup on config load failure
# Apply transformers patches. The module is optional: when it is missing we
# log a warning and continue with unpatched transformers behavior.
try:
    import transformer_patches
except ImportError:
    logger.warning("Could not import transformer_patches")
# Import the service registry and event-system names; when any of the core
# modules is missing, fall back to minimal in-process substitutes so the rest
# of this module can still run in degraded mode.
try:
    from service_registry import registry, MODEL, TOKENIZER, MODEL_MANAGER, COMMUNICATOR, PIPELINE, PRETRAINED_MODEL
    # Import event system types
    from utils.event_system import (
        EVENT_STDP_REQUEST, EVENT_STDP_RESPONSE, EVENT_TOKEN_GENERATED,
        EVENT_USER_INPUT, EVENT_MODEL_REQUEST, EVENT_MODEL_RESPONSE,
        EVENT_RESPONSE_COMPLETE, EVENT_ERROR
    )
    # Also import event bus for lightweight communication
    from utils.event_bus import event_bus
except ImportError as e:
    logger.error(f"Failed to import core modules: {e}")

    # Define minimal registry
    class Registry:
        """In-memory fallback for the real service registry."""

        def __init__(self):
            self._registry = {}

        # BUG FIX: initialize_system() calls register(..., overwrite=True);
        # the fallback used to reject that keyword with a TypeError. The
        # default (True) preserves the old always-overwrite behavior for
        # two-argument calls, so existing callers are unaffected.
        def register(self, key, value, overwrite=True):
            if overwrite or key not in self._registry:
                self._registry[key] = value

        def get(self, key, default=None):
            return self._registry.get(key, default)

        def has(self, key):
            return key in self._registry

    registry = Registry()
    MODEL = "model"
    TOKENIZER = "tokenizer"
    MODEL_MANAGER = "model_manager"
    COMMUNICATOR = "communicator"
    PIPELINE = "pipeline"
    PRETRAINED_MODEL = "pretrained_model"  # Added this constant

    # Define minimal event constants
    EVENT_STDP_REQUEST = "stdp_request"
    EVENT_STDP_RESPONSE = "stdp_response"
    EVENT_TOKEN_GENERATED = "token_generated"
    EVENT_USER_INPUT = "user_input"
    EVENT_MODEL_REQUEST = "model_request"
    EVENT_MODEL_RESPONSE = "model_response"
    EVENT_RESPONSE_COMPLETE = "response_complete"
    EVENT_ERROR = "error"

    # Define minimal event bus (events are dropped silently in fallback mode)
    class EventBus:
        def publish(self, event_type, data):
            pass

    event_bus = EventBus()
from find_weights import find_transformer_weights, find_snn_weights
# Import API components; when unavailable, install stubs so later references
# to TLMInterface / verify_model_repo_access still resolve.
try:
    from api_wp import TLMInterface
    from verify_repo import verify_model_repo_access
except Exception as e:
    logger.error(f"Error importing API components: {e}")

    # Define minimal placeholders
    class TLMInterface:
        # Stub: reports failure so callers fall back gracefully.
        def initialize(self, force=False):
            return False

        # Stub: fixed response indicating the real API layer is missing.
        def process_input(self, text):
            return {"response": "API unavailable"}

    def verify_model_repo_access():
        # Stub: without verify_repo we cannot confirm repository access.
        return False
# --- Helper functions ---
def fix_config_file(config_path="config.json"):
    """Normalize awkward value shapes in the JSON config file, in place.

    Inside the TRANSFORMER_CONFIG section:
      * a list-valued ``specialization`` is collapsed to its first entry
        (or ``"general"`` when the list is empty);
      * a dict- or list-valued ``DATASET_PATH`` is collapsed to a single
        string (or ``""`` when empty).
    The original file is preserved as ``<config_path>.bak``.

    Args:
        config_path: Path of the JSON config file to fix.

    Returns:
        True when the file was rewritten, False on any failure (missing
        file, unreadable JSON, I/O error).
    """
    log = logging.getLogger(__name__)  # same logger instance the module uses
    try:
        # Check if file exists
        if not os.path.exists(config_path):
            log.error(f"Config file not found at {config_path}")
            return False
        # Read the config file
        with open(config_path, 'r') as f:
            config_data = json.load(f)
        # Fix TRANSFORMER_CONFIG section if it exists
        transformer_config = config_data.get('TRANSFORMER_CONFIG')
        if isinstance(transformer_config, dict):
            _normalize_specialization(transformer_config, log)
            _normalize_dataset_path(transformer_config, log)
        # Create backup of original. os.replace (unlike os.rename) also
        # overwrites a stale .bak on Windows instead of raising.
        backup_path = f"{config_path}.bak"
        os.replace(config_path, backup_path)
        # Write the fixed config
        with open(config_path, 'w') as f:
            json.dump(config_data, f, indent=2)
        log.info(f"Config file fixed successfully (backup saved to {backup_path})")
        return True
    except Exception as e:
        log.error(f"Error fixing config file: {e}")
        return False


def _normalize_specialization(transformer_config, log):
    """Collapse a list-valued `specialization` to a single string."""
    if 'specialization' in transformer_config and isinstance(transformer_config['specialization'], list):
        if transformer_config['specialization']:
            log.info(f"Converting specialization from list to string: {transformer_config['specialization'][0]}")
            transformer_config['specialization'] = transformer_config['specialization'][0]
        else:
            log.info("Setting empty specialization list to 'general'")
            transformer_config['specialization'] = "general"


def _normalize_dataset_path(transformer_config, log):
    """Collapse a dict- or list-valued `DATASET_PATH` to a single string."""
    if 'DATASET_PATH' not in transformer_config:
        return
    value = transformer_config['DATASET_PATH']
    if isinstance(value, dict):
        if value:
            # Take first value from dict
            first_key = next(iter(value))
            path_value = value[first_key]
            if isinstance(path_value, list) and path_value:
                log.info(f"Converting DATASET_PATH from dict of list to string: {path_value[0]}")
                transformer_config['DATASET_PATH'] = path_value[0]
            else:
                log.info(f"Converting DATASET_PATH from dict to string: {path_value}")
                transformer_config['DATASET_PATH'] = str(path_value)
        else:
            transformer_config['DATASET_PATH'] = ""
    elif isinstance(value, list):
        if value:
            # Take first value from list
            log.info(f"Converting DATASET_PATH from list to string: {value[0]}")
            transformer_config['DATASET_PATH'] = value[0]
        else:
            transformer_config['DATASET_PATH'] = ""
def get_component(component_key: str, default: Any = None) -> Any:
    """Get a component from the registry with proper error handling.

    Args:
        component_key: Registry key to look up.
        default: Value to return when the key is absent.

    Returns:
        The registered component, or ``default`` when the key is missing.
    """
    component = registry.get(component_key)
    if component is None:
        logger.warning(f"Component {component_key} not found in registry - using default")
        # BUG FIX: the old `component or default` also replaced legitimately
        # falsy components (0, "", empty containers) with the default.
        return default
    return component
# Function to ensure critical components exist
def ensure_critical_components():
    """Make sure critical components exist in registry, create minimal versions if not.

    Idempotent: each component is only created when its registry key is
    absent. Failures are logged and do not abort the remaining checks.
    """
    logger.info("Ensuring critical components are available in registry")
    # Ensure tokenizer exists
    if not registry.has(TOKENIZER):
        try:
            logger.info("Creating minimal tokenizer")
            from transformers import AutoTokenizer
            minimal_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
            registry.register(TOKENIZER, minimal_tokenizer)
            logger.info("Registered minimal tokenizer in registry")
        except Exception as e:
            logger.error(f"Failed to create minimal tokenizer: {e}")
    # Ensure model_manager exists
    if not registry.has(MODEL_MANAGER):
        logger.info("Creating minimal model manager")
        minimal_manager = _create_minimal_model_manager()
        registry.register(MODEL_MANAGER, minimal_manager)
        logger.info("Registered minimal model manager in registry")
    # Ensure communicator exists
    if not registry.has(COMMUNICATOR):
        # BUG FIX: this branch was the only one without an exception guard;
        # a missing `communicator` module used to crash the whole ensure step
        # (and with it generate_response).
        try:
            logger.info("Creating communicator via factory")
            from communicator import create_communicator
            create_communicator(registry.get(MODEL_MANAGER))
            logger.info("Registered communicator in registry")
        except Exception as e:
            logger.error(f"Failed to create communicator: {e}")
    # Ensure communicator_stdp exists if possible
    if not registry.has("communicator_stdp"):
        try:
            # BUG FIX: a bare `import importlib` does not guarantee the
            # importlib.util submodule is loaded; import it explicitly before
            # calling find_spec.
            import importlib.util
            if importlib.util.find_spec("communicator_STDP"):
                module = importlib.import_module("communicator_STDP")
                if hasattr(module, "Communicator_STDP"):
                    logger.info("Creating minimal STDP communicator")
                    minimal_stdp = module.Communicator_STDP()
                    registry.register("communicator_stdp", minimal_stdp)
                    logger.info("Registered minimal STDP communicator")
        except Exception as e:
            logger.warning(f"Could not create STDP communicator: {e}")
# Initialize the pipeline components
def initialize_pipeline():
    """Create the processing Pipeline and register it in the service registry.

    Returns:
        The Pipeline instance on success, or None when the pipeline module
        is missing or construction/registration fails.
    """
    try:
        # Deferred import: avoids circular dependencies at module load time.
        from pipeline import Pipeline
        instance = Pipeline()
        registry.register(PIPELINE, instance)
        logger.info("Initialized pipeline successfully")
        return instance
    except ImportError:
        logger.warning("Pipeline module not available - skipping initialization")
    except Exception as e:
        logger.error(f"Error initializing pipeline: {e}")
    return None
# Helper functions for model management
def _create_minimal_model_manager():
    """Create a minimal model manager for basic functionality.

    Returns a MinimalModelManager that eagerly tries to load a real model
    (downloading weights when possible) and installs a stub model when every
    attempt fails. Relies on the module-level `registry`, `logger`, `time`,
    `os`, `torch` and `importlib` names.
    """
    class MinimalModelManager:
        def __init__(self):
            # name -> model; "default" is the one served to prompts
            self.models = {}
            self.model_pool = {}  # Add model_pool attribute that was missing
            self.attempted_full_load = False
            self.retry_count = 0
            self.max_retries = 5
            self.last_attempt_time = 0
            self.retry_backoff = 60  # Start with 1 minute between retries
            # First load attempt
            self._attempt_load_full_model()

        def _attempt_load_full_model(self):
            """Attempt to load a full model with weights.

            Returns True when a real model was registered, False otherwise.
            The `finally` clause guarantees self.models is never left empty:
            a stub model is created when nothing loaded.
            """
            self.attempted_full_load = True
            self.last_attempt_time = time.time()
            self.retry_count += 1
            try:
                logger.info(f"Attempting to load full model (attempt {self.retry_count}/{self.max_retries})")
                # First try to download model weights - Use primary Wildnerve model repo
                try:
                    # Import weight downloader
                    from load_model_weights import download_model_files
                    # Use primary Wildnerve repo - this is the CORRECT primary model
                    repo_id_base = "Wildnerve/tlm-0.05Bx12"
                    # Try both default and Transformer/SNN subdirectories
                    logger.info(f"Trying to download weights from {repo_id_base}")
                    result = download_model_files(
                        repo_id_base=repo_id_base,
                        cache_dir=None  # Use default cache dir
                    )
                    if result and "transformer" in result:
                        logger.info(f"Successfully downloaded weights from {repo_id_base}")
                    else:
                        # Try original repo from logs
                        fallback_repo = "EvolphTech/Checkpoints"
                        logger.info(f"Trying to download weights from {fallback_repo}")
                        result = download_model_files(
                            repo_id_base=fallback_repo,
                            cache_dir=None
                        )
                        if not result or "transformer" not in result:
                            # Try with Transformer subdirectory
                            logger.info(f"Trying to download weights from {fallback_repo}/Transformer")
                            result = download_model_files(
                                repo_id_base=f"{fallback_repo}/Transformer",
                                cache_dir=None
                            )
                    # Load weights into model if available
                    if result and "transformer" in result:
                        # Set environment variable for other components
                        os.environ["TLM_TRANSFORMER_WEIGHTS"] = result["transformer"]
                        # Import the model classes
                        from model_Custm import Wildnerve_tlm01
                        # Create model instance with correct parameters for Wildnerve-tlm01-0.05Bx12
                        # NOTE(review): 30522 is the BERT vocab size while the
                        # rest of the app uses GPT-2 (50257) — confirm the
                        # checkpoint really was trained with this vocab.
                        model_params = {
                            "specialization": "general",
                            "vocab_size": 30522,
                            "embedding_dim": 768,
                            "num_heads": 12,
                            "hidden_dim": 768,
                            "num_layers": 6,
                            "output_size": 30522,
                            "dropout": 0.1,
                            "max_seq_length": 512
                        }
                        model = Wildnerve_tlm01(**model_params)
                        # Load weights into model
                        try:
                            # Import the weight loader function
                            from load_model_weights import load_weights_into_model
                            # Load weights
                            if load_weights_into_model(model, result["transformer"], strict=False):
                                logger.info(f"Loading weights from TLM_TRANSFORMER_WEIGHTS: {result['transformer']}")
                                # Register model in registry
                                registry.register(MODEL, model)
                                self.models["default"] = model
                                # Store model name for reference
                                self.primary_model_name = "Wildnerve-tlm01-0.05Bx12"
                                logger.info(f"Successfully initialized {self.primary_model_name} model")
                                return True
                            else:
                                logger.warning(f"Failed to load weights from {result['transformer']}")
                        except Exception as load_error:
                            logger.warning(f"Failed to load weights from TLM_TRANSFORMER_WEIGHTS: {load_error}")
                except Exception as dl_error:
                    logger.error(f"Error downloading model weights: {dl_error}")
                # Try loading model implementations
                model_modules = ["model_Custm", "model_Combn", "model_PrTr"]
                model_class_names = ["Wildnerve_tlm01", "PretrainedTransformer", "CombinedModel"]
                # Try each module and class combination
                # NOTE(review): only model_class_names[0] is ever looked up, so
                # modules lacking Wildnerve_tlm01 are skipped silently.
                for module_name in model_modules:
                    try:
                        module = importlib.import_module(module_name)
                        if hasattr(module, model_class_names[0]):
                            model_class = getattr(module, model_class_names[0])
                            model = model_class()
                            # Try to load weights if TLM_TRANSFORMER_WEIGHTS is set
                            if "TLM_TRANSFORMER_WEIGHTS" in os.environ and os.path.exists(os.environ["TLM_TRANSFORMER_WEIGHTS"]):
                                try:
                                    from load_model_weights import load_weights_into_model
                                    if load_weights_into_model(model, os.environ["TLM_TRANSFORMER_WEIGHTS"], strict=False):
                                        logger.info(f"Loading weights from TLM_TRANSFORMER_WEIGHTS: {os.environ['TLM_TRANSFORMER_WEIGHTS']}")
                                    else:
                                        logger.warning(f"Failed to load weights from {os.environ['TLM_TRANSFORMER_WEIGHTS']}")
                                except Exception as load_error:
                                    logger.warning(f"Failed to load weights from TLM_TRANSFORMER_WEIGHTS: {load_error}")
                            # Register model in registry
                            registry.register(MODEL, model)
                            self.models["default"] = model
                            logger.info(f"Successfully loaded model from {module_name}")
                            return True
                    except Exception as e:
                        logger.warning(f"Error loading {module_name}: {e}")
                # If we get here, all attempts to load full models failed
                logger.warning("All attempts to load full models failed")
                # Schedule next retry with exponential backoff if needed
                if self.retry_count < self.max_retries:
                    backoff = self.retry_backoff * (2 ** (self.retry_count - 1))
                    logger.info(f"Will retry in {backoff} seconds (attempt {self.retry_count}/{self.max_retries})")
                else:
                    logger.warning(f"Reached maximum retry attempts ({self.max_retries})")
                return False
            except Exception as e:
                logger.warning(f"Could not load an actual model: {e}")
                return False
            finally:
                # Create minimal model as fallback
                if not self.models:
                    # If we have no models at all, create a minimal one
                    self._create_minimal_model()

        def _create_minimal_model(self):
            """Create a minimal model as fallback"""
            logger.info("Creating minimal model instance")
            try:
                class MinimalModel:
                    # Stub model; _is_minimal marks it for the retry logic in
                    # select_model_for_prompt.
                    def __init__(self):
                        self._is_minimal = True

                    def forward(self, input_ids):
                        # Zero logits over a 30522-entry vocabulary
                        return torch.zeros((input_ids.shape[0], input_ids.shape[1], 30522))

                    def generate_with_decoding(self, input_ids):
                        return "I'm running in minimal mode due to initialization issues. Please try again later for full model responses."

                minimal_model = MinimalModel()
                self.models["default"] = minimal_model
                try:
                    registry.register(MODEL, minimal_model)
                    logger.warning("Registered MINIMAL model in registry - FULL MODEL UNAVAILABLE")
                except Exception as reg_error:
                    logger.error(f"Failed to register minimal model: {reg_error}")
            except Exception as minimal_error:
                logger.error(f"Failed to create minimal model: {minimal_error}")

        def select_model_for_prompt(self, prompt):
            """Return the model to use for `prompt`, retrying a full load
            (with exponential backoff) while only the stub model is present."""
            # Check if we should retry loading a full model
            should_retry = (
                len(self.models) > 0 and
                "default" in self.models and
                hasattr(self.models["default"], "_is_minimal") and
                self.models["default"]._is_minimal and
                self.retry_count < self.max_retries and
                (time.time() - self.last_attempt_time) > self.retry_backoff * (2 ** (self.retry_count - 1))
            )
            if should_retry:
                # Try to load full model again
                logger.info("Attempting to reload full model on demand")
                self._attempt_load_full_model()
            # Return the model
            if self.models and "default" in self.models:
                model = self.models["default"]
                if hasattr(model, '_is_minimal') and model._is_minimal:
                    logger.warning("Using minimal model for prompt - FULL MODEL UNAVAILABLE")
                else:
                    logger.info("Using full model for prompt")
                return model
            logger.debug(f"Minimal model manager received prompt but has no models: {prompt[:30]}...")
            return None

        def get_available_models(self):
            # name -> model mapping (may contain only the stub model)
            return self.models

    return MinimalModelManager()
def _create_minimal_communicator():
    """Create a minimal communicator for basic functionality.

    Returns a MinimalCommunicator that first tries real model inference via
    the registry's model/tokenizer and otherwise answers from a tiny
    hard-coded knowledge base. Relies on module-level `registry`, `logger`,
    `torch` and `time`.
    """
    class MinimalCommunicator:
        def __init__(self):
            # Knowledge base for minimal mode
            self.knowledge_base = {
                "malaysia": """
Malaysia is a Southeast Asian country located on the Malay Peninsula and parts of Borneo island. Key facts:
- Capital: Kuala Lumpur (administrative capital is Putrajaya)
- Population: Approximately 32 million
- Languages: Malay (official), English, Chinese dialects, Tamil
- Government: Federal constitutional elective monarchy
- Currency: Malaysian Ringgit (MYR)
- Major ethnic groups: Malay, Chinese, Indian, indigenous peoples
- Notable landmarks: Petronas Twin Towers, Mount Kinabalu, Langkawi Island
- Cuisine: Famous for dishes like nasi lemak, satay, laksa, and roti canai
- Economy: Significant sectors include manufacturing, services, and agriculture (especially palm oil)
Malaysia gained independence from British colonial rule in 1957 and is known for its diverse culture and rainforests.
""",
                "python": """Python is a high-level programming language known for its readability and versatility.
It's great for web development, data science, AI, and more.""",
                "javascript": "JavaScript is a programming language used primarily for web development.",
                "ai": "Artificial Intelligence refers to systems that can perform tasks requiring human intelligence.",
                "machine learning": "Machine learning is a subset of AI where systems learn from data without explicit programming."
            }
            logger.info("Initialized minimal communicator with basic knowledge base")
            # Try to get model from registry
            self.model = registry.get(MODEL)
            self.tokenizer = registry.get(TOKENIZER)
            if self.model:
                logger.info(f"Minimal communicator found model: {type(self.model).__name__}")

        def process_input(self, prompt, **kwargs):
            """Answer `prompt`; returns a {"response": text} dict."""
            # First try using a real model if available
            if self.model and self.tokenizer:
                try:
                    logger.info("Attempting model inference with actual model")
                    inputs = self.tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
                    if hasattr(self.model, "generate_with_decoding"):
                        response = self.model.generate_with_decoding(inputs.input_ids)
                    elif hasattr(self.model, "generate"):
                        output_ids = self.model.generate(inputs.input_ids)
                        response = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
                    else:
                        # Forward pass
                        outputs = self.model(inputs.input_ids)
                        response = self.tokenizer.decode(torch.argmax(outputs, dim=-1)[0], skip_special_tokens=True)
                    if response and len(response) > 10:  # Require reasonably long response
                        logger.info("Generated model response successfully")
                        return {"response": response}
                except Exception as e:
                    logger.warning(f"Model inference failed: {e}")
            # Check if prompt contains keywords we can respond to meaningfully
            logger.debug(f"Minimal communicator processing: {prompt[:30]}...")
            response = self._get_knowledge_response(prompt)
            if response:
                return {"response": response}
            return {"response": f"I'm analyzing your request about {prompt.split()[0] if prompt.split() else 'this topic'}..."}

        def process_request(self, prompt, model=None):
            """Answer `prompt` as a plain string, preferring `model` if given."""
            # Use provided model if given
            if model:
                try:
                    logger.info(f"Attempting inference with provided model: {type(model).__name__}")
                    tokenizer = registry.get(TOKENIZER)
                    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
                    if hasattr(model, "generate_with_decoding"):
                        return model.generate_with_decoding(inputs.input_ids)
                    elif hasattr(model, "generate"):
                        output_ids = model.generate(inputs.input_ids)
                        return tokenizer.decode(output_ids[0], skip_special_tokens=True)
                    else:
                        outputs = model(inputs.input_ids)
                        return tokenizer.decode(torch.argmax(outputs, dim=-1)[0], skip_special_tokens=True)
                except Exception as e:
                    logger.warning(f"Model inference failed: {e}")
            # Fall back to process_input
            result = self.process_input(prompt)
            return result.get("response", f"I'm analyzing '{prompt[:20]}...'")

        def process_request_streaming(self, prompt, model=None):
            # Fake token streaming: emit the complete answer word by word.
            response = self.process_request(prompt, model)
            words = response.split()
            for word in words:
                yield word + " "
                time.sleep(0.05)  # Simulate streaming

        def _get_knowledge_response(self, prompt):
            """Check if we have knowledge about topics in the prompt"""
            prompt_lower = prompt.lower()
            # Check each topic in knowledge base
            for topic, info in self.knowledge_base.items():
                if topic in prompt_lower:
                    return info
            # No matching topic found
            return None

    logger.info("Created enhanced minimal communicator for fallback functionality")
    communicator = MinimalCommunicator()
    return communicator
def generate_response(prompt: str, stream: bool = False) -> Union[str, Generator[str, None, None]]:
    """Generate a response using the appropriate models and communicators.

    Args:
        prompt: User prompt text.
        stream: When True, return a generator of tokens; otherwise a string.

    Returns:
        The complete response string, or a token generator when streaming.
    """
    # Make sure the registry holds something usable before dereferencing it.
    ensure_critical_components()
    # Pull the collaborators out of the registry.
    model_manager = get_component(MODEL_MANAGER)
    communicator = get_component(COMMUNICATOR)
    communicator_stdp = get_component("communicator_stdp")
    logger.info(f"Generating response for prompt: {prompt[:50]}...")
    handler = _generate_streaming_response if stream else _generate_complete_response
    try:
        return handler(prompt, model_manager, communicator, communicator_stdp)
    except Exception as e:
        logger.error(f"Error generating response: {e}")
        return f"Error generating response: {str(e)}"
def _generate_complete_response(prompt: str, model_manager: Any, communicator: Any, communicator_stdp: Any) -> str:
    """Generate a complete response (non-streaming).

    Tries, in order: communicator.process_request (with a model chosen by the
    model manager), communicator.process_input, then the STDP communicator.
    A candidate is rejected when it merely echoes the prompt or ends in
    "[PAD]". Falls back to a canned acknowledgement string.

    Args:
        prompt: The user prompt.
        model_manager: Optional object exposing select_model_for_prompt().
        communicator: Optional primary communicator.
        communicator_stdp: Optional STDP-based communicator.

    Returns:
        The generated response text (never None).
    """
    logger.info(f"Generating response for prompt: {prompt[:50]}...")
    # (The unused start_time stopwatch from the original was removed.)

    def _usable(text):
        # A candidate is useful only when non-empty, not a bare echo of the
        # prompt, and not padded out with "[PAD]" by the tokenizer.
        return bool(text) and text != prompt and not text.endswith("[PAD]")

    # 1) Primary path: communicator.process_request with a selected model.
    if communicator and hasattr(communicator, 'process_request'):
        try:
            model = None
            if model_manager and hasattr(model_manager, 'select_model_for_prompt'):
                model = model_manager.select_model_for_prompt(prompt)
            response = communicator.process_request(prompt, model)
            if _usable(response):
                logger.info(f"Got response from communicator: {response[:50]}...")
                return response
            logger.warning("Communicator returned echo or padded response, trying alternative methods")
        except Exception as e:
            logger.error(f"Error in communicator.process_request: {e}")

    # 2) Alternative path: process_input, which returns a dictionary.
    if communicator and hasattr(communicator, 'process_input'):
        try:
            result = communicator.process_input(prompt)
            if isinstance(result, dict) and "response" in result:
                response = result["response"]
                if _usable(response):
                    logger.info(f"Got response from process_input: {response[:50]}...")
                    return response
                logger.warning("process_input returned echo or padded response")
        except Exception as e:
            logger.error(f"Error in communicator.process_input: {e}")

    # 3) STDP communicator as a last model-backed path.
    if communicator_stdp:
        try:
            stdp_response = communicator_stdp.process_request(prompt, None)
            if _usable(stdp_response):
                logger.info(f"Got response from STDP communicator: {stdp_response[:50]}...")
                return stdp_response
            logger.warning("STDP communicator returned echo or padded response")
        except Exception as e:
            logger.error(f"Error in STDP communicator: {e}")

    # Final fallback: acknowledge the request.
    return f"I'm processing your request about '{prompt[:20]}...'"
def _generate_streaming_response(prompt: str, model_manager: Any, communicator: Any, communicator_stdp: Any) -> Generator[str, None, None]:
    """Generate a streaming response (yields tokens).

    Prefers the primary communicator's streaming API, then the STDP
    communicator's, and finally simulates streaming by splitting a complete
    response into words.
    """
    logger.info(f"Generating streaming response for: {prompt[:50]}...")
    # Primary communicator: stream with a model picked by the manager.
    if communicator and hasattr(communicator, 'process_request_streaming'):
        try:
            selected = None
            if model_manager and hasattr(model_manager, 'select_model_for_prompt'):
                selected = model_manager.select_model_for_prompt(prompt)
            yield from communicator.process_request_streaming(prompt, selected)
            return
        except Exception as e:
            logger.error(f"Error in communicator streaming: {e}")
    # STDP communicator streaming, when exposed.
    if communicator_stdp and hasattr(communicator_stdp, 'process_request_streaming'):
        try:
            yield from communicator_stdp.process_request_streaming(prompt)
            return
        except Exception as e:
            logger.error(f"Error in STDP streaming: {e}")
    # Fallback: fake a stream from the complete answer, word by word.
    full_text = _generate_complete_response(prompt, model_manager, communicator, communicator_stdp)
    for token in full_text.split():
        yield token + " "
        time.sleep(0.05)  # Simulate streaming
def parse_args(argv: Optional[List[str]] = None):
    """Parse command line arguments.

    Args:
        argv: Optional explicit argument list; defaults to sys.argv[1:].
            Passing it explicitly makes the parser unit-testable
            (backward-compatible: existing zero-argument calls are unchanged).

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser(description="Tiny Language Model Server")
    parser.add_argument("--port", type=int, default=7860, help="Port to run the server on")
    parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to run the server on")
    parser.add_argument("--share", action="store_true", help="Enable sharing")
    parser.add_argument("--debug", action="store_true", help="Enable debug mode")
    parser.add_argument("--data_dir", type=str, help="Data directory")
    parser.add_argument("--model_dir", type=str, help="Model directory")
    parser.add_argument("--api_only", action="store_true", help="Run as API server only (no UI)")
    parser.add_argument("--initialize", action="store_true", help="Initialize the models and exit")
    return parser.parse_args(argv)
def setup_environment(args):
    """Apply CLI options to the process environment and ensure data dirs exist.

    Args:
        args: Parsed argparse namespace with data_dir, model_dir and debug
            attributes.
    """
    # Map CLI overrides onto environment variables, logging each one applied.
    overrides = (
        ("TLM_DATA_DIR", args.data_dir, "Using custom data directory: "),
        ("TLM_MODEL_DIR", args.model_dir, "Using custom model directory: "),
    )
    for env_key, value, message in overrides:
        if value:
            logger.info(message + value)
            os.environ[env_key] = value
    if args.debug:
        logger.info("Debug mode enabled")
        logging.getLogger().setLevel(logging.DEBUG)
    # Prefer the project's directory bootstrap; otherwise create the layout.
    try:
        from config import ensure_data_directories
        ensure_data_directories()
    except ImportError:
        base = os.environ.get('TLM_DATA_DIR', '/tmp/tlm_data')
        os.makedirs(base, exist_ok=True)
        os.makedirs(os.path.join(base, "models"), exist_ok=True)
def initialize_system():
    """Initialize all components in the correct order.

    Order: tokenizer -> pretrained GPT-2 -> custom Wildnerve model. Each
    stage logs its own failures; only the custom-model stage determines the
    return value.

    Returns:
        True when the custom model was registered, False otherwise.
    """
    logger.info("Starting system initialization")
    # First tokenizer - Use GPT-2 tokenizer instead of BERT
    try:
        from transformers import GPT2Tokenizer
        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
        # GPT-2 tokenizer doesn't have a pad_token by default, so we set it
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
    except Exception as e:
        logger.warning(f"Could not load GPT-2 tokenizer, falling back to wrapper: {e}")
        from tokenizer import TokenizerWrapper
        tokenizer = TokenizerWrapper(model_name="gpt2")
    # Then register tokenizer
    # NOTE(review): registry.register must accept overwrite=True here (the
    # real service_registry implementation does) — verify the fallback
    # Registry in this module also supports it.
    registry.register(TOKENIZER, tokenizer, overwrite=True)
    logger.info("Tokenizer registered")
    # Initialize pretrained model (GPT-2)
    try:
        from model_PrTr import GPT_2 as PretrainedModel
        pretrained = PretrainedModel(model_name="gpt2", tokenizer=tokenizer)
        registry.register(PRETRAINED_MODEL, pretrained, overwrite=True)
        logger.info("GPT-2 pretrained model registered")
    except Exception as e:
        logger.error(f"Failed to initialize GPT-2 model: {e}", exc_info=True)
    # Now load custom model
    try:
        from model_Custm import Wildnerve_tlm01
        # Use architecture parameters from config
        arch_params = get_model_architecture_params()
        model = Wildnerve_tlm01(
            vocab_size=arch_params["vocab_size"],
            specialization="general",
            dataset_path=None,
            model_name="gpt2",
            embedding_dim=arch_params["embedding_dim"],
            num_heads=arch_params["num_heads"],
            hidden_dim=arch_params["hidden_dim"],
            num_layers=arch_params["num_layers"],
            output_size=arch_params["vocab_size"],
            dropout=arch_params["dropout"],
            max_seq_length=arch_params["max_seq_length"],
            pooling_mode="last",
            tokenizer=tokenizer
        )
        # Register model
        registry.register(MODEL, model, overwrite=True)
        logger.info("Custom model registered successfully")
        return True
    except Exception as e:
        logger.error(f"Failed to initialize custom model: {e}", exc_info=True)
        return False
def main():
    """Main application entry point: initialize components, then serve.

    Starts a uvicorn server on 0.0.0.0 at $PORT (default 7860). Blocks until
    the server exits.
    """
    # Initialize the system first
    success = initialize_system()
    logger.info(f"System initialization {'successful' if success else 'failed'}")
    # Import the app eagerly so a missing/broken app module fails fast here.
    from app import app  # noqa: F401
    import uvicorn
    logger.info("Starting TLM application")
    # BUG FIX: uvicorn ignores `workers` (and warns) when handed an app
    # *object*; multi-worker mode requires the application as an import
    # string, so pass "app:app" instead of the imported object.
    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=int(os.getenv("PORT", "7860")),
        workers=os.cpu_count() or 1,
        loop="auto"
    )


if __name__ == "__main__":
    main()