Frontend upgrade: heat-red identity, danger banner, callout cards, richer detail
Browse files- Section headers use heat-red (#e63946) bottom borders instead of gold
- Sidebar logo and nav active state use heat-red accent
- Sidebar gradient shifts warm (#221a18) vs Weather AI 2's neutral (#222018)
- Dashboard: danger banner when triggers are active, callout cards explaining
UHI effect and AI prediction (matches Weather AI 2's detail level)
- Stage cards have colored top borders (blue/amber/red)
- Added CSS for callout cards, danger banner, temperature value colors
- Fixed missing Thermometer import in ProgramDesigner
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
frontend/src/components/Sidebar.tsx
CHANGED
|
@@ -21,12 +21,12 @@ export default function Sidebar() {
|
|
| 21 |
return (
|
| 22 |
<aside
|
| 23 |
className="fixed top-0 left-0 z-50 h-full w-56 flex flex-col"
|
| 24 |
-
style={{ background: 'linear-gradient(180deg, #1a1a1a 0%, #
|
| 25 |
>
|
| 26 |
{/* Brand */}
|
| 27 |
<div className="flex items-center h-16 px-5 border-b border-white/10">
|
| 28 |
<NavLink to="/" className="flex items-center gap-2.5 no-underline">
|
| 29 |
-
<div className="w-8 h-8 rounded-lg
|
| 30 |
<Thermometer size={18} className="text-white" />
|
| 31 |
</div>
|
| 32 |
<div>
|
|
@@ -51,7 +51,7 @@ export default function Sidebar() {
|
|
| 51 |
className={({ isActive }) =>
|
| 52 |
`flex items-center gap-3 px-3 py-2.5 rounded-lg text-sm font-sans font-medium transition-colors duration-100 ${
|
| 53 |
isActive
|
| 54 |
-
? 'bg-
|
| 55 |
: 'text-[#e0dcd5] hover:bg-white/5 hover:text-white'
|
| 56 |
}`
|
| 57 |
}
|
|
|
|
| 21 |
return (
|
| 22 |
<aside
|
| 23 |
className="fixed top-0 left-0 z-50 h-full w-56 flex flex-col"
|
| 24 |
+
style={{ background: 'linear-gradient(180deg, #1a1a1a 0%, #221a18 100%)' }}
|
| 25 |
>
|
| 26 |
{/* Brand */}
|
| 27 |
<div className="flex items-center h-16 px-5 border-b border-white/10">
|
| 28 |
<NavLink to="/" className="flex items-center gap-2.5 no-underline">
|
| 29 |
+
<div className="w-8 h-8 rounded-lg flex items-center justify-center" style={{ background: '#e63946' }}>
|
| 30 |
<Thermometer size={18} className="text-white" />
|
| 31 |
</div>
|
| 32 |
<div>
|
|
|
|
| 51 |
className={({ isActive }) =>
|
| 52 |
`flex items-center gap-3 px-3 py-2.5 rounded-lg text-sm font-sans font-medium transition-colors duration-100 ${
|
| 53 |
isActive
|
| 54 |
+
? 'bg-[#e63946]/15 text-[#ff6b6b]'
|
| 55 |
: 'text-[#e0dcd5] hover:bg-white/5 hover:text-white'
|
| 56 |
}`
|
| 57 |
}
|
frontend/src/index.css
CHANGED
|
@@ -255,12 +255,68 @@
|
|
| 255 |
transform: translateY(-3px);
|
| 256 |
}
|
| 257 |
|
| 258 |
-
/* ── Section headers ── */
|
| 259 |
.section-header {
|
| 260 |
@apply uppercase text-[0.78rem] font-sans font-semibold text-warm-muted pb-2 mb-4;
|
| 261 |
letter-spacing: 1.5px;
|
| 262 |
-
border-bottom: 2px solid #
|
| 263 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 264 |
}
|
| 265 |
|
| 266 |
@layer utilities {
|
|
|
|
| 255 |
transform: translateY(-3px);
|
| 256 |
}
|
| 257 |
|
| 258 |
+
/* ── Section headers — heat-red accent (distinct from Weather AI 2's gold) ── */
|
| 259 |
.section-header {
|
| 260 |
@apply uppercase text-[0.78rem] font-sans font-semibold text-warm-muted pb-2 mb-4;
|
| 261 |
letter-spacing: 1.5px;
|
| 262 |
+
border-bottom: 2px solid #e63946;
|
| 263 |
}
|
| 264 |
+
|
| 265 |
+
/* ── Stage card top borders (heat-themed color coding) ── */
|
| 266 |
+
.stage-card[data-stage="data"] { border-top: 3px solid #1565C0; }
|
| 267 |
+
.stage-card[data-stage="forecast"] { border-top: 3px solid #e67e22; }
|
| 268 |
+
.stage-card[data-stage="program"] { border-top: 3px solid #e63946; }
|
| 269 |
+
|
| 270 |
+
/* ── Heat status indicators ── */
|
| 271 |
+
.heat-pill {
|
| 272 |
+
@apply inline-flex items-center gap-1.5 rounded-full px-2.5 py-1 text-[11px] font-semibold font-sans;
|
| 273 |
+
}
|
| 274 |
+
.heat-pill.safe { background: rgba(42, 157, 143, 0.15); color: #2a9d8f; }
|
| 275 |
+
.heat-pill.caution { background: rgba(212, 160, 25, 0.15); color: #d4a019; }
|
| 276 |
+
.heat-pill.warning { background: rgba(230, 126, 34, 0.15); color: #e67e22; }
|
| 277 |
+
.heat-pill.danger { background: rgba(230, 57, 70, 0.15); color: #e63946; }
|
| 278 |
+
|
| 279 |
+
/* ── Callout cards (left-border accent — distinct from Health Optimizer's teal) ── */
|
| 280 |
+
.callout {
|
| 281 |
+
@apply bg-white rounded-[10px] border border-warm-border p-4;
|
| 282 |
+
border-left: 3px solid #e63946;
|
| 283 |
+
}
|
| 284 |
+
.callout.amber { border-left-color: #e67e22; }
|
| 285 |
+
.callout.blue { border-left-color: #1565C0; }
|
| 286 |
+
.callout.green { border-left-color: #2a9d8f; }
|
| 287 |
+
|
| 288 |
+
.callout-title {
|
| 289 |
+
@apply text-[0.82rem] font-semibold font-sans text-[#1a1a1a] mb-1;
|
| 290 |
+
}
|
| 291 |
+
.callout-body {
|
| 292 |
+
@apply text-[0.78rem] text-warm-body leading-relaxed;
|
| 293 |
+
}
|
| 294 |
+
|
| 295 |
+
/* ── Danger banner ── */
|
| 296 |
+
.danger-banner {
|
| 297 |
+
@apply rounded-[10px] mb-6 flex items-center gap-3;
|
| 298 |
+
background: linear-gradient(135deg, rgba(230,57,70,0.06) 0%, rgba(230,126,34,0.04) 100%);
|
| 299 |
+
border: 1px solid rgba(230,57,70,0.15);
|
| 300 |
+
padding: 14px 20px;
|
| 301 |
+
}
|
| 302 |
+
|
| 303 |
+
.danger-dot {
|
| 304 |
+
width: 10px; height: 10px; border-radius: 50%; flex-shrink: 0;
|
| 305 |
+
background: #e63946;
|
| 306 |
+
box-shadow: 0 0 0 3px rgba(230,57,70,0.2);
|
| 307 |
+
}
|
| 308 |
+
|
| 309 |
+
.danger-text {
|
| 310 |
+
@apply text-[0.82rem] font-sans font-medium;
|
| 311 |
+
color: #e63946;
|
| 312 |
+
}
|
| 313 |
+
|
| 314 |
+
/* ── Temperature value colors ── */
|
| 315 |
+
.temp-safe { color: #2a9d8f; }
|
| 316 |
+
.temp-caution { color: #d4a019; }
|
| 317 |
+
.temp-warning { color: #e67e22; }
|
| 318 |
+
.temp-danger { color: #d35400; }
|
| 319 |
+
.temp-extreme { color: #e63946; }
|
| 320 |
}
|
| 321 |
|
| 322 |
@layer utilities {
|
frontend/src/pages/Dashboard.tsx
CHANGED
|
@@ -4,17 +4,21 @@ import { Satellite, Thermometer, SlidersHorizontal, ChevronDown, ChevronRight, A
|
|
| 4 |
import MetricCard from '../components/MetricCard'
|
| 5 |
import StatusBadge from '../components/StatusBadge'
|
| 6 |
import { LoadingSpinner, ErrorState } from '../components/LoadingState'
|
| 7 |
-
import { usePipelineStats, usePipelineRuns } from '../lib/api'
|
| 8 |
|
| 9 |
export default function Dashboard() {
|
| 10 |
const stats = usePipelineStats()
|
| 11 |
const runs = usePipelineRuns()
|
|
|
|
| 12 |
const [showRuns, setShowRuns] = useState(false)
|
| 13 |
|
| 14 |
if (stats.isLoading) return <LoadingSpinner />
|
| 15 |
if (stats.isError) return <ErrorState onRetry={() => stats.refetch()} />
|
| 16 |
|
| 17 |
const s = stats.data
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
return (
|
| 20 |
<div className="animate-slide-up">
|
|
@@ -26,11 +30,22 @@ export default function Dashboard() {
|
|
| 26 |
</p>
|
| 27 |
</div>
|
| 28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
{/* Stage Cards */}
|
| 30 |
<div data-tour="stage-cards" className="mb-8">
|
| 31 |
<div className="section-header">How It Works</div>
|
| 32 |
<div className="grid grid-cols-1 md:grid-cols-3 gap-4 items-center">
|
| 33 |
-
<Link to="/heat-monitor" className="stage-card no-underline">
|
| 34 |
<div className="flex items-center gap-3 mb-2">
|
| 35 |
<div className="w-9 h-9 rounded-lg bg-blue-50 flex items-center justify-center">
|
| 36 |
<Satellite size={18} className="text-info" />
|
|
@@ -49,7 +64,7 @@ export default function Dashboard() {
|
|
| 49 |
<ArrowRight size={20} className="text-warm-border" />
|
| 50 |
</div>
|
| 51 |
|
| 52 |
-
<Link to="/heat-monitor" className="stage-card no-underline">
|
| 53 |
<div className="flex items-center gap-3 mb-2">
|
| 54 |
<div className="w-9 h-9 rounded-lg bg-amber-50 flex items-center justify-center">
|
| 55 |
<Thermometer size={18} className="text-warning" />
|
|
@@ -68,7 +83,7 @@ export default function Dashboard() {
|
|
| 68 |
<ArrowRight size={20} className="text-warm-border" />
|
| 69 |
</div>
|
| 70 |
|
| 71 |
-
<Link to="/calibrate" className="stage-card no-underline">
|
| 72 |
<div className="flex items-center gap-3 mb-2">
|
| 73 |
<div className="w-9 h-9 rounded-lg bg-red-50 flex items-center justify-center">
|
| 74 |
<SlidersHorizontal size={18} className="text-error" />
|
|
@@ -112,12 +127,30 @@ export default function Dashboard() {
|
|
| 112 |
</div>
|
| 113 |
</div>
|
| 114 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
{/* Run History (collapsible) */}
|
| 116 |
<div className="mb-8">
|
| 117 |
<button
|
| 118 |
onClick={() => setShowRuns(!showRuns)}
|
| 119 |
className="flex items-center gap-2 section-header cursor-pointer w-full text-left border-b-0 pb-0 mb-0 bg-transparent border-none"
|
| 120 |
-
style={{ borderBottom: '2px solid #
|
| 121 |
>
|
| 122 |
{showRuns ? <ChevronDown size={14} /> : <ChevronRight size={14} />}
|
| 123 |
Update History
|
|
|
|
| 4 |
import MetricCard from '../components/MetricCard'
|
| 5 |
import StatusBadge from '../components/StatusBadge'
|
| 6 |
import { LoadingSpinner, ErrorState } from '../components/LoadingState'
|
| 7 |
+
import { usePipelineStats, usePipelineRuns, useTriggers } from '../lib/api'
|
| 8 |
|
| 9 |
export default function Dashboard() {
|
| 10 |
const stats = usePipelineStats()
|
| 11 |
const runs = usePipelineRuns()
|
| 12 |
+
const triggers = useTriggers()
|
| 13 |
const [showRuns, setShowRuns] = useState(false)
|
| 14 |
|
| 15 |
if (stats.isLoading) return <LoadingSpinner />
|
| 16 |
if (stats.isError) return <ErrorState onRetry={() => stats.refetch()} />
|
| 17 |
|
| 18 |
const s = stats.data
|
| 19 |
+
const activeTriggers = triggers.data?.triggers ?? []
|
| 20 |
+
const criticalZones = activeTriggers.filter((t) => t.trigger_level === 'critical')
|
| 21 |
+
const dangerNames = activeTriggers.slice(0, 3).map((t) => t.zone_name)
|
| 22 |
|
| 23 |
return (
|
| 24 |
<div className="animate-slide-up">
|
|
|
|
| 30 |
</p>
|
| 31 |
</div>
|
| 32 |
|
| 33 |
+
{/* Danger banner — shows when there are active triggers */}
|
| 34 |
+
{activeTriggers.length > 0 && (
|
| 35 |
+
<div className="danger-banner">
|
| 36 |
+
<div className="danger-dot" />
|
| 37 |
+
<div className="danger-text">
|
| 38 |
+
<strong>{activeTriggers.length} zone{activeTriggers.length !== 1 ? 's' : ''}</strong> currently exceeding safe heat levels
|
| 39 |
+
{dangerNames.length > 0 && <> — {dangerNames.join(', ')}{activeTriggers.length > 3 ? ` and ${activeTriggers.length - 3} more` : ''}</>}
|
| 40 |
+
</div>
|
| 41 |
+
</div>
|
| 42 |
+
)}
|
| 43 |
+
|
| 44 |
{/* Stage Cards */}
|
| 45 |
<div data-tour="stage-cards" className="mb-8">
|
| 46 |
<div className="section-header">How It Works</div>
|
| 47 |
<div className="grid grid-cols-1 md:grid-cols-3 gap-4 items-center">
|
| 48 |
+
<Link to="/heat-monitor" className="stage-card no-underline" data-stage="data">
|
| 49 |
<div className="flex items-center gap-3 mb-2">
|
| 50 |
<div className="w-9 h-9 rounded-lg bg-blue-50 flex items-center justify-center">
|
| 51 |
<Satellite size={18} className="text-info" />
|
|
|
|
| 64 |
<ArrowRight size={20} className="text-warm-border" />
|
| 65 |
</div>
|
| 66 |
|
| 67 |
+
<Link to="/heat-monitor" className="stage-card no-underline" data-stage="forecast">
|
| 68 |
<div className="flex items-center gap-3 mb-2">
|
| 69 |
<div className="w-9 h-9 rounded-lg bg-amber-50 flex items-center justify-center">
|
| 70 |
<Thermometer size={18} className="text-warning" />
|
|
|
|
| 83 |
<ArrowRight size={20} className="text-warm-border" />
|
| 84 |
</div>
|
| 85 |
|
| 86 |
+
<Link to="/calibrate" className="stage-card no-underline" data-stage="program">
|
| 87 |
<div className="flex items-center gap-3 mb-2">
|
| 88 |
<div className="w-9 h-9 rounded-lg bg-red-50 flex items-center justify-center">
|
| 89 |
<SlidersHorizontal size={18} className="text-error" />
|
|
|
|
| 127 |
</div>
|
| 128 |
</div>
|
| 129 |
|
| 130 |
+
{/* Explainer callouts */}
|
| 131 |
+
<div className="grid grid-cols-1 md:grid-cols-2 gap-3 mb-8">
|
| 132 |
+
<div className="callout amber">
|
| 133 |
+
<div className="callout-title">Urban Heat Island Effect</div>
|
| 134 |
+
<div className="callout-body">
|
| 135 |
+
Informal settlements with tin roofs can be 3-6°C hotter than surrounding areas.
|
| 136 |
+
The system adjusts satellite readings to reflect what workers actually experience on the ground.
|
| 137 |
+
</div>
|
| 138 |
+
</div>
|
| 139 |
+
<div className="callout blue">
|
| 140 |
+
<div className="callout-title">AI-Powered Prediction</div>
|
| 141 |
+
<div className="callout-body">
|
| 142 |
+
Two AI models work together to predict dangerous heat 7 days out.
|
| 143 |
+
When one model lacks data, the system automatically falls back to the next most reliable method.
|
| 144 |
+
</div>
|
| 145 |
+
</div>
|
| 146 |
+
</div>
|
| 147 |
+
|
| 148 |
{/* Run History (collapsible) */}
|
| 149 |
<div className="mb-8">
|
| 150 |
<button
|
| 151 |
onClick={() => setShowRuns(!showRuns)}
|
| 152 |
className="flex items-center gap-2 section-header cursor-pointer w-full text-left border-b-0 pb-0 mb-0 bg-transparent border-none"
|
| 153 |
+
style={{ borderBottom: '2px solid #e63946', paddingBottom: 8, marginBottom: 16 }}
|
| 154 |
>
|
| 155 |
{showRuns ? <ChevronDown size={14} /> : <ChevronRight size={14} />}
|
| 156 |
Update History
|
frontend/src/pages/HeatMonitor.tsx
CHANGED
|
@@ -61,7 +61,10 @@ export default function HeatMonitor() {
|
|
| 61 |
<div data-tour="heat-monitor-title" className="pt-2 pb-6">
|
| 62 |
<h1 className="page-title">Heat Monitor</h1>
|
| 63 |
<p className="page-caption">
|
| 64 |
-
Temperature,
|
|
|
|
|
|
|
|
|
|
| 65 |
</p>
|
| 66 |
</div>
|
| 67 |
|
|
@@ -235,8 +238,11 @@ export default function HeatMonitor() {
|
|
| 235 |
{/* Temperature Chart */}
|
| 236 |
<div className="card card-body">
|
| 237 |
<h3 className="text-sm font-semibold font-sans text-[#1a1a1a] mb-1">
|
| 238 |
-
90
|
| 239 |
</h3>
|
|
|
|
|
|
|
|
|
|
| 240 |
<p className="text-xs text-warm-muted mb-4">
|
| 241 |
Current:{' '}
|
| 242 |
<span style={{ color: tempColor(selectedData?.temp_current ?? 0) }} className="font-semibold">
|
|
|
|
| 61 |
<div data-tour="heat-monitor-title" className="pt-2 pb-6">
|
| 62 |
<h1 className="page-title">Heat Monitor</h1>
|
| 63 |
<p className="page-caption">
|
| 64 |
+
Temperature, heat stress, and feels-like conditions across all zones
|
| 65 |
+
</p>
|
| 66 |
+
<p className="text-sm text-warm-body mt-2 leading-relaxed" style={{ maxWidth: '640px' }}>
|
| 67 |
+
Live readings from satellite sensors, adjusted for how hot each neighborhood actually feels to someone working outside. Red means danger — workers in those zones face unsafe conditions today.
|
| 68 |
</p>
|
| 69 |
</div>
|
| 70 |
|
|
|
|
| 238 |
{/* Temperature Chart */}
|
| 239 |
<div className="card card-body">
|
| 240 |
<h3 className="text-sm font-semibold font-sans text-[#1a1a1a] mb-1">
|
| 241 |
+
How hot has it been? 90 days in {selectedData?.zone_name}
|
| 242 |
</h3>
|
| 243 |
+
<p className="text-xs text-warm-body mb-1 leading-relaxed">
|
| 244 |
+
The solid red line is what workers actually feel (adjusted for local conditions). The dashed line is raw satellite data before adjustment. When the red line crosses the 35°C danger mark, outdoor work becomes unsafe.
|
| 245 |
+
</p>
|
| 246 |
<p className="text-xs text-warm-muted mb-4">
|
| 247 |
Current:{' '}
|
| 248 |
<span style={{ color: tempColor(selectedData?.temp_current ?? 0) }} className="font-semibold">
|
frontend/src/pages/ProgramDesigner.tsx
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
import { useState, useEffect, useRef } from 'react'
|
|
|
|
| 2 |
import MetricCard from '../components/MetricCard'
|
| 3 |
import { LoadingSpinner, ErrorState } from '../components/LoadingState'
|
| 4 |
import { useCalibrateQuery } from '../lib/api'
|
|
@@ -284,8 +285,16 @@ export default function ProgramDesigner() {
|
|
| 284 |
</table>
|
| 285 |
</div>
|
| 286 |
{sortedZones.length === 0 && (
|
| 287 |
-
<div className="text-center py-
|
| 288 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 289 |
</div>
|
| 290 |
)}
|
| 291 |
</>
|
|
|
|
| 1 |
import { useState, useEffect, useRef } from 'react'
|
| 2 |
+
import { Thermometer } from 'lucide-react'
|
| 3 |
import MetricCard from '../components/MetricCard'
|
| 4 |
import { LoadingSpinner, ErrorState } from '../components/LoadingState'
|
| 5 |
import { useCalibrateQuery } from '../lib/api'
|
|
|
|
| 285 |
</table>
|
| 286 |
</div>
|
| 287 |
{sortedZones.length === 0 && (
|
| 288 |
+
<div className="text-center py-16 font-sans">
|
| 289 |
+
<div className="text-2xl mb-3">
|
| 290 |
+
<Thermometer size={36} className="mx-auto text-warm-border" />
|
| 291 |
+
</div>
|
| 292 |
+
<p className="text-sm font-semibold text-[#1a1a1a] mb-2">
|
| 293 |
+
No zones reached the danger threshold
|
| 294 |
+
</p>
|
| 295 |
+
<p className="text-sm text-warm-muted leading-relaxed max-w-md mx-auto">
|
| 296 |
+
None of the monitored zones have experienced enough consecutive hot days at this temperature to trigger an alert. Try lowering the danger temperature, reducing the consecutive hot days requirement, or increasing your budget to cover more zones.
|
| 297 |
+
</p>
|
| 298 |
</div>
|
| 299 |
)}
|
| 300 |
</>
|
frontend/src/pages/Zones.tsx
CHANGED
|
@@ -192,6 +192,11 @@ export default function Zones() {
|
|
| 192 |
|
| 193 |
{activeTab === 'exposure' && (
|
| 194 |
<div className="animate-tab-enter">
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 195 |
{enrolled.isLoading ? (
|
| 196 |
<LoadingSpinner message="Loading enrollment data..." />
|
| 197 |
) : enrolled.isError ? (
|
|
|
|
| 192 |
|
| 193 |
{activeTab === 'exposure' && (
|
| 194 |
<div className="animate-tab-enter">
|
| 195 |
+
<div className="card card-body mb-6">
|
| 196 |
+
<p className="text-sm text-warm-body leading-relaxed m-0">
|
| 197 |
+
How many workers in each zone spend their day outdoors in the heat — and how many are enrolled in the protection program. Zones with high outdoor exposure but low enrollment are the biggest coverage gaps.
|
| 198 |
+
</p>
|
| 199 |
+
</div>
|
| 200 |
{enrolled.isLoading ? (
|
| 201 |
<LoadingSpinner message="Loading enrollment data..." />
|
| 202 |
) : enrolled.isError ? (
|
scripts/train_on_era5.py
ADDED
|
@@ -0,0 +1,491 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Train all ML models on real ERA5 reanalysis data.
|
| 3 |
+
|
| 4 |
+
Steps:
|
| 5 |
+
1. Fetch 2 years of ERA5 data for all 20 zones via Google ARCO Zarr store
|
| 6 |
+
2. Validate data quality (coverage, temp ranges, nulls)
|
| 7 |
+
3. Retrain XGBoost heat predictor on real data
|
| 8 |
+
4. Retrain LSTM on real data
|
| 9 |
+
5. Verify UHI model works with real ERA5 temps
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import sys
|
| 13 |
+
import os
|
| 14 |
+
import time
|
| 15 |
+
import logging
|
| 16 |
+
import math
|
| 17 |
+
|
| 18 |
+
import numpy as np
|
| 19 |
+
|
| 20 |
+
# Project root on sys.path
|
| 21 |
+
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 22 |
+
|
| 23 |
+
from config import ZONES, ZONE_MAP
|
| 24 |
+
from src.ingestion.era5_fetcher import fetch_era5_sync
|
| 25 |
+
from src.ingestion.models import DailyReading
|
| 26 |
+
from src.indexing.heat_index import calculate_wbgt
|
| 27 |
+
|
| 28 |
+
logging.basicConfig(
|
| 29 |
+
level=logging.INFO,
|
| 30 |
+
format="%(asctime)s %(name)s %(levelname)s %(message)s",
|
| 31 |
+
datefmt="%H:%M:%S",
|
| 32 |
+
)
|
| 33 |
+
log = logging.getLogger("train_era5")
|
| 34 |
+
|
| 35 |
+
# Expected temp ranges per city (max daily temps, deg C)
|
| 36 |
+
EXPECTED_RANGES = {
|
| 37 |
+
"Nairobi": (18, 35),
|
| 38 |
+
"Dar es Salaam": (25, 40),
|
| 39 |
+
"Kampala": (22, 36),
|
| 40 |
+
"Kigali": (20, 34),
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
# ======================================================================
|
| 45 |
+
# Step 1: Fetch ERA5 data
|
| 46 |
+
# ======================================================================
|
| 47 |
+
|
| 48 |
+
def fetch_data():
|
| 49 |
+
log.info("=" * 60)
|
| 50 |
+
log.info("STEP 1: Fetching 2 years of ERA5 data for %d zones", len(ZONES))
|
| 51 |
+
log.info("=" * 60)
|
| 52 |
+
t0 = time.time()
|
| 53 |
+
data = fetch_era5_sync(ZONES, days_back=730)
|
| 54 |
+
elapsed = time.time() - t0
|
| 55 |
+
log.info("Fetch complete in %.1f seconds", elapsed)
|
| 56 |
+
return data
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
# ======================================================================
|
| 60 |
+
# Step 2: Validate data quality
|
| 61 |
+
# ======================================================================
|
| 62 |
+
|
| 63 |
+
def validate_data(data: dict[str, list[DailyReading]]):
|
| 64 |
+
log.info("=" * 60)
|
| 65 |
+
log.info("STEP 2: Validating ERA5 data quality")
|
| 66 |
+
log.info("=" * 60)
|
| 67 |
+
|
| 68 |
+
issues = []
|
| 69 |
+
stats = {}
|
| 70 |
+
|
| 71 |
+
for zone in ZONES:
|
| 72 |
+
zid = zone.zone_id
|
| 73 |
+
readings = data.get(zid, [])
|
| 74 |
+
|
| 75 |
+
if not readings:
|
| 76 |
+
issues.append(f"{zid}: NO DATA")
|
| 77 |
+
stats[zid] = {"days": 0, "issue": "no data"}
|
| 78 |
+
continue
|
| 79 |
+
|
| 80 |
+
temps = [r.temp_max_c for r in readings if r.temp_max_c is not None]
|
| 81 |
+
humids = [r.humidity_pct for r in readings if r.humidity_pct is not None]
|
| 82 |
+
winds = [r.wind_speed_ms for r in readings if r.wind_speed_ms is not None]
|
| 83 |
+
|
| 84 |
+
if not temps:
|
| 85 |
+
issues.append(f"{zid}: all temps are null")
|
| 86 |
+
stats[zid] = {"days": len(readings), "issue": "all null temps"}
|
| 87 |
+
continue
|
| 88 |
+
|
| 89 |
+
t_min, t_max = min(temps), max(temps)
|
| 90 |
+
t_mean = sum(temps) / len(temps)
|
| 91 |
+
|
| 92 |
+
# Check physical reasonableness
|
| 93 |
+
exp_lo, exp_hi = EXPECTED_RANGES.get(zone.city, (15, 42))
|
| 94 |
+
if t_min < exp_lo - 5 or t_max > exp_hi + 5:
|
| 95 |
+
issues.append(
|
| 96 |
+
f"{zid} ({zone.city}): temp range [{t_min:.1f}, {t_max:.1f}] "
|
| 97 |
+
f"outside expected [{exp_lo-5}, {exp_hi+5}]"
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
null_count = sum(1 for r in readings if r.temp_max_c is None)
|
| 101 |
+
|
| 102 |
+
stats[zid] = {
|
| 103 |
+
"days": len(readings),
|
| 104 |
+
"temp_days": len(temps),
|
| 105 |
+
"temp_min": round(t_min, 1),
|
| 106 |
+
"temp_max": round(t_max, 1),
|
| 107 |
+
"temp_mean": round(t_mean, 1),
|
| 108 |
+
"humidity_mean": round(sum(humids)/len(humids), 1) if humids else None,
|
| 109 |
+
"wind_mean": round(sum(winds)/len(winds), 1) if winds else None,
|
| 110 |
+
"null_temps": null_count,
|
| 111 |
+
}
|
| 112 |
+
|
| 113 |
+
# Print summary
|
| 114 |
+
print("\n--- ERA5 Data Summary ---")
|
| 115 |
+
print(f"{'Zone':<12} {'City':<16} {'Days':>5} {'Temp min':>9} {'Temp max':>9} {'Temp mean':>10} {'Humidity':>9} {'Nulls':>6}")
|
| 116 |
+
print("-" * 90)
|
| 117 |
+
|
| 118 |
+
by_city = {}
|
| 119 |
+
for zone in ZONES:
|
| 120 |
+
s = stats.get(zone.zone_id, {})
|
| 121 |
+
days = s.get("days", 0)
|
| 122 |
+
t_lo = s.get("temp_min", "N/A")
|
| 123 |
+
t_hi = s.get("temp_max", "N/A")
|
| 124 |
+
t_mn = s.get("temp_mean", "N/A")
|
| 125 |
+
hum = s.get("humidity_mean", "N/A")
|
| 126 |
+
nulls = s.get("null_temps", "N/A")
|
| 127 |
+
print(f"{zone.zone_id:<12} {zone.city:<16} {days:>5} {t_lo:>9} {t_hi:>9} {t_mn:>10} {hum:>9} {nulls:>6}")
|
| 128 |
+
|
| 129 |
+
city = zone.city
|
| 130 |
+
if city not in by_city:
|
| 131 |
+
by_city[city] = []
|
| 132 |
+
by_city[city].append(s)
|
| 133 |
+
|
| 134 |
+
print("\n--- Per-city aggregated temp ranges ---")
|
| 135 |
+
for city, zone_stats in by_city.items():
|
| 136 |
+
all_mins = [s["temp_min"] for s in zone_stats if s.get("temp_min") is not None]
|
| 137 |
+
all_maxs = [s["temp_max"] for s in zone_stats if s.get("temp_max") is not None]
|
| 138 |
+
if all_mins and all_maxs:
|
| 139 |
+
print(f" {city:<16}: {min(all_mins):.1f} - {max(all_maxs):.1f} C")
|
| 140 |
+
|
| 141 |
+
if issues:
|
| 142 |
+
print(f"\n ISSUES ({len(issues)}):")
|
| 143 |
+
for issue in issues:
|
| 144 |
+
print(f" - {issue}")
|
| 145 |
+
else:
|
| 146 |
+
print("\n No data quality issues found.")
|
| 147 |
+
|
| 148 |
+
zones_with_data = sum(1 for s in stats.values() if s.get("days", 0) > 0)
|
| 149 |
+
assert zones_with_data == len(ZONES), f"Only {zones_with_data}/{len(ZONES)} zones have data"
|
| 150 |
+
print(f"\n All {zones_with_data} zones have data.\n")
|
| 151 |
+
|
| 152 |
+
return stats
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
# ======================================================================
|
| 156 |
+
# Step 3: Retrain XGBoost heat predictor on real data
|
| 157 |
+
# ======================================================================
|
| 158 |
+
|
| 159 |
+
def retrain_xgboost(data: dict[str, list[DailyReading]]):
|
| 160 |
+
log.info("=" * 60)
|
| 161 |
+
log.info("STEP 3: Retraining XGBoost heat predictor on real ERA5 data")
|
| 162 |
+
log.info("=" * 60)
|
| 163 |
+
|
| 164 |
+
from src.prediction.heat_forecast import HeatWavePredictor, CITY_THRESHOLDS, CITY_CLIMATE
|
| 165 |
+
from src.prediction.lstm_model import CITY_CLIMATE as _ # ensure import works
|
| 166 |
+
|
| 167 |
+
import xgboost as xgb
|
| 168 |
+
|
| 169 |
+
# We replicate the training logic from HeatWavePredictor.train() but
|
| 170 |
+
# use real ERA5 temps/humidity instead of synthetic series.
|
| 171 |
+
|
| 172 |
+
all_X = []
|
| 173 |
+
all_y = []
|
| 174 |
+
|
| 175 |
+
for zone in ZONES:
|
| 176 |
+
zid = zone.zone_id
|
| 177 |
+
readings = data.get(zid, [])
|
| 178 |
+
if len(readings) < 40:
|
| 179 |
+
log.warning("Zone %s has only %d readings, skipping for XGBoost training", zid, len(readings))
|
| 180 |
+
continue
|
| 181 |
+
|
| 182 |
+
city = zone.city
|
| 183 |
+
threshold = CITY_THRESHOLDS.get(city, 33.0)
|
| 184 |
+
|
| 185 |
+
# Extract time series from real data
|
| 186 |
+
temps = []
|
| 187 |
+
humidity = []
|
| 188 |
+
for r in readings:
|
| 189 |
+
t = r.temp_max_c
|
| 190 |
+
h = r.humidity_pct
|
| 191 |
+
if t is None:
|
| 192 |
+
continue
|
| 193 |
+
temps.append(t)
|
| 194 |
+
humidity.append(h if h is not None else 65.0)
|
| 195 |
+
|
| 196 |
+
n_days = len(temps)
|
| 197 |
+
if n_days < 40:
|
| 198 |
+
log.warning("Zone %s has only %d valid temp readings, skipping", zid, n_days)
|
| 199 |
+
continue
|
| 200 |
+
|
| 201 |
+
# Compute WBGT series
|
| 202 |
+
wbgt_series = [calculate_wbgt(t, h) for t, h in zip(temps, humidity)]
|
| 203 |
+
|
| 204 |
+
# Labels: trigger within next 7 days (2+ consecutive above threshold)
|
| 205 |
+
labels = [0] * n_days
|
| 206 |
+
for day in range(n_days - 7):
|
| 207 |
+
window = temps[day + 1:day + 8]
|
| 208 |
+
consec = 0
|
| 209 |
+
triggered = False
|
| 210 |
+
for t in window:
|
| 211 |
+
if t > threshold:
|
| 212 |
+
consec += 1
|
| 213 |
+
if consec >= 2:
|
| 214 |
+
triggered = True
|
| 215 |
+
break
|
| 216 |
+
else:
|
| 217 |
+
consec = 0
|
| 218 |
+
labels[day] = 1 if triggered else 0
|
| 219 |
+
|
| 220 |
+
# Vulnerability encoding
|
| 221 |
+
vuln_map = {"high": 1.0, "moderate": 0.5, "low": 0.0}
|
| 222 |
+
zone_vuln = vuln_map.get(zone.heat_vulnerability, 0.5)
|
| 223 |
+
|
| 224 |
+
rng = np.random.default_rng(42)
|
| 225 |
+
|
| 226 |
+
# Build features (need 30-day lookback)
|
| 227 |
+
for day in range(30, n_days - 7):
|
| 228 |
+
t_window = temps[day - 30:day + 1]
|
| 229 |
+
h_window = humidity[day - 30:day + 1]
|
| 230 |
+
w_window = wbgt_series[day - 30:day + 1]
|
| 231 |
+
|
| 232 |
+
current_temp = t_window[-1]
|
| 233 |
+
current_wbgt = w_window[-1]
|
| 234 |
+
current_humidity = h_window[-1]
|
| 235 |
+
|
| 236 |
+
# Trend: slope of last 7 days
|
| 237 |
+
x7 = np.arange(7, dtype=np.float64)
|
| 238 |
+
y7 = np.array(t_window[-7:], dtype=np.float64)
|
| 239 |
+
temp_trend = float(np.polyfit(x7, y7, 1)[0])
|
| 240 |
+
|
| 241 |
+
# Anomaly: current vs 30-day mean
|
| 242 |
+
temp_anomaly = current_temp - float(np.mean(t_window))
|
| 243 |
+
|
| 244 |
+
# Soil moisture proxy
|
| 245 |
+
soil_proxy = float(np.clip(1.0 - (temp_anomaly + 2.0) / 4.0, 0.0, 1.0))
|
| 246 |
+
|
| 247 |
+
# Rolling error (use neutral prior for training data)
|
| 248 |
+
rolling_err = rng.uniform(0.1, 0.5)
|
| 249 |
+
|
| 250 |
+
# Day-of-year encoding (use day index within 365-day cycle)
|
| 251 |
+
doy = day % 365
|
| 252 |
+
doy_sin = np.sin(2 * np.pi * doy / 365.0)
|
| 253 |
+
doy_cos = np.cos(2 * np.pi * doy / 365.0)
|
| 254 |
+
|
| 255 |
+
# Random hour for variety
|
| 256 |
+
hour = rng.integers(6, 19)
|
| 257 |
+
hour_sin = np.sin(2 * np.pi * hour / 24.0)
|
| 258 |
+
hour_cos = np.cos(2 * np.pi * hour / 24.0)
|
| 259 |
+
|
| 260 |
+
row = [
|
| 261 |
+
current_temp,
|
| 262 |
+
current_wbgt,
|
| 263 |
+
current_humidity,
|
| 264 |
+
temp_trend,
|
| 265 |
+
temp_anomaly,
|
| 266 |
+
soil_proxy,
|
| 267 |
+
rolling_err,
|
| 268 |
+
doy_sin,
|
| 269 |
+
doy_cos,
|
| 270 |
+
hour_sin,
|
| 271 |
+
hour_cos,
|
| 272 |
+
zone_vuln,
|
| 273 |
+
]
|
| 274 |
+
|
| 275 |
+
all_X.append(row)
|
| 276 |
+
all_y.append(labels[day])
|
| 277 |
+
|
| 278 |
+
X = np.array(all_X, dtype=np.float32)
|
| 279 |
+
y = np.array(all_y, dtype=np.int32)
|
| 280 |
+
|
| 281 |
+
pos_rate = y.sum() / len(y) if len(y) > 0 else 0
|
| 282 |
+
log.info(
|
| 283 |
+
"XGBoost training data: %d samples, %.1f%% positive rate",
|
| 284 |
+
len(X), pos_rate * 100,
|
| 285 |
+
)
|
| 286 |
+
|
| 287 |
+
# Create a fresh predictor to get the model object, then retrain
|
| 288 |
+
predictor = HeatWavePredictor.__new__(HeatWavePredictor)
|
| 289 |
+
predictor.model_path = HeatWavePredictor.__init__.__defaults__[0] # fallback
|
| 290 |
+
from pathlib import Path
|
| 291 |
+
predictor.model_path = Path(__file__).resolve().parents[1] / "models" / "heat_predictor_xgb.json"
|
| 292 |
+
predictor._rolling_errors = []
|
| 293 |
+
|
| 294 |
+
model = xgb.XGBClassifier(
|
| 295 |
+
n_estimators=150,
|
| 296 |
+
max_depth=5,
|
| 297 |
+
learning_rate=0.1,
|
| 298 |
+
eval_metric="logloss",
|
| 299 |
+
random_state=42,
|
| 300 |
+
)
|
| 301 |
+
|
| 302 |
+
# Train/validation split (temporal: first 75% train, last 25% val)
|
| 303 |
+
split = int(len(X) * 0.75)
|
| 304 |
+
X_train, X_val = X[:split], X[split:]
|
| 305 |
+
y_train, y_val = y[:split], y[split:]
|
| 306 |
+
|
| 307 |
+
model.fit(
|
| 308 |
+
X_train, y_train,
|
| 309 |
+
eval_set=[(X_val, y_val)],
|
| 310 |
+
verbose=False,
|
| 311 |
+
)
|
| 312 |
+
|
| 313 |
+
# Evaluate on validation set
|
| 314 |
+
from sklearn.metrics import roc_auc_score, precision_score, recall_score
|
| 315 |
+
|
| 316 |
+
val_probs = model.predict_proba(X_val)[:, 1]
|
| 317 |
+
val_preds = (val_probs > 0.5).astype(int)
|
| 318 |
+
|
| 319 |
+
if len(set(y_val)) > 1:
|
| 320 |
+
auroc = roc_auc_score(y_val, val_probs)
|
| 321 |
+
precision = precision_score(y_val, val_preds, zero_division=0)
|
| 322 |
+
recall = recall_score(y_val, val_preds, zero_division=0)
|
| 323 |
+
else:
|
| 324 |
+
auroc, precision, recall = 0.5, 0.0, 0.0
|
| 325 |
+
|
| 326 |
+
print(f"\n--- XGBoost Results (real ERA5 data) ---")
|
| 327 |
+
print(f" Training samples: {len(X_train)}")
|
| 328 |
+
print(f" Validation samples: {len(X_val)}")
|
| 329 |
+
print(f" Positive rate: {pos_rate:.1%}")
|
| 330 |
+
print(f" Val AUROC: {auroc:.4f}")
|
| 331 |
+
print(f" Val Precision: {precision:.4f}")
|
| 332 |
+
print(f" Val Recall: {recall:.4f}")
|
| 333 |
+
|
| 334 |
+
# Save model
|
| 335 |
+
predictor.model_path.parent.mkdir(parents=True, exist_ok=True)
|
| 336 |
+
model.save_model(str(predictor.model_path))
|
| 337 |
+
log.info("XGBoost model saved to %s", predictor.model_path)
|
| 338 |
+
|
| 339 |
+
return {
|
| 340 |
+
"train_samples": len(X_train),
|
| 341 |
+
"val_samples": len(X_val),
|
| 342 |
+
"positive_rate": round(pos_rate, 4),
|
| 343 |
+
"val_auroc": round(auroc, 4),
|
| 344 |
+
"val_precision": round(precision, 4),
|
| 345 |
+
"val_recall": round(recall, 4),
|
| 346 |
+
}
|
| 347 |
+
|
| 348 |
+
|
| 349 |
+
# ======================================================================
|
| 350 |
+
# Step 4: Retrain LSTM on real data
|
| 351 |
+
# ======================================================================
|
| 352 |
+
|
| 353 |
+
def retrain_lstm(data: dict[str, list[DailyReading]]):
|
| 354 |
+
log.info("=" * 60)
|
| 355 |
+
log.info("STEP 4: Retraining LSTM on real ERA5 data")
|
| 356 |
+
log.info("=" * 60)
|
| 357 |
+
|
| 358 |
+
from src.prediction.lstm_model import LSTMTrainer
|
| 359 |
+
|
| 360 |
+
# Convert ERA5 DailyReading objects into the format the LSTM trainer expects:
|
| 361 |
+
# dict of zone_id -> list of dicts with keys: temp_max_c, humidity_pct, wind_speed_ms, city
|
| 362 |
+
zone_readings = {}
|
| 363 |
+
for zone in ZONES:
|
| 364 |
+
zid = zone.zone_id
|
| 365 |
+
readings = data.get(zid, [])
|
| 366 |
+
days = []
|
| 367 |
+
for r in readings:
|
| 368 |
+
if r.temp_max_c is None:
|
| 369 |
+
continue
|
| 370 |
+
days.append({
|
| 371 |
+
"temp_max_c": r.temp_max_c,
|
| 372 |
+
"humidity_pct": r.humidity_pct if r.humidity_pct is not None else 65.0,
|
| 373 |
+
"wind_speed_ms": r.wind_speed_ms if r.wind_speed_ms is not None else 3.0,
|
| 374 |
+
"city": zone.city,
|
| 375 |
+
})
|
| 376 |
+
if len(days) > 30:
|
| 377 |
+
zone_readings[zid] = days
|
| 378 |
+
log.info("Zone %s: %d valid readings for LSTM", zid, len(days))
|
| 379 |
+
else:
|
| 380 |
+
log.warning("Zone %s: only %d valid readings, skipping LSTM", zid, len(days))
|
| 381 |
+
|
| 382 |
+
log.info("Training LSTM on %d zones", len(zone_readings))
|
| 383 |
+
trainer = LSTMTrainer(epochs=50, patience=5)
|
| 384 |
+
metrics = trainer.train(zone_readings)
|
| 385 |
+
|
| 386 |
+
print(f"\n--- LSTM Results (real ERA5 data) ---")
|
| 387 |
+
for k, v in metrics.items():
|
| 388 |
+
print(f" {k}: {v}")
|
| 389 |
+
|
| 390 |
+
return metrics
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
# ======================================================================
|
| 394 |
+
# Step 5: Verify UHI model with real ERA5 temps
|
| 395 |
+
# ======================================================================
|
| 396 |
+
|
| 397 |
+
def verify_uhi(data: dict[str, list[DailyReading]]):
|
| 398 |
+
log.info("=" * 60)
|
| 399 |
+
log.info("STEP 5: Verifying UHI model with real ERA5 temperatures")
|
| 400 |
+
log.info("=" * 60)
|
| 401 |
+
|
| 402 |
+
from src.downscaling.uhi_model import UHICorrector
|
| 403 |
+
|
| 404 |
+
corrector = UHICorrector()
|
| 405 |
+
|
| 406 |
+
results = {}
|
| 407 |
+
for zone in ZONES:
|
| 408 |
+
zid = zone.zone_id
|
| 409 |
+
readings = data.get(zid, [])
|
| 410 |
+
if not readings:
|
| 411 |
+
continue
|
| 412 |
+
|
| 413 |
+
# Use real ERA5 temps as grid baseline
|
| 414 |
+
real_temps = [r.temp_max_c for r in readings if r.temp_max_c is not None]
|
| 415 |
+
if not real_temps:
|
| 416 |
+
continue
|
| 417 |
+
|
| 418 |
+
# Sample a few real temps and apply UHI correction
|
| 419 |
+
sample_indices = np.linspace(0, len(real_temps) - 1, min(20, len(real_temps)), dtype=int)
|
| 420 |
+
deltas = []
|
| 421 |
+
corrected_temps = []
|
| 422 |
+
|
| 423 |
+
for idx in sample_indices:
|
| 424 |
+
grid_temp = real_temps[idx]
|
| 425 |
+
corrected, delta, conf = corrector.correct_temperature(zone, grid_temp, hour=14, month=1)
|
| 426 |
+
deltas.append(delta)
|
| 427 |
+
corrected_temps.append(corrected)
|
| 428 |
+
|
| 429 |
+
results[zid] = {
|
| 430 |
+
"city": zone.city,
|
| 431 |
+
"settlement": zone.settlement_type,
|
| 432 |
+
"mean_grid_temp": round(sum(real_temps) / len(real_temps), 1),
|
| 433 |
+
"mean_uhi_delta": round(sum(deltas) / len(deltas), 2),
|
| 434 |
+
"mean_corrected": round(sum(corrected_temps) / len(corrected_temps), 1),
|
| 435 |
+
}
|
| 436 |
+
|
| 437 |
+
print(f"\n--- UHI Verification with Real ERA5 Temps ---")
|
| 438 |
+
print(f"{'Zone':<12} {'City':<16} {'Type':<12} {'Grid T':>7} {'UHI +':>7} {'Corrected':>10}")
|
| 439 |
+
print("-" * 70)
|
| 440 |
+
for zid, r in results.items():
|
| 441 |
+
print(
|
| 442 |
+
f"{zid:<12} {r['city']:<16} {r['settlement']:<12} "
|
| 443 |
+
f"{r['mean_grid_temp']:>6.1f}C {r['mean_uhi_delta']:>+6.2f}C {r['mean_corrected']:>9.1f}C"
|
| 444 |
+
)
|
| 445 |
+
|
| 446 |
+
return results
|
| 447 |
+
|
| 448 |
+
|
| 449 |
+
# ======================================================================
|
| 450 |
+
# Main
|
| 451 |
+
# ======================================================================
|
| 452 |
+
|
| 453 |
+
def main():
|
| 454 |
+
t_start = time.time()
|
| 455 |
+
|
| 456 |
+
# Step 1: Fetch
|
| 457 |
+
data = fetch_data()
|
| 458 |
+
|
| 459 |
+
# Step 2: Validate
|
| 460 |
+
data_stats = validate_data(data)
|
| 461 |
+
|
| 462 |
+
# Step 3: XGBoost
|
| 463 |
+
xgb_metrics = retrain_xgboost(data)
|
| 464 |
+
|
| 465 |
+
# Step 4: LSTM
|
| 466 |
+
lstm_metrics = retrain_lstm(data)
|
| 467 |
+
|
| 468 |
+
# Step 5: UHI verification
|
| 469 |
+
uhi_results = verify_uhi(data)
|
| 470 |
+
|
| 471 |
+
total_time = time.time() - t_start
|
| 472 |
+
|
| 473 |
+
print("\n" + "=" * 60)
|
| 474 |
+
print("TRAINING COMPLETE")
|
| 475 |
+
print("=" * 60)
|
| 476 |
+
|
| 477 |
+
total_days = sum(
|
| 478 |
+
len([r for r in data.get(z.zone_id, []) if r.temp_max_c is not None])
|
| 479 |
+
for z in ZONES
|
| 480 |
+
)
|
| 481 |
+
print(f" Total real data points: {total_days} zone-days across {len(ZONES)} zones")
|
| 482 |
+
print(f" XGBoost val AUROC: {xgb_metrics['val_auroc']:.4f}")
|
| 483 |
+
print(f" LSTM val AUROC: {lstm_metrics.get('val_auroc', 'N/A')}")
|
| 484 |
+
print(f" LSTM epochs trained: {lstm_metrics.get('epochs_trained', 'N/A')}")
|
| 485 |
+
print(f" LSTM final val loss: {lstm_metrics.get('val_loss', 'N/A')}")
|
| 486 |
+
print(f" Total time: {total_time:.1f}s")
|
| 487 |
+
print()
|
| 488 |
+
|
| 489 |
+
|
| 490 |
+
if __name__ == "__main__":
|
| 491 |
+
main()
|