# Estimate the size of a candidate pool by multiplying a base population by
# one factor per filter criterion. The factor tables live in `.assumptions`.
from __future__ import annotations

from bisect import bisect_right
from dataclasses import dataclass

from .assumptions import (
    AGE_BAND_FACTORS,
    BASELINE,
    EDUCATION_FACTORS,
    HEIGHT_FACTORS,
    INCOME_FACTORS,
    REGION_FACTORS,
    RELATIONSHIP_STATUS_FACTORS,
    TARGET_POPULATION_FACTORS,
)

# Inclusive (label, min_age, max_age) bands used by age_factor(). Each label
# must be a key of AGE_BAND_FACTORS. Defined once here instead of being
# rebuilt on every call.
_AGE_BANDS: tuple[tuple[str, int, int], ...] = (
    ("18-24", 18, 24),
    ("25-34", 25, 34),
    ("35-44", 35, 44),
    ("45-54", 45, 54),
    ("55-70", 55, 70),
)


@dataclass(frozen=True)
class Criteria:
    """Immutable set of filters describing who counts toward the pool."""

    # Starting head count before any factor is applied.
    base_population: int
    # Key into TARGET_POPULATION_FACTORS.
    target_population: str
    # Inclusive age bounds, compared against the bands in _AGE_BANDS.
    age_min: int
    age_max: int
    # Key into REGION_FACTORS.
    region_scope: str
    # Key into RELATIONSHIP_STATUS_FACTORS.
    relationship_status: str
    # Minimum height, interpolated against the keys of HEIGHT_FACTORS
    # (presumably those keys are also centimetres -- confirm in .assumptions).
    min_height_cm: int
    # Key into INCOME_FACTORS.
    income_level: str
    # Key into EDUCATION_FACTORS.
    education_level: str


@dataclass(frozen=True)
class PoolEstimate:
    """Central pool-size estimate bracketed by the BASELINE multipliers."""

    # central * BASELINE.uncertainty_low
    conservative: float
    # base_population multiplied by every factor from model_factors().
    central: float
    # central * BASELINE.uncertainty_high
    optimistic: float


def age_factor(age_min: int, age_max: int) -> float:
    """Return the multiplicative factor for the inclusive range [age_min, age_max].

    Each band in ``_AGE_BANDS`` contributes its ``AGE_BAND_FACTORS`` value,
    scaled by the fraction of the band's whole years that fall inside the
    requested range. In effect, ages are treated as evenly spread within a
    band. Ages outside 18-70 contribute nothing.

    Returns:
        The summed contribution, clamped to ``[0.01, 1.0]``. Because of the
        0.01 floor, a range that overlaps no band yields 0.01 rather than 0.
        This includes an inverted range where ``age_min > age_max``.

    Raises:
        KeyError: if an overlapped band's label is missing from
            ``AGE_BAND_FACTORS``.
    """
    selected = 0.0
    for label, band_min, band_max in _AGE_BANDS:
        overlap_min = max(age_min, band_min)
        overlap_max = min(age_max, band_max)
        if overlap_min > overlap_max:
            continue  # no shared years with this band
        # +1 because both ends of each range are inclusive whole years.
        band_width = band_max - band_min + 1
        overlap_width = overlap_max - overlap_min + 1
        selected += AGE_BAND_FACTORS[label] * (overlap_width / band_width)
    return max(0.01, min(selected, 1.0))


def height_factor(min_height_cm: int) -> float:
    """Return the factor for a minimum-height filter.

    ``HEIGHT_FACTORS`` maps height thresholds to factors. An input at or below
    the smallest threshold takes that threshold's factor. An input at or above
    the largest threshold likewise takes that threshold's factor. An input in
    between is linearly interpolated between its two bracketing thresholds.

    Raises:
        IndexError: if ``HEIGHT_FACTORS`` is empty.
    """
    thresholds = sorted(HEIGHT_FACTORS)
    if min_height_cm <= thresholds[0]:
        return HEIGHT_FACTORS[thresholds[0]]
    if min_height_cm >= thresholds[-1]:
        return HEIGHT_FACTORS[thresholds[-1]]
    # Strictly inside the range here, so 1 <= idx <= len(thresholds) - 1 and
    # thresholds[idx - 1] <= min_height_cm < thresholds[idx].
    idx = bisect_right(thresholds, min_height_cm)
    lower = thresholds[idx - 1]
    if lower == min_height_cm:
        return HEIGHT_FACTORS[lower]  # exact hit: nothing to interpolate
    upper = thresholds[idx]
    ratio = (min_height_cm - lower) / (upper - lower)
    lower_factor = HEIGHT_FACTORS[lower]
    return lower_factor + ratio * (HEIGHT_FACTORS[upper] - lower_factor)


def model_factors(criteria: Criteria) -> list[tuple[str, float]]:
    """Return ``(label, factor)`` pairs in the order they are applied.

    Categorical criteria are looked up directly in their factor tables. Age
    and height are computed by ``age_factor`` and ``height_factor``.

    Raises:
        KeyError: if a categorical criterion is not a key of its table.
    """
    return [
        ("Target population", TARGET_POPULATION_FACTORS[criteria.target_population]),
        ("Age range", age_factor(criteria.age_min, criteria.age_max)),
        ("Region scope", REGION_FACTORS[criteria.region_scope]),
        (
            "Relationship status",
            RELATIONSHIP_STATUS_FACTORS[criteria.relationship_status],
        ),
        ("Minimum height", height_factor(criteria.min_height_cm)),
        ("Income threshold", INCOME_FACTORS[criteria.income_level]),
        ("Education filter", EDUCATION_FACTORS[criteria.education_level]),
    ]


def central_estimate(criteria: Criteria) -> float:
    """Return ``base_population`` multiplied by every factor in ``model_factors``."""
    value = float(criteria.base_population)
    for _, factor in model_factors(criteria):
        value *= factor
    return value


def estimate_pool(criteria: Criteria) -> PoolEstimate:
    """Return the central estimate and its range scaled by the BASELINE multipliers."""
    central = central_estimate(criteria)
    return PoolEstimate(
        conservative=central * BASELINE.uncertainty_low,
        central=central,
        optimistic=central * BASELINE.uncertainty_high,
    )


def sensitivity_table(criteria: Criteria) -> list[dict[str, float | str]]:
    """Return one row per factor, showing the running product after each step.

    The first row is the unfiltered baseline. Each later row has the keys
    ``"factor"`` (label), ``"coefficient"`` (rounded to 4 decimals) and
    ``"remaining"`` (rounded to 2 decimals). Rounding applies only to the
    emitted values; the running product itself is carried unrounded.
    """
    remaining = float(criteria.base_population)
    rows: list[dict[str, float | str]] = [
        {"factor": "Baseline", "coefficient": 1.0, "remaining": remaining}
    ]
    for label, coefficient in model_factors(criteria):
        remaining *= coefficient
        rows.append(
            {
                "factor": label,
                "coefficient": round(coefficient, 4),
                "remaining": round(remaining, 2),
            }
        )
    return rows