# TEZv's picture
# Replace income categories with salary slider
# 82c668e
from __future__ import annotations
from dataclasses import dataclass
from .assumptions import (
AGE_BAND_FACTORS,
ALCOHOL_FACTORS,
BASELINE,
CHILDREN_STATUS_FACTORS,
EDUCATION_FACTORS,
FUTURE_CHILDREN_FACTORS,
HEIGHT_FACTORS,
HOUSING_FACTORS,
INCOME_CURVE_POINTS_UAH,
LANGUAGE_FACTORS,
MILITARY_STATUS_FACTORS,
PETS_FACTORS,
REGION_FACTORS,
RELATIONSHIP_STATUS_FACTORS,
RELOCATION_FACTORS,
SMOKING_FACTORS,
TARGET_POPULATION_FACTORS,
)
@dataclass(frozen=True)
class Criteria:
    """Immutable bundle of the filters that drive one pool estimate.

    The string fields are used as keys into the factor tables imported from
    ``assumptions``; the numeric fields are passed to ``age_factor``,
    ``height_factor`` and ``income_factor``. Field order is part of the
    constructor signature and must not change.
    """

    # Head-count that every coefficient is multiplied against.
    base_population: int
    # Key into TARGET_POPULATION_FACTORS.
    target_population: str
    # Requested age range; both bounds are treated as inclusive by age_factor.
    age_min: int
    age_max: int
    # Key into REGION_FACTORS.
    region_scope: str
    # Key into RELATIONSHIP_STATUS_FACTORS.
    relationship_status: str
    # Minimum height, interpolated against the HEIGHT_FACTORS thresholds.
    min_height_cm: int
    # Minimum income, interpolated against INCOME_CURVE_POINTS_UAH.
    income_min_uah: int
    # Key into EDUCATION_FACTORS.
    education_level: str
    # Key into CHILDREN_STATUS_FACTORS.
    children_status: str
    # Key into FUTURE_CHILDREN_FACTORS.
    future_children: str
    # Key into MILITARY_STATUS_FACTORS.
    military_status: str
    # Key into RELOCATION_FACTORS.
    relocation: str
    # Key into HOUSING_FACTORS.
    housing: str
    # Key into SMOKING_FACTORS.
    smoking: str
    # Key into ALCOHOL_FACTORS.
    alcohol: str
    # Key into LANGUAGE_FACTORS.
    language: str
    # Key into PETS_FACTORS.
    pets: str
@dataclass(frozen=True)
class PoolEstimate:
    """Immutable three-point estimate of the remaining pool size."""

    # Central value scaled by BASELINE.uncertainty_low (see estimate_pool).
    conservative: float
    # Base population multiplied by every model coefficient.
    central: float
    # Central value scaled by BASELINE.uncertainty_high (see estimate_pool).
    optimistic: float
def age_factor(age_min: int, age_max: int) -> float:
    """Return the population share covered by an inclusive age range.

    Each fixed age band contributes its ``AGE_BAND_FACTORS`` coefficient
    scaled by the fraction of the band's years that fall inside
    ``[age_min, age_max]``. Ages outside 18-70 contribute nothing.

    Returns:
        The summed share, clamped to the interval [0.01, 1.0].
    """
    # (label in AGE_BAND_FACTORS, first year, last year) -- both years inclusive.
    band_table = (
        ("18-24", 18, 24),
        ("25-34", 25, 34),
        ("35-44", 35, 44),
        ("45-54", 45, 54),
        ("55-70", 55, 70),
    )

    share = 0.0
    for label, first_year, last_year in band_table:
        covered_from = max(age_min, first_year)
        covered_to = min(age_max, last_year)
        if covered_from > covered_to:
            # The requested range does not touch this band.
            continue
        years_in_band = last_year - first_year + 1
        years_covered = covered_to - covered_from + 1
        share += AGE_BAND_FACTORS[label] * (years_covered / years_in_band)

    # Never report less than 1% or more than 100%.
    return max(0.01, min(share, 1.0))
def height_factor(min_height_cm: int) -> float:
    """Return the coefficient for a minimum-height requirement.

    ``HEIGHT_FACTORS`` maps height thresholds to coefficients. Requests at or
    beyond either end of the table get the end value; an exact threshold gets
    its own coefficient; anything in between is linearly interpolated between
    the two neighbouring thresholds.
    """
    knots = sorted(HEIGHT_FACTORS)

    first_knot = knots[0]
    if min_height_cm <= first_knot:
        return HEIGHT_FACTORS[first_knot]

    last_knot = knots[-1]
    if min_height_cm >= last_knot:
        return HEIGHT_FACTORS[last_knot]

    # Nearest thresholds on either side; the guards above ensure both exist.
    below = next(knot for knot in reversed(knots) if knot <= min_height_cm)
    above = next(knot for knot in knots if knot >= min_height_cm)
    if below == above:
        # The request sits exactly on a threshold.
        return HEIGHT_FACTORS[below]

    start_value = HEIGHT_FACTORS[below]
    end_value = HEIGHT_FACTORS[above]
    fraction = (min_height_cm - below) / (above - below)
    return start_value + fraction * (end_value - start_value)
def income_factor(income_min_uah: int) -> float:
    """Return the coefficient for a minimum-income requirement.

    ``INCOME_CURVE_POINTS_UAH`` is read as a collection of
    ``(income, coefficient)`` pairs. Requests at or beyond either end of the
    sorted curve get the end coefficient; an exact curve point gets its own
    coefficient; anything in between is linearly interpolated between the two
    neighbouring points.
    """
    curve = sorted(INCOME_CURVE_POINTS_UAH)

    head = curve[0]
    if income_min_uah <= head[0]:
        return head[1]

    tail = curve[-1]
    if income_min_uah >= tail[0]:
        return tail[1]

    # Nearest curve points on either side; the guards above ensure both exist.
    left = next(point for point in reversed(curve) if point[0] <= income_min_uah)
    right = next(point for point in curve if point[0] >= income_min_uah)
    if left == right:
        # The request sits exactly on a curve point.
        return left[1]

    fraction = (income_min_uah - left[0]) / (right[0] - left[0])
    return left[1] + fraction * (right[1] - left[1])
def model_factors(criteria: Criteria) -> list[tuple[str, float]]:
    """Return the ordered ``(label, coefficient)`` pairs for *criteria*.

    Categorical fields are looked up in their factor tables (an unknown key
    raises ``KeyError``); age, height and income go through their helper
    functions. The order of the pairs is the order in which the coefficients
    are applied by the callers in this module.
    """
    # Labels are unique, so an insertion-ordered mapping yields exactly the
    # same pairs, in the same order, as listing the tuples directly.
    coefficient_by_label = {
        "Target population": TARGET_POPULATION_FACTORS[criteria.target_population],
        "Age range": age_factor(criteria.age_min, criteria.age_max),
        "Region scope": REGION_FACTORS[criteria.region_scope],
        "Relationship status": RELATIONSHIP_STATUS_FACTORS[criteria.relationship_status],
        "Minimum height": height_factor(criteria.min_height_cm),
        "Minimum income": income_factor(criteria.income_min_uah),
        "Education filter": EDUCATION_FACTORS[criteria.education_level],
        "Children status": CHILDREN_STATUS_FACTORS[criteria.children_status],
        "Future children": FUTURE_CHILDREN_FACTORS[criteria.future_children],
        "Military status": MILITARY_STATUS_FACTORS[criteria.military_status],
        "Relocation": RELOCATION_FACTORS[criteria.relocation],
        "Housing": HOUSING_FACTORS[criteria.housing],
        "Smoking": SMOKING_FACTORS[criteria.smoking],
        "Alcohol": ALCOHOL_FACTORS[criteria.alcohol],
        "Language": LANGUAGE_FACTORS[criteria.language],
        "Pets": PETS_FACTORS[criteria.pets],
    }
    return list(coefficient_by_label.items())
def central_estimate(criteria: Criteria) -> float:
    """Return the base population multiplied by every model coefficient."""
    estimate = float(criteria.base_population)
    coefficients = [coefficient for _label, coefficient in model_factors(criteria)]
    # Multiply one coefficient at a time, in model order.
    for coefficient in coefficients:
        estimate = estimate * coefficient
    return estimate
def estimate_pool(criteria: Criteria) -> PoolEstimate:
    """Return the conservative / central / optimistic estimate for *criteria*.

    The outer values are the central estimate scaled by
    ``BASELINE.uncertainty_low`` and ``BASELINE.uncertainty_high``.
    """
    midpoint = central_estimate(criteria)
    low_end = midpoint * BASELINE.uncertainty_low
    high_end = midpoint * BASELINE.uncertainty_high
    return PoolEstimate(conservative=low_end, central=midpoint, optimistic=high_end)
def sensitivity_table(criteria: Criteria) -> list[dict[str, float | str]]:
    """Return a step-by-step breakdown of how each factor shrinks the pool.

    The first row is the untouched baseline. Each following row applies one
    coefficient from ``model_factors`` (in model order) and reports:

    * ``factor`` -- the factor's label,
    * ``coefficient`` -- the coefficient, rounded to 4 decimals,
    * ``remaining`` -- the head-count left so far, rounded to 2 decimals,
    * ``percent_of_baseline`` -- the share of the baseline left so far,
      rounded to 6 decimals.

    A ``base_population`` of zero is handled without raising: the percentage
    is then taken from the running product of the coefficients.
    """
    baseline = criteria.base_population
    remaining = float(baseline)
    # Running product of the coefficients; only used for the percentage when
    # the baseline is zero and dividing by it is impossible.
    cumulative_share = 1.0

    rows: list[dict[str, float | str]] = [
        {
            "factor": "Baseline",
            "coefficient": 1.0,
            "remaining": remaining,
            "percent_of_baseline": 100.0,
        }
    ]
    for label, coefficient in model_factors(criteria):
        remaining *= coefficient
        cumulative_share *= coefficient
        if baseline:
            percent = (remaining / baseline) * 100
        else:
            # Avoid ZeroDivisionError for an empty baseline.
            percent = cumulative_share * 100
        rows.append(
            {
                "factor": label,
                "coefficient": round(coefficient, 4),
                "remaining": round(remaining, 2),
                "percent_of_baseline": round(percent, 6),
            }
        )
    return rows