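# Hosting-page header captured with the file (not part of the module):
#   uploader: dima806 · commit message: "Upload 39 files" · commit: eeeaee6 (verified)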
"""Inference utilities for salary prediction."""
import pickle
from pathlib import Path
import pandas as pd
import yaml
from src.schema import SalaryInput
from src.preprocessing import prepare_features
# Load the trained model bundle once, when this module is imported.
model_path = Path("models/model.pkl")
if not model_path.exists():
    raise FileNotFoundError(
        f"Model file not found at {model_path}. Please run 'python -m src.train' first."
    )

# The unpickled object is indexed by "model" and "feature_columns" below.
with model_path.open("rb") as f:
    artifacts = pickle.load(f)  # nosec B301 - loading trusted local model artifact

# Estimator plus the feature column list used to align inference inputs.
model, feature_columns = artifacts["model"], artifacts["feature_columns"]
# Load the allowed category values; predict_salary rejects any input whose
# categorical fields are not listed in this mapping.
valid_categories_path = Path("config/valid_categories.yaml")
if not valid_categories_path.exists():
    raise FileNotFoundError(
        f"Valid categories file not found at {valid_categories_path}. "
        "Please run 'python -m src.train' first."
    )
# Decode explicitly as UTF-8 instead of the platform's locale encoding so that
# non-ASCII category values load identically on every OS.
with open(valid_categories_path, "r", encoding="utf-8") as f:
    valid_categories = yaml.safe_load(f)
# Load currency conversion rates. The file is optional: when it is absent (or
# empty) currency_rates stays an empty dict and get_local_currency returns None.
currency_rates_path = Path("config/currency_rates.yaml")
currency_rates = {}
if currency_rates_path.exists():
    # Decode explicitly as UTF-8 instead of the platform's locale encoding so
    # that non-ASCII currency or country names load identically on every OS.
    with open(currency_rates_path, "r", encoding="utf-8") as f:
        # safe_load returns None for an empty document; fall back to {}.
        currency_rates = yaml.safe_load(f) or {}
def get_local_currency(country: str, salary_usd: float) -> dict | None:
"""Convert USD salary to local currency for a given country.
Returns:
Dict with code, name, rate, and salary_local, or None if unavailable.
"""
if country not in currency_rates:
return None
info = currency_rates[country]
return {
"code": info["code"],
"name": info["name"],
"rate": info["rate"],
"salary_local": round(salary_usd * info["rate"], 2),
}
# One row per validated categorical field:
#   (key in valid_categories, attribute on the input object,
#    label printed after "Invalid", noun phrase printed after the count).
# A noun phrase of None means the message lists the allowed values themselves
# instead of their count (the Employment message does this).
_CATEGORY_CHECKS = (
    ("Country", "country", "country", "valid countries"),
    ("EdLevel", "education_level", "education level", "valid education levels"),
    ("DevType", "dev_type", "developer type", "valid developer types"),
    ("Industry", "industry", "industry", "valid industries"),
    ("Age", "age", "age", "valid age ranges"),
    ("ICorPM", "ic_or_pm", "IC or PM value", "valid values"),
    ("OrgSize", "org_size", "organization size", "valid sizes"),
    ("Employment", "employment", "employment status", None),
)


def _validate_categories(data: SalaryInput) -> None:
    """Raise ValueError for the first categorical field not in valid_categories."""
    for key, attr, label, noun in _CATEGORY_CHECKS:
        value = getattr(data, attr)
        allowed = valid_categories[key]
        if value in allowed:
            continue
        expected = f"{allowed}" if noun is None else f"{len(allowed)} {noun}"
        raise ValueError(
            f"Invalid {label}: '{value}'. "
            f"Must be one of {expected}. "
            f"Check config/valid_categories.yaml for all valid values."
        )


def predict_salary(data: SalaryInput) -> float:
    """Predict salary based on input features.

    Args:
        data: SalaryInput model with developer information.

    Returns:
        Predicted annual salary in USD, clamped to be non-negative.

    Raises:
        ValueError: If any of country, education_level, dev_type, industry,
            age, ic_or_pm, org_size or employment is not listed in the
            corresponding entry of valid_categories. Fields are checked in
            that order and the first failure is reported.
    """
    # Validate input against valid categories from training.
    _validate_categories(data)

    # Single-row frame using the column names the preprocessing step receives.
    input_df = pd.DataFrame(
        {
            "Country": [data.country],
            "YearsCode": [data.years_code],
            "WorkExp": [data.work_exp],
            "EdLevel": [data.education_level],
            "DevType": [data.dev_type],
            "Industry": [data.industry],
            "Age": [data.age],
            "ICorPM": [data.ic_or_pm],
            "OrgSize": [data.org_size],
            "Employment": [data.employment],
        }
    )

    # Apply the same preprocessing as training.
    input_encoded = prepare_features(input_df)

    # Align with the training feature columns: reindex adds any missing
    # columns filled with 0, drops extras, and fixes the column order.
    input_encoded = input_encoded.reindex(columns=feature_columns, fill_value=0)

    prediction = model.predict(input_encoded)[0]

    # Ensure non-negative salary.
    return max(0.0, float(prediction))