File size: 3,531 Bytes
3c0e2ec
 
 
 
 
 
 
 
 
 
 
 
ac4e07f
3c0e2ec
 
 
 
 
 
ac4e07f
3c0e2ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ac4e07f
3c0e2ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
from __future__ import annotations

from dataclasses import dataclass

from .assumptions import (
    AGE_BAND_FACTORS,
    BASELINE,
    EDUCATION_FACTORS,
    HEIGHT_FACTORS,
    INCOME_FACTORS,
    REGION_FACTORS,
    RELATIONSHIP_STATUS_FACTORS,
    TARGET_POPULATION_FACTORS,
)


@dataclass(frozen=True)
class Criteria:
    """Immutable bundle of filters applied to a base population.

    ``model_factors`` turns each filter into one multiplicative coefficient.
    The string fields are used as subscripts into the factor tables imported
    from ``.assumptions``; the numeric fields are passed to ``age_factor``
    and ``height_factor``.
    """

    # Starting head-count; converted to float and multiplied by every factor.
    base_population: int
    # Subscript into TARGET_POPULATION_FACTORS.
    target_population: str
    # Age bounds handed to age_factor, which treats both ends as inclusive.
    age_min: int
    age_max: int
    # Subscript into REGION_FACTORS.
    region_scope: str
    # Subscript into RELATIONSHIP_STATUS_FACTORS.
    relationship_status: str
    # Height threshold handed to height_factor.
    min_height_cm: int
    # Subscript into INCOME_FACTORS.
    income_level: str
    # Subscript into EDUCATION_FACTORS.
    education_level: str


@dataclass(frozen=True)
class PoolEstimate:
    """Immutable three-point estimate of the filtered pool size.

    As populated by ``estimate_pool``: ``central`` is the base population
    multiplied by every model factor, while ``conservative`` and
    ``optimistic`` are that same value scaled by
    ``BASELINE.uncertainty_low`` and ``BASELINE.uncertainty_high``.
    """

    # central * BASELINE.uncertainty_low
    conservative: float
    # Unscaled product of base population and all factors.
    central: float
    # central * BASELINE.uncertainty_high
    optimistic: float


def age_factor(age_min: int, age_max: int) -> float:
    """Return the coefficient for an inclusive ``age_min``..``age_max`` range.

    Every fixed age band contributes its ``AGE_BAND_FACTORS`` entry scaled by
    the fraction of the band's whole years that the requested range covers.
    The accumulated value is clamped to the interval [0.01, 1.0], so a range
    that touches no band (including an inverted one) yields 0.01.
    """
    # (label, first year, last year) - both ends inclusive.
    band_edges = (
        ("18-24", 18, 24),
        ("25-34", 25, 34),
        ("35-44", 35, 44),
        ("45-54", 45, 54),
        ("55-70", 55, 70),
    )

    total = 0.0
    for label, low, high in band_edges:
        first_year = max(age_min, low)
        last_year = min(age_max, high)
        covers_band = first_year <= last_year
        if not covers_band:
            continue
        years_in_band = high - low + 1
        years_covered = last_year - first_year + 1
        # Partial bands count proportionally to the years they share.
        total += AGE_BAND_FACTORS[label] * (years_covered / years_in_band)

    capped = min(total, 1.0)
    return max(0.01, capped)


def height_factor(min_height_cm: int) -> float:
    """Return the coefficient for a minimum-height requirement.

    ``HEIGHT_FACTORS`` is read as a table keyed by height threshold. Heights
    at or beyond either end of the table take the end value, an exact
    threshold takes its own value, and anything in between is linearly
    interpolated from the two neighbouring thresholds.
    """
    knots = sorted(HEIGHT_FACTORS)
    shortest, tallest = knots[0], knots[-1]

    # Outside the tabulated span: hold the nearest end value.
    if min_height_cm <= shortest:
        return HEIGHT_FACTORS[shortest]
    if min_height_cm >= tallest:
        return HEIGHT_FACTORS[tallest]

    at_or_below = [knot for knot in knots if knot <= min_height_cm]
    floor_knot = max(at_or_below)
    at_or_above = [knot for knot in knots if knot >= min_height_cm]
    ceil_knot = min(at_or_above)

    # Landing exactly on a threshold needs no interpolation.
    if floor_knot == ceil_knot:
        return HEIGHT_FACTORS[floor_knot]

    fraction = (min_height_cm - floor_knot) / (ceil_knot - floor_knot)
    return HEIGHT_FACTORS[floor_knot] + fraction * (HEIGHT_FACTORS[ceil_knot] - HEIGHT_FACTORS[floor_knot])


def model_factors(criteria: Criteria) -> list[tuple[str, float]]:
    """Return the ordered ``(label, coefficient)`` pairs for ``criteria``.

    Categorical filters are resolved by subscripting the matching factor
    table from ``.assumptions``; the age range and minimum height go through
    ``age_factor`` and ``height_factor``. The order of the returned list is
    the order in which the coefficients are applied elsewhere in this module.
    """
    steps: list[tuple[str, float]] = []

    population_share = TARGET_POPULATION_FACTORS[criteria.target_population]
    steps.append(("Target population", population_share))

    age_share = age_factor(criteria.age_min, criteria.age_max)
    steps.append(("Age range", age_share))

    region_share = REGION_FACTORS[criteria.region_scope]
    steps.append(("Region scope", region_share))

    status_share = RELATIONSHIP_STATUS_FACTORS[criteria.relationship_status]
    steps.append(("Relationship status", status_share))

    height_share = height_factor(criteria.min_height_cm)
    steps.append(("Minimum height", height_share))

    income_share = INCOME_FACTORS[criteria.income_level]
    steps.append(("Income threshold", income_share))

    education_share = EDUCATION_FACTORS[criteria.education_level]
    steps.append(("Education filter", education_share))

    return steps


def central_estimate(criteria: Criteria) -> float:
    """Return the base population multiplied by every model coefficient.

    The coefficients come from ``model_factors`` and are applied one after
    another in the order that function returns them.
    """
    running = float(criteria.base_population)
    coefficients = [coefficient for _label, coefficient in model_factors(criteria)]
    for coefficient in coefficients:
        running *= coefficient
    return running


def estimate_pool(criteria: Criteria) -> PoolEstimate:
    """Return conservative, central and optimistic pool sizes for ``criteria``.

    The central value is ``central_estimate(criteria)``; the other two are
    that value scaled by ``BASELINE.uncertainty_low`` and
    ``BASELINE.uncertainty_high`` respectively.
    """
    midpoint = central_estimate(criteria)
    low_end = midpoint * BASELINE.uncertainty_low
    high_end = midpoint * BASELINE.uncertainty_high
    return PoolEstimate(conservative=low_end, central=midpoint, optimistic=high_end)


def sensitivity_table(criteria: Criteria) -> list[dict[str, float | str]]:
    """Return one row per filter showing how the pool shrinks step by step.

    The first row is the unfiltered baseline (coefficient 1.0, unrounded
    population). Each following row carries a factor's label, its coefficient
    rounded to 4 decimals, and the running population after applying it,
    rounded to 2 decimals. Rounding affects only the reported values; the
    running product itself is carried forward unrounded.
    """
    running_total = float(criteria.base_population)
    table: list[dict[str, float | str]] = []
    table.append({"factor": "Baseline", "coefficient": 1.0, "remaining": running_total})

    for name, weight in model_factors(criteria):
        running_total *= weight
        row = {
            "factor": name,
            "coefficient": round(weight, 4),
            "remaining": round(running_total, 2),
        }
        table.append(row)

    return table