# Repository page metadata captured along with this file (not part of the module):
#   uploader: TEZv
#   commit:   "Format counts and add target population" (ac4e07f)
#   size:     3.53 kB
from __future__ import annotations
from dataclasses import dataclass
from .assumptions import (
AGE_BAND_FACTORS,
BASELINE,
EDUCATION_FACTORS,
HEIGHT_FACTORS,
INCOME_FACTORS,
REGION_FACTORS,
RELATIONSHIP_STATUS_FACTORS,
TARGET_POPULATION_FACTORS,
)
@dataclass(frozen=True)
class Criteria:
    """Immutable bundle of filter settings that the pool model is run against.

    The string fields are used as lookup keys into the factor tables imported
    from ``.assumptions`` (see ``model_factors``); the set of valid keys is
    defined by those tables, not by this class.
    """

    # Starting head count that every factor is multiplied against.
    base_population: int
    # Lookup key into TARGET_POPULATION_FACTORS.
    target_population: str
    # Age bounds handed to age_factor(), which treats both ends as inclusive.
    age_min: int
    age_max: int
    # Lookup key into REGION_FACTORS.
    region_scope: str
    # Lookup key into RELATIONSHIP_STATUS_FACTORS.
    relationship_status: str
    # Minimum height handed to height_factor().
    min_height_cm: int
    # Lookup key into INCOME_FACTORS.
    income_level: str
    # Lookup key into EDUCATION_FACTORS.
    education_level: str
@dataclass(frozen=True)
class PoolEstimate:
    """Immutable three-point estimate as produced by ``estimate_pool``."""

    # Central value scaled by BASELINE.uncertainty_low.
    conservative: float
    # Base population multiplied by every model factor.
    central: float
    # Central value scaled by BASELINE.uncertainty_high.
    optimistic: float
def age_factor(age_min: int, age_max: int) -> float:
    """Return the age-range coefficient for the inclusive range given.

    Every age band contributes its ``AGE_BAND_FACTORS`` entry scaled by the
    fraction of the band's whole years that the requested range covers. The
    accumulated total is clamped to ``[0.01, 1.0]``, so a range that touches
    no band (including ``age_min > age_max``) yields ``0.01``.

    Args:
        age_min: Lowest age included in the range.
        age_max: Highest age included in the range.

    Returns:
        The clamped sum of the partial band factors.
    """
    # (label, first age, last age) with both ends inclusive.
    age_bands = (
        ("18-24", 18, 24),
        ("25-34", 25, 34),
        ("35-44", 35, 44),
        ("45-54", 45, 54),
        ("55-70", 55, 70),
    )

    total = 0.0
    for label, first_age, last_age in age_bands:
        covered_from = max(age_min, first_age)
        covered_to = min(age_max, last_age)
        if not (covered_from <= covered_to):
            # The requested range does not reach into this band.
            continue
        years_in_band = last_age - first_age + 1
        years_covered = covered_to - covered_from + 1
        total += AGE_BAND_FACTORS[label] * (years_covered / years_in_band)

    # Cap at 1.0 first, then apply the 0.01 floor.
    capped = min(total, 1.0)
    return max(0.01, capped)
def height_factor(min_height_cm: int) -> float:
    """Return the coefficient for a minimum-height requirement.

    ``HEIGHT_FACTORS`` maps height thresholds to coefficients. A request at or
    below the smallest threshold, or at or above the largest, takes that end's
    coefficient. A request equal to a threshold takes its coefficient
    directly, and anything in between is linearly interpolated between the
    two neighbouring thresholds.

    Args:
        min_height_cm: Requested minimum height.

    Returns:
        The looked-up or interpolated coefficient.
    """
    known = sorted(HEIGHT_FACTORS)
    smallest = known[0]
    largest = known[-1]

    # Outside the table: clamp to the nearest end.
    if min_height_cm <= smallest:
        return HEIGHT_FACTORS[smallest]
    if min_height_cm >= largest:
        return HEIGHT_FACTORS[largest]

    # Nearest thresholds on either side of the request.
    floor_cm = max([cm for cm in known if cm <= min_height_cm])
    ceil_cm = min([cm for cm in known if cm >= min_height_cm])
    if floor_cm == ceil_cm:
        # The request sits exactly on a threshold.
        return HEIGHT_FACTORS[floor_cm]

    floor_value = HEIGHT_FACTORS[floor_cm]
    ceil_value = HEIGHT_FACTORS[ceil_cm]
    weight = (min_height_cm - floor_cm) / (ceil_cm - floor_cm)
    return floor_value + weight * (ceil_value - floor_value)
def model_factors(criteria: Criteria) -> list[tuple[str, float]]:
    """Return the ordered ``(label, coefficient)`` pairs for *criteria*.

    Table-backed coefficients are looked up by the matching ``criteria``
    field; age and height go through ``age_factor`` and ``height_factor``.
    An unknown key propagates whatever the factor table raises on a failed
    lookup (``KeyError`` if the table is a plain dict).

    Args:
        criteria: The filter settings to evaluate.

    Returns:
        Seven pairs, in the order the factors are applied.
    """
    # Evaluated in application order so any lookup error surfaces in the
    # same order as the factors are listed.
    target = TARGET_POPULATION_FACTORS[criteria.target_population]
    age = age_factor(criteria.age_min, criteria.age_max)
    region = REGION_FACTORS[criteria.region_scope]
    status = RELATIONSHIP_STATUS_FACTORS[criteria.relationship_status]
    height = height_factor(criteria.min_height_cm)
    income = INCOME_FACTORS[criteria.income_level]
    education = EDUCATION_FACTORS[criteria.education_level]

    return [
        ("Target population", target),
        ("Age range", age),
        ("Region scope", region),
        ("Relationship status", status),
        ("Minimum height", height),
        ("Income threshold", income),
        ("Education filter", education),
    ]
def central_estimate(criteria: Criteria) -> float:
    """Return the base population multiplied by every model factor in turn.

    Args:
        criteria: The filter settings to evaluate.

    Returns:
        ``criteria.base_population`` as a float, scaled successively by each
        coefficient from ``model_factors``.
    """
    running = float(criteria.base_population)
    for _label, coefficient in model_factors(criteria):
        running = running * coefficient
    return running
def estimate_pool(criteria: Criteria) -> PoolEstimate:
    """Return the conservative, central and optimistic estimates for *criteria*.

    The central value comes from ``central_estimate``; the other two are that
    value scaled by ``BASELINE.uncertainty_low`` and
    ``BASELINE.uncertainty_high`` respectively.

    Args:
        criteria: The filter settings to evaluate.

    Returns:
        A ``PoolEstimate`` holding the three values.
    """
    midpoint = central_estimate(criteria)
    low = midpoint * BASELINE.uncertainty_low
    high = midpoint * BASELINE.uncertainty_high
    return PoolEstimate(conservative=low, central=midpoint, optimistic=high)
def sensitivity_table(criteria: Criteria) -> list[dict[str, float | str]]:
    """Return one row per factor showing how the pool shrinks step by step.

    The first row is the ``"Baseline"`` with coefficient ``1.0`` and the
    unmodified base population. Each following row carries a factor's label,
    its coefficient rounded to 4 places, and the running population after
    applying it rounded to 2 places. The running value itself is carried
    forward unrounded.

    Args:
        criteria: The filter settings to evaluate.

    Returns:
        A list of dicts with keys ``"factor"``, ``"coefficient"`` and
        ``"remaining"``.
    """
    running = float(criteria.base_population)
    table: list[dict[str, float | str]] = [
        {"factor": "Baseline", "coefficient": 1.0, "remaining": running}
    ]
    for name, coeff in model_factors(criteria):
        running *= coeff
        row = dict(
            factor=name,
            coefficient=round(coeff, 4),
            remaining=round(running, 2),
        )
        table.append(row)
    return table