File size: 22,725 Bytes
938949f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
"""
ChronosForecaster: Day-ahead photosynthesis (A) forecasting using Amazon
Chronos-2 foundation model with native covariate support and optional
LoRA fine-tuning.

Improvement history:
  v1: Broken — daytime-only rows with hidden gaps → MAE ~8.5
  v2: Regular 15-min grid + predict_df + daytime eval → MAE ~1.75 (20w)
  v3: + On-site sensor covariates (PAR, VPD, T_leaf, CO2)
      + 14-day context (captures ~2 weeks of diurnal pattern)
      + LoRA fine-tuning (1000 steps, lr=1e-4)
      + Configurable covariate modes for ablation
      → MAE 1.37 (May), 3.0-3.4 (Jun-Sep), overall beats ML baseline (2.7)
  v4: Revisited input features: added engineered time (hour_sin/cos, doy_sin/cos) and
      stress_risk_ims (VPD from IMS T+RH) in load_data; tried extended IMS (tdmax/tdmin).
      Ablation on current data: best zero-shot = sensor (MAE ~3.86) or all (MAE ~3.91, R² 0.52).
      Time/stress as covariates slightly hurt; kept 4-col IMS + sensor for \"all\".
"""

from __future__ import annotations

from pathlib import Path
from typing import Optional

import numpy as np
import pandas as pd
import torch
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score

from config.settings import (
    PROCESSED_DIR, IMS_CACHE_DIR, OUTPUTS_DIR, GROWING_SEASON_MONTHS,
)
from src.time_features import add_cyclical_time_features

# ---------------------------------------------------------------------------
# Covariate definitions
# ---------------------------------------------------------------------------

# IMS station 43 weather (available as day-ahead forecasts in production).
# tdmax_c / tdmin_c also exist in the IMS data, but ablation showed this
# 4-column set performs best for this dataset.
IMS_COVARIATE_COLS = [
    "ghi_w_m2", "air_temperature_c", "rh_percent", "wind_speed_ms",
]

# On-site Seymour sensors (past-only: not available as forecasts)
SENSOR_COVARIATE_COLS = [
    "PAR_site", "VPD_site", "T_leaf_site", "CO2_site",
]

# Engineered time features (deterministic from timestamp; available for future)
TIME_COVARIATE_COLS = ["hour_sin", "hour_cos", "doy_sin", "doy_cos"]

# Stress risk from IMS-derived VPD (past + future; VPD_ims from T + RH)
STRESS_COVARIATE_COL = "stress_risk_ims"

# Column mapping from raw sensor CSV → clean names
_SENSOR_COL_MAP = {
    "Air1_PAR_ref": "PAR_site",
    "Air1_VPD_ref": "VPD_site",
    "Air1_leafTemperature_ref": "T_leaf_site",
    "Air1_CO2_ref": "CO2_site",
}

# Regular grid spacing used throughout (resampling, windows, forecasts).
FREQ = "15min"
# Grid steps per day, derived from FREQ so the two constants cannot drift
# apart (96 for the 15-minute grid).
STEPS_PER_DAY = int(pd.Timedelta("1D") // pd.Timedelta(FREQ))

# VPD (kPa) from IMS air temperature and RH; used to build stress_risk_ims.
def _vpd_from_ims_kpa(T_c: np.ndarray, rh_percent: np.ndarray) -> np.ndarray:
    """Return vapour-pressure deficit (kPa) from air temperature and RH.

    Saturation vapour pressure uses the Tetens / Murray (1967) form
    ``esat = 0.611 * exp(17.27 * T / (T + 237.3))`` (kPa, T in °C), then
    ``VPD = esat * (1 - RH / 100)`` with RH clipped to [0, 100].

    Args:
        T_c: Air temperature in °C (scalar, list or array).
        rh_percent: Relative humidity in percent, broadcastable against T_c.

    Returns:
        VPD in kPa with the broadcast shape of the inputs. NaN inputs
        propagate to NaN outputs.
    """
    # asarray lets plain lists / scalars work (list + float would raise).
    T_c = np.asarray(T_c, dtype=float)
    rh_percent = np.asarray(rh_percent, dtype=float)
    esat = 0.611 * np.exp(17.27 * T_c / (T_c + 237.3))
    return esat * (1.0 - np.clip(rh_percent, 0, 100) / 100.0)


# Covariate mode presets: mode name -> {"past": [...], "future": [...]}.
#   "past"   columns are supplied alongside the history window only.
#   "future" columns are additionally supplied over the forecast horizon.
# "all" = 4-column IMS (past + future) + on-site sensors (past only).
# TIME_COVARIATE_COLS and STRESS_COVARIATE_COL are computed by load_data but
# no preset uses them (ablation showed they slightly hurt).
# Each preset holds its own list copy so mutating one cannot silently alter
# another preset or the module-level column constants.
COVARIATE_MODES = {
    "none": {"past": [], "future": []},
    "ims": {
        "past": list(IMS_COVARIATE_COLS),
        "future": list(IMS_COVARIATE_COLS),
    },
    "sensor": {"past": list(SENSOR_COVARIATE_COLS), "future": []},
    "all": {
        "past": IMS_COVARIATE_COLS + SENSOR_COVARIATE_COLS,
        "future": list(IMS_COVARIATE_COLS),
    },
}


class ChronosForecaster:
    """Day-ahead A forecaster using Chronos-2 with configurable covariates.

    Typical flow: ``load_data()`` builds a regular ``FREQ`` grid,
    ``finetune()`` optionally LoRA-adapts the model on the training split,
    ``benchmark()`` runs a walk-forward evaluation and
    ``plot_sample_forecast()`` renders one test day.
    """

    def __init__(
        self,
        model_name: str = "amazon/chronos-2",
        device: str = "mps",
        context_days: int = 14,
    ):
        """Configure the forecaster; the model itself is loaded lazily.

        Args:
            model_name: Model id/path passed to
                ``Chronos2Pipeline.from_pretrained``.
            device: Torch device string forwarded as ``device_map``.
            context_days: History length in days, converted to grid steps
                via ``STEPS_PER_DAY``.

        Raises:
            ValueError: if ``context_days`` is not positive (there would be no
                history to forecast from).
        """
        if context_days <= 0:
            raise ValueError(f"context_days must be positive, got {context_days}")
        self.model_name = model_name
        self.device = device
        self.context_steps = context_days * STEPS_PER_DAY
        self._pipeline = None

    @property
    def pipeline(self):
        """Lazy-load the Chronos-2 pipeline (float32) on first use."""
        if self._pipeline is None:
            from chronos import Chronos2Pipeline

            self._pipeline = Chronos2Pipeline.from_pretrained(
                self.model_name,
                device_map=self.device,
                dtype=torch.float32,
            )
        return self._pipeline

    @pipeline.setter
    def pipeline(self, value):
        """Replace the pipeline (e.g. with the result of fine-tuning)."""
        self._pipeline = value

    # ------------------------------------------------------------------
    # Data loading and resampling
    # ------------------------------------------------------------------

    @staticmethod
    def load_data(
        labels_path: Optional[Path] = None,
        ims_path: Optional[Path] = None,
        sensor_path: Optional[Path] = None,
        growing_season_only: bool = True,
    ) -> pd.DataFrame:
        """Load labels + IMS + on-site sensors, merge, resample to a regular grid.

        Steps:
          1. Inner-join labels with IMS weather on ``timestamp_utc`` (only
             timestamps present in both survive), then left-join sensors.
          2. Reindex onto a complete ``FREQ`` grid spanning the merged range.
             Missing ``A`` is filled with 0 — intended for night-time steps,
             but note it also zero-fills any daytime gaps in the labels.
             IMS and sensor covariates are time-interpolated and edge-filled;
             ``ghi_w_m2`` and ``PAR_site`` are clipped at 0.
          3. Add cyclical time features and ``stress_risk_ims``
             (IMS-derived VPD / 6 kPa, clipped to [0, 1]).
          4. Optionally keep only ``GROWING_SEASON_MONTHS``. The dropped
             months leave gaps in the time axis; seasons are concatenated and
             tagged by the ``season`` column (calendar year).

        Args:
            labels_path: Labels CSV with ``time`` and ``A`` columns; defaults
                to ``PROCESSED_DIR / "stage1_labels.csv"``.
            ims_path: IMS CSV with ``timestamp_utc``; defaults to
                ``IMS_CACHE_DIR / "ims_merged_15min.csv"``.
            sensor_path: Wide sensor CSV; defaults to
                ``SEYMOUR_DIR / "sensors_wide.csv"``.
            growing_season_only: Drop months outside the growing season.

        Returns:
            DataFrame with a tz-aware UTC ``timestamp_utc`` column, ``A``,
            covariates, time features, ``stress_risk_ims`` (when IMS T and RH
            are present) and ``season``, on a fresh positional index.
        """
        from config.settings import SEYMOUR_DIR

        labels_path = labels_path or PROCESSED_DIR / "stage1_labels.csv"
        ims_path = ims_path or IMS_CACHE_DIR / "ims_merged_15min.csv"
        sensor_path = sensor_path or SEYMOUR_DIR / "sensors_wide.csv"

        # --- Labels ---
        labels = pd.read_csv(labels_path, parse_dates=["time"])
        labels.rename(columns={"time": "timestamp_utc"}, inplace=True)
        labels["timestamp_utc"] = pd.to_datetime(labels["timestamp_utc"], utc=True)

        # --- IMS ---
        ims = pd.read_csv(ims_path, parse_dates=["timestamp_utc"])
        ims["timestamp_utc"] = pd.to_datetime(ims["timestamp_utc"], utc=True)

        # --- On-site sensors ---
        raw_cols = ["time"] + list(_SENSOR_COL_MAP.keys())
        sensors = pd.read_csv(sensor_path, usecols=raw_cols, parse_dates=["time"])
        sensors.rename(columns={"time": "timestamp_utc", **_SENSOR_COL_MAP}, inplace=True)
        sensors["timestamp_utc"] = pd.to_datetime(sensors["timestamp_utc"], utc=True)

        # --- Merge ---
        merged = labels.merge(ims, on="timestamp_utc", how="inner")
        merged = merged.merge(sensors, on="timestamp_utc", how="left")
        merged.sort_values("timestamp_utc", inplace=True)
        merged.set_index("timestamp_utc", inplace=True)

        # --- Resample to regular grid ---
        full_idx = pd.date_range(
            merged.index.min(), merged.index.max(), freq=FREQ, tz="UTC",
        )
        resampled = merged.reindex(full_idx)
        resampled.index.name = "timestamp_utc"

        # Fill A=0 overnight, interpolate covariates
        resampled["A"] = resampled["A"].fillna(0.0)
        all_cov_cols = [
            c for c in IMS_COVARIATE_COLS + SENSOR_COVARIATE_COLS
            if c in resampled.columns
        ]
        for col in all_cov_cols:
            resampled[col] = (
                resampled[col].interpolate(method="time").ffill().bfill()
            )
            # Irradiance cannot be negative; interpolation may overshoot.
            if col in ("ghi_w_m2", "PAR_site"):
                resampled[col] = resampled[col].clip(lower=0)

        # Engineered time covariates (deterministic; available for future)
        resampled = add_cyclical_time_features(resampled, index_is_timestamp=True)

        # Stress risk from IMS VPD (past + future; 0–1 scale, clip VPD at 6 kPa)
        if "air_temperature_c" in resampled.columns and "rh_percent" in resampled.columns:
            vpd_ims = _vpd_from_ims_kpa(
                resampled["air_temperature_c"].values,
                resampled["rh_percent"].values,
            )
            resampled[STRESS_COVARIATE_COL] = np.clip(vpd_ims / 6.0, 0.0, 1.0)

        resampled.reset_index(inplace=True)

        # --- Growing-season filter ---
        if growing_season_only:
            resampled["month"] = resampled["timestamp_utc"].dt.month
            resampled = resampled[
                resampled["month"].isin(GROWING_SEASON_MONTHS)
            ].copy()
            resampled.drop(columns=["month"], inplace=True)
            resampled.reset_index(drop=True, inplace=True)

        # Add season column (year of growing season)
        resampled["season"] = resampled["timestamp_utc"].dt.year

        return resampled

    @staticmethod
    def load_sparse_data(
        labels_path: Optional[Path] = None,
        ims_path: Optional[Path] = None,
    ) -> pd.DataFrame:
        """Load the original label/IMS inner join (no resampling).

        The labels only cover daytime rows, so the returned timestamps are
        used by ``benchmark`` as the daytime evaluation mask.

        Returns:
            Merged rows sorted by ``timestamp_utc`` (tz-aware UTC) with a
            fresh positional index.
        """
        labels_path = labels_path or PROCESSED_DIR / "stage1_labels.csv"
        ims_path = ims_path or IMS_CACHE_DIR / "ims_merged_15min.csv"

        labels = pd.read_csv(labels_path, parse_dates=["time"])
        labels.rename(columns={"time": "timestamp_utc"}, inplace=True)
        labels["timestamp_utc"] = pd.to_datetime(labels["timestamp_utc"], utc=True)

        ims = pd.read_csv(ims_path, parse_dates=["timestamp_utc"])
        ims["timestamp_utc"] = pd.to_datetime(ims["timestamp_utc"], utc=True)

        merged = labels.merge(ims, on="timestamp_utc", how="inner")
        merged.sort_values("timestamp_utc", inplace=True)
        merged.reset_index(drop=True, inplace=True)
        return merged

    # ------------------------------------------------------------------
    # Grid-gap helper
    # ------------------------------------------------------------------

    @staticmethod
    def _segment_start(timestamps: pd.Series, start: int, end: int) -> int:
        """Return the first position of the gap-free run that ends at ``end - 1``.

        Scans ``timestamps.iloc[start:end]`` for consecutive rows that are not
        exactly one ``FREQ`` apart (e.g. the dormant months removed between
        growing seasons by ``load_data``). Returns the position just after
        the last such break, or ``start`` when the range has no break.
        """
        window = timestamps.iloc[start:end]
        step = pd.Timedelta(FREQ)
        # diff()[j + 1] != step means a new run starts at window position j + 1.
        breaks = np.flatnonzero(window.diff().iloc[1:].ne(step).to_numpy())
        return start + int(breaks[-1]) + 1 if breaks.size else start

    # ------------------------------------------------------------------
    # predict_df based forecasting
    # ------------------------------------------------------------------

    def forecast_day(
        self,
        df: pd.DataFrame,
        context_end_idx: int,
        prediction_length: int = STEPS_PER_DAY,
        covariate_mode: str = "all",
    ) -> pd.DataFrame:
        """Forecast the ``prediction_length`` steps starting at ``context_end_idx``.

        The history is ``df.iloc[ctx_start:context_end_idx]``, at most
        ``context_steps`` rows, trimmed so it never straddles a gap in the
        time grid — including a gap directly between the history and the
        first forecast step — because the model would otherwise treat data
        from different seasons as consecutive steps.

        Args:
            df: Grid as returned by ``load_data`` (positional index with
                ``timestamp_utc`` and ``A`` columns).
            context_end_idx: Positional index of the first forecast step.
            prediction_length: Number of steps to forecast.
            covariate_mode: Key of ``COVARIATE_MODES`` ('none', 'ims',
                'sensor', 'all'). Future covariates are only passed when ``df``
                has a full ``prediction_length`` rows from ``context_end_idx``.

        Returns:
            DataFrame with ``timestamp_utc``, ``median``, ``low_10`` and
            ``high_90`` (the 0.5 / 0.1 / 0.9 quantiles), truncated to the
            shorter of the model output and the timestamps available in ``df``.

        Raises:
            KeyError: unknown ``covariate_mode``.
            ValueError: no history row precedes ``context_end_idx`` within the
                same contiguous segment (start of data, or right after a gap).
        """
        mode_cfg = COVARIATE_MODES[covariate_mode]
        past_cols = [c for c in mode_cfg["past"] if c in df.columns]
        future_cols = [c for c in mode_cfg["future"] if c in df.columns]

        # Scan one row past the history (when it exists) so a break between
        # history and horizon is detected too.
        ctx_start = self._segment_start(
            df["timestamp_utc"],
            max(0, context_end_idx - self.context_steps),
            min(context_end_idx + 1, len(df)),
        )
        if ctx_start >= context_end_idx:
            raise ValueError(
                f"No usable history before index {context_end_idx} "
                "(start of data or a gap in the time grid)."
            )
        ctx = df.iloc[ctx_start:context_end_idx].copy()

        # Build history DataFrame
        hist = ctx[["timestamp_utc", "A"]].copy()
        hist.rename(columns={"timestamp_utc": "timestamp", "A": "target"}, inplace=True)
        hist["item_id"] = "A"
        for col in past_cols:
            hist[col] = ctx[col].values

        # Build future covariates DataFrame (only with a full horizon available)
        future_df = None
        if future_cols:
            fwd = df.iloc[context_end_idx : context_end_idx + prediction_length]
            if len(fwd) >= prediction_length:
                future_df = fwd[["timestamp_utc"]].copy()
                future_df.rename(columns={"timestamp_utc": "timestamp"}, inplace=True)
                future_df["item_id"] = "A"
                for col in future_cols:
                    future_df[col] = fwd[col].values

        result = self.pipeline.predict_df(
            df=hist,
            future_df=future_df,
            id_column="item_id",
            timestamp_column="timestamp",
            target="target",
            prediction_length=prediction_length,
            quantile_levels=[0.1, 0.5, 0.9],
        )

        fwd_timestamps = df["timestamp_utc"].iloc[
            context_end_idx : context_end_idx + prediction_length
        ].values
        # Truncate every column to a common length so near the end of df the
        # frame is still well-formed instead of raising on mismatched lengths.
        n = min(len(result), len(fwd_timestamps))

        return pd.DataFrame({
            "timestamp_utc": fwd_timestamps[:n],
            "median": result["0.5"].values[:n],
            "low_10": result["0.1"].values[:n],
            "high_90": result["0.9"].values[:n],
        })

    # ------------------------------------------------------------------
    # LoRA fine-tuning
    # ------------------------------------------------------------------

    def finetune(
        self,
        df: pd.DataFrame,
        train_ratio: float = 0.75,
        prediction_length: int = STEPS_PER_DAY,
        covariate_mode: str = "all",
        num_steps: int = 500,
        learning_rate: float = 1e-5,
        batch_size: Optional[int] = None,
        output_dir: Optional[str] = None,
    ) -> None:
        """LoRA fine-tune Chronos-2 on the training portion of the data.

        Only rows before ``int(len(df) * train_ratio)`` — the same split
        ``benchmark`` uses — are read, so test windows are never seen.

        Examples are sliding windows of ``context_steps`` history +
        ``prediction_length`` horizon, advanced by ``prediction_length``:
        ``target`` and past covariates cover the history, future covariates
        cover the horizon (observed values). Windows straddling a gap in the
        time grid are skipped. The last ~10% of windows become validation
        inputs; they overlap the final training windows, so validation loss is
        not independent of training data.

        On success ``self.pipeline`` is replaced by the fine-tuned pipeline and
        the model is written to ``output_dir`` (default
        ``OUTPUTS_DIR / "chronos_finetuned"``). Prints a message and returns
        without training when no complete window fits.
        """
        split_idx = int(len(df) * train_ratio)
        train_df = df.iloc[:split_idx]
        timestamps = train_df["timestamp_utc"]

        mode_cfg = COVARIATE_MODES[covariate_mode]
        past_cols = [c for c in mode_cfg["past"] if c in df.columns]
        future_cols = [c for c in mode_cfg["future"] if c in df.columns]

        window_len = self.context_steps + prediction_length
        inputs: list[dict] = []

        # Stride by prediction_length for diversity; "+ 1" keeps the final
        # window that ends exactly at split_idx.
        for end_idx in range(window_len, len(train_df) + 1, prediction_length):
            ctx_start = end_idx - window_len
            ctx_end = end_idx - prediction_length

            # A window spanning a season gap is not a contiguous series.
            if self._segment_start(timestamps, ctx_start, end_idx) != ctx_start:
                continue

            entry: dict = {
                "target": train_df["A"].iloc[ctx_start:ctx_end].values.astype(np.float32),
            }
            if past_cols:
                entry["past_covariates"] = {
                    col: train_df[col].iloc[ctx_start:ctx_end].values.astype(np.float32)
                    for col in past_cols
                }
            if future_cols:
                # Observed values over the horizon stand in for forecasts.
                entry["future_covariates"] = {
                    col: train_df[col].iloc[ctx_end:end_idx].values.astype(np.float32)
                    for col in future_cols
                }
            inputs.append(entry)

        if not inputs:
            print("Not enough training data for fine-tuning.")
            return

        # Hold out the last ~10% of windows, but always keep at least one
        # training window (a single window would otherwise leave the training
        # set empty and batch_size 0).
        val_split = max(1, int(len(inputs) * 0.9))
        train_inputs = inputs[:val_split]
        val_inputs = inputs[val_split:] or None

        output_dir = output_dir or str(OUTPUTS_DIR / "chronos_finetuned")
        effective_batch = batch_size if batch_size is not None else min(32, len(train_inputs))

        print(f"Fine-tuning with LoRA: {len(train_inputs)} train windows, "
              f"{len(val_inputs) if val_inputs else 0} val windows, "
              f"{num_steps} steps, batch_size={effective_batch}")

        finetuned = self.pipeline.fit(
            inputs=train_inputs,
            prediction_length=prediction_length,
            validation_inputs=val_inputs,
            finetune_mode="lora",
            learning_rate=learning_rate,
            num_steps=num_steps,
            batch_size=effective_batch,
            output_dir=output_dir,
        )

        self.pipeline = finetuned
        print(f"Fine-tuning complete. Model saved → {output_dir}")

    # ------------------------------------------------------------------
    # Walk-forward benchmark
    # ------------------------------------------------------------------

    def benchmark(
        self,
        df: Optional[pd.DataFrame] = None,
        train_ratio: float = 0.75,
        prediction_length: int = STEPS_PER_DAY,
        max_test_days: Optional[int] = None,
        covariate_modes: Optional[list[str]] = None,
    ) -> pd.DataFrame:
        """Walk-forward evaluation across covariate modes.

        Non-overlapping ``prediction_length`` windows start at
        ``int(len(df) * train_ratio)``. Windows whose horizon crosses a gap in
        the time grid, or that have no history in the same contiguous segment,
        are skipped (their actuals belong to a different season). Metrics are
        computed only on steps whose timestamps appear in
        ``load_sparse_data()`` (the daytime label rows); windows with fewer
        than 5 such steps are skipped. Forecast medians are clipped at 0.

        NOTE(review): the daytime mask always comes from the default data
        paths, even when ``df`` was built from custom paths.

        Returns:
            One row per mode (MAE, RMSE, R2, n_windows, n_steps) plus a fixed
            "ML baseline (best)" row; also saved to
            ``OUTPUTS_DIR / "chronos_benchmark.csv"``.
        """
        if df is None:
            df = self.load_data()

        if covariate_modes is None:
            covariate_modes = ["none", "ims", "sensor", "all"]

        sparse = self.load_sparse_data()
        daytime_timestamps = set(sparse["timestamp_utc"])
        timestamps = df["timestamp_utc"]

        split_idx = int(len(df) * train_ratio)
        # "+ 1" so the last complete window (start = len - prediction_length)
        # is evaluated; every window is therefore full length.
        test_starts = [
            start
            for start in range(split_idx, len(df) - prediction_length + 1, prediction_length)
            # Gap-free horizon with >= 1 history row in the same segment.
            if self._segment_start(
                timestamps,
                max(0, start - self.context_steps),
                start + prediction_length,
            ) < start
        ]
        if max_test_days is not None:
            test_starts = test_starts[:max_test_days]

        results = {}
        for mode in covariate_modes:
            all_actual, all_pred = [], []

            for start_idx in test_starts:
                actual_slice = df.iloc[start_idx : start_idx + prediction_length]
                forecast_df = self.forecast_day(
                    df, start_idx, prediction_length, covariate_mode=mode,
                )
                n_fc = len(forecast_df)

                daytime_mask = actual_slice["timestamp_utc"].isin(daytime_timestamps).values
                daytime_mask = daytime_mask[:n_fc]

                if daytime_mask.sum() < 5:
                    continue

                actual_day = actual_slice["A"].values[:n_fc][daytime_mask]
                pred_day = np.clip(forecast_df["median"].values[daytime_mask], 0, None)

                all_actual.append(actual_day)
                all_pred.append(pred_day)

            if not all_actual:
                continue

            actual_flat = np.concatenate(all_actual)
            pred_flat = np.concatenate(all_pred)

            results[mode] = {
                "MAE": round(float(mean_absolute_error(actual_flat, pred_flat)), 4),
                "RMSE": round(
                    float(np.sqrt(mean_squared_error(actual_flat, pred_flat))), 4
                ),
                "R2": round(float(r2_score(actual_flat, pred_flat)), 4),
                "n_windows": len(all_actual),
                "n_steps": len(actual_flat),
            }
            print(f"  {mode:12s}: MAE={results[mode]['MAE']:.4f}  "
                  f"RMSE={results[mode]['RMSE']:.4f}  R²={results[mode]['R2']:.4f}  "
                  f"({results[mode]['n_windows']} windows, "
                  f"{results[mode]['n_steps']} daytime steps)")

        comparison = pd.DataFrame(results).T
        comparison.index.name = "mode"
        comparison.reset_index(inplace=True)

        # Append ML baseline row for app comparison
        ml_baseline = pd.DataFrame([{
            "mode": "ML baseline (best)",
            "MAE": 2.7,
            "RMSE": np.nan,
            "R2": np.nan,
            "n_windows": np.nan,
            "n_steps": np.nan,
        }])
        comparison = pd.concat([comparison, ml_baseline], ignore_index=True)

        OUTPUTS_DIR.mkdir(parents=True, exist_ok=True)
        comparison.to_csv(OUTPUTS_DIR / "chronos_benchmark.csv", index=False)
        print(f"Saved benchmark → {OUTPUTS_DIR / 'chronos_benchmark.csv'}")

        return comparison

    # ------------------------------------------------------------------
    # Sample forecast plot
    # ------------------------------------------------------------------

    def plot_sample_forecast(
        self,
        df: Optional[pd.DataFrame] = None,
        test_day_idx: int = 0,
        train_ratio: float = 0.75,
        prediction_length: int = STEPS_PER_DAY,
    ) -> None:
        """Plot one "all"-mode forecast (median + 10-90% band) against actuals.

        The window starts ``test_day_idx`` windows after the train/test split.
        Saves ``OUTPUTS_DIR / "chronos_forecast_sample.png"``; prints a message
        and returns if ``df`` has too few rows for the window.
        """
        import matplotlib.pyplot as plt

        if df is None:
            df = self.load_data()

        split_idx = int(len(df) * train_ratio)
        start_idx = split_idx + test_day_idx * prediction_length

        if start_idx + prediction_length > len(df):
            print("Not enough data for sample forecast plot.")
            return

        forecast_df = self.forecast_day(
            df, start_idx, prediction_length, covariate_mode="all",
        )
        actual = df["A"].iloc[start_idx : start_idx + prediction_length].values

        fig, ax = plt.subplots(figsize=(12, 5))
        # NOTE(review): 0.25 h per step assumes the 15-minute FREQ grid.
        hours = np.arange(len(forecast_df)) * 0.25

        ax.plot(hours, actual[:len(forecast_df)], "k-", linewidth=1.5, label="Actual A")
        ax.plot(
            hours, np.clip(forecast_df["median"].values, 0, None),
            "b-", linewidth=1.5, label="Chronos-2 median",
        )
        ax.fill_between(
            hours,
            np.clip(forecast_df["low_10"].values, 0, None),
            forecast_df["high_90"].values,
            alpha=0.25, color="steelblue", label="10-90% CI",
        )
        ax.set_xlabel("Hours ahead")
        ax.set_ylabel("A (umol CO2 m-2 s-1)")
        ax.axhline(0, color="gray", linewidth=0.5, linestyle="--")

        ts = df["timestamp_utc"].iloc[start_idx]
        ax.set_title(f"Chronos-2 Day-Ahead Forecast — {ts:%Y-%m-%d %H:%M}")
        ax.legend()
        ax.grid(True, alpha=0.3)

        OUTPUTS_DIR.mkdir(parents=True, exist_ok=True)
        fig.savefig(
            OUTPUTS_DIR / "chronos_forecast_sample.png", dpi=150, bbox_inches="tight",
        )
        plt.close(fig)
        print(f"Saved plot → {OUTPUTS_DIR / 'chronos_forecast_sample.png'}")


# ----------------------------------------------------------------------
# CLI entry point
# ----------------------------------------------------------------------

if __name__ == "__main__":
    # CLI: load data, optionally LoRA fine-tune ("all" covariates), run the
    # walk-forward benchmark, and optionally save a sample forecast plot.
    import argparse

    parser = argparse.ArgumentParser(description="Chronos-2 day-ahead A forecasting")
    parser.add_argument("--device", default="mps", help="torch device")
    parser.add_argument("--context-days", type=int, default=14, help="context window in days")
    parser.add_argument("--max-days", type=int, default=None, help="limit test windows")
    parser.add_argument("--plot", action="store_true", help="generate sample forecast plot")
    parser.add_argument(
        "--finetune", action="store_true",
        help="LoRA fine-tune before benchmarking",
    )
    parser.add_argument("--ft-steps", type=int, default=500, help="fine-tuning steps")
    parser.add_argument(
        "--ft-lr", type=float, default=1e-5, help="fine-tuning learning rate",
    )
    parser.add_argument(
        "--modes", nargs="+", default=["none", "ims", "sensor", "all"],
        # Reject unknown modes up front instead of a KeyError mid-benchmark.
        choices=list(COVARIATE_MODES),
        help="covariate modes to benchmark",
    )
    args = parser.parse_args()

    forecaster = ChronosForecaster(
        device=args.device, context_days=args.context_days,
    )

    print("Loading data (growing-season grid + on-site sensors)...")
    df = forecaster.load_data()
    print(f"  Grid: {len(df)} rows, seasons: {sorted(df['season'].unique())}")

    if args.finetune:
        print(f"\nLoRA fine-tuning ({args.ft_steps} steps)...")
        forecaster.finetune(
            df, num_steps=args.ft_steps, learning_rate=args.ft_lr,
            covariate_mode="all",
        )

    print("\nRunning walk-forward benchmark (daytime-only evaluation)...")
    results = forecaster.benchmark(
        df, max_test_days=args.max_days, covariate_modes=args.modes,
    )
    print(f"\n{results.to_string(index=False)}")

    if args.plot:
        print("\nGenerating sample forecast plot...")
        forecaster.plot_sample_forecast(df)