jyyd23 commited on
Commit
6ed9cb6
·
verified ·
1 Parent(s): 130d8c3

Upload 6 files

Browse files
RL_SingleShot.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
dataset/all_test_pred2.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bf6c0136608c8ef425d40ca4031f0af7b160238d9c4f2b5676ebbaf7be21391e
3
+ size 74641224
dataset/combined_test_special.txt ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ min5_m_v1_0_d270_sc2_s15_04123.npz
2
+ min5_m_v1_0_d270_sc2_s21_04130.npz
3
+ min5_m_v1_0_d270_sc2_s52_04164.npz
4
+ min5_m_v1_0_d270_sc2_s61_04174.npz
5
+ min5_m_v1_0_d270_sc2_s72_04186.npz
6
+ min5_m_v1_0_d270_sc2_s83_04198.npz
7
+ min5_m_v1_0_d270_sc2_s87_04202.npz
8
+ min5_m_v1_0_d270_sc2_s93_04209.npz
9
+ min5_m_v1_0_d270_sc4_s15_04221.npz
10
+ min5_m_v1_0_d270_sc4_s21_04228.npz
11
+ min5_m_v1_0_d270_sc4_s52_04262.npz
12
+ min5_m_v1_0_d270_sc4_s61_04272.npz
13
+ min5_m_v1_0_d270_sc4_s72_04284.npz
14
+ min5_m_v1_0_d270_sc4_s83_04296.npz
15
+ min5_m_v1_0_d270_sc4_s87_04300.npz
16
+ min5_m_v1_0_d270_sc4_s93_04307.npz
17
+ min5_m_v1_0_d270_sc6_s15_04319.npz
18
+ min5_m_v1_0_d270_sc6_s21_04326.npz
19
+ min5_m_v1_0_d270_sc6_s52_04360.npz
20
+ min5_m_v1_0_d270_sc6_s61_04370.npz
21
+ min5_m_v1_0_d270_sc6_s72_04382.npz
22
+ min5_m_v1_0_d270_sc6_s83_04394.npz
23
+ min5_m_v1_0_d270_sc6_s87_04398.npz
24
+ min5_m_v1_0_d270_sc6_s93_04405.npz
25
+ min5_m_v1_0_d90_sc2_s15_04417.npz
26
+ min5_m_v1_0_d90_sc2_s21_04424.npz
27
+ min5_m_v1_0_d90_sc2_s52_04458.npz
28
+ min5_m_v1_0_d90_sc2_s61_04468.npz
29
+ min5_m_v1_0_d90_sc2_s72_04480.npz
30
+ min5_m_v1_0_d90_sc2_s83_04492.npz
31
+ min5_m_v1_0_d90_sc2_s87_04496.npz
32
+ min5_m_v1_0_d90_sc2_s93_04503.npz
33
+ min5_m_v1_0_d90_sc4_s15_04515.npz
34
+ min5_m_v1_0_d90_sc4_s21_04522.npz
35
+ min5_m_v1_0_d90_sc4_s52_04556.npz
36
+ min5_m_v1_0_d90_sc4_s61_04566.npz
37
+ min5_m_v1_0_d90_sc4_s72_04578.npz
38
+ min5_m_v1_0_d90_sc4_s83_04590.npz
39
+ min5_m_v1_0_d90_sc4_s87_04594.npz
40
+ min5_m_v1_0_d90_sc4_s93_04601.npz
41
+ min5_m_v1_0_d90_sc6_s15_04613.npz
42
+ min5_m_v1_0_d90_sc6_s21_04620.npz
43
+ min5_m_v1_0_d90_sc6_s52_04654.npz
44
+ min5_m_v1_0_d90_sc6_s61_04664.npz
45
+ min5_m_v1_0_d90_sc6_s72_04676.npz
46
+ min5_m_v1_0_d90_sc6_s83_04688.npz
47
+ min5_m_v1_0_d90_sc6_s87_04692.npz
48
+ min5_m_v1_0_d90_sc6_s93_04699.npz
49
+ min5_m_v2_0_d270_sc2_s15_08827.npz
50
+ min5_m_v2_0_d270_sc2_s21_08834.npz
51
+ min5_m_v2_0_d270_sc2_s52_08868.npz
52
+ min5_m_v2_0_d270_sc2_s61_08878.npz
53
+ min5_m_v2_0_d270_sc2_s72_08890.npz
54
+ min5_m_v2_0_d270_sc2_s83_08902.npz
55
+ min5_m_v2_0_d270_sc2_s87_08906.npz
56
+ min5_m_v2_0_d270_sc2_s93_08913.npz
57
+ min5_m_v2_0_d270_sc4_s15_08925.npz
58
+ min5_m_v2_0_d270_sc4_s21_08932.npz
59
+ min5_m_v2_0_d270_sc4_s52_08966.npz
60
+ min5_m_v2_0_d270_sc4_s61_08976.npz
61
+ min5_m_v2_0_d270_sc4_s72_08988.npz
62
+ min5_m_v2_0_d270_sc4_s83_09000.npz
63
+ min5_m_v2_0_d270_sc4_s87_09004.npz
64
+ min5_m_v2_0_d270_sc4_s93_09011.npz
65
+ min5_m_v2_0_d270_sc6_s15_09023.npz
66
+ min5_m_v2_0_d270_sc6_s21_09030.npz
67
+ min5_m_v2_0_d270_sc6_s52_09064.npz
68
+ min5_m_v2_0_d270_sc6_s61_09074.npz
69
+ min5_m_v2_0_d270_sc6_s72_09086.npz
70
+ min5_m_v2_0_d270_sc6_s83_09098.npz
71
+ min5_m_v2_0_d270_sc6_s87_09102.npz
72
+ min5_m_v2_0_d270_sc6_s93_09109.npz
73
+ min5_m_v2_0_d90_sc2_s15_09121.npz
74
+ min5_m_v2_0_d90_sc2_s21_09128.npz
75
+ min5_m_v2_0_d90_sc2_s52_09162.npz
76
+ min5_m_v2_0_d90_sc2_s61_09172.npz
77
+ min5_m_v2_0_d90_sc2_s72_09184.npz
78
+ min5_m_v2_0_d90_sc2_s83_09196.npz
79
+ min5_m_v2_0_d90_sc2_s87_09200.npz
80
+ min5_m_v2_0_d90_sc2_s93_09207.npz
81
+ min5_m_v2_0_d90_sc4_s15_09219.npz
82
+ min5_m_v2_0_d90_sc4_s21_09226.npz
83
+ min5_m_v2_0_d90_sc4_s52_09260.npz
84
+ min5_m_v2_0_d90_sc4_s61_09270.npz
85
+ min5_m_v2_0_d90_sc4_s72_09282.npz
86
+ min5_m_v2_0_d90_sc4_s83_09294.npz
87
+ min5_m_v2_0_d90_sc4_s87_09298.npz
88
+ min5_m_v2_0_d90_sc4_s93_09305.npz
89
+ min5_m_v2_0_d90_sc6_s15_09317.npz
90
+ min5_m_v2_0_d90_sc6_s21_09324.npz
91
+ min5_m_v2_0_d90_sc6_s52_09358.npz
92
+ min5_m_v2_0_d90_sc6_s61_09368.npz
93
+ min5_m_v2_0_d90_sc6_s72_09380.npz
94
+ min5_m_v2_0_d90_sc6_s83_09392.npz
95
+ min5_m_v2_0_d90_sc6_s87_09396.npz
96
+ min5_m_v2_0_d90_sc6_s93_09403.npz
97
+ min5_m_v3_0_d270_sc2_s15_13531.npz
98
+ min5_m_v3_0_d270_sc2_s21_13538.npz
99
+ min5_m_v3_0_d270_sc2_s52_13572.npz
100
+ min5_m_v3_0_d270_sc2_s61_13582.npz
101
+ min5_m_v3_0_d270_sc2_s72_13594.npz
102
+ min5_m_v3_0_d270_sc2_s83_13606.npz
103
+ min5_m_v3_0_d270_sc2_s87_13610.npz
104
+ min5_m_v3_0_d270_sc2_s93_13617.npz
105
+ min5_m_v3_0_d270_sc4_s15_13629.npz
106
+ min5_m_v3_0_d270_sc4_s21_13636.npz
107
+ min5_m_v3_0_d270_sc4_s52_13670.npz
108
+ min5_m_v3_0_d270_sc4_s61_13680.npz
109
+ min5_m_v3_0_d270_sc4_s72_13692.npz
110
+ min5_m_v3_0_d270_sc4_s83_13704.npz
111
+ min5_m_v3_0_d270_sc4_s87_13708.npz
112
+ min5_m_v3_0_d270_sc4_s93_13715.npz
113
+ min5_m_v3_0_d270_sc6_s15_13727.npz
114
+ min5_m_v3_0_d270_sc6_s21_13734.npz
115
+ min5_m_v3_0_d270_sc6_s52_13768.npz
116
+ min5_m_v3_0_d270_sc6_s61_13778.npz
117
+ min5_m_v3_0_d270_sc6_s72_13790.npz
118
+ min5_m_v3_0_d270_sc6_s83_13802.npz
119
+ min5_m_v3_0_d270_sc6_s87_13806.npz
120
+ min5_m_v3_0_d270_sc6_s93_13813.npz
121
+ min5_m_v3_0_d90_sc2_s15_13825.npz
122
+ min5_m_v3_0_d90_sc2_s21_13832.npz
123
+ min5_m_v3_0_d90_sc2_s52_13866.npz
124
+ min5_m_v3_0_d90_sc2_s61_13876.npz
125
+ min5_m_v3_0_d90_sc2_s72_13888.npz
126
+ min5_m_v3_0_d90_sc2_s83_13900.npz
127
+ min5_m_v3_0_d90_sc2_s87_13904.npz
128
+ min5_m_v3_0_d90_sc2_s93_13911.npz
129
+ min5_m_v3_0_d90_sc4_s15_13923.npz
130
+ min5_m_v3_0_d90_sc4_s21_13930.npz
131
+ min5_m_v3_0_d90_sc4_s52_13964.npz
132
+ min5_m_v3_0_d90_sc4_s61_13974.npz
133
+ min5_m_v3_0_d90_sc4_s72_13986.npz
134
+ min5_m_v3_0_d90_sc4_s83_13998.npz
135
+ min5_m_v3_0_d90_sc4_s87_14002.npz
136
+ min5_m_v3_0_d90_sc4_s93_14009.npz
137
+ min5_m_v3_0_d90_sc6_s15_14021.npz
138
+ min5_m_v3_0_d90_sc6_s21_14028.npz
139
+ min5_m_v3_0_d90_sc6_s52_14062.npz
140
+ min5_m_v3_0_d90_sc6_s61_14072.npz
141
+ min5_m_v3_0_d90_sc6_s72_14084.npz
142
+ min5_m_v3_0_d90_sc6_s83_14096.npz
143
+ min5_m_v3_0_d90_sc6_s87_14100.npz
144
+ min5_m_v3_0_d90_sc6_s93_14107.npz
145
+ min5_m_v4_0_d270_sc2_s15_18235.npz
146
+ min5_m_v4_0_d270_sc2_s21_18242.npz
147
+ min5_m_v4_0_d270_sc2_s52_18276.npz
148
+ min5_m_v4_0_d270_sc2_s61_18286.npz
149
+ min5_m_v4_0_d270_sc2_s72_18298.npz
150
+ min5_m_v4_0_d270_sc2_s83_18310.npz
151
+ min5_m_v4_0_d270_sc2_s87_18314.npz
152
+ min5_m_v4_0_d270_sc2_s93_18321.npz
153
+ min5_m_v4_0_d270_sc4_s15_18333.npz
154
+ min5_m_v4_0_d270_sc4_s21_18340.npz
155
+ min5_m_v4_0_d270_sc4_s52_18374.npz
156
+ min5_m_v4_0_d270_sc4_s61_18384.npz
157
+ min5_m_v4_0_d270_sc4_s72_18396.npz
158
+ min5_m_v4_0_d270_sc4_s83_18408.npz
159
+ min5_m_v4_0_d270_sc4_s87_18412.npz
160
+ min5_m_v4_0_d270_sc4_s93_18419.npz
161
+ min5_m_v4_0_d270_sc6_s15_18431.npz
162
+ min5_m_v4_0_d270_sc6_s21_18438.npz
163
+ min5_m_v4_0_d270_sc6_s52_18472.npz
164
+ min5_m_v4_0_d270_sc6_s61_18482.npz
165
+ min5_m_v4_0_d270_sc6_s72_18494.npz
166
+ min5_m_v4_0_d270_sc6_s83_18506.npz
167
+ min5_m_v4_0_d270_sc6_s87_18510.npz
168
+ min5_m_v4_0_d270_sc6_s93_18517.npz
169
+ min5_m_v4_0_d90_sc2_s15_18529.npz
170
+ min5_m_v4_0_d90_sc2_s21_18536.npz
171
+ min5_m_v4_0_d90_sc2_s52_18570.npz
172
+ min5_m_v4_0_d90_sc2_s61_18580.npz
173
+ min5_m_v4_0_d90_sc2_s72_18592.npz
174
+ min5_m_v4_0_d90_sc2_s83_18604.npz
175
+ min5_m_v4_0_d90_sc2_s87_18608.npz
176
+ min5_m_v4_0_d90_sc2_s93_18615.npz
177
+ min5_m_v4_0_d90_sc4_s15_18627.npz
178
+ min5_m_v4_0_d90_sc4_s21_18634.npz
179
+ min5_m_v4_0_d90_sc4_s52_18668.npz
180
+ min5_m_v4_0_d90_sc4_s61_18678.npz
181
+ min5_m_v4_0_d90_sc4_s72_18690.npz
182
+ min5_m_v4_0_d90_sc4_s83_18702.npz
183
+ min5_m_v4_0_d90_sc4_s87_18706.npz
184
+ min5_m_v4_0_d90_sc4_s93_18713.npz
185
+ min5_m_v4_0_d90_sc6_s15_18725.npz
186
+ min5_m_v4_0_d90_sc6_s21_18732.npz
187
+ min5_m_v4_0_d90_sc6_s52_18766.npz
188
+ min5_m_v4_0_d90_sc6_s61_18776.npz
189
+ min5_m_v4_0_d90_sc6_s72_18788.npz
190
+ min5_m_v4_0_d90_sc6_s83_18800.npz
191
+ min5_m_v4_0_d90_sc6_s87_18804.npz
192
+ min5_m_v4_0_d90_sc6_s93_18811.npz
193
+ min5_m_v5_0_d270_sc2_s15_22939.npz
194
+ min5_m_v5_0_d270_sc2_s21_22946.npz
195
+ min5_m_v5_0_d270_sc2_s52_22980.npz
196
+ min5_m_v5_0_d270_sc2_s61_22990.npz
197
+ min5_m_v5_0_d270_sc2_s72_23002.npz
198
+ min5_m_v5_0_d270_sc2_s83_23014.npz
199
+ min5_m_v5_0_d270_sc2_s87_23018.npz
200
+ min5_m_v5_0_d270_sc2_s93_23025.npz
201
+ min5_m_v5_0_d270_sc4_s15_23037.npz
202
+ min5_m_v5_0_d270_sc4_s21_23044.npz
203
+ min5_m_v5_0_d270_sc4_s52_23078.npz
204
+ min5_m_v5_0_d270_sc4_s61_23088.npz
205
+ min5_m_v5_0_d270_sc4_s72_23100.npz
206
+ min5_m_v5_0_d270_sc4_s83_23112.npz
207
+ min5_m_v5_0_d270_sc4_s87_23116.npz
208
+ min5_m_v5_0_d270_sc4_s93_23123.npz
209
+ min5_m_v5_0_d270_sc6_s15_23135.npz
210
+ min5_m_v5_0_d270_sc6_s21_23142.npz
211
+ min5_m_v5_0_d270_sc6_s52_23176.npz
212
+ min5_m_v5_0_d270_sc6_s61_23186.npz
213
+ min5_m_v5_0_d270_sc6_s72_23198.npz
214
+ min5_m_v5_0_d270_sc6_s83_23210.npz
215
+ min5_m_v5_0_d270_sc6_s87_23214.npz
216
+ min5_m_v5_0_d270_sc6_s93_23221.npz
217
+ min5_m_v5_0_d90_sc2_s15_23233.npz
218
+ min5_m_v5_0_d90_sc2_s21_23240.npz
219
+ min5_m_v5_0_d90_sc2_s52_23274.npz
220
+ min5_m_v5_0_d90_sc2_s61_23284.npz
221
+ min5_m_v5_0_d90_sc2_s72_23296.npz
222
+ min5_m_v5_0_d90_sc2_s83_23308.npz
223
+ min5_m_v5_0_d90_sc2_s87_23312.npz
224
+ min5_m_v5_0_d90_sc2_s93_23319.npz
225
+ min5_m_v5_0_d90_sc4_s15_23331.npz
226
+ min5_m_v5_0_d90_sc4_s21_23338.npz
227
+ min5_m_v5_0_d90_sc4_s52_23372.npz
228
+ min5_m_v5_0_d90_sc4_s61_23382.npz
229
+ min5_m_v5_0_d90_sc4_s72_23394.npz
230
+ min5_m_v5_0_d90_sc4_s83_23406.npz
231
+ min5_m_v5_0_d90_sc4_s87_23410.npz
232
+ min5_m_v5_0_d90_sc4_s93_23417.npz
233
+ min5_m_v5_0_d90_sc6_s15_23429.npz
234
+ min5_m_v5_0_d90_sc6_s21_23436.npz
235
+ min5_m_v5_0_d90_sc6_s52_23470.npz
236
+ min5_m_v5_0_d90_sc6_s61_23480.npz
237
+ min5_m_v5_0_d90_sc6_s72_23492.npz
238
+ min5_m_v5_0_d90_sc6_s83_23504.npz
239
+ min5_m_v5_0_d90_sc6_s87_23508.npz
240
+ min5_m_v5_0_d90_sc6_s93_23515.npz
dataset/conditioned_results_v0_5_d45_n40.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3b1ad22eb1731f16f69c468666acda2df38d65f9c9b3c19804c6f5a5f50c47ed
3
+ size 2516841027
dataset/min5_m_v1_0_d270_sc2_s10_04118.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:21ef8d148671d73c06742cfcf61ffc48d4590fbc70f55f91edab628252160705
3
+ size 2523009
function.py ADDED
@@ -0,0 +1,1945 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import numpy as np
3
+ import os
4
+ from scipy import stats
5
+ from scipy.spatial import cKDTree
6
+ from scipy.ndimage import binary_dilation
7
+ from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
8
+ from skimage.metrics import structural_similarity as ssim
9
+ from tqdm import trange
10
+
11
+ import matplotlib.pyplot as plt
12
+ import matplotlib.ticker as ticker
13
+ import matplotlib.cm as cm
14
+ import matplotlib
15
+ matplotlib.rcParams['font.sans-serif'] = ['Arial']
16
+ matplotlib.rcParams['font.size'] = 16
17
+
18
+
19
+
20
+ class DataLoader:
21
+ '''
22
+ Fucntions:
23
+ 1. load_predictions: 从 npz 文件加载预测和真实浓度场
24
+ 2. load_metadata: 从 meta txt 文件加载元信息(风速、风向、稳定度、源编号等)
25
+ 3. load_conds_data: 从 pkl 文件加载条件预测数据(如果有)
26
+ 4. log2ppm: 将 log 浓度转换为 ppm 浓度(根据给定的关系)
27
+ 5. get_sample: 根据索引获取单个样本的预测场、真实场、条件预测和元信息
28
+ '''
29
+ def __init__(self, pred_npz_path, meta_txt_path, conds_pkl_path):
30
+ self.pred_npz_path = pred_npz_path
31
+ self.meta_txt_path = meta_txt_path
32
+ self.conds_pkl_path = conds_pkl_path
33
+
34
+ self.load_predictions()
35
+ self.meta = self.load_metadata()
36
+ self.conds_data = self.load_conds_data()
37
+
38
+ def load_predictions(self):
39
+ data = np.load(self.pred_npz_path)
40
+ self.preds = data['preds'].squeeze(1)
41
+ self.trues = data['trues'].squeeze(1)
42
+ _, non_building_mask = PrintMetrics.get_building_area()
43
+ self.preds = self.preds * non_building_mask[None, :, :]
44
+ self.trues = self.trues * non_building_mask[None, :, :]
45
+ return self.preds, self.trues
46
+
47
+ def load_metadata(self):
48
+ df = pd.read_csv(self.meta_txt_path, sep=',', header=None)
49
+ df.columns = ['npz_colname']
50
+ pattern = r'v([0-9_]+)_d(\d+)_sc(\d+)_s(\d+)'
51
+ df[['wind_speed', 'wind_direction', 'sc', 'source_number']] = (
52
+ df['npz_colname'].str.extract(pattern))
53
+ df['wind_speed'] = df['wind_speed'].str.replace('_', '.').astype(float)
54
+ df[['wind_direction', 'sc', 'source_number']] = df[['wind_direction', 'sc',
55
+ 'source_number']].astype(int)
56
+ return df
57
+
58
+ def load_conds_data(self):
59
+ conds_data = np.load(self.conds_pkl_path, allow_pickle=True)
60
+ return conds_data
61
+
62
+ @staticmethod
63
+ def log2ppm(log_conc):
64
+ log_conc = np.asarray(log_conc)
65
+ log_conc = np.minimum(log_conc, 15.0)
66
+ ppm_conc = np.expm1(log_conc) * (0.7449)
67
+ return np.maximum(ppm_conc, 0.0)
68
+
69
+ def get_sample(self, idx, in_ppm=True):
70
+ psi_f = self.preds[idx]
71
+ psi_t = self.trues[idx]
72
+ meta = self.meta.iloc[idx]
73
+ conds_preds = self.conds_data[idx]['conds']['preds']
74
+ if in_ppm:
75
+ psi_f = DataLoader.log2ppm(psi_f)
76
+ psi_t = DataLoader.log2ppm(psi_t)
77
+ conds_preds = DataLoader.log2ppm(conds_preds)
78
+ return psi_f, psi_t, conds_preds, meta
79
+
80
+ class ObservationModel:
81
+ '''
82
+ Functions:
83
+ 1. observation_operator_H: 从浓度场 ψ 中提取点位浓度,使用双线性插值(线性算子)
84
+ 2. observation_operator_H_ens: 对 ensemble 预测场批量应用观测算子,得到每个成员的点位浓度
85
+ '''
86
+ @staticmethod # 不依赖实例状态,可以直接通过类调用
87
+ def observation_operator_H(psi, obs_xy):
88
+ # 观测算子 M:
89
+ # 从浓度场 ψ 中提取点位浓度
90
+ # 使用双线性插值(线性算子)
91
+ Hh, Ww = psi.shape
92
+ xs = np.clip(obs_xy[:, 0], 0, Ww - 1 - 1e-6)
93
+ ys = np.clip(obs_xy[:, 1], 0, Hh - 1 - 1e-6)
94
+ x0 = np.floor(xs).astype(int)
95
+ y0 = np.floor(ys).astype(int)
96
+ x1 = np.clip(x0 + 1, 0, Ww - 1)
97
+ y1 = np.clip(y0 + 1, 0, Hh - 1)
98
+ dx = xs - x0
99
+ dy = ys - y0
100
+ f00 = psi[y0, x0]
101
+ f10 = psi[y0, x1]
102
+ f01 = psi[y1, x0]
103
+ f11 = psi[y1, x1]
104
+
105
+ return (
106
+ f00 * (1 - dx) * (1 - dy) +
107
+ f10 * dx * (1 - dy) +
108
+ f01 * (1 - dx) * dy +
109
+ f11 * dx * dy
110
+ )
111
+
112
+ @staticmethod
113
+ def observation_operator_H_ens(psi_ens, obs_xy):
114
+ """
115
+ psi_ens: (N_ens, H, W)
116
+ obs_xy : (n_obs, 2)
117
+ return : HX (N_ens, n_obs)
118
+ """
119
+ N_ens, Hh, Ww = psi_ens.shape
120
+ xs = np.clip(obs_xy[:, 0], 0, Ww - 1 - 1e-6)
121
+ ys = np.clip(obs_xy[:, 1], 0, Hh - 1 - 1e-6)
122
+ x0 = np.floor(xs).astype(np.int64)
123
+ y0 = np.floor(ys).astype(np.int64)
124
+ x1 = np.clip(x0 + 1, 0, Ww - 1)
125
+ y1 = np.clip(y0 + 1, 0, Hh - 1)
126
+ dx = xs - x0
127
+ dy = ys - y0
128
+ f00 = psi_ens[:, y0, x0]
129
+ f10 = psi_ens[:, y0, x1]
130
+ f01 = psi_ens[:, y1, x0]
131
+ f11 = psi_ens[:, y1, x1]
132
+
133
+ HX = (
134
+ f00 * (1 - dx) * (1 - dy) +
135
+ f10 * dx * (1 - dy) +
136
+ f01 * (1 - dx) * dy +
137
+ f11 * dx * dy
138
+ )
139
+ return HX
140
+
141
+ class SamplingStrategies:
142
+
143
+ # =========================
144
+ # (1) Sampling strategies
145
+ # =========================
146
+ @staticmethod
147
+ def sample_random(field_shape, num_points, seed=42):
148
+ rng = np.random.default_rng(seed)
149
+ H, W = field_shape
150
+ _, non_building_mask = PrintMetrics.get_building_area()
151
+ valid_idx = np.where(non_building_mask.ravel())[0]
152
+ chosen = rng.choice(valid_idx, size=num_points, replace=False)
153
+ yy, xx = np.meshgrid(np.arange(H), np.arange(W), indexing="ij")
154
+ coords = np.stack([xx.ravel(), yy.ravel()], axis=1)
155
+ return coords[chosen].astype(float)
156
+
157
+ @staticmethod
158
+ def sample_uniform(field_shape, num_points, margin=20):
159
+ H, W = field_shape
160
+ nx = int(np.ceil(np.sqrt(num_points * W / H)))
161
+ ny = int(np.ceil(num_points / nx))
162
+ xs = np.linspace(margin, W - 1 - margin, nx)
163
+ ys = np.linspace(margin, H - 1 - margin, ny)
164
+ xx, yy = np.meshgrid(xs, ys)
165
+ grid_xy = np.stack([xx.ravel(), yy.ravel()], axis=1)
166
+ _, non_building_mask = PrintMetrics.get_building_area()
167
+ xi = np.clip(grid_xy[:, 0].astype(int), 0, W - 1)
168
+ yi = np.clip(grid_xy[:, 1].astype(int), 0, H - 1)
169
+ valid = non_building_mask[yi, xi] == 1
170
+ grid_xy = grid_xy[valid]
171
+ if len(grid_xy) > num_points:
172
+ idx = np.linspace(0, len(grid_xy) - 1, num_points).astype(int)
173
+ grid_xy = grid_xy[idx]
174
+ return grid_xy
175
+
176
+ @staticmethod
177
+ def two_stage_sampling(
178
+ true_field,
179
+ pred_field,
180
+ num_points,
181
+ ens_preds_ppm=None,
182
+ seed=42,
183
+
184
+ # ====== 全局控制 ======
185
+ min_dist=22,
186
+ n1_ratio=0.65, # Stage1 比例
187
+
188
+ # ====== Stage1 可调参数 =====
189
+ stage1_grad_power=0.8, # 梯度权重幂次
190
+ stage1_value_power=1.2, # 值权重幂次
191
+ stage1_center_boost=1.2, # 是否增强高值区域
192
+ ):
193
+
194
+ # 内部函数: 基于排斥采样的加权随机选择
195
+ def repulse_pick(candidate_idx, weights, k, selected_idx):
196
+ if k <= 0 or len(candidate_idx) == 0:
197
+ return list(selected_idx)
198
+ candidate_idx = np.asarray(candidate_idx, dtype=np.int64)
199
+ weights = np.maximum(np.asarray(weights, dtype=float), 0.0)
200
+ if weights.sum() <= 0:
201
+ weights = np.ones_like(weights)
202
+ weights = weights / weights.sum()
203
+ overs = min(len(candidate_idx), max(k * 15, 200))
204
+ cand = rng.choice(candidate_idx, size=overs, replace=False, p=weights)
205
+ selected = list(selected_idx)
206
+ for idx in cand:
207
+ xy = coords[idx]
208
+ if not selected:
209
+ selected.append(idx)
210
+ continue
211
+ sel_xy = coords[np.asarray(selected)]
212
+ if cKDTree(sel_xy).query(xy, k=1)[0] >= min_dist:
213
+ selected.append(idx)
214
+ if len(selected) >= k + len(selected_idx):
215
+ break
216
+ return selected
217
+
218
+ rng = np.random.default_rng(seed)
219
+ H, W = pred_field.shape
220
+ yy, xx = np.meshgrid(np.arange(H), np.arange(W), indexing="ij")
221
+ coords = np.stack([xx.ravel(), yy.ravel()], axis=1) # 二维坐标网格 (HW, 2) [x,y]
222
+ v = np.maximum(pred_field, 0.0).ravel()
223
+ vmax = float(v.max())
224
+ _, non_building_mask = PrintMetrics.get_building_area()
225
+ non_building_flat = non_building_mask.ravel().astype(bool)
226
+
227
+ if vmax <= 1e-6:
228
+ valid_idx = np.where(non_building_flat)[0]
229
+ idx = rng.choice(valid_idx, size=num_points, replace=False)
230
+ obs_xy = coords[idx]
231
+ obs_val = ObservationModel.observation_operator_H(true_field, obs_xy)
232
+ return obs_xy, obs_val
233
+
234
+ n_center_ratio= 1 - n1_ratio
235
+ n_center = max(2, int(num_points * n_center_ratio))
236
+ n1 = num_points - n_center
237
+
238
+ # 1. 结构修正(LOG)
239
+ z = np.log1p(np.maximum(pred_field, 0.0))
240
+ z_flat = z.ravel()
241
+ gx, gy = np.gradient(z)
242
+ grad = np.sqrt(gx**2 + gy**2).ravel()
243
+ nz = z_flat > 1e-6 # 只在非零区域上算分位数,避免大量0把lo/hi压塌
244
+ z_nz = z_flat[nz]
245
+ if z_nz.size < 50:
246
+ support_mask = (grad.reshape(H, W) > np.quantile(grad, 0.90))
247
+ else:
248
+ lo = np.quantile(z_nz, 0.70)
249
+ hi = np.quantile(z_nz, 0.90)
250
+ core_mask = (z >= lo) & (z <= hi)
251
+ r_out = 12 # 外扩:把结构带膨胀一��,让采样不只盯着最强边界
252
+ support_mask = binary_dilation(core_mask, iterations=r_out)
253
+
254
+ support_idx = np.where(support_mask.ravel() & non_building_flat)[0]
255
+ if len(support_idx) < num_points * 2:
256
+ support_idx = np.where(non_building_flat)[0]
257
+
258
+ # Stage1:梯度主导 + 适度保留外圈
259
+ weights1 = (
260
+ (grad[support_idx] ** stage1_grad_power) *
261
+ (z_flat[support_idx] ** stage1_value_power + 1e-6)
262
+ )
263
+ if stage1_center_boost > 1.0:
264
+ weights1 *= (1 + stage1_center_boost * (z_flat[support_idx] / z_flat.max()))
265
+ selected = repulse_pick(support_idx, weights1, n1, [])
266
+
267
+ # Stage 2: 中心峰值区,只取2-3个点
268
+ peak_idx = np.argmax(z_flat * non_building_flat.astype(float))
269
+ peak_xy = coords[peak_idx]
270
+ # print(f"峰值位置: {peak_xy}, z值: {z_flat[peak_idx]:.3f}")
271
+ selected.append(int(peak_idx)) # 直接把峰值点加进去(1个)
272
+
273
+ # 再在峰值极近邻选1-2个,min_dist放松到5保证不重叠
274
+ if n_center > 1:
275
+ peak_radius = 10 # 很小的半径,只捕捉最高点附近
276
+ stage2_idx = np.where(
277
+ non_building_flat &
278
+ (np.sqrt((coords[:, 0] - peak_xy[0])**2 +
279
+ (coords[:, 1] - peak_xy[1])**2) <= peak_radius)
280
+ )[0]
281
+ stage2_idx = np.setdiff1d(stage2_idx, np.array(selected))
282
+
283
+ if len(stage2_idx) >= 1:
284
+ weights2 = z_flat[stage2_idx]
285
+ weights2 = weights2 / (weights2.sum() + 1e-12)
286
+ overs = min(len(stage2_idx), max((n_center - 1) * 10, 20))
287
+ cands = rng.choice(stage2_idx, size=overs, replace=False, p=weights2)
288
+ for idx in cands:
289
+ xy = coords[idx]
290
+ if cKDTree(coords[np.array(selected)]).query(xy, k=1)[0] >= 5:
291
+ selected.append(int(idx))
292
+ if len(selected) >= n_center + len([]): # 只加到n_center个为止
293
+ break
294
+ if len(selected) - (num_points - n_center) >= n_center:
295
+ break
296
+ selected = list(dict.fromkeys(selected))
297
+
298
+ # 补足剩余点(从Stage1结构带里再补,如果selected不够num_points)
299
+ if len(selected) < num_points:
300
+ remain = np.setdiff1d(support_idx, np.array(selected))
301
+ if len(remain) > 0:
302
+ w_remain = (
303
+ (grad[remain] ** stage1_grad_power) *
304
+ (z_flat[remain] ** stage1_value_power + 1e-6)
305
+ )
306
+ extra = repulse_pick(remain, w_remain,
307
+ num_points - len(selected), selected)
308
+ selected = extra
309
+
310
+ obs_xy = coords[np.array(selected[:num_points])]
311
+ obs_val = ObservationModel.observation_operator_H(true_field, obs_xy)
312
+
313
+ return obs_xy, obs_val
314
+
315
+ @staticmethod
316
+ def two_stage_pro(
317
+ true_field,
318
+ pred_field,
319
+ num_points,
320
+ ens_preds_ppm=None,
321
+ seed=42,
322
+ min_dist=22,
323
+ n1_ratio=0.65,
324
+ stage1_grad_power=0.8,
325
+ stage1_value_power=1.2,
326
+ stage1_center_boost=1.2,
327
+ ):
328
+ import numpy as np
329
+ from scipy.spatial import cKDTree
330
+ from scipy.ndimage import binary_dilation
331
+
332
+ def repulse_pick(candidate_idx, weights, k, selected_idx, this_min_dist):
333
+ if k <= 0 or len(candidate_idx) == 0:
334
+ return list(selected_idx)
335
+
336
+ candidate_idx = np.asarray(candidate_idx, dtype=np.int64)
337
+ weights = np.maximum(np.asarray(weights, dtype=float), 0.0)
338
+
339
+ if weights.sum() <= 0:
340
+ weights = np.ones_like(weights, dtype=float)
341
+
342
+ weights = weights / weights.sum()
343
+
344
+ overs = min(len(candidate_idx), max(k * 15, 200))
345
+ cand = rng.choice(candidate_idx, size=overs, replace=False, p=weights)
346
+
347
+ selected = list(selected_idx)
348
+ for idx in cand:
349
+ idx = int(idx)
350
+ if idx in selected:
351
+ continue
352
+
353
+ xy = coords[idx]
354
+ if not selected:
355
+ selected.append(idx)
356
+ continue
357
+
358
+ sel_xy = coords[np.asarray(selected)]
359
+ if cKDTree(sel_xy).query(xy, k=1)[0] >= this_min_dist:
360
+ selected.append(idx)
361
+
362
+ if len(selected) >= k + len(selected_idx):
363
+ break
364
+
365
+ return selected
366
+
367
+ rng = np.random.default_rng(seed)
368
+ H, W = pred_field.shape
369
+ yy, xx = np.meshgrid(np.arange(H), np.arange(W), indexing="ij")
370
+ coords = np.stack([xx.ravel(), yy.ravel()], axis=1) # (HW, 2), [x, y]
371
+
372
+ v = np.maximum(pred_field, 0.0).ravel()
373
+ vmax = float(v.max())
374
+
375
+ _, non_building_mask = PrintMetrics.get_building_area()
376
+ non_building_flat = non_building_mask.ravel().astype(bool)
377
+
378
+ if vmax <= 1e-6:
379
+ valid_idx = np.where(non_building_flat)[0]
380
+ idx = rng.choice(valid_idx, size=min(num_points, len(valid_idx)), replace=False)
381
+
382
+ if len(idx) < num_points:
383
+ raise ValueError(f"Not enough valid non-building points: need {num_points}, got {len(idx)}")
384
+
385
+ obs_xy = coords[idx]
386
+ obs_val = ObservationModel.observation_operator_H(true_field, obs_xy)
387
+ return obs_xy, obs_val
388
+
389
+ n_center_ratio = 1 - n1_ratio
390
+ n_center = max(2, int(num_points * n_center_ratio))
391
+ n1 = num_points - n_center
392
+
393
+ z = np.log1p(np.maximum(pred_field, 0.0))
394
+ z_flat = z.ravel()
395
+ gx, gy = np.gradient(z)
396
+ grad = np.sqrt(gx**2 + gy**2).ravel()
397
+ nz = z_flat > 1e-6
398
+ z_nz = z_flat[nz]
399
+ if z_nz.size < 50:
400
+ support_mask = (grad.reshape(H, W) > np.quantile(grad, 0.90))
401
+ else:
402
+ lo = np.quantile(z_nz, 0.70)
403
+ hi = np.quantile(z_nz, 0.90)
404
+ core_mask = (z >= lo) & (z <= hi)
405
+ # r_out = int(0.60 * num_points)
406
+ r_out = int(np.clip(num_points / 3 + 16 / 3, 12, 24)) # 根据 num_points 动态调整外扩半径,保持在12-24范围内
407
+ support_mask = binary_dilation(core_mask, iterations=r_out)
408
+
409
+ support_idx = np.where(support_mask.ravel() & non_building_flat)[0]
410
+ if len(support_idx) < num_points * 2:
411
+ support_idx = np.where(non_building_flat)[0]
412
+
413
+ weights1 = (
414
+ (grad[support_idx] ** stage1_grad_power) *
415
+ (z_flat[support_idx] ** stage1_value_power + 1e-6)
416
+ )
417
+
418
+ if stage1_center_boost > 1.0:
419
+ weights1 *= (1 + stage1_center_boost * (z_flat[support_idx] / (z_flat.max() + 1e-12)))
420
+ selected = repulse_pick(support_idx, weights1, n1, [], min_dist)
421
+
422
+ peak_idx = int(np.argmax(z_flat * non_building_flat.astype(float)))
423
+ peak_xy = coords[peak_idx]
424
+ selected.append(int(peak_idx))
425
+
426
+ if n_center > 1:
427
+ peak_radius = 10
428
+ stage2_idx = np.where(
429
+ non_building_flat &
430
+ (np.sqrt((coords[:, 0] - peak_xy[0]) ** 2 +
431
+ (coords[:, 1] - peak_xy[1]) ** 2) <= peak_radius)
432
+ )[0]
433
+ stage2_idx = np.setdiff1d(stage2_idx, np.array(selected))
434
+
435
+ if len(stage2_idx) >= 1:
436
+ weights2 = z_flat[stage2_idx]
437
+ weights2 = weights2 / (weights2.sum() + 1e-12)
438
+
439
+ overs = min(len(stage2_idx), max((n_center - 1) * 10, 20))
440
+ cands = rng.choice(stage2_idx, size=overs, replace=False, p=weights2)
441
+
442
+ for idx in cands:
443
+ idx = int(idx)
444
+ xy = coords[idx]
445
+ if cKDTree(coords[np.array(selected)]).query(xy, k=1)[0] >= 5:
446
+ selected.append(idx)
447
+
448
+ if len(selected) >= n_center + len([]):
449
+ break
450
+ if len(selected) - (num_points - n_center) >= n_center:
451
+ break
452
+
453
+ selected = list(dict.fromkeys(selected))
454
+
455
+ if len(selected) >= num_points:
456
+ selected = selected[:num_points]
457
+ obs_xy = coords[np.array(selected)]
458
+ obs_val = ObservationModel.observation_operator_H(true_field, obs_xy)
459
+ return obs_xy, obs_val
460
+
461
+ remain = np.setdiff1d(support_idx, np.array(selected))
462
+ if len(remain) > 0:
463
+ for d_try in [
464
+ min_dist,
465
+ max(1, int(min_dist * 0.7)),
466
+ max(1, int(min_dist * 0.4)),
467
+ 5,
468
+ 3
469
+ ]:
470
+ if len(selected) >= num_points:
471
+ break
472
+
473
+ remain = np.setdiff1d(support_idx, np.array(selected))
474
+ if len(remain) == 0:
475
+ break
476
+
477
+ w_remain = (
478
+ (grad[remain] ** stage1_grad_power) *
479
+ (z_flat[remain] ** stage1_value_power + 1e-6)
480
+ )
481
+
482
+ selected = repulse_pick(
483
+ remain,
484
+ w_remain,
485
+ num_points - len(selected),
486
+ selected,
487
+ d_try
488
+ )
489
+
490
+ selected = list(dict.fromkeys(selected))
491
+
492
+ if len(selected) < num_points:
493
+ remain_support = np.setdiff1d(support_idx, np.array(selected))
494
+
495
+ if len(remain_support) > 0:
496
+ need = num_points - len(selected)
497
+ extra = rng.choice(
498
+ remain_support,
499
+ size=min(need, len(remain_support)),
500
+ replace=False
501
+ )
502
+ selected.extend(extra.tolist())
503
+
504
+ selected = list(dict.fromkeys(selected))
505
+
506
+ if len(selected) < num_points:
507
+ all_valid = np.where(non_building_flat)[0]
508
+ remain_all = np.setdiff1d(all_valid, np.array(selected))
509
+
510
+ if len(remain_all) > 0:
511
+ need = num_points - len(selected)
512
+ extra = rng.choice(
513
+ remain_all,
514
+ size=min(need, len(remain_all)),
515
+ replace=False
516
+ )
517
+ selected.extend(extra.tolist())
518
+
519
+ selected = list(dict.fromkeys(selected))
520
+ selected = selected[:num_points]
521
+ assert len(selected) == num_points, f"Expected {num_points} points, got {len(selected)}"
522
+
523
+ obs_xy = coords[np.array(selected)]
524
+ obs_val = ObservationModel.observation_operator_H(true_field, obs_xy)
525
+
526
+ return obs_xy, obs_val
527
+
528
+ @staticmethod
529
+ def smart_two_pass(
530
+ enkf,
531
+ psi_f,
532
+ conds_preds,
533
+ true_field,
534
+ n1,
535
+ n2,
536
+ n_rounds=2,
537
+ phase1_method='two_stage',
538
+ min_dist_p2=22,
539
+ under_correct_alpha=1.5,
540
+ use_localization=False,
541
+ loc_radius_pixobs=35.0,
542
+ loc_radius_obsobs=40.0,
543
+ seed=42,
544
+ verbose=True,
545
+ ):
546
+ """
547
+ 多轮迭代选点 + EnKF 同化。
548
+
549
+ 每轮流程:
550
+ Phase 1 — 基于当前先验场选 n1 个点,做 pilot EnKF;
551
+ Phase 2 — 基于 pilot 残差找欠校正区,再选 n2 个点,做 final EnKF;
552
+ 本轮分析场作为下一轮的先验(psi_f)。
553
+
554
+ 参数:
555
+ enkf : EnKF 实例
556
+ psi_f : 初始先验场 (H, W)
557
+ conds_preds : 集合预测场 (N_ens, H, W),协方差来源,全程不变
558
+ true_field : 真值场 (H, W),仅用于观测值提取
559
+ n1 : 每轮 Phase-1 选点数
560
+ n2 : 每轮 Phase-2 选点数
561
+ n_rounds : 迭代轮数(默认 1,即原始两阶段行为)
562
+ phase1_method : Phase-1 采样策略('two_stage' 或其他 generate 支持的方法)
563
+ min_dist_p2 : Phase-2 选点与已有点的最小距离(像素)
564
+ under_correct_alpha : Phase-2 欠校正权重幂次
565
+ use_localization: 是否使用局地化 EnKF
566
+ loc_radius_pixobs / loc_radius_obsobs : 局地化半径
567
+ seed : 随机种子
568
+ verbose : 是否打印中间日志
569
+
570
+ 返回:
571
+ psi_a_final : 最终分析场 (H, W)
572
+ all_obs_xy : 所有轮次累计观测坐标 (n_rounds*(n1+n2), 2)
573
+ all_obs_val : 所有轮次累计观测值
574
+ psi_pilot : 最后一轮的 pilot(Phase-1)分析场
575
+ obs_xy_p1_last : 最后一轮 Phase-1 选点坐标
576
+ """
577
+ conds_preds = np.asarray(conds_preds)
578
+ N_ens, H, W = conds_preds.shape
579
+ rng = np.random.default_rng(seed)
580
+
581
+ _, non_building_mask = PrintMetrics.get_building_area()
582
+ non_building_flat = non_building_mask.ravel().astype(bool)
583
+ yy, xx = np.meshgrid(np.arange(H), np.arange(W), indexing="ij")
584
+ coords = np.stack([xx.ravel(), yy.ravel()], axis=1)
585
+
586
+ # 预先构建集合(整个函数中只用这一个 X_f,协方差永远基于原始集合)
587
+ ens_mean = np.mean(conds_preds, axis=0)
588
+ X_f_base = conds_preds - ens_mean[None, :, :] + psi_f[None, :, :]
589
+
590
+ # 累计所有轮次的观测(跨轮次积累,最终一次性返回)
591
+ all_obs_xy_list = []
592
+ all_obs_val_list = []
593
+
594
+ # 当前先验:第 1 轮用 psi_f,后续轮次用上一轮的分析场
595
+ psi_current = psi_f
596
+ psi_pilot = None
597
+ obs_xy_p1_last = None
598
+
599
+ for round_idx in range(n_rounds):
600
+ round_seed = seed + round_idx # 每轮不同种子,避免重复采样
601
+
602
+ # ── Phase 1:基于当前先验选点 + pilot EnKF ────────────────
603
+ if phase1_method == 'two_stage':
604
+ obs_xy_p1, obs_val_p1 = SamplingStrategies.two_stage_sampling(
605
+ true_field=true_field, pred_field=psi_current,
606
+ num_points=n1, seed=round_seed)
607
+ else:
608
+ obs_xy_p1, obs_val_p1 = SamplingStrategies.generate(
609
+ true_field, psi_current, n1, method=phase1_method, seed=round_seed)
610
+
611
+ # pilot EnKF 使用当前先验重新中心化的集合
612
+ ens_mean_cur = np.mean(conds_preds, axis=0)
613
+ X_f_cur = conds_preds - ens_mean_cur[None, :, :] + psi_current[None, :, :]
614
+
615
+ if use_localization:
616
+ psi_pilot = enkf._enkf_update_localized(
617
+ X_f_cur, obs_xy_p1, obs_val_p1,
618
+ loc_radius_pixobs, loc_radius_obsobs, round_seed)
619
+ else:
620
+ psi_pilot = enkf._enkf_update_standard(X_f_cur, obs_xy_p1, obs_val_p1)
621
+
622
+ if verbose:
623
+ from sklearn.metrics import r2_score as _r2
624
+ print(f"[SmartEnKF] Round {round_idx+1}/{n_rounds} Phase1: "
625
+ f"{n1} 点, pilot R²={_r2(true_field.ravel(), psi_pilot.ravel()):.4f}")
626
+
627
+ # ── Phase 2:找欠校正区,补充选点 ──────────────────────────
628
+ psi_f_flat = psi_current.ravel()
629
+ psi_pilot_flat = psi_pilot.ravel()
630
+ correction_map = np.abs(psi_pilot_flat - psi_f_flat)
631
+
632
+ nz_vals = psi_f_flat[non_building_flat & (psi_f_flat > 1e-4)]
633
+ prior_thresh = np.quantile(nz_vals, 0.20) if len(nz_vals) > 20 else 1e-4
634
+ plume_support = non_building_flat & (psi_f_flat > prior_thresh)
635
+ cand_idx = np.where(plume_support)[0]
636
+ if len(cand_idx) < n2 * 3:
637
+ cand_idx = np.where(non_building_flat & (psi_f_flat > 1e-6))[0]
638
+
639
+ prior_cand = psi_f_flat[cand_idx]
640
+ corr_cand = correction_map[cand_idx]
641
+ prior_norm = prior_cand / (prior_cand.max() + 1e-12)
642
+ corr_norm = corr_cand / (corr_cand.max() + 1e-12)
643
+ under_score = prior_norm * (1.0 - corr_norm + 0.05)
644
+
645
+ p1_tree = cKDTree(obs_xy_p1)
646
+ dist_p1, _ = p1_tree.query(coords[cand_idx], k=1)
647
+ dist_w = np.tanh(dist_p1 / (min_dist_p2 * 2.5))
648
+
649
+ weights_p2 = (under_score ** under_correct_alpha) * (dist_w + 0.05)
650
+ weights_p2 = np.maximum(weights_p2, 1e-12)
651
+ weights_p2 /= weights_p2.sum()
652
+
653
+ rng_round = np.random.default_rng(round_seed)
654
+ n_over = min(len(cand_idx), max(n2 * 30, 600))
655
+ cands = rng_round.choice(cand_idx, size=n_over, replace=False, p=weights_p2)
656
+
657
+ selected_p2 = []
658
+ for cidx in cands:
659
+ xy = coords[cidx]
660
+ if p1_tree.query(xy, k=1)[0] < min_dist_p2:
661
+ continue
662
+ if selected_p2:
663
+ if cKDTree(coords[np.array(selected_p2)]).query(xy, k=1)[0] < min_dist_p2:
664
+ continue
665
+ selected_p2.append(int(cidx))
666
+ if len(selected_p2) >= n2:
667
+ break
668
+
669
+ if len(selected_p2) < n2:
670
+ remain = np.setdiff1d(cand_idx, np.array(selected_p2, dtype=int))
671
+ extra = rng_round.choice(remain,
672
+ size=min(n2 - len(selected_p2), len(remain)),
673
+ replace=False)
674
+ selected_p2.extend(extra.tolist())
675
+
676
+ obs_xy_p2 = coords[np.array(selected_p2[:n2])]
677
+ obs_val_p2 = ObservationModel.observation_operator_H(true_field, obs_xy_p2)
678
+
679
+ if verbose:
680
+ print(f"[SmartEnKF] Round {round_idx+1}/{n_rounds} Phase2: 补充 {n2} 个欠校正区域点")
681
+
682
+ # ── Final:本轮全部点 + 当前先验做最终 EnKF ────────────────
683
+ round_obs_xy = np.vstack([obs_xy_p1, obs_xy_p2])
684
+ round_obs_val = np.concatenate([obs_val_p1, obs_val_p2])
685
+
686
+ if use_localization:
687
+ psi_a_round = enkf._enkf_update_localized(
688
+ X_f_cur, round_obs_xy, round_obs_val,
689
+ loc_radius_pixobs, loc_radius_obsobs, round_seed)
690
+ else:
691
+ psi_a_round = enkf._enkf_update_standard(X_f_cur, round_obs_xy, round_obs_val)
692
+
693
+ if verbose:
694
+ from sklearn.metrics import r2_score as _r2
695
+ print(f"[SmartEnKF] Round {round_idx+1}/{n_rounds} Final: "
696
+ f"{n1+n2} 点, R²={_r2(true_field.ravel(), psi_a_round.ravel()):.4f}")
697
+
698
+ # 累计观测,更新先验进入下一轮
699
+ all_obs_xy_list.append(round_obs_xy)
700
+ all_obs_val_list.append(round_obs_val)
701
+ psi_current = np.maximum(psi_a_round, 0.0)
702
+ obs_xy_p1_last = obs_xy_p1
703
+
704
+ all_obs_xy = np.vstack(all_obs_xy_list)
705
+ all_obs_val = np.concatenate(all_obs_val_list)
706
+
707
+ return (psi_current, all_obs_xy, all_obs_val,
708
+ np.maximum(psi_pilot, 0.0), obs_xy_p1_last)
709
+
710
+ @staticmethod
711
+ def generate(true_field, pred_field, num_points, method="uniform", seed=42,
712
+ enkf=None, conds_preds=None, **sample_params):
713
+ field_shape = true_field.shape
714
+ if method == "random":
715
+ obs_xy = SamplingStrategies.sample_random(field_shape, num_points, seed)
716
+ elif method == "uniform":
717
+ obs_xy = SamplingStrategies.sample_uniform(field_shape, num_points)
718
+ elif method == "two_stage":
719
+ obs_xy, _ = SamplingStrategies.two_stage_sampling(
720
+ true_field,
721
+ pred_field,
722
+ num_points,
723
+ seed=seed,
724
+ **sample_params
725
+ )
726
+ elif method == "two_stage_pro":
727
+ obs_xy, _ = SamplingStrategies.two_stage_pro(
728
+ true_field,
729
+ pred_field,
730
+ num_points,
731
+ seed=seed,
732
+ **sample_params
733
+ )
734
+ elif method == "smart_two_pass":
735
+ if enkf is None or conds_preds is None:
736
+ raise ValueError(
737
+ "method='smart_two_pass' 需要传入 enkf 实例和 conds_preds 集合场。"
738
+ )
739
+ # 解析 n1 / n2(支持用 n1_ratio 自动计算)
740
+ n1_ratio = float(sample_params.pop('n1_ratio', 0.6))
741
+ n1_default = int(round(num_points * n1_ratio))
742
+ n1 = int(sample_params.pop('n1', n1_default))
743
+ if num_points > 1:
744
+ n1 = max(1, min(n1, num_points - 1))
745
+ else:
746
+ n1 = 1
747
+ n2 = int(sample_params.pop('n2', num_points - n1))
748
+ return SamplingStrategies.smart_two_pass(
749
+ enkf=enkf,
750
+ psi_f=pred_field,
751
+ conds_preds=conds_preds,
752
+ true_field=true_field,
753
+ n1=n1,
754
+ n2=n2,
755
+ seed=seed,
756
+ **sample_params,
757
+ )
758
+ else:
759
+ raise ValueError(f"Unknown observation sampling method: {method}")
760
+
761
+ obs_val = ObservationModel.observation_operator_H(true_field, obs_xy)
762
+ return obs_xy, obs_val
763
+
764
+
765
+ class EnKF:
766
+
767
+ def __init__(
768
+ self,
769
+ obs_std_scale=0.08, # relative observation noise level
770
+ damping=1.0,
771
+ jitter=1e-5,
772
+ ):
773
+ self.obs_std_scale = obs_std_scale
774
+ self.damping = damping
775
+ self.jitter = jitter
776
+
777
+ def standard_enkf(self, psi_f, conds_preds, obs_xy, d_obs):
778
+ """
779
+ psi_f: Unet预测的最佳先验场 (H, W)
780
+ conds_preds: 通过扰动参数生成的集合场 (N_ens, H, W)
781
+ obs_xy: 监测站坐标 (n_obs, 2)
782
+ d_obs: 监测站真实浓度 (n_obs,)
783
+ """
784
+ conds_preds = np.asarray(conds_preds)
785
+ N_ens, H, W = conds_preds.shape
786
+ n_obs = obs_xy.shape[0]
787
+
788
+ ens_mean = np.mean(conds_preds, axis=0) # 计算集合均值
789
+ # 重要:将集合成员的波动叠加到 Unet 预测场 psi_f 上 ,
790
+ # 确保分析场的统计中心是 Unet 预测的那个场,而不是集合均值(可能有偏差导致更新不好)
791
+ X_f = conds_preds - ens_mean[None, :, :] + psi_f[None, :, :]
792
+ X_f_flat = X_f.reshape(N_ens, -1) # (N_ens, Pixels)
793
+ HX = ObservationModel.observation_operator_H_ens(X_f, obs_xy) # (N_ens, n_obs)
794
+ HX_mean = np.mean(HX, axis=0)
795
+ X_f_bar = np.mean(X_f_flat, axis=0) # 计算偏差矩阵
796
+ A_prime = (X_f_flat - X_f_bar[None, :]).T # A_prime (状态偏差): (Pixels, N_ens)
797
+ Y_prime = (HX - HX_mean).T # Y_prime (观测空间偏差): (n_obs, N_ens)
798
+
799
+ # # 构造观测误差矩阵 R_e
800
+ # # 基于观测值大小设定自适应噪声 (8% 相对误差)
801
+ obs_std = self.obs_std_scale * np.maximum(np.abs(d_obs), 1.0) # 先定标准差
802
+ rng = np.random.default_rng(42) # SVD正交化生成E
803
+ Z = rng.standard_normal((N_ens, n_obs))
804
+ U, _, Vt = np.linalg.svd(Z, full_matrices=False)
805
+ Z = U @ Vt * np.sqrt(N_ens - 1)
806
+ E = Z * obs_std[None, :] # (N_ens, n_obs),E.T即为文献中的E矩阵
807
+ # 从E计算Re(按照文献公式 Re = EE^T / N-1)
808
+ E_T = E.T # (n_obs, N_ens),对应文献的E
809
+ R_e = (E_T @ E_T.T) / (N_ens - 1) # (n_obs, n_obs)
810
+ R_e += self.jitter * np.eye(n_obs) # 数值稳定项
811
+ Y_o = d_obs[None, :] + E # (N_ens, n_obs)
812
+
813
+ # 增益计算与状态更新 (对应公式 3-16, 3-17)
814
+ # 计算 Pe*H.T 和 H*Pe*H.T 的统计估计值
815
+ Pe_HT = (A_prime @ Y_prime.T) / (N_ens - 1)
816
+ H_Pe_HT = (Y_prime @ Y_prime.T) / (N_ens - 1)
817
+ # 计算集合卡尔曼增益 K_e = Pe*H.T * inverse(H*Pe*H.T + R_e)
818
+ # 使用 solve 提高数值稳定性
819
+ K_e = np.linalg.solve((H_Pe_HT + R_e).T, Pe_HT.T).T
820
+ # 计算创新值 (Innovation): (n_obs, N_ens)
821
+ # 每个成员根据自己的观测扰动和预测值进行修正
822
+ innovation = (Y_o - HX).T
823
+ # 更新系统状态的集合预测矩阵 X_a
824
+ # X_a = X_f + K_e * (Y_o - HX_f)
825
+ X_a_flat = X_f_flat + (self.damping * (K_e @ innovation)).T
826
+ # 输出最终分析场,取集合均值作为最终结果
827
+ psi_a_flat = np.mean(X_a_flat, axis=0)
828
+ psi_a = psi_a_flat.reshape(H, W)
829
+ # 物理约束:确保浓度不为负数
830
+ return np.maximum(psi_a, 0.0)
831
+
832
+ def enkf_localization(self, psi_f, conds_preds, obs_xy, d_obs,
833
+ loc_radius_pixobs=40.0, # Pixel-Obs localization radius (in pixels)
834
+ loc_radius_obsobs=60.0, # Obs-Obs localization radius (in pixels)
835
+ seed=42,
836
+ SAVE_DIAGNOSTICS=False,
837
+ ):
838
+ conds_preds = np.asarray(conds_preds)
839
+ N_ens, H, W = conds_preds.shape
840
+ n_obs = obs_xy.shape[0]
841
+
842
+ # ========= 1) prior ensemble centered at psi_f =========
843
+ ens_mean = np.mean(conds_preds, axis=0)
844
+ X_f = conds_preds - ens_mean[None, :, :] + psi_f[None, :, :]
845
+ X_f_flat = X_f.reshape(N_ens, -1)
846
+
847
+ HX = ObservationModel.observation_operator_H_ens(X_f, obs_xy) # 注意:必须用 X_f
848
+ HX_mean = np.mean(HX, axis=0)
849
+
850
+ X_f_bar = np.mean(X_f_flat, axis=0)
851
+ A_prime = (X_f_flat - X_f_bar[None, :]).T # (Pixels, N_ens)
852
+ Y_prime = (HX - HX_mean).T # (n_obs, N_ens)
853
+
854
+ # # ========= 2) perturbed obs (deterministic-ish, fixed seed) =========
855
+ obs_std = self.obs_std_scale * np.maximum(np.abs(d_obs), 1.0) # 先定标准差
856
+ rng = np.random.default_rng(seed) # SVD正交化生成E
857
+ Z = rng.standard_normal((N_ens, n_obs))
858
+ U, _, Vt = np.linalg.svd(Z, full_matrices=False)
859
+ Z = U @ Vt * np.sqrt(N_ens - 1)
860
+ E = Z * obs_std[None, :] # (N_ens, n_obs),E.T即为文献中的E矩阵
861
+ # 从E计算Re(按照文献公式 Re = EE^T / N-1)
862
+ E_T = E.T # (n_obs, N_ens),对应文献的E
863
+ R_e = (E_T @ E_T.T) / (N_ens - 1) # (n_obs, n_obs)
864
+ R_e += self.jitter * np.eye(n_obs) # 数值稳定项
865
+ Y_o = d_obs[None, :] + E
866
+
867
+ # ========= 3) sample covariances =========
868
+ Pe_HT = (A_prime @ Y_prime.T) / (N_ens - 1) # (Pixels, n_obs)
869
+ H_Pe_HT = (Y_prime @ Y_prime.T) / (N_ens - 1) # (n_obs, n_obs)
870
+
871
+ # ========= 4) localization =========
872
+ # (a) Pixel-Obs localization: rho_xy (Pixels, n_obs)
873
+ yy, xx = np.meshgrid(np.arange(H), np.arange(W), indexing="ij")
874
+ grid = np.stack([xx.ravel(), yy.ravel()], axis=1) # (Pixels,2)
875
+ dx = grid[:, None, 0] - obs_xy[None, :, 0]
876
+ dy = grid[:, None, 1] - obs_xy[None, :, 1]
877
+ dist2_xy = dx*dx + dy*dy
878
+ rho_xy = np.exp(-0.5 * dist2_xy / (loc_radius_pixobs**2))
879
+
880
+ # (b) Obs-Obs localization: rho_oo (n_obs, n_obs)
881
+ dox = obs_xy[:, None, 0] - obs_xy[None, :, 0]
882
+ doy = obs_xy[:, None, 1] - obs_xy[None, :, 1]
883
+ dist2_oo = dox*dox + doy*doy
884
+ rho_oo = np.exp(-0.5 * dist2_oo / (loc_radius_obsobs**2))
885
+ Pe_HT = Pe_HT * rho_xy
886
+ H_Pe_HT = H_Pe_HT * rho_oo
887
+
888
+ # ========= [诊断] P_e 的谱结构 =========
889
+ # P_e = A_prime @ A_prime.T / (N_ens-1),直接分解 A_prime 的奇异值更高效
890
+ # A_prime shape: (Pixels, N_ens),SVD给出 P_e 的特征值 = sigma^2
891
+ U_ens, sigma, Vt_ens = np.linalg.svd(A_prime / np.sqrt(N_ens - 1), full_matrices=False)
892
+ # sigma shape: (N_ens,),对应 P_e 的特征值平方根
893
+ eigenvalues = sigma ** 2 # P_e 的特征值,降序排列
894
+
895
+ # --- 指标1:有效秩 r_eff = (Σλ)² / Σλ² 衡量特征值分布均匀程度---
896
+ # r_eff→1: 近似秩1(能量集中于单一方向)r_eff→N: 各向同性(能量均匀分布)
897
+ r_eff = (eigenvalues.sum() ** 2) / (eigenvalues ** 2).sum()
898
+
899
+ # --- 指标2:主特征值 λ1 = P_e 在主方向上的方差 ---
900
+ # 只受幅度参数(v, Q)影响,随d单调增大
901
+ lambda1 = eigenvalues[0]
902
+ lambda_min = eigenvalues[-2]
903
+
904
+ # --- 指标3:方向集中度 λ1/λ2 衡量P_e各向异性程度 ---
905
+ # 峰值对应最优d配置(d**),超过后集合引入非物理方向
906
+ ratio_1_2 = eigenvalues[0] / eigenvalues[1] if len(eigenvalues) > 1 else np.inf
907
+
908
+ # --- 指标4:主特征向量峰值位置---
909
+ # 峰值位置随d系统性漂移,随v/Q不变,随n随机漂移
910
+ u1 = U_ens[:, 0].reshape(H, W) # u1 = P_e 的第一特征向量,代表集合扰动的主方向
911
+ u1_peak = np.unravel_index(np.abs(u1).argmax(), u1.shape)
912
+
913
+ # ========= 5) Kalman gain =========
914
+ S = H_Pe_HT + R_e
915
+ K_e = np.linalg.solve(S.T, Pe_HT.T).T
916
+
917
+ innovation = (Y_o - HX).T
918
+ X_a_flat = X_f_flat + (self.damping * (K_e @ innovation)).T
919
+
920
+ psi_a = np.mean(X_a_flat, axis=0).reshape(H, W)
921
+ psi_a = np.maximum(psi_a, 0.0)
922
+
923
+ if SAVE_DIAGNOSTICS:
924
+ print("=" * 50)
925
+ print(f"[P_e 谱诊断]")
926
+ print(f" 指标1 r_eff = {r_eff:.2f} # (Σλ)²/Σλ²,建筑影响下界≈2.1")
927
+ print(f" 指���2 λ1 = {lambda1:.2f} {lambda_min:.2f} # 主方向方差,随d单调增大")
928
+ print(f" 指标3 λ1/λ2 = {ratio_1_2:.2f} # 各向异性,d=45°时峰值→最优配置")
929
+ print(f" 指标4 u1峰值位置 = {u1_peak} # d变化时系统漂移,v/Q不变")
930
+ print("=" * 50)
931
+ diag = {
932
+ 'r_eff': r_eff,
933
+ 'lambda1': lambda1,
934
+ 'ratio_1_2': ratio_1_2,
935
+ 'u1_peak_row': u1_peak[0],
936
+ 'u1_peak_col': u1_peak[1],
937
+ }
938
+ return psi_a, diag
939
+ else:
940
+ return psi_a
941
+
942
+ def _enkf_update_standard(self, X_f, obs_xy, d_obs):
943
+ """
944
+ 标准 EnKF 更新,直接接受已中心化的集合 X_f (N_ens, H, W)。
945
+ 返回分析场均值 psi_a (H, W),>=0。
946
+ """
947
+ N_ens, H, W = X_f.shape
948
+ n_obs = obs_xy.shape[0]
949
+ X_f_flat = X_f.reshape(N_ens, -1)
950
+
951
+ HX = ObservationModel.observation_operator_H_ens(X_f, obs_xy)
952
+ HX_mean = np.mean(HX, axis=0)
953
+ X_f_bar = np.mean(X_f_flat, axis=0)
954
+ A_prime = (X_f_flat - X_f_bar[None, :]).T # (Pixels, N_ens)
955
+ Y_prime = (HX - HX_mean).T # (n_obs, N_ens)
956
+
957
+ obs_std = self.obs_std_scale * np.maximum(np.abs(d_obs), 1.0)
958
+ rng = np.random.default_rng(42)
959
+ Z = rng.standard_normal((N_ens, n_obs))
960
+ U, _, Vt = np.linalg.svd(Z, full_matrices=False)
961
+ Z = U @ Vt * np.sqrt(N_ens - 1)
962
+ E = Z * obs_std[None, :]
963
+ E_T = E.T
964
+ R_e = (E_T @ E_T.T) / (N_ens - 1)
965
+ R_e += self.jitter * np.eye(n_obs)
966
+ Y_o = d_obs[None, :] + E
967
+
968
+ Pe_HT = (A_prime @ Y_prime.T) / (N_ens - 1)
969
+ H_Pe_HT = (Y_prime @ Y_prime.T) / (N_ens - 1)
970
+ K_e = np.linalg.solve((H_Pe_HT + R_e).T, Pe_HT.T).T
971
+ innovation = (Y_o - HX).T
972
+ X_a_flat = X_f_flat + (self.damping * (K_e @ innovation)).T
973
+ psi_a = np.mean(X_a_flat, axis=0).reshape(H, W)
974
+ return np.maximum(psi_a, 0.0)
975
+
976
+ def _enkf_update_localized(self, X_f, obs_xy, d_obs,
977
+ loc_radius_pixobs=35.0,
978
+ loc_radius_obsobs=40.0,
979
+ seed=42):
980
+ """
981
+ 局地化 EnKF 更新,直接接受已中心化的集合 X_f (N_ens, H, W)。
982
+ 返回分析场均值 psi_a (H, W),>=0。
983
+ """
984
+ N_ens, H, W = X_f.shape
985
+ n_obs = obs_xy.shape[0]
986
+ X_f_flat = X_f.reshape(N_ens, -1)
987
+
988
+ HX = ObservationModel.observation_operator_H_ens(X_f, obs_xy)
989
+ HX_mean = np.mean(HX, axis=0)
990
+ X_f_bar = np.mean(X_f_flat, axis=0)
991
+ A_prime = (X_f_flat - X_f_bar[None, :]).T
992
+ Y_prime = (HX - HX_mean).T
993
+
994
+ obs_std = self.obs_std_scale * np.maximum(np.abs(d_obs), 1.0)
995
+ rng = np.random.default_rng(seed)
996
+ Z = rng.standard_normal((N_ens, n_obs))
997
+ U, _, Vt = np.linalg.svd(Z, full_matrices=False)
998
+ Z = U @ Vt * np.sqrt(N_ens - 1)
999
+ E = Z * obs_std[None, :]
1000
+ E_T = E.T
1001
+ R_e = (E_T @ E_T.T) / (N_ens - 1)
1002
+ R_e += self.jitter * np.eye(n_obs)
1003
+ Y_o = d_obs[None, :] + E
1004
+
1005
+ Pe_HT = (A_prime @ Y_prime.T) / (N_ens - 1)
1006
+ H_Pe_HT = (Y_prime @ Y_prime.T) / (N_ens - 1)
1007
+
1008
+ yy, xx = np.meshgrid(np.arange(H), np.arange(W), indexing="ij")
1009
+ grid = np.stack([xx.ravel(), yy.ravel()], axis=1)
1010
+ dx = grid[:, None, 0] - obs_xy[None, :, 0]
1011
+ dy = grid[:, None, 1] - obs_xy[None, :, 1]
1012
+ rho_xy = np.exp(-0.5 * (dx*dx + dy*dy) / (loc_radius_pixobs**2))
1013
+ dox = obs_xy[:, None, 0] - obs_xy[None, :, 0]
1014
+ doy = obs_xy[:, None, 1] - obs_xy[None, :, 1]
1015
+ rho_oo = np.exp(-0.5 * (dox*dox + doy*doy) / (loc_radius_obsobs**2))
1016
+
1017
+ Pe_HT = Pe_HT * rho_xy
1018
+ H_Pe_HT = H_Pe_HT * rho_oo
1019
+
1020
+ S = H_Pe_HT + R_e
1021
+ K_e = np.linalg.solve(S.T, Pe_HT.T).T
1022
+ innovation = (Y_o - HX).T
1023
+ X_a_flat = X_f_flat + (self.damping * (K_e @ innovation)).T
1024
+ psi_a = np.mean(X_a_flat, axis=0).reshape(H, W)
1025
+ return np.maximum(psi_a, 0.0)
1026
+
1027
+
1028
+ class PrintMetrics:
1029
+ @staticmethod
1030
+ def pad_center_crop(arr, center_y, center_x, out_h=256, out_w=256):
1031
+ # Pad and center-crop 2D or 3D array
1032
+ if arr.ndim == 3:
1033
+ C, H, W = arr.shape
1034
+ out = np.zeros((C, out_h, out_w), dtype=arr.dtype)
1035
+ else:
1036
+ H, W = arr.shape
1037
+ out = np.zeros((out_h, out_w), dtype=arr.dtype)
1038
+ y0, x0 = center_y - out_h // 2, center_x - out_w // 2
1039
+ y1, x1 = y0 + out_h, x0 + out_w
1040
+ sy0, sy1 = max(0, y0), min(H, y1)
1041
+ sx0, sx1 = max(0, x0), min(W, x1)
1042
+ dy0, dx0 = sy0 - y0, sx0 - x0
1043
+ dy1, dx1 = dy0 + (sy1 - sy0), dx0 + (sx1 - sx0)
1044
+ if arr.ndim == 3:
1045
+ out[:, dy0:dy1, dx0:dx1] = arr[:, sy0:sy1, sx0:sx1]
1046
+ else:
1047
+ out[dy0:dy1, dx0:dx1] = arr[sy0:sy1, sx0:sx1]
1048
+ return out
1049
+
1050
+ @staticmethod
1051
+ def get_building_area():
1052
+ # load building data
1053
+ npz_path = '../Gas_unet/Gas_code/dataset_m/5min_m_Data_special/min5_m_v1_0_d270_sc2_s10_04118.npz'
1054
+ data = np.load(npz_path)
1055
+ build_data = data['three_channel_data'][0]
1056
+ non_building_mask = (build_data == 0).astype(np.uint8)
1057
+ center_y, center_x = 498, 538
1058
+ build_data_256 = PrintMetrics.pad_center_crop(build_data, center_y,
1059
+ center_x, 256, 256)
1060
+ non_building_mask = PrintMetrics.pad_center_crop(non_building_mask,
1061
+ center_y, center_x, 256, 256)
1062
+ return build_data_256, non_building_mask
1063
+
1064
+ @staticmethod
1065
+ def weighted_r2(y_true, y_pred, gamma=1.0, eps=1e-12):
1066
+ """
1067
+ Weighted R2 score emphasizing high-value regions.
1068
+ """
1069
+ y_true = np.asarray(y_true)
1070
+ y_pred = np.asarray(y_pred)
1071
+
1072
+ w = np.maximum(y_true, eps) ** gamma
1073
+ w = w / np.sum(w)
1074
+
1075
+ y_bar = np.sum(w * y_true)
1076
+
1077
+ num = np.sum(w * (y_true - y_pred) ** 2)
1078
+ den = np.sum(w * (y_true - y_bar) ** 2)
1079
+
1080
+ if den < eps:
1081
+ return np.nan
1082
+
1083
+ return 1.0 - num / den
1084
+
1085
+
1086
+
1087
+ @staticmethod
1088
+ def print_metrics(i, wind_speed, wind_direction, sc, source_number,
1089
+ true_field, pred_field, analysis, obs_xy,
1090
+ metrics_save_flag=False, metrics_print_flag=True):
1091
+ """
1092
+ Metrics:
1093
+ 1) Field-wise (all pixels)
1094
+ 2) Plume-aware (true > eps)
1095
+ 3) At observations
1096
+ """
1097
+
1098
+ def nmse_metrics(y_true, y_pred):
1099
+ nmse = np.mean((y_true.flatten() - y_pred.flatten())**2) / (np.mean(y_true) * np.mean(y_pred) + 1e-12)
1100
+ return nmse
1101
+
1102
+ def nmae_metrics(y_true, y_pred):
1103
+ nmae = np.mean(np.abs(y_true.flatten() - y_pred.flatten())) / (np.mean(y_true) + 1e-12)
1104
+ return nmae
1105
+
1106
+ # ========= 保留原始 2D 场 =========
1107
+ _, non_building_mask = PrintMetrics.get_building_area()
1108
+ true_field = np.where(true_field > 0, true_field, 0) * non_building_mask
1109
+ pred_field = np.where(pred_field > 0, pred_field, 0) * non_building_mask
1110
+ analysis = np.where(analysis > 0, analysis, 0) * non_building_mask
1111
+ true_flat = true_field.ravel()
1112
+ pred_flat = pred_field.ravel()
1113
+ ana_flat = analysis.ravel()
1114
+
1115
+ # ===============================
1116
+ # (1) Field-wise (all pixels)
1117
+ # ===============================
1118
+ r2_before = r2_score(true_flat, pred_flat)
1119
+ r2_after = r2_score(true_flat, ana_flat)
1120
+ mse_before = mean_squared_error(true_flat, pred_flat)
1121
+ mse_after = mean_squared_error(true_flat, ana_flat)
1122
+ mae_before = mean_absolute_error(true_flat, pred_flat)
1123
+ mae_after = mean_absolute_error(true_flat, ana_flat)
1124
+ nmse_before = nmse_metrics(true_flat, pred_flat)
1125
+ nmse_after = nmse_metrics(true_flat, ana_flat)
1126
+ nmae_before = nmae_metrics(true_flat, pred_flat)
1127
+ nmae_after = nmae_metrics(true_flat, ana_flat)
1128
+
1129
+ # ===============================
1130
+ # (2) Plume-aware (true > eps)
1131
+ # ===============================
1132
+ plume_mask = true_flat > 1e-6
1133
+ true_p = true_flat[plume_mask]
1134
+ pred_p = pred_flat[plume_mask]
1135
+ ana_p = ana_flat[plume_mask]
1136
+ r2_plume_before = r2_score(true_p, pred_p)
1137
+ r2_plume_after = r2_score(true_p, ana_p)
1138
+ mse_plume_before = mean_squared_error(true_p, pred_p)
1139
+ mse_plume_after = mean_squared_error(true_p, ana_p)
1140
+ mae_plume_before = mean_absolute_error(true_p, pred_p)
1141
+ mae_plume_after = mean_absolute_error(true_p, ana_p)
1142
+ nmse_plume_before = nmse_metrics(true_p, pred_p)
1143
+ nmse_plume_after = nmse_metrics(true_p, ana_p)
1144
+ nmae_plume_before = nmae_metrics(true_p, pred_p)
1145
+ nmae_plume_after = nmae_metrics(true_p, ana_p)
1146
+
1147
+ # ---- Weighted R2 (plume-aware) ----
1148
+ wr2_plume_before = PrintMetrics.weighted_r2(true_p, pred_p, gamma=1.0)
1149
+ wr2_plume_after = PrintMetrics.weighted_r2(true_p, ana_p, gamma=1.0)
1150
+
1151
+
1152
+ # ===============================
1153
+ # (3) At observations
1154
+ # ===============================
1155
+ true_at_obs = ObservationModel.observation_operator_H(true_field, obs_xy)
1156
+ pred_at_obs = ObservationModel.observation_operator_H(pred_field, obs_xy)
1157
+ ana_at_obs = ObservationModel.observation_operator_H(analysis, obs_xy)
1158
+ r2_obs_before = r2_score(true_at_obs, pred_at_obs)
1159
+ r2_obs_after = r2_score(true_at_obs, ana_at_obs)
1160
+ mse_obs_before = mean_squared_error(true_at_obs, pred_at_obs)
1161
+ mse_obs_after = mean_squared_error(true_at_obs, ana_at_obs)
1162
+ mae_obs_before = mean_absolute_error(true_at_obs, pred_at_obs)
1163
+ mae_obs_after = mean_absolute_error(true_at_obs, ana_at_obs)
1164
+ nmse_obs_before = nmse_metrics(true_at_obs, pred_at_obs)
1165
+ nmse_obs_after = nmse_metrics(true_at_obs, ana_at_obs)
1166
+ nmae_obs_before = nmae_metrics(true_at_obs, pred_at_obs)
1167
+ nmae_obs_after = nmae_metrics(true_at_obs, ana_at_obs)
1168
+
1169
+ if metrics_print_flag:
1170
+ print("=== Assimilation Metrics ===")
1171
+ print("[Field-wise]")
1172
+ print(f"R2 : {r2_before:.4f}->{r2_after:.4f}")
1173
+ print(f"MSE : {mse_before:.4f}->{mse_after:.4f}")
1174
+ print(f"MAE : {mae_before:.4f}->{mae_after:.4f}")
1175
+ print("[Plume-aware]")
1176
+ print(f"R2 : {r2_plume_before:.4f}->{r2_plume_after:.4f}")
1177
+ print(f"MSE : {mse_plume_before:.4f}->{mse_plume_after:.4f}")
1178
+ print(f"MAE : {mae_plume_before:.4f}->{mae_plume_after:.4f}")
1179
+ print(f"W-R2 : {wr2_plume_before:.4f}->{wr2_plume_after:.4f}")
1180
+ print("[At observations]")
1181
+ print(f"R2 : {r2_obs_before:.4f}->{r2_obs_after:.4f}")
1182
+ print(f"MSE : {mse_obs_before:.4f}->{mse_obs_after:.4f}")
1183
+ print(f"MAE : {mae_obs_before:.4f}->{mae_obs_after:.4f}")
1184
+
1185
+ if metrics_save_flag:
1186
+ return {
1187
+ 'idx': i,
1188
+ 'wind_speed': wind_speed,
1189
+ 'wind_direction': wind_direction,
1190
+ 'stability_class': sc,
1191
+ 'source_number': source_number,
1192
+ "r2_before": r2_before,
1193
+ "r2_after": r2_after,
1194
+ "r2_plume_before": r2_plume_before,
1195
+ "r2_plume_after": r2_plume_after,
1196
+ "w_r2_plume_before": wr2_plume_before,
1197
+ "w_r2_plume_after": wr2_plume_after,
1198
+ "r2_obs_before": r2_obs_before,
1199
+ "r2_obs_after": r2_obs_after,
1200
+ "mse_before": mse_before,
1201
+ "mse_after": mse_after,
1202
+ "mse_plume_before": mse_plume_before,
1203
+ "mse_plume_after": mse_plume_after,
1204
+ "mse_obs_before": mse_obs_before,
1205
+ "mse_obs_after": mse_obs_after,
1206
+ "mae_before": mae_before,
1207
+ "mae_after": mae_after,
1208
+ "mae_plume_before": mae_plume_before,
1209
+ "mae_plume_after": mae_plume_after,
1210
+ "mae_obs_before": mae_obs_before,
1211
+ "mae_obs_after": mae_obs_after,
1212
+ "nmse_before": nmse_before,
1213
+ "nmse_after": nmse_after,
1214
+ "nmse_plume_before": nmse_plume_before,
1215
+ "nmse_plume_after": nmse_plume_after,
1216
+ "nmae_before": nmae_before,
1217
+ "nmae_after": nmae_after,
1218
+ "nmae_plume_before": nmae_plume_before,
1219
+ "nmae_plume_after": nmae_plume_after,
1220
+ "nmse_obs_before": nmse_obs_before,
1221
+ "nmse_obs_after": nmse_obs_after,
1222
+ "nmae_obs_before": nmae_obs_before,
1223
+ "nmae_obs_after": nmae_obs_after,
1224
+ }
1225
+
1226
+ class Visualization:
1227
+ def plot_assimilation_with_building(
1228
+ true_field,
1229
+ pred_field,
1230
+ analysis,
1231
+ obs_xy,
1232
+ vmax=10,
1233
+ title_suffix=""
1234
+ ):
1235
+ """
1236
+ - 建筑 mask
1237
+ - 非建筑区浓度
1238
+ - 同化前 / 后对比
1239
+ """
1240
+
1241
+ # ---------- 物理裁剪 + 建筑 mask ----------
1242
+ build_data_256, non_building_mask = PrintMetrics.get_building_area()
1243
+ true_field = np.where(true_field > 0, true_field, 0) * non_building_mask
1244
+ pred_field = np.where(pred_field > 0, pred_field, 0) * non_building_mask
1245
+ analysis = np.where(analysis > 0, analysis, 0) * non_building_mask
1246
+
1247
+ # ---------- 画图 ----------
1248
+ fig, axs = plt.subplots(1, 3, figsize=(14, 4), dpi=300)
1249
+ cmap = "inferno"
1250
+ levels = np.linspace(0, vmax, 21)
1251
+
1252
+ im0 = axs[0].contourf(true_field, levels=levels, cmap=cmap, vmin=0, vmax=vmax,
1253
+ extend='max')
1254
+ axs[0].set_title('True Field' + title_suffix)
1255
+ plt.colorbar(im0, ax=axs[0])
1256
+
1257
+ im1 = axs[1].contourf(pred_field, levels=levels, cmap=cmap, vmin=0, vmax=vmax,
1258
+ extend='max')
1259
+ axs[1].set_title(r'Prior Prediction Field $\psi^{f}$')
1260
+ plt.colorbar(im1, ax=axs[1])
1261
+
1262
+ im2 = axs[2].contourf(analysis, levels=levels, cmap=cmap, vmin=0, vmax=vmax,
1263
+ extend='max')
1264
+ axs[2].set_title(r'Analysis $\psi^{a}$')
1265
+ plt.colorbar(im2, ax=axs[2])
1266
+ axs[0].scatter(obs_xy[:, 0], obs_xy[:, 1], c='red', s=15, edgecolors='k')
1267
+ axs[1].scatter(obs_xy[:, 0], obs_xy[:, 1], c='red', s=15, edgecolors='k')
1268
+ axs[2].scatter(obs_xy[:, 0], obs_xy[:, 1], c='red', s=15, edgecolors='k')
1269
+
1270
+ # ---------- 指标 ----------
1271
+ axs[1].text(
1272
+ 80, 15,
1273
+ f"$R^2$={r2_score(true_field.ravel(), pred_field.ravel()):.4f}\n"
1274
+ f"$MSE$={mean_squared_error(true_field.ravel(), pred_field.ravel()):.4f}\n"
1275
+ f"$MAE@Obs$={mean_absolute_error(ObservationModel.observation_operator_H(true_field, obs_xy),
1276
+ ObservationModel.observation_operator_H(pred_field, obs_xy)):.3f}",
1277
+ color='white'
1278
+ )
1279
+ axs[2].text(
1280
+ 80, 15,
1281
+ f"$R^2$={r2_score(true_field.ravel(), analysis.ravel()):.4f}\n"
1282
+ f"$MSE$={mean_squared_error(true_field.ravel(), analysis.ravel()):.4f}\n"
1283
+ f"$MAE@Obs$={mean_absolute_error(ObservationModel.observation_operator_H(true_field, obs_xy),
1284
+ ObservationModel.observation_operator_H(analysis, obs_xy)):.3f}",
1285
+ color='white'
1286
+ )
1287
+ plt.tight_layout()
1288
+ plt.show()
1289
+
1290
+ def plot_assimilation_4panel(
1291
+ true_field,
1292
+ pred_field,
1293
+ analysis,
1294
+ obs_xy,
1295
+ obs_val,
1296
+ vmin=0,
1297
+ vmax=10,
1298
+ title_suffix=""
1299
+ ):
1300
+ # ---------- 物理裁剪 + 建筑 mask ----------
1301
+ _, non_building_mask = PrintMetrics.get_building_area()
1302
+ true_field = np.where(true_field > 0, true_field, 0) * non_building_mask
1303
+ pred_field = np.where(pred_field > 0, pred_field, 0) * non_building_mask
1304
+ analysis = np.where(analysis > 0, analysis, 0) * non_building_mask
1305
+
1306
+ # ---------- Figure ----------
1307
+ fig, axs = plt.subplots(1, 4, figsize=(18, 4), dpi=300)
1308
+ cmap = "inferno"
1309
+ levels = np.linspace(0, vmax, 21)
1310
+
1311
+ # ---------- (a) True field ----------
1312
+ im0 = axs[0].contourf(true_field, levels=levels, cmap=cmap, vmin=vmin, vmax=vmax,
1313
+ extend='both')
1314
+ axs[0].set_title("True Field" + title_suffix)
1315
+ plt.colorbar(im0, ax=axs[0])
1316
+
1317
+ # ---------- (b) Prior prediction ----------
1318
+ im1 = axs[1].contourf(pred_field, levels=levels, cmap=cmap, vmin=vmin, vmax=vmax,
1319
+ extend='both')
1320
+ axs[1].set_title(r"Prior Prediction $\psi^{f}$")
1321
+ plt.colorbar(im1, ax=axs[1])
1322
+
1323
+ # ---------- (c) Observations (points only) ----------
1324
+ sc = axs[2].scatter(
1325
+ obs_xy[:, 0],
1326
+ obs_xy[:, 1],
1327
+ c=obs_val,
1328
+ cmap=cmap,
1329
+ vmin=vmin,
1330
+ vmax=vmax,
1331
+ s=30,
1332
+ edgecolors="k",
1333
+ linewidths=0.4,
1334
+ alpha=0.9
1335
+ )
1336
+ axs[2].set_title("Observations $d_i$")
1337
+ axs[2].set_xlim(0, true_field.shape[1])
1338
+ axs[2].set_ylim(true_field.shape[0], 0)
1339
+ axs[2].set_aspect("equal")
1340
+ axs[2].invert_yaxis()
1341
+ plt.colorbar(sc, ax=axs[2], extend='both')
1342
+
1343
+ # ---------- (d) Analysis field ----------
1344
+ im3 = axs[3].contourf(analysis, levels=levels, cmap=cmap, vmin=vmin, vmax=vmax,
1345
+ extend='both')
1346
+ axs[3].set_title(r"Analysis $\psi^{a}$")
1347
+ plt.colorbar(im3, ax=axs[3])
1348
+
1349
+ # ---------- Metrics ----------pred_at_obs
1350
+ axs[1].text(
1351
+ 0.02, 0.95,
1352
+ f"$R^2$={r2_score(true_field.ravel(), pred_field.ravel()):.4f}\n"
1353
+ f"$MSE$={mean_squared_error(true_field.ravel(), pred_field.ravel()):.4f}",
1354
+ transform=axs[1].transAxes,
1355
+ va="top",
1356
+ color="white"
1357
+ )
1358
+
1359
+ axs[3].text(
1360
+ 0.02, 0.95,
1361
+ f"$R^2$={r2_score(true_field.ravel(), analysis.ravel()):.4f}\n"
1362
+ f"$MSE$={mean_squared_error(true_field.ravel(), analysis.ravel()):.4f}",
1363
+ transform=axs[3].transAxes,
1364
+ va="top",
1365
+ color="white"
1366
+ )
1367
+ plt.tight_layout()
1368
+ plt.show()
1369
+
1370
+ def plot_pe_spectrum(all_diags, save_flag=False):
1371
+ C_BLUE = "#488ABA"
1372
+ C_ORANGE = "#e5954e"
1373
+
1374
+ ds = [diag['d'] for diag in all_diags]
1375
+ r_eff_values = [diag['r_eff'] for diag in all_diags]
1376
+ ratio_12_values = [diag['ratio_1_2'] for diag in all_diags]
1377
+
1378
+ fig, ax = plt.subplots(figsize=(6, 3), dpi=300)
1379
+ l1, = ax.plot(ds, r_eff_values,
1380
+ marker='o', color=C_BLUE, linewidth=1.5,
1381
+ alpha=0.4,
1382
+ markersize=6, label=r'$r_{\rm eff}$')
1383
+ ax.set_xlabel(r'Wind direction', labelpad=3)
1384
+ ax.set_ylabel(r'Effective rank $r_{\rm eff}$',
1385
+ color=C_BLUE, labelpad=4)
1386
+ ax.tick_params(axis='y', colors=C_BLUE)
1387
+ ax.spines['left'].set_color(C_BLUE)
1388
+ ax.set_xticks(ds)
1389
+ ax.set_xticklabels([f'{d}°' for d in ds])
1390
+ ax.set_ylim(1.5, 3)
1391
+
1392
+ # 右轴:λ1/λ2
1393
+ ax2 = ax.twinx()
1394
+ l2, = ax2.plot(ds, ratio_12_values,
1395
+ marker='s', color=C_ORANGE, linewidth=1.5,
1396
+ alpha=0.4,
1397
+ markersize=6, label=r'$\lambda_1/\lambda_2$')
1398
+ ax2.set_ylabel(r'Anisotropy $\lambda_1/\lambda_2$',
1399
+ color=C_ORANGE, labelpad=4)
1400
+ ax2.tick_params(axis='y', colors=C_ORANGE)
1401
+ ax2.spines['right'].set_color(C_ORANGE)
1402
+ ax2.set_ylim(1.5, 3.5)
1403
+
1404
+ opt_idx = int(np.argmax(ratio_12_values))
1405
+ ax2.axvline(ds[opt_idx], color='grey', linewidth=0.8, linestyle='--', alpha=0.6)
1406
+ ax2.text(ds[opt_idx] + 0.8, 1.58, r'$d^{**}$', color='grey')
1407
+
1408
+ # 统一图例
1409
+ fig.legend(handles=[l1, l2],
1410
+ loc='upper left',
1411
+ bbox_to_anchor=(0.15, 0.95),
1412
+ ncol=1, frameon=False)
1413
+
1414
+ plt.tight_layout()
1415
+ if save_flag:
1416
+ plt.savefig('./figures/test1/reff_ratio.png', dpi=300, bbox_inches='tight',
1417
+ transparent=True)
1418
+ # plt.savefig('./figures/test1/reff_ratio.svg', dpi=300, bbox_inches='tight', format='svg')
1419
+ plt.show()
1420
+
1421
+ def assimilation_scatter(psi_t_log, psi_f_log, psi_a_log, obs_xy):
1422
+ def log10_formatter(x, pos):
1423
+ return r'$10^{%d}$' % x
1424
+
1425
+ obs_true = np.log10(ObservationModel.observation_operator_H(psi_t_log, obs_xy)+1e-3)
1426
+ obs_prior = np.log10(ObservationModel.observation_operator_H(psi_f_log, obs_xy)+1e-3)
1427
+ obs_analysis = np.log10(ObservationModel.observation_operator_H(psi_a_log, obs_xy)+1e-3)
1428
+
1429
+ fig, ax = plt.subplots(figsize=(5, 4.5), dpi=300)
1430
+ vmin, vamx = -4, 2
1431
+ lim = [vmin, vamx]
1432
+ ax.plot(lim, lim, 'k--', lw=1, label='1:1 line', zorder=1)
1433
+
1434
+ for obs_pred, label, color in zip(
1435
+ [obs_prior, obs_analysis],
1436
+ ['Prior', 'Analysis'],
1437
+ ['steelblue', 'tomato']
1438
+ ):
1439
+ # 散点
1440
+ ax.scatter(obs_true, obs_pred, s=25, alpha=0.6, color=color, zorder=3)
1441
+
1442
+ slope, intercept, r, _, _ = stats.linregress(obs_true, obs_pred)
1443
+ rmse = np.sqrt(np.mean((obs_pred - obs_true) ** 2))
1444
+ x_fit = np.linspace(lim[0], lim[1], 100)
1445
+ ax.plot(x_fit, slope * x_fit + intercept, '-', color=color, lw=1.5,
1446
+ label=f'{label}r:{r:.2f}', zorder=2)
1447
+
1448
+ ax.set_xlabel('log(True)')
1449
+ ax.set_ylabel('log(Predicted)')
1450
+ ax.legend(loc='lower right', frameon=False)
1451
+ ax.set_xlim(-4, 2)
1452
+ ax.set_ylim(-4, 2)
1453
+ ax.xaxis.set_major_formatter(ticker.FuncFormatter(log10_formatter))
1454
+ ax.yaxis.set_major_formatter(ticker.FuncFormatter(log10_formatter))
1455
+
1456
+ plt.tight_layout()
1457
+ plt.show()
1458
+
1459
+ def none_assimilation_scatter(psi_t_log, psi_f_log, psi_a_log, obs_xy):
1460
+ def sample_independent_points(field, obs_xy, num_points=100, seed=42):
1461
+ H, W = field.shape
1462
+ yy, xx = np.meshgrid(np.arange(H), np.arange(W), indexing='ij')
1463
+ all_xy = np.stack([xx.ravel(), yy.ravel()], axis=1)
1464
+ obs_set = set(map(tuple, np.round(obs_xy).astype(int)))
1465
+ obs_mask = np.array([tuple(p) not in obs_set for p in all_xy])
1466
+ _, non_building_mask = PrintMetrics.get_building_area()
1467
+ building_mask = non_building_mask[yy.ravel(), xx.ravel()] == 1
1468
+ num_mask = field.ravel() > 1e-4
1469
+ candidate_xy = all_xy[obs_mask & building_mask & num_mask]
1470
+ rng = np.random.default_rng(seed)
1471
+ idx = rng.choice(len(candidate_xy), num_points, replace=False)
1472
+ return candidate_xy[idx]
1473
+
1474
+ test_xy = sample_independent_points(psi_t_log, obs_xy, 200)
1475
+ obs_true = np.log10(ObservationModel.observation_operator_H(psi_t_log, test_xy) + 1e-6)
1476
+ obs_prior = np.log10(ObservationModel.observation_operator_H(psi_f_log, test_xy) + 1e-6)
1477
+ obs_analysis = np.log10(ObservationModel.observation_operator_H(psi_a_log, test_xy) + 1e-6)
1478
+
1479
+ def log10_formatter(x, pos):
1480
+ return r'$10^{%d}$' % x
1481
+
1482
+ fig, ax = plt.subplots(figsize=(5, 4.5), dpi=300)
1483
+
1484
+ vmin, vmax = -4, 2
1485
+ lim = [vmin, vmax]
1486
+
1487
+ # 1:1 line
1488
+ ax.plot(lim, lim, 'k--', lw=1, label='1:1 line', zorder=1)
1489
+
1490
+ for obs_pred, label, color in zip(
1491
+ [obs_prior, obs_analysis],
1492
+ ['Prior', 'Analysis'],
1493
+ ['steelblue', 'tomato']
1494
+ ):
1495
+
1496
+ # scatter
1497
+ ax.scatter(obs_true, obs_pred,
1498
+ s=30,
1499
+ alpha=0.65,
1500
+ color=color,
1501
+ zorder=3)
1502
+ slope, intercept, r, _, _ = stats.linregress(obs_true, obs_pred)
1503
+ rmse = np.sqrt(np.mean((obs_pred - obs_true) ** 2))
1504
+ x_fit = np.linspace(lim[0], lim[1], 100)
1505
+ ax.plot(x_fit, slope * x_fit + intercept, '-', color=color, lw=1.5,
1506
+ label=f'{label}r:{r:.2f}', zorder=2)
1507
+
1508
+ ax.set_xlim(vmin, vmax)
1509
+ ax.set_ylim(vmin, vmax)
1510
+ ax.set_xlabel('log(True)')
1511
+ ax.set_ylabel('log(Predicted)')
1512
+ ax.xaxis.set_major_formatter(ticker.FuncFormatter(log10_formatter))
1513
+ ax.yaxis.set_major_formatter(ticker.FuncFormatter(log10_formatter))
1514
+ ax.legend(loc='lower right', frameon=False)
1515
+ plt.tight_layout()
1516
+ plt.show()
1517
+
1518
+ def methods_comparison(source_idx=25,
1519
+ num_points = 10,
1520
+ methods = ['random', 'uniform', 'two_stage']):
1521
+
1522
+ for method in methods:
1523
+ print(f"\n=== 观测点采样方法: {method} ===")
1524
+ data = np.load(f'./dataset/assim_conds/fields_n{num_points}_{method}_obs1.npz', allow_pickle=True)
1525
+
1526
+ all_fields = data['all_fields']
1527
+ sample_data = all_fields[source_idx]
1528
+ psi_t_log = sample_data['trues_log']
1529
+ psi_f_log = sample_data['preds_log']
1530
+ psi_a_log = sample_data['analysis_log']
1531
+ psi_t_ppm = sample_data['trues_ppm']
1532
+ psi_f_ppm = sample_data['preds_ppm']
1533
+ psi_a_ppm = sample_data['analysis_ppm']
1534
+ obs_xy = sample_data['obs_xy']
1535
+ obs_value_log = sample_data['obs_value_log']
1536
+ obs_value_ppm = sample_data['obs_value_ppm']
1537
+ Visualization.plot_assimilation_with_building(
1538
+ true_field=psi_t_log,
1539
+ pred_field=psi_f_log,
1540
+ analysis=psi_a_log,
1541
+ obs_xy=obs_xy,
1542
+ vmax=10,
1543
+ title_suffix=f" (idx={source_idx})"
1544
+ )
1545
+ Visualization.plot_assimilation_4panel(
1546
+ true_field=psi_t_log,
1547
+ pred_field=psi_f_log,
1548
+ analysis=psi_a_log,
1549
+ obs_xy=obs_xy,
1550
+ obs_val=obs_value_log,
1551
+ vmax=10,
1552
+ title_suffix=f" (idx={source_idx})"
1553
+ )
1554
+ Visualization.plot_assimilation_4panel(
1555
+ true_field=psi_t_ppm,
1556
+ pred_field=psi_f_ppm,
1557
+ analysis=psi_a_ppm,
1558
+ obs_xy=obs_xy,
1559
+ obs_val=obs_value_ppm,
1560
+ vmax=200,
1561
+ title_suffix=f" in PPM SPACE (idx={source_idx})"
1562
+ )
1563
+
1564
+ def plot_n_hist_comparison(obs_tag=1, space_mode="log",
1565
+ methods=["random", "uniform", "two_stage"],
1566
+ n_list=[10, 20, 30, 40, 50],
1567
+ base_dir="./dataset/assim_conds",
1568
+ plot_mode="after",
1569
+ target_method="two_stage"):
1570
+ scope_labels = ["overall", "plume", "obs"]
1571
+ metric_keys = {
1572
+ "r2": ["r2", "r2_plume", "r2_obs"],
1573
+ "nmse": ["nmse", "nmse_plume", "nmse_obs"],
1574
+ "nmae": ["nmae", "nmae_plume", "nmae_obs"],
1575
+ }
1576
+
1577
+ def load_method_df(space, method, num_points):
1578
+ candidate_methods = [method]
1579
+ if method != "two_stage_pro":
1580
+ candidate_methods.append("two_stage_pro")
1581
+ for m in candidate_methods:
1582
+ fp = os.path.join(
1583
+ base_dir,
1584
+ f"assimi_{space}_n{num_points}_{m}_obs{obs_tag}.csv"
1585
+ )
1586
+ if os.path.exists(fp):
1587
+ if m != method:
1588
+ print(f"[Info] {method} not exsist, back to {fp}")
1589
+ return pd.read_csv(fp)
1590
+
1591
+ print(f"[Warning] file not found: {candidate_methods}")
1592
+ return None
1593
+
1594
+ def get_metric_value(df, metric_name):
1595
+ before_col = f"{metric_name}_before"
1596
+ after_col = f"{metric_name}_after"
1597
+
1598
+ if before_col not in df.columns or after_col not in df.columns:
1599
+ return np.nan
1600
+
1601
+ before_mean = df[before_col].mean()
1602
+ after_mean = df[after_col].mean()
1603
+
1604
+ if plot_mode == "delta":
1605
+ return after_mean - before_mean
1606
+ return after_mean
1607
+
1608
+ data = {
1609
+ "r2": np.full((len(n_list), 3), np.nan),
1610
+ "nmse": np.full((len(n_list), 3), np.nan),
1611
+ "nmae": np.full((len(n_list), 3), np.nan),
1612
+ }
1613
+
1614
+
1615
+ for i_n, n in enumerate(n_list):
1616
+ df = load_method_df(space_mode, target_method, n)
1617
+ if df is None:
1618
+ continue
1619
+
1620
+ for metric_type in ["r2", "nmse", "nmae"]:
1621
+ vals = []
1622
+ for mk in metric_keys[metric_type]:
1623
+ vals.append(get_metric_value(df, mk))
1624
+ data[metric_type][i_n, :] = vals
1625
+
1626
+ # print(f"space_mode = {space_mode}, plot_mode = {plot_mode}, method = {target_method}")
1627
+ # for metric_type in ["r2", "mse", "mae"]:
1628
+ # print(f"\n{metric_type.upper()}:")
1629
+ # print(pd.DataFrame(data[metric_type], index=n_list, columns=scope_labels))
1630
+
1631
+ fig, ax1 = plt.subplots(figsize=(12, 6), dpi=300)
1632
+ ax2 = ax1.twinx()
1633
+ x_base = np.arange(len(n_list))
1634
+ scope_offsets = {
1635
+ "overall": -0.24,
1636
+ "plume": 0.00,
1637
+ "obs": 0.24,
1638
+ }
1639
+ bar_w = 0.20
1640
+
1641
+ colors = cm.get_cmap("Blues")
1642
+ scope_colors = {
1643
+ "overall": colors(0.45),
1644
+ "plume": colors(0.65),
1645
+ "obs": colors(0.85),
1646
+ "edge": colors(0.85),
1647
+ }
1648
+ scope_linestyles = {
1649
+ "overall": "-",
1650
+ "plume": "--",
1651
+ "obs": ":",
1652
+ }
1653
+
1654
+ scope_markers_mse = {
1655
+ "overall": "o",
1656
+ "plume": "o",
1657
+ "obs": "o",
1658
+ }
1659
+
1660
+ scope_markers_mae = {
1661
+ "overall": "s",
1662
+ "plume": "s",
1663
+ "obs": "s",
1664
+ }
1665
+
1666
+ # -------------------------
1667
+ # 左轴:R2 柱状图
1668
+ # -------------------------
1669
+ text_r2 = r"$\mathit{R}^2$"
1670
+ for j, scope in enumerate(scope_labels):
1671
+ x = x_base + scope_offsets[scope]
1672
+ y = data["r2"][:, j]
1673
+
1674
+ ax1.bar(
1675
+ x, y,
1676
+ width=bar_w,
1677
+ color=scope_colors[scope],
1678
+ alpha=0.75,
1679
+ label=f"{text_r2}-{scope}",
1680
+ edgecolor=scope_colors["edge"],
1681
+ zorder=2
1682
+ )
1683
+ # 给每根柱子加数值
1684
+ for xi, yi in zip(x, y):
1685
+ if np.isfinite(yi):
1686
+ ax1.text(
1687
+ xi, yi + 0.001, f"{yi:.2f}",
1688
+ ha="center", va="bottom"
1689
+ )
1690
+
1691
+ ax1.set_ylabel(r"$\mathit{R}^2$")
1692
+ ax1.set_xticks(x_base)
1693
+ ax1.set_xticklabels([f"n={n}" for n in n_list])
1694
+ ax1.grid(axis="y", linestyle="--", alpha=0.25, zorder=0)
1695
+ # ax1.set_ylim(0.6, 1.1)
1696
+
1697
+ if plot_mode == "delta":
1698
+ ax1.axhline(0, color="k", linewidth=1)
1699
+
1700
+ # 在每个 n 下标出 overall / plume / obs
1701
+ y1_min, y1_max = ax1.get_ylim()
1702
+ y_text = y1_min - 0.06 * (y1_max - y1_min)
1703
+ # for i in range(len(n_list)):
1704
+ # ax1.text(x_base[i] + scope_offsets["overall"], y_text, "overall",
1705
+ # ha="center", va="top")
1706
+ # ax1.text(x_base[i] + scope_offsets["plume"], y_text, "plume",
1707
+ # ha="center", va="top")
1708
+ # ax1.text(x_base[i] + scope_offsets["obs"], y_text, "obs",
1709
+ # ha="center", va="top")
1710
+
1711
+ for j, scope in enumerate(scope_labels):
1712
+ x = x_base + scope_offsets[scope]
1713
+
1714
+ ax2.plot(
1715
+ x,
1716
+ data["nmse"][:, j],
1717
+ color=scope_colors[scope],
1718
+ linestyle="--",
1719
+ marker="o",
1720
+ linewidth=1.8,
1721
+ markersize=5,
1722
+ label=f"NMSE-{scope}",
1723
+ markeredgecolor=scope_colors["edge"],
1724
+ zorder=3
1725
+ )
1726
+
1727
+ ax2.plot(
1728
+ x,
1729
+ data["nmae"][:, j],
1730
+ color=scope_colors[scope],
1731
+ linestyle=":",
1732
+ marker="s",
1733
+ linewidth=1.8,
1734
+ markersize=5,
1735
+ markeredgecolor=scope_colors["edge"],
1736
+ label=f"NMAE-{scope}",
1737
+ zorder=3
1738
+ )
1739
+
1740
+ ax2.set_ylabel("NMSE / NMAE")
1741
+ # ax2.set_ylim(0, 0.5)
1742
+
1743
+ if plot_mode == "delta":
1744
+ ax2.axhline(0, color="gray", linewidth=1, alpha=0.6)
1745
+
1746
+ h1, l1 = ax1.get_legend_handles_labels()
1747
+ h2, l2 = ax2.get_legend_handles_labels()
1748
+
1749
+ ax1.legend(
1750
+ h1 + h2,
1751
+ l1 + l2,
1752
+ frameon=False,
1753
+ loc="center left",
1754
+ bbox_to_anchor=(1.15, 0.5)
1755
+ )
1756
+
1757
+ plt.tight_layout()
1758
+ plt.show()
1759
+
1760
+ def cal_all_test_resluts(config, conds_pkl_path=None,
1761
+ data_path=None,
1762
+ use_localization=True,
1763
+ save_fields_flag=True,
1764
+ save_metrics_flag=False):
1765
+ '''
1766
+ 使用说明:
1767
+ sample_method_lists = ["random", "uniform", "two_stage"]
1768
+ for method in sample_method_lists:
1769
+ config_test = {
1770
+ "num_points": 20,
1771
+ "sample_method": method,
1772
+ "obs_std_scale": 0.01,
1773
+ "damping": 1,
1774
+ "two_stage_params": {
1775
+ "min_dist": 28,
1776
+ "n1_ratio": 0.6,
1777
+ "stage1_support_frac": 0.2,
1778
+ "stage1_grad_power": 0.8,
1779
+ "stage1_value_power": 1.2,
1780
+ "stage1_center_boost": 1.2,
1781
+ },
1782
+ }
1783
+ cal_all_test_resluts(config_test)
1784
+ '''
1785
+ # ==== 加载数据 ====
1786
+ num_points = config['num_points']
1787
+ sample_method = config['sample_method'] # "random" or "uniform"
1788
+ # Kalman 参数
1789
+ obs_std_scale=config['obs_std_scale']
1790
+ damping=config['damping']
1791
+ sample_method = config["sample_method"]
1792
+ sample_params = {}
1793
+ params_key = f"{sample_method}_params"
1794
+ if params_key in config and isinstance(config[params_key], dict):
1795
+ sample_params = config[params_key]
1796
+
1797
+ if conds_pkl_path is not None:
1798
+ loader = DataLoader(
1799
+ pred_npz_path='./dataset/pre_data/all_test_pred2.npz',
1800
+ meta_txt_path='./dataset/pre_data/combined_test_special.txt',
1801
+ conds_pkl_path=conds_pkl_path
1802
+ )
1803
+ else:
1804
+ loader = DataLoader(
1805
+ pred_npz_path='./dataset/pre_data/all_test_pred2.npz',
1806
+ meta_txt_path='./dataset/pre_data/combined_test_special.txt',
1807
+ conds_pkl_path='./dataset/pre_data/pred_condition/test_results/conditioned_results_v0_5_d45_n40.pkl'
1808
+ )
1809
+ trues, preds = loader.trues, loader.preds
1810
+ # print(f"Total test samples: {len(preds)}")
1811
+ all_metrics_log = []
1812
+ all_metrics_ppm = []
1813
+ all_fields = []
1814
+ enkf = EnKF(obs_std_scale=obs_std_scale, damping=damping)
1815
+ for i in trange(len(preds), desc="Running assimilation"):
1816
+ psi_f_ppm, psi_t_ppm, conds_preds, meta = loader.get_sample(idx=i, in_ppm=True)
1817
+ psi_f_log = np.log1p(np.maximum(psi_f_ppm, 0))
1818
+ psi_t_log = np.log1p(np.maximum(psi_t_ppm, 0))
1819
+ conds_log = np.log1p(np.maximum(conds_preds, 0))
1820
+
1821
+ if sample_method == "smart_two_pass":
1822
+ n1_ratio = float(sample_params.get('n1_ratio', 0.6))
1823
+ n1_default = int(round(num_points * n1_ratio))
1824
+ n1 = int(sample_params.get('n1', n1_default))
1825
+ if num_points > 1:
1826
+ n1 = max(1, min(n1, num_points - 1))
1827
+ else:
1828
+ n1 = 1
1829
+ n2 = num_points - n1
1830
+
1831
+ psi_a_log, obs_xy, all_obs_val_log, _, _ = SamplingStrategies.smart_two_pass(
1832
+ enkf=enkf,
1833
+ psi_f=psi_f_log,
1834
+ conds_preds=conds_log,
1835
+ true_field=psi_t_log,
1836
+ n1=n1,
1837
+ n2=n2,
1838
+ phase1_method=sample_params.get('phase1_method', 'two_stage'),
1839
+ min_dist_p2=sample_params.get('min_dist_p2', 22),
1840
+ under_correct_alpha=sample_params.get('under_correct_alpha', 1.5),
1841
+ use_localization=sample_params.get('use_localization', use_localization),
1842
+ loc_radius_pixobs=sample_params.get('loc_radius_pixobs', 35.0),
1843
+ loc_radius_obsobs=sample_params.get('loc_radius_obsobs', 40.0),
1844
+ seed=42,
1845
+ verbose=sample_params.get('verbose', False),
1846
+ )
1847
+
1848
+ obs_value_log = np.asarray(all_obs_val_log)
1849
+ obs_value_ppm = DataLoader.log2ppm(obs_value_log)
1850
+ else:
1851
+ obs_xy, obs_value_ppm = SamplingStrategies.generate(psi_t_ppm, psi_f_ppm, num_points=num_points,
1852
+ seed=42, method=sample_method,
1853
+ ens_preds_ppm=conds_preds,
1854
+ **sample_params
1855
+ )
1856
+ d_obs_log = np.log1p(np.maximum(obs_value_ppm, 0)) # avoid log(0)
1857
+ if use_localization:
1858
+ psi_a_log = enkf.enkf_localization(psi_f_log, conds_log, obs_xy, d_obs_log,
1859
+ loc_radius_pixobs=35.0,
1860
+ loc_radius_obsobs=30.0)
1861
+ else:
1862
+ psi_a_log = enkf.standard_enkf(psi_f_log, conds_log, obs_xy, d_obs_log)
1863
+
1864
+ # 计算innovation,判断是否需要同化
1865
+ obs_prior_at_obs = np.log1p(np.maximum(
1866
+ ObservationModel.observation_operator_H(psi_f_ppm, obs_xy), 0
1867
+ ))
1868
+ obs_innovation = np.mean(np.abs(obs_prior_at_obs - d_obs_log))
1869
+ threshold = config.get('innovation_threshold', 0.05)
1870
+
1871
+ if obs_innovation < threshold:
1872
+ psi_a_log = psi_f_log
1873
+ else:
1874
+ psi_a_log = enkf.enkf_localization(
1875
+ psi_f_log, conds_log, obs_xy, d_obs_log,
1876
+ loc_radius_pixobs=35.0,
1877
+ loc_radius_obsobs=30.0
1878
+ )
1879
+ obs_value_log = np.log1p(np.maximum(obs_value_ppm, 0)) # avoid log(0)
1880
+
1881
+ psi_a_ppm = DataLoader.log2ppm(psi_a_log)
1882
+
1883
+ # 计算指标
1884
+ metrics_log = PrintMetrics.print_metrics(
1885
+ i=i,
1886
+ wind_speed=meta['wind_speed'],
1887
+ wind_direction=meta['wind_direction'],
1888
+ sc=meta['sc'],
1889
+ source_number=meta['source_number'],
1890
+ true_field=psi_t_log,
1891
+ pred_field=psi_f_log,
1892
+ analysis=psi_a_log,
1893
+ obs_xy=obs_xy,
1894
+ metrics_save_flag=True,
1895
+ metrics_print_flag=False
1896
+ )
1897
+ all_metrics_log.append(metrics_log)
1898
+ metrics_ppm = PrintMetrics.print_metrics(
1899
+ i=i,
1900
+ wind_speed=meta['wind_speed'],
1901
+ wind_direction=meta['wind_direction'],
1902
+ sc=meta['sc'],
1903
+ source_number=meta['source_number'],
1904
+ true_field=psi_t_ppm,
1905
+ pred_field=psi_f_ppm,
1906
+ analysis=psi_a_ppm,
1907
+ obs_xy=obs_xy,
1908
+ metrics_save_flag=True,
1909
+ metrics_print_flag=False
1910
+ )
1911
+ all_metrics_ppm.append(metrics_ppm)
1912
+ all_fields.append({
1913
+ "idx": i,
1914
+ "trues_log": psi_t_log,
1915
+ "preds_log": psi_f_log,
1916
+ "analysis_log": psi_a_log,
1917
+ "trues_ppm": psi_t_ppm,
1918
+ "preds_ppm": psi_f_ppm,
1919
+ "analysis_ppm": psi_a_ppm,
1920
+ "obs_xy": obs_xy,
1921
+ "obs_value_log": obs_value_log,
1922
+ "obs_value_ppm": obs_value_ppm,
1923
+ })
1924
+ data_paths = f'./dataset/assim_conds/{data_path}'
1925
+ if not os.path.exists(data_paths):
1926
+ os.makedirs(data_paths)
1927
+ if save_fields_flag:
1928
+ np.savez_compressed(f'./dataset/assim_conds/{data_path}/fields_n{num_points}_{sample_method}_obs{int(obs_std_scale*100)}_damping{damping}.npz',
1929
+ all_fields=all_fields)
1930
+ all_metrics_df_log = pd.DataFrame(all_metrics_log)
1931
+ all_metrics_df_ppm = pd.DataFrame(all_metrics_ppm)
1932
+ if save_metrics_flag:
1933
+ all_metrics_df_ppm.to_csv(f'./dataset/assim_conds/{data_path}/assimi_ppm_n{num_points}_{sample_method}_obs{int(obs_std_scale*100)}_damping{damping}.csv', index=False)
1934
+ all_metrics_df_log.to_csv(f'./dataset/assim_conds/{data_path}/assimi_log_n{num_points}_{sample_method}_obs{int(obs_std_scale*100)}_damping{damping}.csv', index=False)
1935
+ print("\n=== 平均指标提升 ===")
1936
+ metrics_list = ['r2', 'w_r2_plume', 'r2_plume','mse', 'mae']
1937
+ for metric in metrics_list:
1938
+ before_mean = all_metrics_df_log[f"{metric}_before"].mean()
1939
+ after_mean = all_metrics_df_log[f"{metric}_after"].mean()
1940
+ delta = after_mean - before_mean
1941
+ print(f'{metric.upper()}: before={before_mean:.4f}, after={after_mean:.4f}, delta={delta:.4f}')
1942
+ before_mean_ppm = all_metrics_df_ppm[f"{metric}_before"].mean()
1943
+ after_mean_ppm = all_metrics_df_ppm[f"{metric}_after"].mean()
1944
+ delta_ppm = after_mean_ppm - before_mean_ppm
1945
+ print(f'PPM {metric.upper()}: before={before_mean_ppm:.4f}, after={after_mean_ppm:.4f}, delta={delta_ppm:.4f}')