Syntrex committed on
Commit
6e8fc87
·
verified ·
1 Parent(s): 5f24e3f

Update utils/import_savant_csvs.py

Browse files
Files changed (1) hide show
  1. utils/import_savant_csvs.py +73 -32
utils/import_savant_csvs.py CHANGED
@@ -1,5 +1,10 @@
 
 
1
  from pathlib import Path
 
 
2
  import pandas as pd
 
3
 
4
  from database.remote_db import get_connection
5
 
@@ -7,48 +12,84 @@ from database.remote_db import get_connection
7
  BATTER_CSV_PATH = Path("data/batter_savant_data.csv")
8
  PITCHER_CSV_PATH = Path("data/pitcher_savant_data.csv")
9
 
 
10
 
11
def import_savant_csvs():
    """Append both Savant CSV exports to their database tables.

    Both CSV paths are checked up front, so nothing is read or written
    unless the batter and pitcher files are both present. Column names
    are lower-cased before the frames are appended. Both appends share a
    single connection and a single commit.

    Returns:
        A dict with ``batter_rows`` and ``pitcher_rows``: the number of
        rows read from each CSV.

    Raises:
        FileNotFoundError: If either CSV file does not exist.
    """
    if not BATTER_CSV_PATH.exists():
        raise FileNotFoundError("Missing batter CSV")
    if not PITCHER_CSV_PATH.exists():
        raise FileNotFoundError("Missing pitcher CSV")

    batters = pd.read_csv(BATTER_CSV_PATH)
    pitchers = pd.read_csv(PITCHER_CSV_PATH)

    # Lower-case the CSV headers on both frames before they are appended.
    for frame in (batters, pitchers):
        frame.columns = [name.lower() for name in frame.columns]

    to_sql_options = {
        "if_exists": "append",
        "index": False,
        "method": "multi",
        "chunksize": 1000,
    }

    connection = get_connection()
    try:
        batters.to_sql(
            "mlb_batter_statcast_features", connection, **to_sql_options
        )
        pitchers.to_sql(
            "mlb_pitcher_statcast_features", connection, **to_sql_options
        )
        connection.commit()
    finally:
        # Release the connection whether or not the writes succeeded.
        connection.close()

    return {"batter_rows": len(batters), "pitcher_rows": len(pitchers)}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
  from pathlib import Path
4
+ from typing import Iterable
5
+
6
  import pandas as pd
7
+ from sqlalchemy import text
8
 
9
  from database.remote_db import get_connection
10
 
 
12
  BATTER_CSV_PATH = Path("data/batter_savant_data.csv")
13
  PITCHER_CSV_PATH = Path("data/pitcher_savant_data.csv")
14
 
15
+ CHUNK_SIZE = 500
16
 
 
17
 
18
def _clean_dataframe(df: pd.DataFrame, source_file: str) -> pd.DataFrame:
    """Return a normalized copy of ``df`` ready for import.

    The input frame is left untouched. On the copy, every column label
    is converted to ``str``, stripped of surrounding whitespace and
    lower-cased; a ``source_file`` column holding ``source_file`` on
    every row is added (overwriting any existing column of that name);
    and missing values are replaced via ``DataFrame.where(..., None)``.

    NOTE(review): depending on the pandas version, numeric columns may
    keep NaN rather than None after the ``where`` call -- confirm if the
    distinction matters downstream.

    Args:
        df: Raw frame as read from the CSV.
        source_file: Label stored in the ``source_file`` column.

    Returns:
        The cleaned copy.
    """
    normalized_labels = [str(label).strip().lower() for label in df.columns]

    cleaned = df.copy()
    cleaned.columns = normalized_labels
    cleaned["source_file"] = source_file

    # Keep present cells as they are; substitute None where a value is missing.
    return cleaned.where(cleaned.notna(), None)
24
 
 
 
25
 
26
def _chunk_dataframe(df: pd.DataFrame, chunk_size: int) -> Iterable[pd.DataFrame]:
    """Yield consecutive row slices of ``df`` of at most ``chunk_size`` rows.

    Each yielded chunk is an independent copy, so mutating it does not
    affect ``df``. An empty frame yields nothing.

    Args:
        df: Frame to split; rows are taken in positional order.
        chunk_size: Maximum number of rows per chunk; must be positive.

    Yields:
        Copies of successive ``chunk_size``-row slices; the last one may
        be shorter.

    Raises:
        ValueError: If ``chunk_size`` is not positive. Because this is a
            generator, the error surfaces on the first iteration.
    """
    # A negative step would make range() empty and silently drop every row,
    # so reject non-positive sizes explicitly instead of yielding nothing.
    if chunk_size <= 0:
        raise ValueError(
            f"chunk_size must be a positive integer, got {chunk_size!r}"
        )

    for start in range(0, len(df), chunk_size):
        yield df.iloc[start:start + chunk_size].copy()
29
 
 
 
30
 
31
def _truncate_if_requested(table_name: str, clear_first: bool) -> None:
    """Empty ``table_name`` with ``TRUNCATE TABLE`` when ``clear_first`` is set.

    When ``clear_first`` is false this is a no-op and no connection is
    opened. Otherwise a connection is obtained from ``get_connection()``,
    the truncate statement is executed and committed, and the connection
    is closed whether or not the statement succeeded.

    Args:
        table_name: Table to truncate. It is interpolated directly into
            the SQL text, so it must be a trusted identifier.
        clear_first: Whether to truncate at all.
    """
    if clear_first:
        connection = get_connection()
        try:
            statement = text(f"TRUNCATE TABLE {table_name}")
            connection.execute(statement)
            connection.commit()
        finally:
            connection.close()
41
 
42
+
43
def _import_dataframe(df: pd.DataFrame, table_name: str, clear_first: bool = False) -> dict:
    """Append ``df`` to ``table_name`` in batches of ``CHUNK_SIZE`` rows.

    The table is truncated first when ``clear_first`` is true. Each batch
    is written on its own connection from ``get_connection()`` and
    committed on its own, so a failure part-way through leaves the
    earlier batches (and any truncation) in place. The exception from a
    failing batch propagates to the caller.

    Args:
        df: Cleaned frame to insert.
        table_name: Destination table.
        clear_first: Truncate the destination before inserting.

    Returns:
        A dict with ``table_name``, ``total_rows`` (rows in ``df``) and
        ``inserted_rows`` (rows in the batches that were committed).
    """
    _truncate_if_requested(table_name, clear_first=clear_first)

    row_count = len(df)
    to_sql_options = {
        "if_exists": "append",
        "index": False,
        "method": "multi",
        "chunksize": 250,
    }

    written = 0
    for batch in _chunk_dataframe(df, CHUNK_SIZE):
        connection = get_connection()
        try:
            batch.to_sql(table_name, connection, **to_sql_options)
            connection.commit()
        finally:
            connection.close()
        # Only reached when the batch was written and committed.
        written += len(batch)

    return {
        "table_name": table_name,
        "total_rows": row_count,
        "inserted_rows": written,
    }
70
+
71
+
72
def import_batter_savant_csv(clear_first: bool = False) -> dict:
    """Load the batter Savant CSV into ``mlb_batter_statcast_features``.

    Reads ``BATTER_CSV_PATH``, cleans the frame with ``_clean_dataframe``
    (tagging rows with ``batter_savant_data.csv``) and hands it to
    ``_import_dataframe``.

    Args:
        clear_first: Truncate the destination table before inserting.

    Returns:
        The summary dict produced by ``_import_dataframe``.

    Raises:
        FileNotFoundError: If the batter CSV does not exist.
    """
    csv_path = BATTER_CSV_PATH
    if not csv_path.exists():
        raise FileNotFoundError(f"Missing file: {csv_path}")

    raw_frame = pd.read_csv(csv_path)
    cleaned_frame = _clean_dataframe(raw_frame, "batter_savant_data.csv")
    summary = _import_dataframe(
        cleaned_frame,
        "mlb_batter_statcast_features",
        clear_first=clear_first,
    )
    return summary
83
+
84
+
85
def import_pitcher_savant_csv(clear_first: bool = False) -> dict:
    """Load the pitcher Savant CSV into ``mlb_pitcher_statcast_features``.

    Reads ``PITCHER_CSV_PATH``, cleans the frame with ``_clean_dataframe``
    (tagging rows with ``pitcher_savant_data.csv``) and hands it to
    ``_import_dataframe``.

    Args:
        clear_first: Truncate the destination table before inserting.

    Returns:
        The summary dict produced by ``_import_dataframe``.

    Raises:
        FileNotFoundError: If the pitcher CSV does not exist.
    """
    csv_path = PITCHER_CSV_PATH
    if not csv_path.exists():
        raise FileNotFoundError(f"Missing file: {csv_path}")

    raw_frame = pd.read_csv(csv_path)
    cleaned_frame = _clean_dataframe(raw_frame, "pitcher_savant_data.csv")
    summary = _import_dataframe(
        cleaned_frame,
        "mlb_pitcher_statcast_features",
        clear_first=clear_first,
    )
    return summary