Upload main_model.py with huggingface_hub
Browse files- main_model.py +1 -1
main_model.py
CHANGED
|
@@ -523,7 +523,7 @@ def load_models():
|
|
| 523 |
print(f"Warning: {config.tokeniser_path} not found. Using default tokenizer.")
|
| 524 |
|
| 525 |
# Load trained model first to get correct vocab size
|
| 526 |
-
checkpoint = torch.load(config.
|
| 527 |
|
| 528 |
# Extract vocab size from the checkpoint's embedding layer
|
| 529 |
vocab_size_from_checkpoint = checkpoint['text_encoder.embedding.weight'].shape[0]
|
|
|
|
| 523 |
print(f"Warning: {config.tokeniser_path} not found. Using default tokenizer.")
|
| 524 |
|
| 525 |
# Load trained model first to get correct vocab size
|
| 526 |
+
checkpoint = torch.load(config.color_model_path, map_location=config.device)
|
| 527 |
|
| 528 |
# Extract vocab size from the checkpoint's embedding layer
|
| 529 |
vocab_size_from_checkpoint = checkpoint['text_encoder.embedding.weight'].shape[0]
|