clarindasusan commited on
Commit
72d7353
·
verified ·
1 Parent(s): c91dc2e

Update src/train_model.py

Browse files
Files changed (1) hide show
  1. src/train_model.py +8 -6
src/train_model.py CHANGED
@@ -49,12 +49,14 @@ torch.manual_seed(SEED)
49
  # y shape: (n_samples,) — continuous risk score in [0, 1]
50
 
51
  def generate_flood_data(n: int = 5000):
52
- rainfall = pd.read_csv("rainfall_clean.csv")
53
- flood_hist = pd.read_csv("flood_history_clean.csv")
54
- soil = pd.read_csv("soil_moisture.csv")
55
- drainage = pd.read_csv("drainage_capacity.csv")
56
- rivers = pd.read_csv("river_network.csv")
57
- elevation = pd.read_csv("elevation.csv")
 
 
58
 
59
  # ==============================
60
  # Prepare flood labels
 
49
  # y shape: (n_samples,) — continuous risk score in [0, 1]
50
 
51
  def generate_flood_data(n: int = 5000):
52
+ DATA_DIR = os.getenv("DATA_DIR", "/src/data")
53
+
54
+ rainfall = pd.read_csv(os.path.join(DATA_DIR, "rainfall_clean.csv"))
55
+ flood_hist = pd.read_csv(os.path.join(DATA_DIR, "flood_history_clean.csv"))
56
+ soil = pd.read_csv(os.path.join(DATA_DIR, "soil_moisture.csv"))
57
+ drainage = pd.read_csv(os.path.join(DATA_DIR, "drainage_capacity.csv"))
58
+ rivers = pd.read_csv(os.path.join(DATA_DIR, "river_network.csv"))
59
+ elevation = pd.read_csv(os.path.join(DATA_DIR, "elevation.csv"))
60
 
61
  # ==============================
62
  # Prepare flood labels