# NOTE: the original page header ("Spaces: Sleeping") is Hugging Face Spaces
# UI residue from scraping and is not part of this program.
import torch

from model import FuturesModel, CustomTokenizer, build_vocabulary


def load_model_and_tokenizer(
    model_path='checkpoint_best.pt',
    dataset_path='futures_dataset_v2.json',
    vocab_size=5000,
):
    """Load the trained FuturesModel and its CustomTokenizer.

    Args:
        model_path: Path to the training checkpoint. The checkpoint is
            expected to be a dict with a 'model_state_dict' entry.
        dataset_path: JSON dataset used to rebuild the vocabulary. This
            must be the same dataset used at training time so that the
            token ids line up with the saved weights.
        vocab_size: Maximum vocabulary size passed to build_vocabulary.

    Returns:
        (model, tokenizer) tuple. The model is returned in eval mode;
        note it is NOT moved to GPU here — map_location only controls
        where the checkpoint tensors are deserialized, and
        load_state_dict copies them into the model's (CPU) parameters.

    Raises:
        FileNotFoundError: if model_path does not exist.
        KeyError: if the checkpoint lacks 'model_state_dict'.
        RuntimeError: if the hard-coded architecture below does not
            match the checkpoint's shapes.
    """
    # 1. Build vocabulary and tokenizer — ids must match training.
    print("Building vocabulary from dataset...")
    vocab_dict = build_vocabulary(dataset_path, vocab_size=vocab_size)
    tokenizer = CustomTokenizer(vocab_dict)
    print(f"Vocabulary size: {len(vocab_dict)}")

    # 2. Initialize the model with the same architecture as training.
    # NOTE(review): these hyper-parameters are hard-coded and must match
    # the checkpoint exactly, or load_state_dict below will raise.
    print("Initializing model...")
    model = FuturesModel(
        vocab_size=len(vocab_dict),
        n_axes=12,
        d_model=256,
        n_head=8,
        n_layers=4,
        n_experts=8,
        dropout=0.1,
    )
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    # 3. Load the saved state dictionary.
    print(f"Loading model weights from {model_path}...")
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # weights_only=True restricts unpickling to tensors/containers, closing
    # the arbitrary-code-execution hole in plain torch.load (it is the
    # default from torch 2.6 onward). A pure state-dict checkpoint loads
    # fine under it; if this checkpoint stores custom objects, loading
    # will fail loudly rather than execute untrusted code.
    checkpoint = torch.load(model_path, map_location=device, weights_only=True)
    # The state dict is nested in the checkpoint.
    model.load_state_dict(checkpoint['model_state_dict'])

    # 4. Set the model to evaluation mode (disables dropout).
    model.eval()
    print("Model set to evaluation mode.")
    return model, tokenizer
| if __name__ == "__main__": | |
| print("="*80) | |
| print("Loading Futures Prediction Model") | |
| print("="*80) | |
| # Correct paths for running from the root directory | |
| model_path = 'checkpoint_best.pt' | |
| dataset_path = 'futures_dataset_v2.json' | |
| try: | |
| model, tokenizer = load_model_and_tokenizer( | |
| model_path=model_path, | |
| dataset_path=dataset_path | |
| ) | |
| print("\n✅ Model and tokenizer loaded successfully!") | |
| # Example usage | |
| print("\n--- Example Usage ---") | |
| text = "In a future dominated by hyper-automation, societal structures adapt to new forms of labor and community." | |
| print(f"Input text: '{text}'") | |
| token_ids = tokenizer.encode(text) | |
| tokens_tensor = torch.LongTensor(token_ids).unsqueeze(0) # Add batch dimension | |
| print(f"Encoded tokens (first 10): {tokens_tensor[0, :10]}...") | |
| with torch.no_grad(): | |
| axis_logits, lm_logits, stats = model(tokens_tensor) | |
| axis_predictions = torch.sigmoid(axis_logits) | |
| print("\nPredicted Axis Weights:") | |
| axis_names = [ | |
| "HyperAuto", "HumanTech", "Abundant", "Individual", | |
| "Community", "Global", "Crisis", "Restore", | |
| "Adapt", "Digital", "Physical", "Collab" | |
| ] | |
| for name, weight in zip(axis_names, axis_predictions[0]): | |
| print(f" - {name:12s}: {weight:.4f}") | |
| except Exception as e: | |
| print(f"\n❌ An error occurred during loading: {e}") | |
| import traceback | |
| traceback.print_exc() | |
| print("\n" + "="*80) | |