fivedollarwitch committed on
Commit
821648b
·
verified ·
1 Parent(s): a32ebb2

Upload SyntheticModelDriver.py

Browse files
Files changed (1) hide show
  1. SyntheticModelDriver.py +116 -0
SyntheticModelDriver.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Tuple, Union
3
+ from huggingface_hub import hf_hub_download
4
+ import torch
5
+ import pickle
6
+ import pandas as pd
7
+ import importlib.util
8
+ import sys
9
+
10
+
11
+ '''
12
+ This class must do two things
13
+ 1) The constructor must load the model
14
+ 2) This class must implement a method called `run_inference` that takes the input data and returns a tuple
15
+ of float, str representing the predicted sale price and the predicted sale date.
16
+ '''
17
class SyntheticModelDriver:
    """Load the synthetic sale-prediction model and run single-record inference.

    The constructor downloads the model and its preprocessing artifacts from
    the Hugging Face Hub and loads them. ``run_inference`` takes one input
    record and returns a ``(predicted_sale_price, predicted_sale_date)`` tuple.
    """

    # Hugging Face Hub repository holding the model and its artifacts.
    REPO_ID = "fivedollarwitch/synthetic-np"

    # Name under which the downloaded model-class module is registered in
    # ``sys.modules`` so the serialized model can be deserialized by torch.load.
    MODEL_MODULE_NAME = "SyntheticModel"

    # Columns passed through the fitted label encoders.
    CATEGORICAL_COLUMNS = ('city', 'state')

    # Columns passed through the fitted scaler, in the order the scaler expects.
    NUMERICAL_COLUMNS = (
        'price',
        'beds',
        'baths',
        'sqft',
        'lot_size',
        'year_built',
        'days_on_market',
        'latitude',
        'longitude',
        'hoa_dues',
        'property_type',
    )

    def __init__(self):
        """Download and load the model, scaler, label encoders and reference date."""
        print("Loading model...")
        self.model, self.scaler, self.label_encoders, self.min_date = self.load_model()
        print("Model loaded.")

    def run_inference(self, input_data: dict[str, Union[str, int, float]]) -> Tuple[float, str]:
        """Predict the sale price and sale date for one input record.

        Args:
            input_data: Mapping that must contain every key listed in
                ``CATEGORICAL_COLUMNS`` and ``NUMERICAL_COLUMNS``. An optional
                ``'last_sale_date'`` entry is parsed if present and not None.
                The mapping is NOT modified.

        Returns:
            ``(predicted_sale_price, predicted_sale_date)`` where the date is
            formatted as ``YYYY-MM-DD``. A NaN price is replaced by ``1.0``.

        Raises:
            KeyError: If a required column is missing from ``input_data``.
        """
        # Work on a shallow copy: encoding in place would corrupt the caller's
        # dict and make a second call with the same record fail.
        features = dict(input_data)

        # Encode categorical variables.
        for col in self.CATEGORICAL_COLUMNS:
            features[col] = self.label_encoders[col].transform([features[col]])[0]

        # Normalize numerical variables (a single-row batch).
        numerical_values = self.scaler.transform(
            [[features[col] for col in self.NUMERICAL_COLUMNS]]
        )

        # Convert the last sale date to days since the reference date.
        # NOTE(review): the converted value is not fed to the model below; the
        # conversion is kept so malformed dates are still rejected as before.
        last_sale = features.get('last_sale_date')
        if last_sale is not None:
            last_sale_date = pd.to_datetime(last_sale).tz_localize(None)
            features['last_sale_date'] = (last_sale_date - self.min_date).days

        # Convert to tensors.
        x_cat = torch.tensor(
            [[features[col] for col in self.CATEGORICAL_COLUMNS]],
            dtype=torch.long,
        )
        x_num = torch.tensor(numerical_values, dtype=torch.float32)

        # Model inference.
        with torch.no_grad():
            price_pred, date_pred = self.model(x_cat, x_num)
            predicted_sale_price = price_pred.item()
            # The index of the largest date output is used as a day offset.
            day_offset = torch.argmax(date_pred, dim=1).item()

        # Convert the day offset back to a readable date.
        predicted_sale_date = (
            self.min_date + pd.Timedelta(days=day_offset)
        ).strftime('%Y-%m-%d')

        # Fall back to a fixed value so a NaN price is never returned.
        if math.isnan(predicted_sale_price):
            predicted_sale_price = 1.0

        print(f"Predicted Sale Price: ${predicted_sale_price:,.2f}")
        print(f"Predicted Sale Date: {predicted_sale_date}")
        return predicted_sale_price, predicted_sale_date

    def load_model(self):
        """Download all artifacts and load them.

        Returns:
            ``(model, scaler, label_encoders, min_date)``; the model is put in
            evaluation mode.
        """
        (
            model_file,
            scaler_file,
            label_encoders_file,
            min_date_file,
            model_class_file,
        ) = self._download_model_files()

        # The model class must be importable before the model is deserialized.
        self._import_model_class(model_class_file)

        # SECURITY NOTE: torch.load(weights_only=False) and pickle.load execute
        # arbitrary code embedded in the files. Only use with a trusted repo.
        model = torch.load(model_file, weights_only=False)
        model.eval()

        # Load additional artifacts.
        with open(scaler_file, 'rb') as f:
            scaler = pickle.load(f)

        with open(label_encoders_file, 'rb') as f:
            label_encoders = pickle.load(f)

        with open(min_date_file, 'rb') as f:
            min_date = pickle.load(f)

        return model, scaler, label_encoders, min_date

    def _download_model_files(self):
        """Download the model files from the Hugging Face Hub.

        Returns:
            Local paths ``(model_file, scaler_file, label_encoders_file,
            min_date_file, model_class_file)``.
        """
        model_path = self.REPO_ID

        model_file = hf_hub_download(repo_id=model_path, filename="model.pt")
        scaler_file = hf_hub_download(repo_id=model_path, filename="scaler.pkl")
        label_encoders_file = hf_hub_download(repo_id=model_path, filename="label_encoders.pkl")
        min_date_file = hf_hub_download(repo_id=model_path, filename="min_date.pkl")
        model_class_file = hf_hub_download(repo_id=model_path, filename="SyntheticModel.py")

        return model_file, scaler_file, label_encoders_file, min_date_file, model_class_file

    def _import_model_class(self, model_class_file):
        """Import the downloaded model-class source file as a module.

        The module is registered in ``sys.modules`` under
        ``MODEL_MODULE_NAME`` before it is executed.
        Reference: https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly

        Raises:
            ImportError: If no import spec/loader can be built for the file.
        """
        module_name = self.MODEL_MODULE_NAME
        spec = importlib.util.spec_from_file_location(module_name, model_class_file)
        if spec is None or spec.loader is None:
            raise ImportError(
                f"Cannot load module {module_name!r} from {model_class_file!r}"
            )
        model_module = importlib.util.module_from_spec(spec)
        sys.modules[module_name] = model_module
        try:
            spec.loader.exec_module(model_module)
        except Exception:
            # Do not leave a half-initialised module registered.
            sys.modules.pop(module_name, None)
            raise