| | |
| | import torch |
| | import pandas as pd |
| | from model import TabularModel |
| |
|
def load_model_and_predict(data, checkpoint_path='ev_classifier_model.pth'):
    """Load the saved classifier checkpoint and predict a class index per row.

    Args:
        data: pandas DataFrame of raw feature rows. Columns that have a stored
            label encoder are encoded first. All columns are then passed
            through the stored scaler, so they presumably must match the
            training features (the model is built with ``input_size=9``).
        checkpoint_path: Path of the checkpoint to load. Defaults to
            ``'ev_classifier_model.pth'``, the previously hard-coded location.

    Returns:
        numpy.ndarray with one predicted class index per row of ``data``.
        The model is built with ``output_size=2``, so the indices are
        presumably 0 or 1.

    Raises:
        ValueError: if the preprocessed feature matrix is not 2-D with the
            number of columns the model expects.
    """
    n_features = 9  # must agree with input_size used to build the model

    # The checkpoint also stores pickled preprocessing objects (scaler, label
    # encoders). torch>=2.6 defaults to weights_only=True, which refuses them.
    # weights_only=False executes arbitrary pickle code, so only load
    # checkpoints from a trusted source.
    # map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only host.
    checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
    model = TabularModel(input_size=n_features, hidden_sizes=[128, 64, 32], output_size=2)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    scaler = checkpoint['scaler']
    label_encoders = checkpoint['label_encoders']

    # Work on a copy so the caller's DataFrame is not mutated.
    features = data.copy()
    # Assumes label_encoders maps column name -> fitted encoder exposing
    # .transform (e.g. sklearn LabelEncoder) -- TODO confirm against training.
    # Encoders for columns absent from `data` (e.g. an encoded target) are skipped.
    for column, encoder in label_encoders.items():
        if column in features.columns:
            features[column] = encoder.transform(features[column])

    # Assumes scaler exposes an sklearn-style .transform -- TODO confirm.
    scaled = scaler.transform(features)
    inputs = torch.tensor(scaled, dtype=torch.float32)
    if inputs.ndim != 2 or inputs.shape[1] != n_features:
        raise ValueError(
            f"expected {n_features} feature columns after preprocessing, "
            f"got shape {tuple(inputs.shape)}"
        )

    # Inference only: no autograd bookkeeping needed.
    with torch.no_grad():
        # Assumes the model returns (n_rows, 2) class scores -- TODO confirm.
        logits = model(inputs)
    predictions = logits.argmax(dim=1).numpy()
    return predictions
if __name__ == "__main__":
    # Build a single illustrative vehicle record column by column.
    # NOTE(review): only three columns are supplied here, while the model is
    # built with input_size=9. Presumably more feature columns belong in this
    # record -- confirm against the training data.
    example_row = {
        'model_year': [2021],
        'make': ['TESLA'],
        'model': ['MODEL 3'],
    }
    sample_data = pd.DataFrame(example_row)

    prediction = load_model_and_predict(sample_data)
    print(f"Prediction: {prediction}")