Upload load_model_weights.py
Browse files- load_model_weights.py +39 -9
load_model_weights.py
CHANGED
|
@@ -67,7 +67,7 @@ def verify_token():
|
|
| 67 |
|
| 68 |
# Clean up token format - remove any "Bearer " prefix if present
|
| 69 |
if token.startswith("Bearer "):
|
| 70 |
-
token = token[7:].trip()
|
| 71 |
os.environ["HF_TOKEN"] = token # Store the cleaned token
|
| 72 |
|
| 73 |
token_length = len(token)
|
|
@@ -512,21 +512,25 @@ def download_model_files(repo_id_base: str, sub_dir: Optional[str] = None,
|
|
| 512 |
except Exception as e:
|
| 513 |
logger.error(f"Failed to download fallback model: {e}")
|
| 514 |
|
| 515 |
-
# Try public models if private repositories fail
|
| 516 |
if not transformer_path:
|
| 517 |
-
logger.warning("⚠️ Could not download from private repos, trying public models")
|
| 518 |
try:
|
| 519 |
-
# Try to download from public models directly using model IDs
|
| 520 |
public_models = [
|
| 521 |
-
"
|
| 522 |
-
"
|
| 523 |
-
"
|
|
|
|
|
|
|
|
|
|
| 524 |
]
|
| 525 |
|
| 526 |
for model_id in public_models:
|
| 527 |
-
logger.info(f"Trying public model: {model_id}")
|
| 528 |
try:
|
| 529 |
-
|
|
|
|
| 530 |
if transformer_path:
|
| 531 |
downloaded_files["transformer"] = transformer_path
|
| 532 |
logger.info(f"✅ Successfully downloaded weights from {model_id}")
|
|
@@ -536,6 +540,32 @@ def download_model_files(repo_id_base: str, sub_dir: Optional[str] = None,
|
|
| 536 |
|
| 537 |
except Exception as e:
|
| 538 |
logger.error(f"Failed to download public models: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 539 |
|
| 540 |
# Download SNN weights if transformer weights were found
|
| 541 |
if "transformer" in downloaded_files:
|
|
|
|
| 67 |
|
| 68 |
# Clean up token format - remove any "Bearer " prefix if present
|
| 69 |
if token.startswith("Bearer "):
|
| 70 |
+
token = token[7:].strip() # Fix typo: .trip() -> .strip()
|
| 71 |
os.environ["HF_TOKEN"] = token # Store the cleaned token
|
| 72 |
|
| 73 |
token_length = len(token)
|
|
|
|
| 512 |
except Exception as e:
|
| 513 |
logger.error(f"Failed to download fallback model: {e}")
|
| 514 |
|
| 515 |
+
# Try public models if private repositories fail - ADD MORE PUBLIC MODELS
|
| 516 |
if not transformer_path:
|
| 517 |
+
logger.warning("⚠️ Could not download from private repos, trying public models WITHOUT token")
|
| 518 |
try:
|
| 519 |
+
# Try to download from public models directly using model IDs that don't require authentication
|
| 520 |
public_models = [
|
| 521 |
+
"TinyLlama/TinyLlama-1.1B-Chat-v1.0", # Try this one first - it's small but good
|
| 522 |
+
"google/mobilevit-small", # Very small model
|
| 523 |
+
"prajjwal1/bert-tiny", # Extremely small BERT
|
| 524 |
+
"distilbert/distilbert-base-uncased", # Public DistilBERT
|
| 525 |
+
"google/bert_uncased_L-2_H-128_A-2", # Tiny BERT
|
| 526 |
+
"hf-internal-testing/tiny-random-gptj" # Super tiny test model
|
| 527 |
]
|
| 528 |
|
| 529 |
for model_id in public_models:
|
| 530 |
+
logger.info(f"Trying public model WITHOUT token: {model_id}")
|
| 531 |
try:
|
| 532 |
+
# IMPORTANT: Don't pass the token for these public models
|
| 533 |
+
transformer_path = download_file(model_id, "pytorch_model.bin", cache_dir, token=None)
|
| 534 |
if transformer_path:
|
| 535 |
downloaded_files["transformer"] = transformer_path
|
| 536 |
logger.info(f"✅ Successfully downloaded weights from {model_id}")
|
|
|
|
| 540 |
|
| 541 |
except Exception as e:
|
| 542 |
logger.error(f"Failed to download public models: {e}")
|
| 543 |
+
|
| 544 |
+
# If still no weights, try to use a model from the transformers library directly
|
| 545 |
+
if not transformer_path:
|
| 546 |
+
try:
|
| 547 |
+
# Try to use tiny-bert which should be bundled with transformers
|
| 548 |
+
logger.info("Attempting to use tiny-bert from transformers cache")
|
| 549 |
+
from transformers import AutoModel, AutoTokenizer
|
| 550 |
+
|
| 551 |
+
model_id = "prajjwal1/bert-tiny"
|
| 552 |
+
tiny_model = AutoModel.from_pretrained(model_id)
|
| 553 |
+
tiny_tokenizer = AutoTokenizer.from_pretrained(model_id)
|
| 554 |
+
|
| 555 |
+
# Save the model to a local file we can use
|
| 556 |
+
tmp_dir = os.path.join(cache_dir or "/tmp/tlm_cache", "tiny-bert")
|
| 557 |
+
os.makedirs(tmp_dir, exist_ok=True)
|
| 558 |
+
temp_file = os.path.join(tmp_dir, "pytorch_model.bin")
|
| 559 |
+
|
| 560 |
+
# Save model state dict
|
| 561 |
+
torch.save(tiny_model.state_dict(), temp_file)
|
| 562 |
+
logger.info(f"✅ Saved tiny-bert model to {temp_file}")
|
| 563 |
+
|
| 564 |
+
# Add to downloaded files
|
| 565 |
+
downloaded_files["transformer"] = temp_file
|
| 566 |
+
transformer_path = temp_file
|
| 567 |
+
except Exception as e:
|
| 568 |
+
logger.error(f"Failed to use tiny-bert from transformers: {e}")
|
| 569 |
|
| 570 |
# Download SNN weights if transformer weights were found
|
| 571 |
if "transformer" in downloaded_files:
|