"""Inference utilities for salary prediction.""" import pickle from pathlib import Path import pandas as pd import yaml from src.schema import SalaryInput from src.preprocessing import prepare_features # Load model and artifacts at module level model_path = Path("models/model.pkl") if not model_path.exists(): raise FileNotFoundError( f"Model file not found at {model_path}. Please run 'python -m src.train' first." ) with open(model_path, "rb") as f: artifacts = pickle.load(f) # nosec B301 - loading trusted local model artifact model = artifacts["model"] feature_columns = artifacts["feature_columns"] # Load valid categories for input validation valid_categories_path = Path("config/valid_categories.yaml") if not valid_categories_path.exists(): raise FileNotFoundError( f"Valid categories file not found at {valid_categories_path}. Please run 'python -m src.train' first." ) with open(valid_categories_path, "r") as f: valid_categories = yaml.safe_load(f) # Load currency conversion rates currency_rates_path = Path("config/currency_rates.yaml") currency_rates = {} if currency_rates_path.exists(): with open(currency_rates_path, "r") as f: currency_rates = yaml.safe_load(f) or {} def get_local_currency(country: str, salary_usd: float) -> dict | None: """Convert USD salary to local currency for a given country. Returns: Dict with code, name, rate, and salary_local, or None if unavailable. """ if country not in currency_rates: return None info = currency_rates[country] return { "code": info["code"], "name": info["name"], "rate": info["rate"], "salary_local": round(salary_usd * info["rate"], 2), } def predict_salary(data: SalaryInput) -> float: """Predict salary based on input features. Args: data: SalaryInput model with developer information Returns: Predicted annual salary in USD Raises: ValueError: If country or education_level is not in valid categories """ # Validate input against valid categories from training if data.country not in valid_categories["Country"]: raise ValueError( f"Invalid country: '{data.country}'. " f"Must be one of {len(valid_categories['Country'])} valid countries. " f"Check config/valid_categories.yaml for all valid values." ) if data.education_level not in valid_categories["EdLevel"]: raise ValueError( f"Invalid education level: '{data.education_level}'. " f"Must be one of {len(valid_categories['EdLevel'])} valid education levels. " f"Check config/valid_categories.yaml for all valid values." ) if data.dev_type not in valid_categories["DevType"]: raise ValueError( f"Invalid developer type: '{data.dev_type}'. " f"Must be one of {len(valid_categories['DevType'])} valid developer types. " f"Check config/valid_categories.yaml for all valid values." ) if data.industry not in valid_categories["Industry"]: raise ValueError( f"Invalid industry: '{data.industry}'. " f"Must be one of {len(valid_categories['Industry'])} valid industries. " f"Check config/valid_categories.yaml for all valid values." ) if data.age not in valid_categories["Age"]: raise ValueError( f"Invalid age: '{data.age}'. " f"Must be one of {len(valid_categories['Age'])} valid age ranges. " f"Check config/valid_categories.yaml for all valid values." ) if data.ic_or_pm not in valid_categories["ICorPM"]: raise ValueError( f"Invalid IC or PM value: '{data.ic_or_pm}'. " f"Must be one of {len(valid_categories['ICorPM'])} valid values. " f"Check config/valid_categories.yaml for all valid values." ) if data.org_size not in valid_categories["OrgSize"]: raise ValueError( f"Invalid organization size: '{data.org_size}'. " f"Must be one of {len(valid_categories['OrgSize'])} valid sizes. " f"Check config/valid_categories.yaml for all valid values." ) if data.employment not in valid_categories["Employment"]: raise ValueError( f"Invalid employment status: '{data.employment}'. " f"Must be one of {valid_categories['Employment']}. " f"Check config/valid_categories.yaml for all valid values." ) # Create a DataFrame with the input data input_df = pd.DataFrame( { "Country": [data.country], "YearsCode": [data.years_code], "WorkExp": [data.work_exp], "EdLevel": [data.education_level], "DevType": [data.dev_type], "Industry": [data.industry], "Age": [data.age], "ICorPM": [data.ic_or_pm], "OrgSize": [data.org_size], "Employment": [data.employment], } ) # Apply the same preprocessing as training input_encoded = prepare_features(input_df) # Ensure all feature columns from training are present and in correct order # Use reindex to add missing columns with 0s and reorder in one operation input_encoded = input_encoded.reindex(columns=feature_columns, fill_value=0) # Make prediction prediction = model.predict(input_encoded)[0] # Ensure non-negative salary return max(0.0, float(prediction))