WildnerveAI commited on
Commit
234321b
·
verified ·
1 Parent(s): 1c78d33

Upload load_model_weights.py

Browse files
Files changed (1) hide show
  1. load_model_weights.py +39 -9
load_model_weights.py CHANGED
@@ -67,7 +67,7 @@ def verify_token():
67
 
68
  # Clean up token format - remove any "Bearer " prefix if present
69
  if token.startswith("Bearer "):
70
- token = token[7:].trip()
71
  os.environ["HF_TOKEN"] = token # Store the cleaned token
72
 
73
  token_length = len(token)
@@ -512,21 +512,25 @@ def download_model_files(repo_id_base: str, sub_dir: Optional[str] = None,
512
  except Exception as e:
513
  logger.error(f"Failed to download fallback model: {e}")
514
 
515
- # Try public models if private repositories fail
516
  if not transformer_path:
517
- logger.warning("⚠️ Could not download from private repos, trying public models")
518
  try:
519
- # Try to download from public models directly using model IDs
520
  public_models = [
521
- "gpt2", # Small but works well
522
- "distilgpt2", # Even smaller
523
- "facebook/opt-125m" # Another small model
 
 
 
524
  ]
525
 
526
  for model_id in public_models:
527
- logger.info(f"Trying public model: {model_id}")
528
  try:
529
- transformer_path = download_file(model_id, "pytorch_model.bin", cache_dir, None)
 
530
  if transformer_path:
531
  downloaded_files["transformer"] = transformer_path
532
  logger.info(f"✅ Successfully downloaded weights from {model_id}")
@@ -536,6 +540,32 @@ def download_model_files(repo_id_base: str, sub_dir: Optional[str] = None,
536
 
537
  except Exception as e:
538
  logger.error(f"Failed to download public models: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
539
 
540
  # Download SNN weights if transformer weights were found
541
  if "transformer" in downloaded_files:
 
67
 
68
  # Clean up token format - remove any "Bearer " prefix if present
69
  if token.startswith("Bearer "):
70
+ token = token[7:].strip() # Fix typo: .trip() -> .strip()
71
  os.environ["HF_TOKEN"] = token # Store the cleaned token
72
 
73
  token_length = len(token)
 
512
  except Exception as e:
513
  logger.error(f"Failed to download fallback model: {e}")
514
 
515
+ # Try public models if private repositories fail - ADD MORE PUBLIC MODELS
516
  if not transformer_path:
517
+ logger.warning("⚠️ Could not download from private repos, trying public models WITHOUT token")
518
  try:
519
+ # Try to download from public models directly using model IDs that don't require authentication
520
  public_models = [
521
+ "TinyLlama/TinyLlama-1.1B-Chat-v1.0", # Try this one first - it's small but good
522
+ "google/mobilevit-small", # Very small model
523
+ "prajjwal1/bert-tiny", # Extremely small BERT
524
+ "distilbert/distilbert-base-uncased", # Public DistilBERT
525
+ "google/bert_uncased_L-2_H-128_A-2", # Tiny BERT
526
+ "hf-internal-testing/tiny-random-gptj" # Super tiny test model
527
  ]
528
 
529
  for model_id in public_models:
530
+ logger.info(f"Trying public model WITHOUT token: {model_id}")
531
  try:
532
+ # IMPORTANT: Don't pass the token for these public models
533
+ transformer_path = download_file(model_id, "pytorch_model.bin", cache_dir, token=None)
534
  if transformer_path:
535
  downloaded_files["transformer"] = transformer_path
536
  logger.info(f"✅ Successfully downloaded weights from {model_id}")
 
540
 
541
  except Exception as e:
542
  logger.error(f"Failed to download public models: {e}")
543
+
544
+ # If still no weights, try to use a model from the transformers library directly
545
+ if not transformer_path:
546
+ try:
547
+ # Try to use tiny-bert which should be bundled with transformers
548
+ logger.info("Attempting to use tiny-bert from transformers cache")
549
+ from transformers import AutoModel, AutoTokenizer
550
+
551
+ model_id = "prajjwal1/bert-tiny"
552
+ tiny_model = AutoModel.from_pretrained(model_id)
553
+ tiny_tokenizer = AutoTokenizer.from_pretrained(model_id)
554
+
555
+ # Save the model to a local file we can use
556
+ tmp_dir = os.path.join(cache_dir or "/tmp/tlm_cache", "tiny-bert")
557
+ os.makedirs(tmp_dir, exist_ok=True)
558
+ temp_file = os.path.join(tmp_dir, "pytorch_model.bin")
559
+
560
+ # Save model state dict
561
+ torch.save(tiny_model.state_dict(), temp_file)
562
+ logger.info(f"✅ Saved tiny-bert model to {temp_file}")
563
+
564
+ # Add to downloaded files
565
+ downloaded_files["transformer"] = temp_file
566
+ transformer_path = temp_file
567
+ except Exception as e:
568
+ logger.error(f"Failed to use tiny-bert from transformers: {e}")
569
 
570
  # Download SNN weights if transformer weights were found
571
  if "transformer" in downloaded_files: