asteroidddd committed on
Commit
ba69daf
ยท
1 Parent(s): b6795f3

Add trained pipeline + preprocessing code

Browse files
Files changed (2) hide show
  1. auction_pipeline.pkl +2 -2
  2. onbid-map-round-train.py +21 -26
auction_pipeline.pkl CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:1b2302cd4ef6f2af0d667e28288ebf90cf823cef5f08a4e372f443d506f8a42e
3
- size 3567270
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3c1acd90a22aaf1e5520ebfc531247896e5d33df2f3fb66e0e9e74020f59ed71
3
+ size 3575369
onbid-map-round-train.py CHANGED
@@ -5,34 +5,22 @@ import shutil
5
  import stat
6
  import pandas as pd
7
  import joblib
8
- from sklearn.preprocessing import OneHotEncoder, LabelEncoder, FunctionTransformer
9
  from sklearn.compose import ColumnTransformer
10
  from sklearn.pipeline import Pipeline
11
  from xgboost import XGBClassifier
12
  from huggingface_hub import HfApi, Repository
13
 
14
- # ํ™˜๊ฒฝ ๋ณ€์ˆ˜์—์„œ ํ† ํฐ ์ฝ์–ด์˜ค๊ธฐ
15
  HF_REPO_NAME = "asteroidddd/onbid-map-round"
16
  HF_TOKEN = os.getenv("HF_TOKEN")
17
  if HF_TOKEN is None:
18
  raise ValueError("ํ™˜๊ฒฝ ๋ณ€์ˆ˜ HF_TOKEN์ด ์„ค์ •๋˜์–ด ์žˆ์ง€ ์•Š์Šต๋‹ˆ๋‹ค.")
19
 
20
- # ์ด ์Šคํฌ๋ฆฝํŠธ์˜ ๊ฒฝ๋กœ์™€ ํŒŒ์ผ๋ช…
21
  SCRIPT_PATH = os.path.abspath(__file__)
22
- SCRIPT_NAME = os.path.basename(SCRIPT_PATH)
23
-
24
def extract_date_features(df):
    """Derive year/month/day/weekday features from 최초입찰시기 and drop the source column.

    Returns a new DataFrame; the input is not mutated.
    """
    out = df.copy()
    parsed = pd.to_datetime(out["최초입찰시기"])
    derived = (
        ("최초입찰_연도", parsed.dt.year),
        ("최초입찰_월", parsed.dt.month),
        ("최초입찰_일", parsed.dt.day),
        ("최초입찰_요일", parsed.dt.weekday),
    )
    for column, values in derived:
        out[column] = values
    return out.drop(columns=["최초입찰시기"])
33
 
34
def rm_readonly(func, path, exc_info):
    """shutil.rmtree onerror hook: clear the read-only flag, then retry the failed op."""
    writable = stat.S_IWRITE
    os.chmod(path, writable)
    func(path)
38
 
@@ -40,22 +28,30 @@ def main():
40
  # ๋ฐ์ดํ„ฐ ๋กœ๋“œ
41
  df = pd.read_pickle(r'C:\Users\hwang\Desktop\OSSP\data.pkl')
42
 
43
- # ๋‚™์ฐฐ์ฐจ์ˆ˜ ๋ ˆ์ด๋ธ” ์ธ์ฝ”๋”ฉ ํ›„, ๋นˆ๋„ โ‰ค 10์ธ ํด๋ž˜์Šค ์ œ๊ฑฐ
44
  le_label = LabelEncoder()
45
  df["๋‚™์ฐฐ์ฐจ์ˆ˜_LE"] = le_label.fit_transform(df["๋‚™์ฐฐ์ฐจ์ˆ˜"])
46
  counts = df["๋‚™์ฐฐ์ฐจ์ˆ˜_LE"].value_counts()
47
  rare = counts[counts <= 10].index.tolist()
48
  df = df[~df["๋‚™์ฐฐ์ฐจ์ˆ˜_LE"].isin(rare)].reset_index(drop=True)
49
 
50
- # ์ž…๋ ฅ(X)๊ณผ ํƒ€๊นƒ(y) ๋ถ„๋ฆฌ
51
- X = df[["๋Œ€๋ถ„๋ฅ˜", "์ค‘๋ถ„๋ฅ˜", "๊ธฐ๊ด€", "์ตœ์ดˆ์ž…์ฐฐ์‹œ๊ธฐ", "1์ฐจ์ตœ์ €์ž…์ฐฐ๊ฐ€"]]
 
 
 
 
 
 
 
 
 
52
  y = df["๋‚™์ฐฐ์ฐจ์ˆ˜_LE"]
53
 
54
- # ์ „์ฒ˜๋ฆฌ ๋ฐ ๋ชจ๋ธ ํŒŒ์ดํ”„๋ผ์ธ ์ •์˜
55
  cat_cols = ["๋Œ€๋ถ„๋ฅ˜", "์ค‘๋ถ„๋ฅ˜", "๊ธฐ๊ด€"]
56
  preprocessor = ColumnTransformer(
57
  transformers=[
58
- ("datefeat", FunctionTransformer(extract_date_features, validate=False), ["์ตœ์ดˆ์ž…์ฐฐ์‹œ๊ธฐ"]),
59
  ("ohe", OneHotEncoder(handle_unknown="ignore"), cat_cols)
60
  ],
61
  remainder="passthrough"
@@ -65,10 +61,10 @@ def main():
65
  ("classifier", XGBClassifier(eval_metric="mlogloss", random_state=42))
66
  ])
67
 
68
- # ํŒŒ์ดํ”„๋ผ์ธ ํ•™์Šต
69
  pipeline.fit(X, y)
70
 
71
- # ํ•™์Šต๋œ ํŒŒ์ดํ”„๋ผ์ธ๊ณผ ๋ ˆ์ด๋ธ” ์ธ์ฝ”๋” ์ €์žฅ
72
  os.makedirs("output", exist_ok=True)
73
  pipeline_path = "output/auction_pipeline.pkl"
74
  label_path = "output/label_encoder.pkl"
@@ -80,20 +76,20 @@ def main():
80
  with open("requirements.txt", "w", encoding="utf-8") as f:
81
  f.write("\n".join(deps))
82
 
83
- # Hugging Face ๋ ˆํฌ์ง€ํ† ๋ฆฌ ์ƒ์„ฑ ์‹œ๋„
84
  api = HfApi()
85
  try:
86
  api.create_repo(repo_id=HF_REPO_NAME, token=HF_TOKEN)
87
  except:
88
  pass
89
 
90
- # ๋กœ์ปฌ์— ๋ ˆํฌ ํด๋ก  (๊ธฐ์กด ์‚ญ์ œ ์‹œ read-only ์˜ค๋ฅ˜ ์ฒ˜๋ฆฌ)
91
  local_dir = "hf_repo"
92
  if os.path.isdir(local_dir):
93
  shutil.rmtree(local_dir, onerror=rm_readonly)
94
  repo = Repository(local_dir=local_dir, clone_from=HF_REPO_NAME, use_auth_token=HF_TOKEN)
95
 
96
- # ํ•„์š”ํ•œ ํŒŒ์ผ ๋ณต์‚ฌ
97
  for src in [SCRIPT_PATH, "requirements.txt", pipeline_path, label_path]:
98
  dst = os.path.join(local_dir, os.path.basename(src))
99
  shutil.copy(src, dst)
@@ -103,6 +99,5 @@ def main():
103
  repo.git_commit("Add trained pipeline + preprocessing code")
104
  repo.git_push()
105
 
106
-
107
  if __name__ == "__main__":
108
  main()
 
5
  import stat
6
  import pandas as pd
7
  import joblib
8
+ from sklearn.preprocessing import OneHotEncoder, LabelEncoder
9
  from sklearn.compose import ColumnTransformer
10
  from sklearn.pipeline import Pipeline
11
  from xgboost import XGBClassifier
12
  from huggingface_hub import HfApi, Repository
13
 
14
+ # ํ™˜๊ฒฝ ๋ณ€์ˆ˜์—์„œ Hugging Face ํ† ํฐ ์ฝ๊ธฐ
15
  HF_REPO_NAME = "asteroidddd/onbid-map-round"
16
  HF_TOKEN = os.getenv("HF_TOKEN")
17
  if HF_TOKEN is None:
18
  raise ValueError("ํ™˜๊ฒฝ ๋ณ€์ˆ˜ HF_TOKEN์ด ์„ค์ •๋˜์–ด ์žˆ์ง€ ์•Š์Šต๋‹ˆ๋‹ค.")
19
 
20
+ # ์ด ์Šคํฌ๋ฆฝํŠธ์˜ ๊ฒฝ๋กœ
21
  SCRIPT_PATH = os.path.abspath(__file__)
 
 
 
 
 
 
 
 
 
 
 
22
 
23
def rm_readonly(func, path, exc_info):
    """Retry hook for shutil.rmtree: a delete failed (likely a read-only file),
    so grant write permission on *path* and re-invoke *func* on it."""
    os.chmod(path, stat.S_IWRITE)
    func(path)
26
 
 
28
  # ๋ฐ์ดํ„ฐ ๋กœ๋“œ
29
  df = pd.read_pickle(r'C:\Users\hwang\Desktop\OSSP\data.pkl')
30
 
31
+ # ๋ผ๋ฒจ ์ธ์ฝ”๋”ฉ & ๋นˆ๋„ โ‰ค 10์ธ ํด๋ž˜์Šค ์ œ๊ฑฐ
32
  le_label = LabelEncoder()
33
  df["๋‚™์ฐฐ์ฐจ์ˆ˜_LE"] = le_label.fit_transform(df["๋‚™์ฐฐ์ฐจ์ˆ˜"])
34
  counts = df["๋‚™์ฐฐ์ฐจ์ˆ˜_LE"].value_counts()
35
  rare = counts[counts <= 10].index.tolist()
36
  df = df[~df["๋‚™์ฐฐ์ฐจ์ˆ˜_LE"].isin(rare)].reset_index(drop=True)
37
 
38
+ # ๋‚ ์งœ ํŒŒ์ƒ ๋ณ€์ˆ˜ ์ƒ์„ฑ
39
+ df["์ตœ์ดˆ์ž…์ฐฐ_์—ฐ๋„"] = df["์ตœ์ดˆ์ž…์ฐฐ์‹œ๊ธฐ"].dt.year
40
+ df["์ตœ์ดˆ์ž…์ฐฐ_์›”"] = df["์ตœ์ดˆ์ž…์ฐฐ์‹œ๊ธฐ"].dt.month
41
+ df["์ตœ์ดˆ์ž…์ฐฐ_์ผ"] = df["์ตœ์ดˆ์ž…์ฐฐ์‹œ๊ธฐ"].dt.day
42
+ df["์ตœ์ดˆ์ž…์ฐฐ_์š”์ผ"] = df["์ตœ์ดˆ์ž…์ฐฐ์‹œ๊ธฐ"].dt.weekday
43
+ df = df.drop(columns=["์ตœ์ดˆ์ž…์ฐฐ์‹œ๊ธฐ"])
44
+
45
+ # ํ”ผ์ฒ˜/ํƒ€๊นƒ ๋ถ„๋ฆฌ
46
+ X = df[["๋Œ€๋ถ„๋ฅ˜", "์ค‘๋ถ„๋ฅ˜", "๊ธฐ๊ด€",
47
+ "์ตœ์ดˆ์ž…์ฐฐ_์—ฐ๋„", "์ตœ์ดˆ์ž…์ฐฐ_์›”", "์ตœ์ดˆ์ž…์ฐฐ_์ผ", "์ตœ์ดˆ์ž…์ฐฐ_์š”์ผ",
48
+ "1์ฐจ์ตœ์ €์ž…์ฐฐ๊ฐ€"]]
49
  y = df["๋‚™์ฐฐ์ฐจ์ˆ˜_LE"]
50
 
51
+ # ์ „์ฒ˜๋ฆฌ + ๋ชจ๋ธ ํŒŒ์ดํ”„๋ผ์ธ
52
  cat_cols = ["๋Œ€๋ถ„๋ฅ˜", "์ค‘๋ถ„๋ฅ˜", "๊ธฐ๊ด€"]
53
  preprocessor = ColumnTransformer(
54
  transformers=[
 
55
  ("ohe", OneHotEncoder(handle_unknown="ignore"), cat_cols)
56
  ],
57
  remainder="passthrough"
 
61
  ("classifier", XGBClassifier(eval_metric="mlogloss", random_state=42))
62
  ])
63
 
64
+ # ํ•™์Šต
65
  pipeline.fit(X, y)
66
 
67
+ # ํŒŒ์ดํ”„๋ผ์ธ & ๋ผ๋ฒจ ์ธ์ฝ”๋” ์ €์žฅ
68
  os.makedirs("output", exist_ok=True)
69
  pipeline_path = "output/auction_pipeline.pkl"
70
  label_path = "output/label_encoder.pkl"
 
76
  with open("requirements.txt", "w", encoding="utf-8") as f:
77
  f.write("\n".join(deps))
78
 
79
+ # Hugging Face ๋ ˆํฌ ์ƒ์„ฑ ์‹œ๋„
80
  api = HfApi()
81
  try:
82
  api.create_repo(repo_id=HF_REPO_NAME, token=HF_TOKEN)
83
  except:
84
  pass
85
 
86
+ # ๋กœ์ปฌ์— ๋ ˆํฌ ํด๋ก  (๊ธฐ์กด ์‚ญ์ œ ์‹œ read-only ์ฒ˜๋ฆฌ)
87
  local_dir = "hf_repo"
88
  if os.path.isdir(local_dir):
89
  shutil.rmtree(local_dir, onerror=rm_readonly)
90
  repo = Repository(local_dir=local_dir, clone_from=HF_REPO_NAME, use_auth_token=HF_TOKEN)
91
 
92
+ # ํŒŒ์ผ ๋ณต์‚ฌ
93
  for src in [SCRIPT_PATH, "requirements.txt", pipeline_path, label_path]:
94
  dst = os.path.join(local_dir, os.path.basename(src))
95
  shutil.copy(src, dst)
 
99
  repo.git_commit("Add trained pipeline + preprocessing code")
100
  repo.git_push()
101
 
 
102
  if __name__ == "__main__":
103
  main()