|
|
import math |
|
|
from typing import Tuple, Union |
|
|
from huggingface_hub import hf_hub_download |
|
|
import torch |
|
|
import pickle |
|
|
import pandas as pd |
|
|
import importlib.util |
|
|
import sys |
|
|
|
|
|
|
|
|
'''
Contract for the class below:

1) The constructor must load the model.

2) The class must implement a method called `run_inference` that takes the
   input data and returns a tuple of (float, str) representing the predicted
   sale price and the predicted sale date.
'''
|
|
class SyntheticModelDriver:
    """Drive inference for the synthetic real-estate model hosted on the
    Hugging Face Hub.

    The constructor downloads and deserializes the model plus its
    preprocessing artifacts (scaler, label encoders, minimum training date);
    ``run_inference`` maps one raw listing dict to a
    ``(predicted_sale_price, predicted_sale_date)`` tuple.
    """

    # Categorical features that are label-encoded before being fed to the
    # model's embedding inputs.
    _CATEGORICAL_COLUMNS = ('city', 'state')

    # Numerical features, in the exact column order the scaler was fit on.
    # Do not reorder: the scaler's per-column statistics are positional.
    _NUMERICAL_COLUMNS = (
        'price', 'beds', 'baths', 'sqft', 'lot_size', 'year_built',
        'days_on_market', 'latitude', 'longitude', 'hoa_dues',
        'property_type',
    )

    def __init__(self):
        """Download (if not cached) and load the model and preprocessing state."""
        # fixed: these were f-strings with no placeholders
        print("Loading model...")
        self.model, self.scaler, self.label_encoders, self.min_date = self.load_model()
        print("Model loaded.")

    def run_inference(self, input_data: dict[str, Union[str, int, float]]) -> Tuple[float, str]:
        """Predict the sale price and sale date for a single property listing.

        Args:
            input_data: Raw feature dict. Must contain every key in
                ``_CATEGORICAL_COLUMNS`` and ``_NUMERICAL_COLUMNS``; may
                optionally contain ``'last_sale_date'`` (anything accepted
                by ``pd.to_datetime``).

        Returns:
            Tuple of (predicted sale price as a float, predicted sale date
            formatted as a ``'YYYY-MM-DD'`` string).

        Raises:
            KeyError: if a required feature key is missing.
            ValueError: if a categorical value was not seen during training
                (propagated from the label encoder's ``transform``).
        """
        # Work on a copy so the caller's dict is not mutated. (The original
        # implementation wrote encoded values back into input_data.)
        features = dict(input_data)

        # Label-encode the categorical features.
        for col in self._CATEGORICAL_COLUMNS:
            features[col] = self.label_encoders[col].transform([features[col]])[0]

        # Assemble the numerical feature row in training order and scale it.
        numerical_values = self.scaler.transform(
            [[features[col] for col in self._NUMERICAL_COLUMNS]]
        )

        # Convert an optional last-sale date into "days since the training
        # minimum date". NOTE(review): this derived value is never fed to
        # the model below — kept for parity with the original behavior;
        # confirm whether it was meant to be part of the feature vector.
        if features.get('last_sale_date') is not None:  # fixed: was `!= None`
            last_sale_date = pd.to_datetime(features['last_sale_date']).tz_localize(None)
            features['last_sale_date'] = (last_sale_date - self.min_date).days

        x_cat = torch.tensor(
            [[features[col] for col in self._CATEGORICAL_COLUMNS]],
            dtype=torch.long,
        )
        x_num = torch.tensor(numerical_values, dtype=torch.float32)

        # Inference only: no gradient tracking needed.
        with torch.no_grad():
            price_pred, date_pred = self.model(x_cat, x_num)
            predicted_sale_price = price_pred.item()
            # The date head is treated as a classifier over day offsets
            # from min_date; argmax picks the most likely offset.
            predicted_sale_date = torch.argmax(date_pred, dim=1).item()

        predicted_sale_date = (
            self.min_date + pd.Timedelta(days=predicted_sale_date)
        ).strftime('%Y-%m-%d')

        # Guard against a NaN regression output; 1.0 is the original
        # fallback sentinel.
        if math.isnan(predicted_sale_price):
            predicted_sale_price = 1.0

        print(f"Predicted Sale Price: ${predicted_sale_price:,.2f}")
        print(f"Predicted Sale Date: {predicted_sale_date}")
        return predicted_sale_price, predicted_sale_date

    def load_model(self):
        """Download all artifacts from the Hub and deserialize them.

        Returns:
            Tuple of (model, scaler, label_encoders, min_date).
        """
        (model_file, scaler_file, label_encoders_file,
         min_date_file, model_class_file) = self._download_model_files()

        # The pickled checkpoint references the SyntheticModel class by
        # module name, so its module must be importable before torch.load.
        self._import_model_class(model_class_file)

        # weights_only=False is required because the checkpoint is a full
        # pickled module, not a state dict. NOTE(review): full-pickle
        # torch.load executes arbitrary code on deserialization — only load
        # checkpoints from a trusted repository.
        model = torch.load(model_file, weights_only=False)
        model.eval()  # inference mode: freeze dropout / batch-norm behavior

        with open(scaler_file, 'rb') as f:
            scaler = pickle.load(f)

        with open(label_encoders_file, 'rb') as f:
            label_encoders = pickle.load(f)

        with open(min_date_file, 'rb') as f:
            min_date = pickle.load(f)

        return model, scaler, label_encoders, min_date

    def _download_model_files(self):
        """Fetch (or reuse from local cache) every model artifact.

        Returns:
            Tuple of local file paths:
            (model, scaler, label_encoders, min_date, model class source).
        """
        model_path = "fivedollarwitch/synthetic-np"

        model_file = hf_hub_download(repo_id=model_path, filename="model.pt")
        scaler_file = hf_hub_download(repo_id=model_path, filename="scaler.pkl")
        label_encoders_file = hf_hub_download(repo_id=model_path, filename="label_encoders.pkl")
        min_date_file = hf_hub_download(repo_id=model_path, filename="min_date.pkl")
        model_class_file = hf_hub_download(repo_id=model_path, filename="SyntheticModel.py")

        return model_file, scaler_file, label_encoders_file, min_date_file, model_class_file

    def _import_model_class(self, model_class_file):
        """Import the downloaded SyntheticModel module and register it in
        ``sys.modules`` so unpickling can resolve the class by module name.

        Args:
            model_class_file: Local path to the downloaded ``SyntheticModel.py``.
        """
        module_name = "SyntheticModel"
        spec = importlib.util.spec_from_file_location(module_name, model_class_file)
        model_module = importlib.util.module_from_spec(spec)
        # Register before exec so unpickling (and any self-imports inside
        # the module) can find it under its canonical name.
        sys.modules[module_name] = model_module
        spec.loader.exec_module(model_module)
|
|
|