"""
Functions for downloading model weights from Hugging Face repositories.

The module resolves a Hugging Face token (environment first, then a
``.hf_token`` file next to this module), looks for weights on local disk, and
otherwise walks a list of candidate repositories/files until transformer (and
optionally SNN) weights are obtained. Resolved paths are published through the
``TLM_TRANSFORMER_WEIGHTS`` / ``TLM_SNN_WEIGHTS`` environment variables.
"""
import os
import sys
import time
import logging
import traceback
import torch
from pathlib import Path
from typing import Dict, Optional, Tuple, List, Any, Union
from urllib.error import HTTPError
from huggingface_hub import hf_hub_download, HfFileSystem, HfApi

# Add the current directory to Python's path to ensure modules are found
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

# Configure logging
logger = logging.getLogger(__name__)

_PLACEHOLDER_TOKEN = "your_token_here"
_BEARER_PREFIX = "Bearer "
_DEFAULT_CACHE_DIR = "/tmp/tlm_cache"
_TOKEN_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), ".hf_token")

# Try local direct import first with fallback to a minimal version
try:
    from model_repo_config import get_repo_config
    logger.info("Successfully imported model_repo_config")
except ImportError:
    logger.warning("model_repo_config module not found, using minimal implementation")

    class MinimalRepoConfig:
        """Minimal repository config used when ``model_repo_config`` is absent."""

        def __init__(self):
            self.repo_id = "EvolphTech/Weights"
            self.cache_dir = _DEFAULT_CACHE_DIR
            self.weight_locations = ["Wildnerve-tlm01-0.05Bx12.bin", "model.bin", "pytorch_model.bin"]
            self.snn_weight_locations = ["stdp_model_epoch_30.bin", "snn_model.bin"]
            self.default_repo = "EvolphTech/Weights"
            self.alternative_paths = ["Wildnerve/tlm-0.05Bx12", "Wildnerve/tlm", "EvolphTech/Checkpoints"]
            logger.info("Using minimal repository config")

        def get_auth_token(self):
            """Return the token from ``HF_TOKEN`` or ``HF_API_TOKEN`` (may be None)."""
            return os.environ.get("HF_TOKEN") or os.environ.get("HF_API_TOKEN")

        def save_download_status(self, success, files):
            """Minimal implementation that just logs the outcome."""
            logger.info(f"Download status: success={success}, files={len(files) if files else 0}")

    def get_repo_config():
        """Return the minimal repository config."""
        return MinimalRepoConfig()


# ---------------------------------------------------------------------------
# Token handling
# ---------------------------------------------------------------------------

def _strip_bearer(token: str) -> str:
    """Return ``token`` without surrounding whitespace or a ``Bearer `` prefix."""
    token = token.strip()
    if token.startswith(_BEARER_PREFIX):
        token = token[len(_BEARER_PREFIX):].strip()
    return token


def get_token_from_file() -> Optional[str]:
    """
    Load HF token from file if available.

    Returns:
        Optional[str]: The token if found in file, None otherwise
    """
    if not os.path.exists(_TOKEN_FILE):
        return None
    try:
        with open(_TOKEN_FILE, "r", encoding="utf-8") as f:
            token = f.read().strip()
    except (OSError, UnicodeDecodeError) as e:
        logger.error(f"Error reading token file: {e}")
        return None
    return token or None


def _bootstrap_token() -> None:
    """
    Populate ``HF_TOKEN`` at import time.

    A placeholder value is discarded, a missing value is filled from the token
    file, and the variable is left as an empty string when nothing is found so
    that code reading ``os.environ["HF_TOKEN"]`` directly never hits KeyError.
    """
    if os.environ.get("HF_TOKEN") == _PLACEHOLDER_TOKEN:
        logger.warning("Token is still set to the placeholder 'your_token_here'")
        logger.warning("Please set a real token using set_token.py")
        os.environ["HF_TOKEN"] = ""  # Clear the placeholder

    if os.environ.get("HF_TOKEN"):
        return

    file_token = get_token_from_file()
    if file_token:
        os.environ["HF_TOKEN"] = file_token
        logger.info(f"Loaded token from file with length {len(file_token)}")
        return

    if not os.environ.get("HF_API_TOKEN"):
        logger.warning("No token found in environment or token file")
        logger.warning("Run: python set_token.py YOUR_HF_TOKEN to set your token")
    os.environ["HF_TOKEN"] = ""  # Set empty to avoid None issues


_bootstrap_token()


def verify_token():
    """
    Verify the HF token is available and properly formatted.

    The token is taken from ``HF_TOKEN``, then ``HF_API_TOKEN``, then the token
    file. A ``Bearer `` prefix is stripped and the cleaned value is written back
    to ``HF_TOKEN``. The token is then checked against the Hugging Face
    ``whoami`` endpoint, but a failed check only produces warnings.

    Returns:
        bool: False when no token is available; True otherwise (even if the
        online validation failed or could not be performed).
    """
    # `or` (not a .get default) so an empty HF_TOKEN still falls through.
    token = os.environ.get("HF_TOKEN") or os.environ.get("HF_API_TOKEN")

    if not token:
        token = get_token_from_file()
        if token:
            os.environ["HF_TOKEN"] = token
            logger.info("Loaded HF_TOKEN from file")

    if not token:
        logger.error("❌ HF_TOKEN not found in environment variables or token file!")
        return False

    cleaned = _strip_bearer(token)
    if not cleaned:
        logger.error("❌ HF_TOKEN is empty after removing the 'Bearer ' prefix")
        return False
    if cleaned != token:
        os.environ["HF_TOKEN"] = cleaned  # Store the cleaned token
    token = cleaned

    # Only the length is logged: echoing parts of a secret leaks it into logs.
    logger.info(f"HF Token found: length={len(token)}")

    # Test if token works against a public Hugging Face API endpoint
    try:
        import requests

        headers = {"Authorization": f"Bearer {token}"}
        test_url = "https://huggingface.co/api/whoami"
        response = requests.get(test_url, headers=headers, timeout=10)
        if response.status_code == 200:
            user_info = response.json()
            logger.info(f"✅ Token validated for user: {user_info.get('name', 'unknown')}")
            return True

        logger.warning(f"❌ Token validation failed: {response.status_code} - {response.text[:100]}")
        logger.warning("Please make sure your token has the correct permissions")
        if response.status_code == 401:
            logger.warning("Token appears to be invalid or expired")
        elif response.status_code == 403:
            logger.warning("Token doesn't have required permissions")
    except Exception as e:
        # Best effort: no network / no `requests` must not block startup.
        logger.warning(f"Error testing token: {e}")

    # Return based on token presence, even if validation failed
    return True


# Evaluated once at import so other modules can read the result.
token_verified = verify_token()


def set_token(token: str, save_to_file: bool = True) -> bool:
    """
    Set the HF token for accessing private repositories.

    Args:
        token: The Hugging Face token to set
        save_to_file: Whether to save the token to a file for persistence

    Returns:
        bool: True if successful, False otherwise
    """
    try:
        token = _strip_bearer(token) if token else ""
        if not token:
            logger.error("Error setting token: token is empty")
            return False

        os.environ["HF_TOKEN"] = token
        logger.info(f"Token set in environment with length {len(token)}")

        if save_to_file:
            # Create the file owner-readable only; the token is a credential.
            fd = os.open(_TOKEN_FILE, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
            with os.fdopen(fd, "w", encoding="utf-8") as f:
                f.write(token)
            try:
                # Tighten permissions of a file that already existed.
                os.chmod(_TOKEN_FILE, 0o600)
            except OSError:
                pass  # Not supported on every platform; the write succeeded.
            logger.info(f"Token saved to file: {_TOKEN_FILE}")

        return True
    except Exception as e:
        logger.error(f"Error setting token: {e}")
        return False


# ---------------------------------------------------------------------------
# Repository access
# ---------------------------------------------------------------------------

def verify_repository(repo_id: str, token: Optional[str] = None) -> Tuple[bool, List[str]]:
    """
    Verify that a repository exists and is accessible.

    Args:
        repo_id: Repository ID to verify
        token: Optional Hugging Face API token

    Returns:
        (success, files): Tuple of success flag and list of files
    """
    logger.info(f"Verifying access to repository: {repo_id}")
    try:
        files = list(HfApi().list_repo_files(repo_id, token=token))
    except Exception as e:
        error_msg = str(e).lower()
        if "not found" in error_msg or "404" in error_msg:
            logger.error(f"Repository {repo_id} not found. Please check the name.")
        elif "unauthorized" in error_msg or "permission" in error_msg or "401" in error_msg:
            if token:
                logger.error(f"Authentication failed for repository {repo_id} despite token")
            else:
                logger.error(f"No token provided for private repository {repo_id}")
        else:
            logger.error(f"Error accessing repository {repo_id}: {e}")
        return False, []

    logger.info(f"Repository {repo_id} is accessible")
    logger.info(f"Found {len(files)} files in repository")
    return True, files


def download_file(repo_id: str, file_path: str, cache_dir: str, token: Optional[str] = None) -> Optional[str]:
    """
    Download a file from a Hugging Face repository with retry logic.

    Args:
        repo_id: Repository ID (``namespace/name``)
        file_path: Path of the file inside the repository
        cache_dir: Directory passed to ``hf_hub_download`` as its cache
        token: Optional token; a ``Bearer `` prefix is removed and an empty
            value is treated as "no token"

    Returns:
        Local path of the downloaded, non-empty file, or None on failure.
        Authentication and not-found errors are not retried because repeating
        the request cannot change their outcome.
    """
    max_retries = 3

    # huggingface_hub adds "Bearer" itself; never send an empty token.
    # NOTE(review): with token=None the hub library may still pick up an
    # implicit token from the environment — confirm if that matters to callers.
    token = (_strip_bearer(token) if token else "") or None

    for attempt in range(1, max_retries + 1):
        logger.info(f"Downloading {file_path} from {repo_id} (attempt {attempt}/{max_retries})...")
        if attempt > 1:
            token_status = "No token" if not token else f"Token with length {len(token)}"
            logger.info(f"Using: {token_status}")
            logger.info(f"Repo ID: {repo_id}, Path: {file_path}")

        try:
            local_path = hf_hub_download(
                repo_id=repo_id,
                filename=file_path,
                cache_dir=cache_dir,
                force_download=attempt > 1,  # bypass a possibly corrupt cache entry
                token=token,
                local_files_only=False,  # Force online check
            )
        except Exception as e:
            error_msg = str(e).lower()
            if "401" in error_msg or "unauthorized" in error_msg:
                logger.warning(f"❌ Authentication error when downloading {file_path} from {repo_id}: {e}")
                logger.warning("Please check your HF_TOKEN environment variable")
                return None
            if "404" in error_msg or "not found" in error_msg:
                logger.warning(f"❌ File or repository not found: {file_path} in {repo_id}")
                return None
            logger.warning(f"❌ Failed to download {file_path} from {repo_id} (attempt {attempt}/{max_retries}): {e}")
        else:
            if os.path.exists(local_path) and os.path.getsize(local_path) > 0:
                size_mb = os.path.getsize(local_path) / 1024 / 1024
                logger.info(f"✅ Successfully downloaded {file_path} to {local_path} ({size_mb:.1f} MB)")
                return local_path
            logger.warning(f"⚠️ Downloaded file exists but may be empty: {local_path}")

        if attempt < max_retries:
            time.sleep(1)  # Wait before retry

    return None


def list_model_files(repo_id: str, token: Optional[str] = None) -> List[str]:
    """
    List model files in a repository.

    Args:
        repo_id: Repository ID
        token: Optional Hugging Face API token

    Returns:
        List of file paths ending in ``.bin``, ``.pt`` or ``.pth``; empty on error
    """
    try:
        files = HfApi().list_repo_files(repo_id, token=token)
    except Exception as e:
        logger.error(f"Error listing model files in {repo_id}: {e}")
        return []

    model_files = [f for f in files if f.endswith((".bin", ".pt", ".pth"))]
    logger.info(f"Found {len(model_files)} model files in {repo_id}")
    return model_files


# ---------------------------------------------------------------------------
# Local weight discovery
# ---------------------------------------------------------------------------

def check_for_local_weights():
    """
    Check if weights are available locally.

    Looks at previously set marker variables, then at the paths named by
    ``TLM_TRANSFORMER_WEIGHTS`` / ``TLM_SNN_WEIGHTS``, then at a fixed list of
    common locations. Found paths are written back to the environment.

    Returns:
        bool: True when transformer weights were located, False otherwise.
    """
    # Avoid redundant filesystem checks once weights have been found.
    if os.environ.get("MODEL_WEIGHTS_FOUND") == "true" or os.environ.get("USING_LOCAL_WEIGHTS") == "true":
        logger.info("Using previously found local weights")
        return True

    transformer_weights = os.environ.get("TLM_TRANSFORMER_WEIGHTS")
    if transformer_weights and os.path.exists(transformer_weights):
        logger.info(f"Found transformer weights locally at: {transformer_weights}")
        snn_weights = os.environ.get("TLM_SNN_WEIGHTS")
        if snn_weights and os.path.exists(snn_weights):
            logger.info(f"Found SNN weights locally at: {snn_weights}")
        os.environ["MODEL_WEIGHTS_FOUND"] = "true"
        os.environ["USING_LOCAL_WEIGHTS"] = "true"
        return True

    transformer_paths = [
        "/app/Weights/Transformer/Wildnerve-tlm01-0.05Bx12.bin",
        "/app/Weights/Wildnerve-tlm01-0.05Bx12.bin",
        "/app/weights/Wildnerve-tlm01-0.05Bx12.bin",
        "./Weights/Transformer/Wildnerve-tlm01-0.05Bx12.bin",
        "./Weights/Wildnerve-tlm01-0.05Bx12.bin",
    ]
    snn_paths = [
        "/app/Weights/SNN/stdp_model_epoch_30.bin",
        "/app/Weights/stdp_model_epoch_30.bin",
        "/app/weights/stdp_model_epoch_30.bin",
        "./Weights/SNN/stdp_model_epoch_30.bin",
        "./Weights/stdp_model_epoch_30.bin",
    ]

    for path in transformer_paths:
        if not os.path.exists(path):
            continue
        logger.info(f"Found transformer weights at: {path}")
        os.environ["TLM_TRANSFORMER_WEIGHTS"] = path
        os.environ["MODEL_WEIGHTS_FOUND"] = "true"

        # SNN weights are optional; record the first one that exists.
        for snn_path in snn_paths:
            if os.path.exists(snn_path):
                logger.info(f"Found SNN weights at: {snn_path}")
                os.environ["TLM_SNN_WEIGHTS"] = snn_path
                break
        return True

    return False


def load_model_weights(model=None):
    """
    Load model weights from local files or download from repository.

    Args:
        model: Unused; kept for backward compatibility with existing callers.

    Returns:
        Dict mapping ``"transformer"`` / ``"snn"`` (and possibly ``"config"``)
        to local paths. For local weights the ``"snn"`` value may be None.
    """
    logger.info("Checking for local model weights...")
    if check_for_local_weights():
        logger.info("Using local weights, skipping repository download")
        return {
            "transformer": os.environ.get("TLM_TRANSFORMER_WEIGHTS"),
            "snn": os.environ.get("TLM_SNN_WEIGHTS"),
        }

    logger.info("No local weights found, attempting to download from repository")
    config = get_repo_config()
    return download_model_files(config.repo_id, None, config.cache_dir)


# ---------------------------------------------------------------------------
# Download orchestration
# ---------------------------------------------------------------------------

def _unique(items) -> List[str]:
    """Return ``items`` as a list with duplicates removed, order preserved."""
    return list(dict.fromkeys(items))


def _download_first(repo_id: str, candidates, cache_dir, token) -> Optional[str]:
    """Download the first candidate file that exists in ``repo_id``."""
    for candidate in candidates:
        logger.info(f"Attempting to download: {repo_id}/{candidate}")
        local_path = download_file(repo_id, candidate, cache_dir, token)
        if local_path:
            return local_path
    return None


def _find_local_model_files() -> Dict[str, str]:
    """Return locally available weights (and config, if any) or an empty dict."""
    local_weight_paths = [
        "./Wildnerve-tlm01-0.05Bx12.bin",
        "./weights/Wildnerve-tlm01-0.05Bx12.bin",
        "./pytorch_model.bin",
        "./model.bin",
        "/app/Wildnerve-tlm01-0.05Bx12.bin",  # For HF Spaces environment
        "/app/weights/Wildnerve-tlm01-0.05Bx12.bin",
        "/app/pytorch_model.bin",
    ]
    for weight_path in local_weight_paths:
        if not os.path.exists(weight_path):
            continue
        logger.info(f"Found local weights: {weight_path}")
        found = {"transformer": weight_path}
        config_candidates = [
            os.path.join(os.path.dirname(weight_path), "config.json"),
            "./config.json",
            "/app/config.json",
        ]
        for config_path in config_candidates:
            if os.path.exists(config_path):
                found["config"] = config_path
                break
        return found
    return {}


def _publish_weight_paths(downloaded_files: Dict[str, str]) -> None:
    """Expose resolved weight paths to other modules via environment variables."""
    if "transformer" in downloaded_files:
        os.environ["TLM_TRANSFORMER_WEIGHTS"] = downloaded_files["transformer"]
    if "snn" in downloaded_files:
        os.environ["TLM_SNN_WEIGHTS"] = downloaded_files["snn"]


def _download_from_evolphtech(cache_dir, token) -> Dict[str, str]:
    """Try the Transformer/ and SNN/ subdirectories of EvolphTech/Weights."""
    evolphtech_repo = "EvolphTech/Weights"
    logger.info("Trying EvolphTech/Weights repository with proper subdirectories")

    success, files = verify_repository(evolphtech_repo, token)
    if not success:
        return {}

    logger.info(f"✅ Successfully connected to {evolphtech_repo}")
    logger.info(f"File list preview (first 10 files): {files[:10]}")

    logger.info("Trying to download transformer weights from Transformer subdirectory")
    transformer_path = _download_first(
        evolphtech_repo,
        [
            "Transformer/Wildnerve-tlm01-0.05Bx12.bin",
            "Transformer/model.bin",
            "Transformer/pytorch_model.bin",
        ],
        cache_dir,
        token,
    )
    if not transformer_path:
        return {}

    found = {"transformer": transformer_path}
    logger.info(f"✅ Successfully downloaded transformer weights: {transformer_path}")

    logger.info("Trying to download SNN weights from SNN subdirectory")
    snn_path = _download_first(
        evolphtech_repo,
        ["SNN/stdp_model_epoch_30.bin", "SNN/snn_model.bin"],
        cache_dir,
        token,
    )
    if snn_path:
        found["snn"] = snn_path
        logger.info(f"✅ Successfully downloaded SNN weights: {snn_path}")
    return found


def _resolve_repository(repo_id_base: str, config, token) -> str:
    """Return the first accessible candidate repository, else the default."""
    candidates = _unique(
        ["Wildnerve/tlm-0.05Bx12", repo_id_base, "Wildnerve/tlm", "EvolphTech/Checkpoints"]
        + list(config.alternative_paths)
    )
    for candidate in candidates:
        logger.info(f"Trying repository: {candidate}")
        success, _ = verify_repository(candidate, token)
        if success:
            return candidate
    logger.info(f"No candidate repository accessible, using default {config.default_repo}")
    return config.default_repo


def _save_tiny_bert(cache_dir) -> Optional[str]:
    """Materialise ``prajjwal1/bert-tiny`` weights as a local state-dict file."""
    try:
        logger.info("Attempting to use tiny-bert from transformers cache")
        from transformers import AutoModel

        tiny_model = AutoModel.from_pretrained("prajjwal1/bert-tiny")
        tmp_dir = os.path.join(cache_dir or _DEFAULT_CACHE_DIR, "tiny-bert")
        os.makedirs(tmp_dir, exist_ok=True)
        temp_file = os.path.join(tmp_dir, "pytorch_model.bin")
        torch.save(tiny_model.state_dict(), temp_file)
        logger.info(f"✅ Saved tiny-bert model to {temp_file}")
        return temp_file
    except Exception as e:
        logger.error(f"Failed to use tiny-bert from transformers: {e}")
        return None


def _download_fallback_transformer(config, cache_dir, token) -> Optional[str]:
    """Last-resort chain: default repo, public BERT variants, then tiny-bert."""
    logger.warning("No transformer weights found, trying public BERT model as fallback")
    for repo in _unique([config.default_repo, "bert-base-uncased", "distilbert-base-uncased"]):
        path = download_file(repo, "pytorch_model.bin", cache_dir, token)
        if path:
            logger.info(f"Successfully downloaded fallback model from {repo}")
            return path

    logger.warning("⚠️ Could not download from private repos, trying public models WITHOUT token")
    public_models = [
        "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        "google/mobilevit-small",
        "prajjwal1/bert-tiny",
        "distilbert/distilbert-base-uncased",
        "google/bert_uncased_L-2_H-128_A-2",
        "hf-internal-testing/tiny-random-gptj",
    ]
    for model_id in public_models:
        logger.info(f"Trying public model WITHOUT token: {model_id}")
        path = download_file(model_id, "pytorch_model.bin", cache_dir, token=None)
        if path:
            logger.info(f"✅ Successfully downloaded weights from {model_id}")
            return path

    return _save_tiny_bert(cache_dir)


def download_model_files(repo_id_base: str, sub_dir: Optional[str] = None, cache_dir: Optional[str] = None) -> Dict[str, str]:
    """
    Download model files from a Hugging Face repository.

    Resolution order: local files, the ``Transformer/`` and ``SNN/``
    subdirectories of ``EvolphTech/Weights``, the first accessible candidate
    repository, and finally a chain of public fallback models.

    Args:
        repo_id_base: Base repository ID
        sub_dir: Optional subdirectory within the repository; it is applied as
            a filename prefix (a repository ID cannot contain a subdirectory)
        cache_dir: Optional cache directory

    Returns:
        Dictionary of downloaded files (file_type: local_path)
    """
    config = get_repo_config()
    cache_dir = cache_dir or config.cache_dir
    token = config.get_auth_token()

    # 1) Local files win over any download.
    logger.info("Checking for local model weights...")
    local_files = _find_local_model_files()
    if local_files:
        os.environ["TLM_TRANSFORMER_WEIGHTS"] = local_files["transformer"]
        if "config" in local_files:
            os.environ["TLM_CONFIG_PATH"] = local_files["config"]
        logger.info(f"Using local weights: {local_files['transformer']}")
        return local_files

    logger.info("No local weights found, attempting to download from repository")

    # 2) Preferred layout: EvolphTech/Weights with Transformer/ and SNN/ folders.
    downloaded_files = _download_from_evolphtech(cache_dir, token)
    if "transformer" in downloaded_files:
        _publish_weight_paths(downloaded_files)
        config.save_download_status(bool(downloaded_files), downloaded_files)
        return downloaded_files

    logger.warning("Couldn't find weights in Transformer/SNN subdirectories, trying alternative paths")

    # 3) Pick a repository and download config + weights from it.
    repo_id = _resolve_repository(repo_id_base, config, token)
    prefix = sub_dir.strip("/") + "/" if sub_dir and sub_dir.strip("/") else ""

    def with_sub_dir(names):
        """Prefer files under ``sub_dir`` when downloading from the base repo."""
        names = _unique(names)
        if prefix and repo_id == repo_id_base:
            return _unique([prefix + name for name in names] + names)
        return names

    downloaded_files = {}

    logger.info(f"Downloading config from {repo_id}...")
    config_path = _download_first(repo_id, with_sub_dir(["config.json"]), cache_dir, token)
    if config_path:
        downloaded_files["config"] = config_path
    else:
        logger.warning("Will use default config values")

    logger.info(f"Downloading transformer weights from {repo_id}...")
    weight_names = ["Wildnerve-tlm01-0.05Bx12.bin", "model.bin", "pytorch_model.bin"] + list(config.weight_locations)
    transformer_path = _download_first(repo_id, with_sub_dir(weight_names), cache_dir, token)

    # 4) Fall back to public models when the chosen repository has no weights.
    if not transformer_path:
        transformer_path = _download_fallback_transformer(config, cache_dir, token)
    if transformer_path:
        downloaded_files["transformer"] = transformer_path

    # SNN weights are only useful alongside transformer weights.
    if "transformer" in downloaded_files:
        logger.info(f"Downloading SNN weights from {repo_id}...")
        snn_path = _download_first(repo_id, with_sub_dir(list(config.snn_weight_locations)), cache_dir, token)
        if snn_path:
            downloaded_files["snn"] = snn_path

    _publish_weight_paths(downloaded_files)
    config.save_download_status(bool(downloaded_files), downloaded_files)
    return downloaded_files


# ---------------------------------------------------------------------------
# Loading weights into a model
# ---------------------------------------------------------------------------

def find_expanded_weights(base_weight_path, target_dim=768):
    """
    Find expanded weights in various potential locations based on the base weight path.

    Args:
        base_weight_path: Path to the original weights file
        target_dim: Target embedding dimension to look for

    Returns:
        Path to expanded weights if found, otherwise None
    """
    if not base_weight_path:
        return None

    base_stem, ext = os.path.splitext(os.path.basename(base_weight_path))
    expanded_name = f"{base_stem}_expanded_{target_dim}{ext}"

    # Common writable directories first, then the directory of the original file.
    search_dirs = _unique(
        [
            "/tmp",
            "/tmp/tlm_data",
            os.environ.get("TLM_DATA_DIR", "/tmp/tlm_data"),
            os.path.dirname(base_weight_path),
        ]
    )
    for directory in search_dirs:
        if not directory:
            continue
        expanded_path = os.path.join(directory, expanded_name)
        if os.path.exists(expanded_path):
            logger.info(f"Found expanded weights at {expanded_path}")
            return expanded_path

    # Finally look for the bare file name relative to the working directory.
    if os.path.exists(expanded_name):
        return expanded_name
    return None


def _find_shape_mismatches(model, state_dict) -> List[Tuple[str, tuple, tuple]]:
    """Return ``(name, model_shape, weights_shape)`` for shared keys that differ."""
    model_shapes = {name: tuple(param.shape) for name, param in model.named_parameters()}
    mismatches = []
    for name, tensor in state_dict.items():
        shape = getattr(tensor, "shape", None)
        if shape is None or name not in model_shapes:
            continue
        if tuple(shape) != model_shapes[name]:
            mismatches.append((name, model_shapes[name], tuple(shape)))
    return mismatches


def load_weights_into_model(model, weights_path: str, strict: bool = False) -> bool:
    """
    Load weights from a file into a model.

    Args:
        model: The model to load weights into
        weights_path: Path to the weights file
        strict: Whether to strictly enforce that the keys in the weights file match the model

    Returns:
        bool: True if weights were successfully loaded, False otherwise
    """
    if model is None or not weights_path:
        logger.error("Failed to load weights: model and weights_path are required")
        return False

    try:
        logger.info(f"Loading weights from {weights_path}")

        expanded_path = find_expanded_weights(weights_path)
        if expanded_path:
            logger.info(f"Using expanded weights: {expanded_path}")
            weights_path = expanded_path

        # SECURITY NOTE(review): torch.load unpickles the file, which can run
        # arbitrary code. Files here may come from remote repositories; consider
        # weights_only=True once all supported checkpoints are known to load.
        state_dict = torch.load(weights_path, map_location="cpu")

        # Unwrap training checkpoints that nest the actual weights.
        if isinstance(state_dict, dict):
            for wrapper_key in ("model_state_dict", "state_dict"):
                if wrapper_key in state_dict:
                    state_dict = state_dict[wrapper_key]
                    break

        if not hasattr(state_dict, "items"):
            logger.error(f"Failed to load weights: unsupported checkpoint type {type(state_dict).__name__}")
            return False

        # load_state_dict rejects size mismatches even when strict=False, so
        # detect them up front across all shared parameters and explain why.
        mismatches = _find_shape_mismatches(model, state_dict)
        if mismatches:
            name, model_shape, weights_shape = mismatches[0]
            logger.warning(
                f"⚠️ Dimensional mismatch detected in {len(mismatches)} parameter(s), "
                f"e.g. {name}: model={model_shape}, weights={weights_shape}"
            )
            logger.error("❌ Aborting weight loading due to dimension mismatch")
            logger.error("You must use weights compatible with your model architecture")
            return False

        try:
            result = model.load_state_dict(state_dict, strict=strict)
        except Exception as e:
            logger.error(f"Error loading state dict: {e}")
            if not strict:
                return False
            logger.info("Attempting non-strict loading")
            try:
                result = model.load_state_dict(state_dict, strict=False)
            except Exception as ne:
                logger.error(f"Non-strict loading also failed: {ne}")
                return False

        missing_keys, unexpected_keys = result
        logger.info(f"Loaded weights with {len(missing_keys)} missing keys and {len(unexpected_keys)} unexpected keys")
        return True
    except Exception as e:
        logger.error(f"Failed to load weights: {e}")
        return False


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    import argparse

    parser = argparse.ArgumentParser(description="Download model weights or set HF token")
    parser.add_argument("--repo-id", type=str, default=None, help="Repository ID")
    parser.add_argument("--sub-dir", type=str, default=None, help="Subdirectory within repository")
    parser.add_argument("--cache-dir", type=str, default=None, help="Cache directory")
    parser.add_argument("--set-token", type=str, help="Set Hugging Face token for private repositories")
    args = parser.parse_args()

    # Setting a token is a standalone action: report and exit.
    if args.set_token:
        success = set_token(args.set_token)
        if success:
            print(f"✅ Token saved successfully with length {len(args.set_token)}")
            print("You can now use the model with this token")
        else:
            print("❌ Failed to set token")
        sys.exit(0 if success else 1)

    repo_id = args.repo_id or os.environ.get("MODEL_REPO") or get_repo_config().repo_id
    result = download_model_files(repo_id, args.sub_dir, args.cache_dir)

    print("\nDownload Results:")
    if "transformer" in result:
        print(f"Transformer weights: {result['transformer']}")
    else:
        print("⚠️ No transformer weights downloaded")
    if "snn" in result:
        print(f"SNN weights: {result['snn']}")
    else:
        print("⚠️ No SNN weights downloaded")