sharktide commited on
Commit
23e714d
·
verified ·
1 Parent(s): 848217c

Update load.py

Browse files
Files changed (1) hide show
  1. load.py +35 -1
load.py CHANGED
@@ -17,6 +17,39 @@ hft.download_model("sharktide", "HurricaneTrustNet")
17
  hft.download_model("sharktide", "TornadoNet")
18
  hft.download_model("sharktide", "TornadoTrustNet")
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  FireNet = hft.load_model("sharktide", "FireNet", "tf_model.h5", True)
21
  FireTrustNet = hft.load_model("sharktide", "FireTrustNet", "tf_model.h5", True)
22
  FireScaler = joblib.load("scalers/firetrust_scaler.pkl")
@@ -25,7 +58,8 @@ FloodNet = hft.load_model("sharktide", "FV-FloodNet", "tf_model.h5", True)
25
  FloodTrustNet = hft.load_model("sharktide", "FV-FloodTrustNet", "tf_model.h5", True)
26
  FloodScaler = joblib.load("scalers/FV-floodtrust_scaler.pkl")
27
 
28
- PV_FloodNet = hft.load_model("sharktide", "PV-FloodNet", "tf_model.h5", True)
 
29
  PV_FloodTrustNet = hft.load_model("sharktide", "PV-FloodTrustNet", "tf_model.h5", True)
30
  PV_FloodScaler = joblib.load("scalers/PV-floodtrust_scaler.pkl")
31
 
 
17
  hft.download_model("sharktide", "TornadoNet")
18
  hft.download_model("sharktide", "TornadoTrustNet")
19
 
20
+ import tensorflow as tf
21
+ from tensorflow.keras import layers, models, callbacks
22
+ from tensorflow.keras.saving import register_keras_serializable
23
+
24
+ @register_keras_serializable()
25
+ def surface_runoff_amplifier(inputs):
26
+ rain = inputs[:, 0]
27
+ impervious = inputs[:, 1]
28
+ rain_boost = tf.sigmoid((rain - 60) * 0.06)
29
+ impervious_boost = tf.sigmoid((impervious - 0.6) * 10)
30
+ return (1.0 + 0.3 * rain_boost * impervious_boost)[:, None]
31
+
32
+ @register_keras_serializable()
33
+ def drainage_penalty(inputs):
34
+ dd = inputs[:, 2]
35
+ return (1.0 - 0.4 * tf.sigmoid((dd - 3.5) * 2))[:, None]
36
+
37
+ @register_keras_serializable()
38
+ def convergence_suppressor(inputs):
39
+ ci = inputs[:, 4]
40
+ return (1.0 + 0.3 * tf.sigmoid((ci - 0.5) * 8))[:, None]
41
+
42
+ @register_keras_serializable()
43
+ def clip_modulation(x):
44
+ return tf.clip_by_value(x, 0.7, 1.3)
45
+
46
+ CUSTOM_OBJECTS = {
47
+ 'drainage_penalty': drainage_penalty,
48
+ 'convergence_suppressor': convergence_suppressor,
49
+ 'surface_runoff_amplifier': surface_runoff_amplifier,
50
+ 'clip_modulation': clip_modulation
51
+ }
52
+
53
  FireNet = hft.load_model("sharktide", "FireNet", "tf_model.h5", True)
54
  FireTrustNet = hft.load_model("sharktide", "FireTrustNet", "tf_model.h5", True)
55
  FireScaler = joblib.load("scalers/firetrust_scaler.pkl")
 
58
  FloodTrustNet = hft.load_model("sharktide", "FV-FloodTrustNet", "tf_model.h5", True)
59
  FloodScaler = joblib.load("scalers/FV-floodtrust_scaler.pkl")
60
 
61
+ get_path = lambda usr, model: (str(hft.get_model_folder(user, model)) + "tf_model.h5")
62
+ PV_FloodNet = tf.keras.models.load_model(get_path("sharktide", "PV-FloodNet"), safe_mode=False, custom_objects=CUSTOM_OBJECTS)
63
  PV_FloodTrustNet = hft.load_model("sharktide", "PV-FloodTrustNet", "tf_model.h5", True)
64
  PV_FloodScaler = joblib.load("scalers/PV-floodtrust_scaler.pkl")
65