""" Functions for downloading model weights from Hugging Face repositories. """ import os import sys import time import logging import traceback import torch # Add missing torch import from pathlib import Path from typing import Dict, Optional, Tuple, List, Any, Union from urllib.error import HTTPError from huggingface_hub import hf_hub_download, HfFileSystem, HfApi # Add the current directory to Python's path to ensure modules are found sys.path.append(os.path.dirname(os.path.abspath(__file__))) # Configure Logging logger = logging.getLogger(__name__) # Fix typo: getLOgger -> getLogger # Try local direct import first with fallback to a minimal version try: from model_repo_config import get_repo_config logger.info("Successfully imported model_repo_config") except ImportError: logger.warning("model_repo_config module not found, using minimal implementation") # Define minimal version inline as fallback class MinimalRepoConfig: """Minimal repository config for fallback""" def __init__(self): self.repo_id = "EvolphTech/Weights" self.cache_dir = "/tmp/tlm_cache" self.weight_locations = ["Wildnerve-tlm01-0.05Bx12.bin", "model.bin", "pytorch_model.bin"] self.snn_weight_locations = ["stdp_model_epoch_30.bin", "snn_model.bin"] self.default_repo = "EvolphTech/Weights" self.alternative_paths = ["Wildnerve/tlm-0.05Bx12", "Wildnerve/tlm", "EvolphTech/Checkpoints"] logger.info("Using minimal repository config") def get_auth_token(self): """Get authentication token from environment""" return os.environ.get("HF_TOKEN") or os.environ.get("HF_API_TOKEN") def save_download_status(self, success, files): """Minimal implementation that just logs""" logger.info(f"Download status: success={success}, files={len(files) if files else 0}") def get_repo_config(): """Get minimal repository config""" return MinimalRepoConfig() # Only set if not already set if not os.environ.get("HF_TOKEN"): os.environ["HF_TOKEN"] = "your_token_here" # Replace with your actual token # Configure logging logger = logging.getLogger(__name__) def verify_token(): """Verify the HF token is available and properly formatted.""" token = os.environ.get("HF_TOKEN", os.environ.get("HF_API_TOKEN")) if token: token_length = len(token) token_preview = token[:5] + "..." + token[-5:] if token_length > 10 else "too_short" logger.info(f"HF Token found: length={token_length}, preview={token_preview}") # Test if token works against a public Hugging Face API endpoint try: import requests headers = {"Authorization": f"Bearer {token}"} test_url = "https://huggingface.co/api/whoami" response = requests.get(test_url, headers=headers, timeout=10) if response.status_code == 200: user_info = response.json() logger.info(f"Token validated for user: {user_info.get('name', 'unknown')}") return True else: logger.warning(f"Token validation failed: {response.status_code} - {response.text[:100]}") except Exception as e: logger.warning(f"Error testing token: {e}") # Even if test fails, return True if we have a token return True else: logger.error("❌ HF Token not found in environment variables!") return False # Call this early in the script or application startup token_verified = verify_token() def verify_repository(repo_id: str, token: Optional[str] = None) -> Tuple[bool, List[str]]: """ Verify that a repository exists and is accessible. Args: repo_id: Repository ID to verify token: Optional Hugging Face API token Returns: (success, files): Tuple of success flag and list of files """ try: # Try to list the repository contents api = HfApi() logger.info(f"Verifying access to repository: {repo_id}") try: files = api.list_repo_files(repo_id, token=token) logger.info(f"Repository {repo_id} is accessible") logger.info(f"Found {len(files)} files in repository") return True, files except Exception as e: error_msg = str(e).lower() if "not found" in error_msg or "404" in error_msg: logger.error(f"Repository {repo_id} not found. Please check the name.") return False, [] elif "unauthorized" in error_msg or "permission" in error_msg or "401" in error_msg: if token: logger.error(f"Authentication failed for repository {repo_id} despite token") else: logger.error(f"No token provided for private repository {repo_id}") return False, [] else: logger.error(f"Error accessing repository {repo_id}: {e}") return False, [] except Exception as e: logger.error(f"Unexpected error verifying repository {repo_id}: {e}") return False, [] def download_file(repo_id: str, file_path: str, cache_dir: str, token: Optional[str] = None) -> Optional[str]: """ Download a file from a Hugging Face repository with retry logic. Args: repo_id: Repository ID file_path: Path to the file within the repository cache_dir: Directory to save the file token: Optional Hugging Face API token Returns: Path to the downloaded file if successful, None otherwise """ max_retries = 3 for attempt in range(1, max_retries + 1): try: logger.info(f"Downloading {file_path} from {repo_id} (attempt {attempt}/{max_retries})...") local_path = hf_hub_download( repo_id=repo_id, filename=file_path, cache_dir=cache_dir, force_download=attempt > 1, token=token ) logger.info(f"Successfully downloaded {file_path} to {local_path}") return local_path except Exception as e: logger.warning(f"Failed to download {file_path} from {repo_id} (attempt {attempt}/{max_retries}): {e}") if attempt == max_retries: return None time.sleep(1) # Wait before retry def check_for_local_weights(): """Check if weights are available locally""" # First check if we've already found weights (avoid redundant checks) if os.environ.get("MODEL_WEIGHTS_FOUND") == "true" or os.environ.get("USING_LOCAL_WEIGHTS") == "true": logger.info("Using previously found local weights") return True # Check for transformer weights transformer_weights = os.environ.get("TLM_TRANSFORMER_WEIGHTS") if transformer_weights and os.path.exists(transformer_weights): logger.info(f"Found transformer weights locally at: {transformer_weights}") # Check for SNN weights snn_weights = os.environ.get("TLM_SNN_WEIGHTS") if snn_weights and os.path.exists(snn_weights): logger.info(f"Found SNN weights locally at: {snn_weights}") # Set environment variable to indicate weights are found os.environ["MODEL_WEIGHTS_FOUND"] = "true" os.environ["USING_LOCAL_WEIGHTS"] = "true" return True # Check common paths for transformer weights transformer_paths = [ "/app/Weights/Transformer/Wildnerve-tlm01-0.05Bx12.bin", "/app/Weights/Wildnerve-tlm01-0.05Bx12.bin", "/app/weights/Wildnerve-tlm01-0.05Bx12.bin", "./Weights/Transformer/Wildnerve-tlm01-0.05Bx12.bin", "./Weights/Wildnerve-tlm01-0.05Bx12.bin" ] for path in transformer_paths: if os.path.exists(path): logger.info(f"Found transformer weights at: {path}") os.environ["TLM_TRANSFORMER_WEIGHTS"] = path os.environ["MODEL_WEIGHTS_FOUND"] = "true" # Check for SNN weights snn_paths = [ "/app/Weights/SNN/stdp_model_epoch_30.bin", "/app/Weights/stdp_model_epoch_30.bin", "/app/weights/stdp_model_epoch_30.bin", "./Weights/SNN/stdp_model_epoch_30.bin", "./Weights/stdp_model_epoch_30.bin" ] for snn_path in snn_paths: if os.path.exists(snn_path): logger.info(f"Found SNN weights at: {snn_path}") os.environ["TLM_SNN_WEIGHTS"] = snn_path break return True return False def load_model_weights(model=None): """Load model weights from local files or download from repository.""" # Check for local model weights first logger.info("Checking for local model weights...") if check_for_local_weights(): logger.info("Using local weights, skipping repository download") return { "transformer": os.environ.get("TLM_TRANSFORMER_WEIGHTS"), "snn": os.environ.get("TLM_SNN_WEIGHTS") } # Only attempt to download if no local weights logger.info("No local weights found, attempting to download from repository") # Get repository configuration config = get_repo_config() repo_id_base = config.repo_id cache_dir = config.cache_dir sub_dir = None return download_model_files(repo_id_base, sub_dir, cache_dir) def download_model_files(repo_id_base: str, sub_dir: Optional[str] = None, cache_dir: Optional[str] = None) -> Dict[str, str]: """ Download model files from a Hugging Face repository. Args: repo_id_base: Base repository ID sub_dir: Optional subdirectory within the repository cache_dir: Optional cache directory Returns: Dictionary of downloaded files (file_type: local_path) """ # Get global configuration config = get_repo_config() # Use provided cache_dir or fall back to config's cache_dir cache_dir = cache_dir or config.cache_dir # Get authentication token if available token = config.get_auth_token() # Dictionary to store downloaded file paths downloaded_files = {} # FIRST: Check if weights exist locally in the current directory or app directory local_weight_paths = [ "./Wildnerve-tlm01-0.05Bx12.bin", "./weights/Wildnerve-tlm01-0.05Bx12.bin", "./pytorch_model.bin", "./model.bin", "/app/Wildnerve-tlm01-0.05Bx12.bin", # For HF Spaces environment "/app/weights/Wildnerve-tlm01-0.05Bx12.bin", "/app/pytorch_model.bin" ] # Look for local weights first logger.info("Checking for local model weights...") for weight_path in local_weight_paths: if os.path.exists(weight_path): logger.info(f"Found local weights: {weight_path}") downloaded_files["transformer"] = weight_path # Try to find a config file too local_config_paths = [ os.path.join(os.path.dirname(weight_path), "config.json"), "./config.json", "/app/config.json" ] for config_path in local_config_paths: if os.path.exists(config_path): downloaded_files["config"] = config_path break # Set environment variables os.environ["TLM_TRANSFORMER_WEIGHTS"] = downloaded_files["transformer"] if "config" in downloaded_files: os.environ["TLM_CONFIG_PATH"] = downloaded_files["config"] # Return early since we found local weights logger.info(f"Using local weights: {weight_path}") return downloaded_files # If no local weights, continue with normal HF download procedure logger.info("No local weights found, attempting to download from repository") # Create full repository path (with subdir if provided) repo_id = repo_id_base if sub_dir: # Remove any trailing slashes from repo_id and leading slashes from sub_dir repo_id = repo_id_base.rstrip('/') + '/' + sub_dir.lstrip('/') # First try the primary Wildnerve model repository wildnerve_repo = "Wildnerve/tlm-0.05Bx12" logger.info(f"Trying primary Wildnerve model repository: {wildnerve_repo}") success, files = verify_repository(wildnerve_repo, token) if success: repo_id = wildnerve_repo else: # Verify repository exists and is accessible success, files = verify_repository(repo_id, token) if not success: # Try alternatives logger.info(f"Primary repository {repo_id} not accessible, trying alternatives") # Try Wildnerve model repo variants first wildnerve_variants = ["Wildnerve/tlm", "EvolphTech/Checkpoints"] for wildnerve_alt in wildnerve_variants: logger.info(f"Trying Wildnerve alternative: {wildnerve_alt}") success, files = verify_repository(wildnerve_alt, token) if success: repo_id = wildnerve_alt break # If still not successful, try other fallbacks if not success: for alt_repo in config.alternative_paths: logger.info(f"Trying alternative repository: {alt_repo}") success, files = verify_repository(alt_repo, token) if success: repo_id = alt_repo break # Use default if all alternatives fail if not success: repo_id = config.default_repo success, files = verify_repository(repo_id, token) # Dictionary to store downloaded file paths downloaded_files = {} # Download configuration if available try: logger.info(f"Downloading config from {repo_id}...") config_path = download_file(repo_id, "config.json", cache_dir, token) if config_path: downloaded_files["config"] = config_path else: logger.warning("Will use default config values") except Exception as e: logger.warning(f"Error downloading config: {e}") # Download transformer weights logger.info(f"Downloading transformer weights from {repo_id}...") transformer_path = None # First try the specific Wildnerve model file name wildnerve_paths = ["Wildnerve-tlm01-0.05Bx12.bin", "model.bin", "pytorch_model.bin"] for path in wildnerve_paths: logger.info(f"Trying Wildnerve model path: {path}") transformer_path = download_file(repo_id, path, cache_dir, token) if transformer_path: downloaded_files["transformer"] = transformer_path break # If that doesn't work, try the standard paths if not transformer_path: for path in config.weight_locations: transformer_path = download_file(repo_id, path, cache_dir, token) if transformer_path: downloaded_files["transformer"] = transformer_path break logger.info(f"Trying path: {path}") if not transformer_path: logger.warning("No transformer weights found, trying public BERT model as fallback") try: # Try to download BERT weights transformer_path = download_file(config.default_repo, "pytorch_model.bin", cache_dir, token) if transformer_path: downloaded_files["transformer"] = transformer_path logger.info("Successfully downloaded fallback BERT model") else: # Additional fallbacks to try for alt_repo in ["bert-base-uncased", "distilbert-base-uncased"]: transformer_path = download_file(alt_repo, "pytorch_model.bin", cache_dir, token) if transformer_path: downloaded_files["transformer"] = transformer_path logger.info(f"Successfully downloaded fallback model from {alt_repo}") break except Exception as e: logger.error(f"Failed to download fallback model: {e}") # Download SNN weights if transformer weights were found if "transformer" in downloaded_files: logger.info(f"Downloading SNN weights from {repo_id}...") snn_path = None for path in config.snn_weight_locations: snn_path = download_file(repo_id, path, cache_dir, token) if snn_path: downloaded_files["snn"] = snn_path break logger.info(f"Trying path: {path}") # Set environment variables for other modules to use if "transformer" in downloaded_files: os.environ["TLM_TRANSFORMER_WEIGHTS"] = downloaded_files["transformer"] if "snn" in downloaded_files: os.environ["TLM_SNN_WEIGHTS"] = downloaded_files["snn"] # Save download status config.save_download_status(bool(downloaded_files), downloaded_files) return downloaded_files def find_expanded_weights(base_weight_path, target_dim=768): """ Find expanded weights in various potential locations based on the base weight path. Args: base_weight_path: Path to the original weights file target_dim: Target embedding dimension to look for Returns: Path to expanded weights if found, otherwise None """ if not base_weight_path: return None base_name = os.path.basename(base_weight_path) base_stem, ext = os.path.splitext(base_name) expanded_name = f"{base_stem}_expanded_{target_dim}{ext}" # Check in common writable directories common_dirs = [ "/tmp", "/tmp/tlm_data", os.environ.get("TLM_DATA_DIR", "/tmp/tlm_data") ] # Also check the original directory original_dir = os.path.dirname(base_weight_path) if original_dir: common_dirs.append(original_dir) # Check each location for directory in common_dirs: if not directory: continue expanded_path = os.path.join(directory, expanded_name) if os.path.exists(expanded_path): logger.info(f"Found expanded weights at {expanded_path}") return expanded_path # Check just the base filename for absolute paths if os.path.exists(expanded_name): return expanded_name return None def load_weights_into_model(model, weights_path: str, strict: bool = False) -> bool: """ Load weights from a file into a model. Args: model: The model to load weights into weights_path: Path to the weights file strict: Whether to strictly enforce that the keys in the weights file match the model Returns: bool: True if weights were successfully loaded, False otherwise """ try: logger.info(f"Loading weights from {weights_path}") # Try expanded weights first expanded_path = find_expanded_weights(weights_path) if expanded_path: logger.info(f"Using expanded weights: {expanded_path}") weights_path = expanded_path # Load the state dictionary state_dict = torch.load(weights_path, map_location="cpu") # If state_dict has nested structure, extract the actual model weights if isinstance(state_dict, dict) and "model_state_dict" in state_dict: state_dict = state_dict["model_state_dict"] elif isinstance(state_dict, dict) and "state_dict" in state_dict: state_dict = state_dict["state_dict"] # Special handling for Wildnerve-tlm01-0.05Bx12 model if "Wildnerve-tlm01" in str(model.__class__): logger.info("Detected Wildnerve-tlm01 model, applying special weight loading") # Check if keys need to be remapped model_keys = dict(model.named_parameters()) state_dict_keys = set(state_dict.keys()) # Check key alignment if not any(k in state_dict_keys for k in model_keys.keys()): logger.info("Wildnerve model keys don't match state dict keys, attempting remapping") # Create mapping for common Wildnerve model patterns key_mappings = { "embedding.weight": ["embeddings.word_embeddings.weight", "embedding.weight", "word_embeddings.weight"], "pos_encoder.pe": ["position_embeddings.weight", "pos_encoder.pe", "pe"], "transformer_encoder": ["encoder.layer", "transformer.encoder", "transformer_encoder"], "classifier.weight": ["output.weight", "classifier.weight", "lm_head.weight"], "classifier.bias": ["output.bias", "classifier.bias", "lm_head.bias"] } # Apply mappings adapted_state_dict = {} for target_key, source_keys in key_mappings.items(): for source_key in source_keys: for sd_key in state_dict_keys: if source_key in sd_key: if target_key not in model_keys: # Find a target key that's close enough for mk in model_keys: if target_key.split('.')[0] in mk: adapted_state_dict[mk] = state_dict[sd_key] break else: adapted_state_dict[target_key] = state_dict[sd_key] # Try to load the remapped weights if adapted_state_dict: logger.info(f"Attempting to load with {len(adapted_state_dict)} remapped keys") try: missing_keys, unexpected_keys = model.load_state_dict(adapted_state_dict, strict=False) logger.info(f"Loaded remapped weights with {len(missing_keys)} missing keys and {len(unexpected_keys)} unexpected keys") return True except Exception as e: logger.error(f"Error loading remapped weights: {e}") # Special handling for transformer models from Hugging Face if all(k.startswith("bert.") or k.startswith("roberta.") or k.startswith("model.") for k in state_dict.keys()): # Try to adapt the state dict keys to match our model logger.info("Adapting pretrained Hugging Face transformer weights") adapted_state_dict = {} # Map expected model keys to state dict keys key_mappings = { # Common mappings for transformer models "embedding.weight": ["embeddings.word_embeddings.weight", "bert.embeddings.word_embeddings.weight"], "pos_encoder.pe": ["embeddings.position_embeddings.weight", "bert.embeddings.position_embeddings.weight"], "transformer_encoder": ["encoder.layer", "bert.encoder.layer"], "classifier.weight": ["cls.predictions.decoder.weight", "bert.pooler.dense.weight"], "classifier.bias": ["cls.predictions.decoder.bias", "bert.pooler.dense.bias"] } # Try to map keys from state dict to model model_keys = dict(model.named_parameters()) # First try exact matches for target_key, source_keys in key_mappings.items(): for source_key in source_keys: if source_key in state_dict: adapted_state_dict[target_key] = state_dict[source_key] break # If we have very few matches, try partial matches if len(adapted_state_dict) < len(model_keys) * 0.1: logger.info("Using partial key matching for weights") for model_key in model_keys: for sd_key in state_dict: # Skip keys already matched if model_key in adapted_state_dict: continue # Try to find common substrings in the key names key_parts = model_key.split('.') sd_parts = sd_key.split('.') # Check for common parts like "attention", "layer", etc. common_parts = set(key_parts) & set(sd_parts) if len(common_parts) > 0: adapted_state_dict[model_key] = state_dict[sd_key] break # If we still don't have many matches, try direct loading with non-strict mode if len(adapted_state_dict) < len(model_keys) * 0.5: logger.warning(f"Could not adapt many keys ({len(adapted_state_dict)}/{len(model_keys)})") logger.warning("Attempting to load original state dict with non-strict mode") try: # Load with non-strict mode to allow partial loading missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) logger.info(f"Loaded weights with {len(missing_keys)} missing keys and {len(unexpected_keys)} unexpected keys") return True except Exception as e: logger.error(f"Error loading original state dict: {e}") return False else: # Load adapted state dict logger.info(f"Loading adapted state dict with {len(adapted_state_dict)} keys") try: missing_keys, unexpected_keys = model.load_state_dict(adapted_state_dict, strict=False) logger.info(f"Loaded weights with {len(missing_keys)} missing keys and {len(unexpected_keys)} unexpected keys") return True except Exception as e: logger.error(f"Error loading adapted state dict: {e}") return False else: # Standard loading try: missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=strict) logger.info(f"Loaded weights with {len(missing_keys)} missing keys and {len(unexpected_keys)} unexpected keys") return True except Exception as e: logger.error(f"Error loading state dict: {e}") # Try non-strict loading if strict failed if strict: logger.info("Attempting non-strict loading") try: missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) logger.info(f"Loaded weights with {len(missing_keys)} missing keys and {len(unexpected_keys)} unexpected keys") return True except Exception as ne: logger.error(f"Non-strict loading also failed: {ne}") return False except Exception as e: logger.error(f"Failed to load weights: {e}") return False def list_model_files(repo_id: str, token: Optional[str] = None) -> List[str]: """ List model files in a repository. Args: repo_id: Repository ID token: Optional Hugging Face API token Returns: List of file paths """ try: api = HfApi() files = api.list_repo_files(repo_id, token=token) # Filter for model files model_files = [f for f in files if f.endswith('.bin') or f.endswith('.pt') or f.endswith('.pth')] logger.info(f"Found {len(model_files)} model files in {repo_id}") return model_files except Exception as e: logger.error(f"Error listing model files in {repo_id}: {e}") return [] if __name__ == "__main__": # Configure logging logging.basicConfig(level=logging.INFO) # Get arguments import argparse parser = argparse.ArgumentParser(description="Download model weights") parser.add_argument("--repo-id", type=str, default=None, help="Repository ID") parser.add_argument("--sub-dir", type=str, default=None, help="Subdirectory within repository") parser.add_argument("--cache-dir", type=str, default=None, help="Cache directory") args = parser.parse_args() # Download model files repo_id = args.repo_id or os.environ.get("MODEL_REPO") or get_repo_config().repo_id result = download_model_files(repo_id, args.sub_dir, args.cache_dir) # Print results print(f"\nDownload Results:") if "transformer" in result: print(f"Transformer weights: {result['transformer']}") else: print(f"⚠️ No transformer weights downloaded") if "snn" in result: print(f"SNN weights: {result['snn']}") else: print(f"⚠️ No SNN weights downloaded")