File size: 15,508 Bytes
4d0ffdd
 
 
 
 
 
 
 
 
 
f6c65ef
4d0ffdd
 
f6c65ef
 
4d0ffdd
 
 
 
 
6a28f91
 
4d0ffdd
 
 
 
6a28f91
 
4d0ffdd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f6c65ef
 
 
 
 
 
 
4d0ffdd
 
 
 
 
 
 
 
f6c65ef
4d0ffdd
 
 
 
 
 
 
 
 
f6c65ef
4d0ffdd
 
 
 
 
 
 
 
 
 
 
 
 
f6c65ef
 
4d0ffdd
 
 
f6c65ef
4d0ffdd
 
 
 
 
 
 
 
f6c65ef
4d0ffdd
 
 
 
 
 
f6c65ef
4d0ffdd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f6c65ef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4d0ffdd
 
 
 
 
 
 
 
 
 
 
f6c65ef
4d0ffdd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f6c65ef
4d0ffdd
 
 
 
f6c65ef
4d0ffdd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f6c65ef
4d0ffdd
f6c65ef
 
4d0ffdd
 
 
 
 
 
 
 
 
 
 
 
 
 
f6c65ef
 
 
 
 
 
 
 
 
 
4d0ffdd
f6c65ef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6a28f91
 
 
 
 
f6c65ef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6a28f91
 
 
f6c65ef
4d0ffdd
 
 
 
 
 
 
 
 
 
f6c65ef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4d0ffdd
f6c65ef
 
 
 
 
 
 
 
 
 
 
 
4d0ffdd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
"""Synthetic case generator (Phase 2).

Generates Case objects between start_date and end_date using:
- CASE_TYPE_DISTRIBUTION
- Monthly seasonality factors
- Urgent case percentage
- Court working days (CourtCalendar)

Also provides CSV export/import helpers compatible with scripts.
"""

from __future__ import annotations

import csv
import random
from dataclasses import dataclass
from datetime import date, timedelta
from pathlib import Path
from typing import Iterable, List, Tuple

from src.core.case import Case
from src.data.config import (
    CASE_TYPE_DISTRIBUTION,
    MONTHLY_SEASONALITY,
    URGENT_CASE_PERCENTAGE,
)
from src.data.param_loader import load_parameters
from src.utils.calendar import CourtCalendar


def _month_iter(start: date, end: date) -> Iterable[Tuple[int, int]]:
    y, m = start.year, start.month
    while (y, m) <= (end.year, end.month):
        yield (y, m)
        if m == 12:
            y += 1
            m = 1
        else:
            m += 1


@dataclass
class CaseGenerator:
    """Synthetic case generator (Phase 2).

    Produces Case objects filed on court working days between ``start`` and
    ``end`` (inclusive).  Monthly volumes are weighted by seasonality and
    the number of working days; case type, urgency and initial stage are
    sampled from configurable distributions.

    Note: ``generate`` seeds the module-global ``random`` RNG with ``seed``
    so output is reproducible, but this is a process-wide side effect.
    """

    start: date  # first allowed filing date (inclusive)
    end: date  # last allowed filing date (inclusive)
    seed: int = 42  # RNG seed for reproducible output

    def generate(
        self,
        n_cases: int,
        stage_mix: dict | None = None,
        stage_mix_auto: bool = False,
        case_type_distribution: dict | None = None,
    ) -> List[Case]:
        """Generate ``n_cases`` synthetic cases.

        Args:
            n_cases: total number of cases to create.
            stage_mix: optional {stage: weight} distribution for each case's
                initial stage; normalized internally.  Defaults to every
                case starting in ADMISSION.
            stage_mix_auto: when True, derive ``stage_mix`` from the loaded
                parameters' stationary stage distribution (takes precedence
                over an explicitly passed ``stage_mix``).
            case_type_distribution: optional {case_type: weight} override of
                CASE_TYPE_DISTRIBUTION.  Non-positive/None entries are
                dropped; if nothing valid remains, the default is used.

        Returns:
            List of Case objects with filing dates, stages, urgency flags,
            and — for cases filed more than 30 days before ``end`` — a
            synthetic hearing history.
        """
        random.seed(self.seed)
        cal = CourtCalendar()
        if stage_mix_auto:
            params = load_parameters()
            stage_mix = params.get_stage_stationary_distribution()
        stage_mix = stage_mix or {"ADMISSION": 1.0}
        # Normalize stage weights explicitly (guard against a zero sum).
        total_mix = sum(stage_mix.values()) or 1.0
        stage_mix = {k: v / total_mix for k, v in stage_mix.items()}
        # Precompute the cumulative distribution for inverse-CDF sampling.
        stage_items = list(stage_mix.items())
        scum = []
        accs = 0.0
        for _, p in stage_items:
            accs += p
            scum.append(accs)
        if scum:
            # Pin the last bound to exactly 1.0 to absorb float rounding.
            scum[-1] = 1.0

        def sample_stage() -> str:
            # Inverse-CDF draw over the normalized stage mix.
            if not stage_items:
                return "ADMISSION"
            r = random.random()
            for i, (st, _) in enumerate(stage_items):
                if r <= scum[i]:
                    return st
            return stage_items[-1][0]

        # 1) Build monthly working-day lists and weights
        #    (seasonality factor * number of working days in the month).
        month_days = {}
        month_weight = {}
        for y, m in _month_iter(self.start, self.end):
            days = cal.get_working_days_in_month(y, m)
            # Restrict to the configured [start, end] window.
            days = [d for d in days if self.start <= d <= self.end]
            if not days:
                continue
            month_days[(y, m)] = days
            month_weight[(y, m)] = MONTHLY_SEASONALITY.get(m, 1.0) * len(days)

        # No working days at all in the window -> nothing to generate.
        total_w = sum(month_weight.values())
        if total_w == 0:
            return []

        # 2) Allocate case counts per month (round, then adjust so the
        #    rounded counts sum exactly to n_cases).
        alloc = {}
        for key, w in month_weight.items():
            cnt = int(round(n_cases * (w / total_w)))
            alloc[key] = cnt
        diff = n_cases - sum(alloc.values())
        if diff != 0:
            # Distribute the rounding difference deterministically,
            # one unit at a time in sorted key order.
            keys = sorted(alloc.keys())
            idx = 0
            step = 1 if diff > 0 else -1
            for _ in range(abs(diff)):
                alloc[keys[idx]] += step
                idx = (idx + 1) % len(keys)

        # 3) Case-type sampling: allow a custom distribution override,
        #    defaulting to the historical one (from config/EDA).
        if case_type_distribution is None:
            type_dist = dict(CASE_TYPE_DISTRIBUTION)
        else:
            # Validate the user-provided distribution: drop None and
            # non-positive weights and empty keys.
            valid_items = {
                str(k): float(v)
                for k, v in case_type_distribution.items()
                if v is not None and float(v) > 0.0 and str(k)
            }
            if not valid_items:
                # Fall back to defaults if nothing valid remains.
                type_dist = dict(CASE_TYPE_DISTRIBUTION)
            else:
                total = sum(valid_items.values())
                # Normalize to 1.0.
                type_dist = {k: v / total for k, v in valid_items.items()}

        type_items = list(type_dist.items())
        type_acc = []
        cum = 0.0
        for _, p in type_items:
            cum += p
            type_acc.append(cum)
        # Ensure the last bound is exactly 1.0 despite float rounding.
        if type_acc:
            type_acc[-1] = 1.0

        def sample_case_type() -> str:
            # Inverse-CDF draw over the normalized case-type distribution.
            r = random.random()
            for i, (ct, _) in enumerate(type_items):
                if r <= type_acc[i]:
                    return ct
            return type_items[-1][0]

        cases: List[Case] = []
        seq = 0  # round-robin cursor over each month's working days
        for key in sorted(alloc.keys()):
            y, m = key
            days = month_days[key]
            if not days or alloc[key] <= 0:
                continue
            # Spread this month's cases evenly across its working days.
            for _ in range(alloc[key]):
                filed = days[seq % len(days)]
                seq += 1
                ct = sample_case_type()
                urgent = random.random() < URGENT_CASE_PERCENTAGE
                cid = f"{ct}/{filed.year}/{len(cases) + 1:05d}"
                init_stage = sample_stage()
                # Each case is treated as having just entered its current
                # stage on the filing date (stage_start = filed_date,
                # days_in_stage = 0 as of simulation start).
                c = Case(
                    case_id=cid,
                    case_type=ct,
                    filed_date=filed,
                    current_stage=init_stage,
                    is_urgent=urgent,
                )
                c.stage_start_date = filed
                c.days_in_stage = 0
                # Initialize a realistic hearing history: spread last
                # hearings across the past 7-30 days so a constant stream
                # of cases becomes eligible rather than all at once.
                days_since_filed = (self.end - filed).days
                if days_since_filed > 30:  # Only if filed at least 30 days before end
                    # Roughly one historical hearing per month of age.
                    c.hearing_count = max(1, days_since_filed // 30)

                    # Purpose pools: bottleneck purposes stall a case,
                    # ripe purposes indicate it is ready to progress.
                    bottleneck_purposes = [
                        "ISSUE SUMMONS",
                        "FOR NOTICE",
                        "AWAIT SERVICE OF NOTICE",
                        "STAY APPLICATION PENDING",
                        "FOR ORDERS",
                    ]
                    ripe_purposes = [
                        "ARGUMENTS",
                        "HEARING",
                        "FINAL ARGUMENTS",
                        "FOR JUDGMENT",
                        "EVIDENCE",
                    ]

                    # Build a small hearing history list on Case.history.
                    c.history = []

                    # Hearing dates are spaced across the case lifetime,
                    # ending 7-30 days before the generation window's end.
                    days_before_end = random.randint(7, 30)
                    last_hearing_date = self.end - timedelta(days=days_before_end)
                    if c.hearing_count == 1:
                        hearing_dates = [last_hearing_date]
                    else:
                        # Approximate even spacing over the case's lifetime.
                        span_days = max(days_since_filed - days_before_end, 30)
                        step = max(1, span_days // c.hearing_count)
                        hearing_dates = [
                            last_hearing_date - timedelta(days=step * i)
                            for i in range(c.hearing_count - 1)
                        ]
                        hearing_dates = sorted(hearing_dates) + [last_hearing_date]

                    # Assign purposes: earlier hearings draw from the mixed
                    # pool; the final one depends on the initial stage.
                    for i, hdt in enumerate(hearing_dates):
                        if i == len(hearing_dates) - 1:
                            # Final hearing purpose depends on stage and a
                            # random bottleneck share.
                            if init_stage == "ADMISSION" and c.hearing_count < 3:
                                purpose = (
                                    random.choice(bottleneck_purposes)
                                    if random.random() < 0.4
                                    else random.choice(ripe_purposes)
                                )
                            elif init_stage in [
                                "ARGUMENTS",
                                "ORDERS / JUDGMENT",
                                "FINAL DISPOSAL",
                            ]:
                                purpose = random.choice(ripe_purposes)
                            else:
                                purpose = (
                                    random.choice(bottleneck_purposes)
                                    if random.random() < 0.2
                                    else random.choice(ripe_purposes)
                                )
                        else:
                            purpose = random.choice(bottleneck_purposes + ripe_purposes)

                        # Notice/summons-stage purposes mean the case was
                        # not substantively heard on that date.
                        was_heard = purpose not in (
                            "ISSUE SUMMONS",
                            "FOR NOTICE",
                            "AWAIT SERVICE OF NOTICE",
                        )
                        c.history.append(
                            {
                                "date": hdt,
                                "event": "hearing",
                                "was_heard": was_heard,
                                "outcome": "",
                                "stage": init_stage,
                                "purpose": purpose,
                            }
                        )

                    # Update aggregate fields from the generated history.
                    c.last_hearing_date = last_hearing_date
                    c.days_since_last_hearing = days_before_end
                    c.last_hearing_purpose = (
                        c.history[-1]["purpose"] if c.history else None
                    )

                cases.append(c)

        return cases

    # CSV helpers -----------------------------------------------------------
    @staticmethod
    def to_csv(cases: List[Case], out_path: Path) -> None:
        """Write one summary row per case to ``out_path`` (creates parents).

        Schema: case_id,case_type,filed_date,current_stage,is_urgent,
        hearing_count,last_hearing_date,days_since_last_hearing,
        last_hearing_purpose.
        """
        out_path.parent.mkdir(parents=True, exist_ok=True)
        with out_path.open("w", newline="") as f:
            w = csv.writer(f)
            w.writerow(
                [
                    "case_id",
                    "case_type",
                    "filed_date",
                    "current_stage",
                    "is_urgent",
                    "hearing_count",
                    "last_hearing_date",
                    "days_since_last_hearing",
                    "last_hearing_purpose",
                ]
            )
            for c in cases:
                w.writerow(
                    [
                        c.case_id,
                        c.case_type,
                        c.filed_date.isoformat(),
                        c.current_stage,
                        1 if c.is_urgent else 0,
                        c.hearing_count,
                        c.last_hearing_date.isoformat() if c.last_hearing_date else "",
                        c.days_since_last_hearing,
                        c.last_hearing_purpose or "",
                    ]
                )

    @staticmethod
    def to_hearings_csv(cases: List[Case], out_path: Path) -> None:
        """Write flattened hearing histories for generated cases.

        Schema: case_id,date,stage,purpose,was_heard,event
        """
        out_path.parent.mkdir(parents=True, exist_ok=True)
        with out_path.open("w", newline="") as f:
            w = csv.writer(f)
            w.writerow(["case_id", "date", "stage", "purpose", "was_heard", "event"])
            for c in cases:
                # Cases filed close to the window end may have no history.
                for ev in getattr(c, "history", []) or []:
                    if ev.get("event") == "hearing":
                        w.writerow(
                            [
                                c.case_id,
                                (ev.get("date") or c.filed_date).isoformat(),
                                ev.get("stage") or c.current_stage,
                                ev.get("purpose", ""),
                                1 if ev.get("was_heard", False) else 0,
                                ev.get("event"),
                            ]
                        )

    @staticmethod
    def from_csv(path: Path) -> List[Case]:
        """Load cases from a CSV previously written by :meth:`to_csv`.

        Missing/empty optional columns (hearing_count, last_hearing_date,
        days_since_last_hearing, last_hearing_purpose) are simply skipped,
        so older CSVs without them still load.
        """
        cases: List[Case] = []
        with path.open("r", newline="") as f:
            r = csv.DictReader(f)
            for row in r:
                c = Case(
                    case_id=row["case_id"],
                    case_type=row["case_type"],
                    filed_date=date.fromisoformat(row["filed_date"]),
                    current_stage=row.get("current_stage", "ADMISSION"),
                    is_urgent=(str(row.get("is_urgent", "0")) in ("1", "true", "True")),
                )
                # Restore hearing aggregates when the columns are present.
                if "hearing_count" in row and row["hearing_count"]:
                    c.hearing_count = int(row["hearing_count"])
                if "last_hearing_date" in row and row["last_hearing_date"]:
                    c.last_hearing_date = date.fromisoformat(row["last_hearing_date"])
                if "days_since_last_hearing" in row and row["days_since_last_hearing"]:
                    c.days_since_last_hearing = int(row["days_since_last_hearing"])
                if "last_hearing_purpose" in row and row["last_hearing_purpose"]:
                    c.last_hearing_purpose = row["last_hearing_purpose"]
                cases.append(c)
        return cases