# Source: synthetic-np / SyntheticModelDriver.py
# (Hugging Face Hub upload "Upload SyntheticModelDriver.py" by fivedollarwitch, commit 821648b, verified)
import math
from typing import Tuple, Union
from huggingface_hub import hf_hub_download
import torch
import pickle
import pandas as pd
import importlib.util
import sys
'''
This class must do two things:
1) The constructor must load the model.
2) The class must implement a method called `run_inference` that takes the input data and returns a
(float, str) tuple holding the predicted sale price and the predicted sale date.
'''
class SyntheticModelDriver:
    """Load the synthetic-np model from the Hugging Face Hub and run predictions.

    The constructor downloads and loads the model together with its
    preprocessing artifacts (scaler, label encoders, reference date);
    ``run_inference`` turns one property record into a predicted sale price
    and a predicted sale date.
    """

    def __init__(self):
        print("Loading model...")
        self.model, self.scaler, self.label_encoders, self.min_date = self.load_model()
        print("Model loaded.")

    def run_inference(self, input_data: dict[str, Union[str, int, float]]) -> Tuple[float, str]:
        """Predict the sale price and sale date for a single property.

        Args:
            input_data: Feature mapping for one property. Must contain ``city``,
                ``state`` and every key in ``numerical_columns`` below;
                ``last_sale_date`` is optional. The caller's dict is not modified.

        Returns:
            ``(predicted_sale_price, predicted_sale_date)``, where the date is a
            ``YYYY-MM-DD`` string computed as ``min_date`` plus the argmax index
            of the model's date output, in days. A NaN price prediction is
            replaced by ``1.0``.

        Raises:
            KeyError: If a required feature or label encoder is missing.
        """
        # Work on a shallow copy: the encoding steps below overwrite values.
        # Mutating the caller's dict would make a second call with the same
        # dict feed already-encoded integers back into the label encoders.
        features = dict(input_data)

        categorical_columns = ['city', 'state']
        # Encode categorical variables
        for col in categorical_columns:
            features[col] = self.label_encoders[col].transform([features[col]])[0]

        # Normalize numerical variables.
        # NOTE: this order presumably matches the column order the scaler and
        # model were trained with, so do not reorder it.
        numerical_columns = [
            'price',
            'beds',
            'baths',
            'sqft',
            'lot_size',
            'year_built',
            'days_on_market',
            'latitude',
            'longitude',
            'hoa_dues',
            'property_type',
        ]
        numerical_values = self.scaler.transform([[features[col] for col in numerical_columns]])

        # Convert dates to days since reference.
        # NOTE(review): the resulting offset is not passed to the model below.
        # Its only observable effect is that an unparseable date raises here.
        # Confirm whether it was meant to be a model feature.
        if features.get('last_sale_date') is not None:
            last_sale_date = pd.to_datetime(features['last_sale_date']).tz_localize(None)
            features['last_sale_date'] = (last_sale_date - self.min_date).days

        # Convert to tensors (batch of one row)
        x_cat = torch.tensor([[features[col] for col in categorical_columns]], dtype=torch.long)
        x_num = torch.tensor(numerical_values, dtype=torch.float32)

        # Model inference
        with torch.no_grad():
            price_pred, date_pred = self.model(x_cat, x_num)
            predicted_sale_price = price_pred.item()
            # The date head is treated as a classifier over day offsets from min_date.
            predicted_day_offset = torch.argmax(date_pred, dim=1).item()

        # Convert the day offset back to a readable date
        predicted_sale_date = (self.min_date + pd.Timedelta(days=predicted_day_offset)).strftime('%Y-%m-%d')
        if math.isnan(predicted_sale_price):
            predicted_sale_price = 1.0

        print(f"Predicted Sale Price: ${predicted_sale_price:,.2f}")
        print(f"Predicted Sale Date: {predicted_sale_date}")
        return predicted_sale_price, predicted_sale_date

    def load_model(self):
        """Download and deserialize the model and its preprocessing artifacts.

        Returns:
            ``(model, scaler, label_encoders, min_date)``, with the model in
            eval mode and mapped to CPU.
        """
        model_file, scaler_file, label_encoders_file, min_date_file, model_class_file = self._download_model_files()
        # The model class must be importable before torch.load unpickles the
        # full model object below.
        self._import_model_class(model_class_file)

        # Load the model.
        # map_location="cpu": run_inference always builds CPU input tensors, so
        # the weights must live on CPU even if the checkpoint was saved on a GPU.
        # SECURITY NOTE: weights_only=False (like pickle.load below) can execute
        # arbitrary code from the downloaded file. Only use trusted repositories.
        model = torch.load(model_file, map_location="cpu", weights_only=False)
        model.eval()

        # Load additional artifacts (pickled; see security note above)
        with open(scaler_file, 'rb') as f:
            scaler = pickle.load(f)
        with open(label_encoders_file, 'rb') as f:
            label_encoders = pickle.load(f)
        with open(min_date_file, 'rb') as f:
            min_date = pickle.load(f)
        return model, scaler, label_encoders, min_date

    def _download_model_files(self):
        """Fetch every model artifact from the Hub and return their local paths.

        Returns:
            ``(model_file, scaler_file, label_encoders_file, min_date_file,
            model_class_file)`` as local file paths.
        """
        model_path = "fivedollarwitch/synthetic-np"
        # Download the model files from the Hugging Face Hub
        model_file = hf_hub_download(repo_id=model_path, filename="model.pt")
        scaler_file = hf_hub_download(repo_id=model_path, filename="scaler.pkl")
        label_encoders_file = hf_hub_download(repo_id=model_path, filename="label_encoders.pkl")
        min_date_file = hf_hub_download(repo_id=model_path, filename="min_date.pkl")
        model_class_file = hf_hub_download(repo_id=model_path, filename="SyntheticModel.py")
        return model_file, scaler_file, label_encoders_file, min_date_file, model_class_file

    def _import_model_class(self, model_class_file):
        """Import ``model_class_file`` as the top-level module ``SyntheticModel``.

        The module is registered in ``sys.modules``, presumably so the unpickler
        used by ``torch.load`` can resolve the model class by module name.

        Raises:
            ImportError: If no import spec/loader can be created for the file.
        """
        # Reference docs here: https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly
        module_name = "SyntheticModel"
        spec = importlib.util.spec_from_file_location(module_name, model_class_file)
        if spec is None or spec.loader is None:
            raise ImportError(f"Cannot load module {module_name!r} from {model_class_file!r}")
        model_module = importlib.util.module_from_spec(spec)
        # Register before executing, mirroring the documented importlib recipe.
        sys.modules[module_name] = model_module
        # SECURITY NOTE: this executes Python code downloaded from the Hub.
        spec.loader.exec_module(model_module)