import torch
from model import FuturesModel, CustomTokenizer, build_vocabulary
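# FuturesModel, CustomTokenizer, and build_vocabulary are defined in the local
# `model` module that ships alongside this script.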


def load_model_and_tokenizer(
    model_path='checkpoint_best.pt',
    dataset_path='futures_dataset_v2.json',
    vocab_size=5000,
):
    """Loads the trained FuturesModel and CustomTokenizer."""
    # 1. Build vocabulary and tokenizer
    print("Building vocabulary from dataset...")
    vocab_dict = build_vocabulary(dataset_path, vocab_size=vocab_size)
    tokenizer = CustomTokenizer(vocab_dict)
    print(f"Vocabulary size: {len(vocab_dict)}")

    # 2. Initialize the model with the same architecture
    print("Initializing model...")
    model = FuturesModel(
        vocab_size=len(vocab_dict),
        n_axes=12,
        d_model=256,
        n_head=8,
        n_layers=4,
        n_experts=8,
        dropout=0.1
    )
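    # NOTE: these hyperparameters are assumed to match the training configuration;
    # load_state_dict() below will raise a size-mismatch error if they differ.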
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
# 3. Load the saved state dictionary
print(f"Loading model weights from {model_path}...")
device = 'cuda' if torch.cuda.is_available() else 'cpu'
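    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.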
    checkpoint = torch.load(model_path, map_location=device)
    # The state dict is nested in the checkpoint
    model.load_state_dict(checkpoint['model_state_dict'])

    # 4. Set the model to evaluation mode
    model.eval()
    print("Model set to evaluation mode.")

    return model, tokenizer

if __name__ == "__main__":
print("="*80)
print("Loading Futures Prediction Model")
print("="*80)
# Correct paths for running from the root directory
model_path = 'checkpoint_best.pt'
dataset_path = 'futures_dataset_v2.json'

    try:
        model, tokenizer = load_model_and_tokenizer(
            model_path=model_path,
            dataset_path=dataset_path
        )
        print("\n✅ Model and tokenizer loaded successfully!")

        # Example usage
        print("\n--- Example Usage ---")
        text = "In a future dominated by hyper-automation, societal structures adapt to new forms of labor and community."
        print(f"Input text: '{text}'")

        token_ids = tokenizer.encode(text)
        tokens_tensor = torch.LongTensor(token_ids).unsqueeze(0)  # Add batch dimension
        print(f"Encoded tokens (first 10): {tokens_tensor[0, :10]}...")

        with torch.no_grad():
            axis_logits, lm_logits, stats = model(tokens_tensor)
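        # axis_logits holds the 12 per-axis scores, lm_logits the language-model
        # head output, and stats auxiliary diagnostics (unused here).
        # Sigmoid maps each axis logit independently into (0, 1), so the axes are
        # scored independently rather than as a softmax distribution.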
        axis_predictions = torch.sigmoid(axis_logits)
        print("\nPredicted Axis Weights:")
        axis_names = [
            "HyperAuto", "HumanTech", "Abundant", "Individual",
            "Community", "Global", "Crisis", "Restore",
            "Adapt", "Digital", "Physical", "Collab"
        ]
        for name, weight in zip(axis_names, axis_predictions[0]):
            print(f" - {name:12s}: {weight:.4f}")

    except Exception as e:
        print(f"\n❌ An error occurred during loading: {e}")
        import traceback
        traceback.print_exc()

    print("\n" + "="*80)