asteroidddd committed on
Commit
5a2a0c6
·
1 Parent(s): 4b4d652

Add trained pipeline + preprocessing code

Browse files
Files changed (1) hide show
  1. onbid-map-round-train.py +18 -15
onbid-map-round-train.py CHANGED
@@ -12,18 +12,19 @@ from sklearn.pipeline import Pipeline
12
  from xgboost import XGBClassifier
13
  from huggingface_hub import HfApi, Repository
14
 
15
- # ํ™˜๊ฒฝ ๋ณ€์ˆ˜์—์„œ ํ† ํฐ ์ฝ์–ด์˜ค๊ธฐ
16
  HF_REPO_NAME = "asteroidddd/onbid-map-round"
17
  HF_TOKEN = os.getenv("HF_TOKEN")
18
  if HF_TOKEN is None:
19
  raise ValueError("ํ™˜๊ฒฝ ๋ณ€์ˆ˜ HF_TOKEN์ด ์„ค์ •๋˜์–ด ์žˆ์ง€ ์•Š์Šต๋‹ˆ๋‹ค.")
20
 
21
- # ์Šคํฌ๋ฆฝํŠธ ๊ฒฝ๋กœ/์ด๋ฆ„
22
  SCRIPT_PATH = os.path.abspath(__file__)
23
  SCRIPT_NAME = os.path.basename(SCRIPT_PATH)
24
 
25
-
26
  class DateFeatures(BaseEstimator, TransformerMixin):
 
27
  def __init__(self, date_column="์ตœ์ดˆ์ž…์ฐฐ์‹œ๊ธฐ"):
28
  self.date_column = date_column
29
 
@@ -40,27 +41,29 @@ class DateFeatures(BaseEstimator, TransformerMixin):
40
  return X.drop(columns=[self.date_column])
41
 
42
 
 
43
def rm_readonly(func, path, exc_info):
    """Error handler for ``shutil.rmtree``: clear the read-only flag and retry.

    On Windows, files under ``.git`` are often marked read-only, which makes
    deletion raise ``PermissionError``.  ``rmtree`` then invokes this handler
    with the failing operation (*func*), the offending *path*, and the
    exception info; we grant owner-write permission and retry the operation.
    The same (func, path, exc_info) signature is what ``rmtree(onerror=...)``
    expects.
    """
    # Make the entry writable, then re-run the operation that just failed.
    writable = stat.S_IWRITE
    os.chmod(path, writable)
    func(path)
46
 
47
 
48
  def main():
49
- # 1) ๋ฐ์ดํ„ฐ ๋กœ๋“œ
 
50
  df = pd.read_pickle(r'C:\Users\hwang\Desktop\OSSP\data.pkl')
51
 
52
- # 2) ๋ผ๋ฒจ ์ธ์ฝ”๋”ฉ ๋ฐ ํฌ๊ท€ ๋ ˆ์ด๋ธ” ์ œ๊ฑฐ
53
  le_label = LabelEncoder()
54
  df["๋‚™์ฐฐ์ฐจ์ˆ˜_LE"] = le_label.fit_transform(df["๋‚™์ฐฐ์ฐจ์ˆ˜"])
55
  counts = df["๋‚™์ฐฐ์ฐจ์ˆ˜_LE"].value_counts()
56
  rare = counts[counts <= 10].index.tolist()
57
  df = df[~df["๋‚™์ฐฐ์ฐจ์ˆ˜_LE"].isin(rare)].reset_index(drop=True)
58
 
59
- # 3) ํ”ผ์ฒ˜/ํƒ€๊นƒ ๋ถ„๋ฆฌ
60
  X = df[["๋Œ€๋ถ„๋ฅ˜", "์ค‘๋ถ„๋ฅ˜", "๊ธฐ๊ด€", "์ตœ์ดˆ์ž…์ฐฐ์‹œ๊ธฐ", "1์ฐจ์ตœ์ €์ž…์ฐฐ๊ฐ€"]]
61
  y = df["๋‚™์ฐฐ์ฐจ์ˆ˜_LE"]
62
 
63
- # 4) ์ „์ฒ˜๋ฆฌ + ๋ชจ๋ธ ํŒŒ์ดํ”„๋ผ์ธ
64
  cat_cols = ["๋Œ€๋ถ„๋ฅ˜", "์ค‘๋ถ„๋ฅ˜", "๊ธฐ๊ด€"]
65
  preprocessor = ColumnTransformer(
66
  transformers=[
@@ -74,40 +77,40 @@ def main():
74
  ("classifier", XGBClassifier(eval_metric="mlogloss", random_state=42))
75
  ])
76
 
77
- # 5) ํ•™์Šต
78
  pipeline.fit(X, y)
79
 
80
- # 6) ํŒŒ์ดํ”„๋ผ์ธ ๋ฐ ๋ผ๋ฒจ ์ธ์ฝ”๋” ์ €์žฅ
81
  os.makedirs("output", exist_ok=True)
82
  pipeline_path = "output/auction_pipeline.pkl"
83
  label_path = "output/label_encoder.pkl"
84
  joblib.dump(pipeline, pipeline_path)
85
  joblib.dump(le_label, label_path)
86
 
87
- # 7) requirements.txt ์ž‘์„ฑ
88
  deps = ["pandas", "scikit-learn", "xgboost", "joblib", "huggingface_hub"]
89
  with open("requirements.txt", "w", encoding="utf-8") as f:
90
  f.write("\n".join(deps))
91
 
92
- # 8) Hugging Face ๋ฆฌํฌ ์ƒ์„ฑ ์‹œ๋„
93
  api = HfApi()
94
  try:
95
  api.create_repo(repo_id=HF_REPO_NAME, token=HF_TOKEN)
96
  except:
97
- pass
98
 
99
- # 9) ๋กœ์ปฌ์— ๋ ˆํฌ ํด๋ก  (๊ธฐ์กด ์‚ญ์ œ ์‹œ read-only ๋ฌธ์ œ ํ•ด๊ฒฐ)
100
  local_dir = "hf_repo"
101
  if os.path.isdir(local_dir):
102
  shutil.rmtree(local_dir, onerror=rm_readonly)
103
  repo = Repository(local_dir=local_dir, clone_from=HF_REPO_NAME, use_auth_token=HF_TOKEN)
104
 
105
- # 10) ํŒŒ์ผ ๋ณต์‚ฌ
106
  for src in [SCRIPT_PATH, "requirements.txt", pipeline_path, label_path]:
107
  dst = os.path.join(local_dir, os.path.basename(src))
108
  shutil.copy(src, dst)
109
 
110
- # 11) ์ปค๋ฐ‹ ๋ฐ ํ‘ธ์‹œ
111
  repo.git_add(auto_lfs_track=True)
112
  repo.git_commit("Add trained pipeline + preprocessing code")
113
  repo.git_push()
 
12
  from xgboost import XGBClassifier
13
  from huggingface_hub import HfApi, Repository
14
 
15
+ # ํ™˜๊ฒฝ ๋ณ€์ˆ˜์—์„œ Hugging Face ํ† ํฐ ์ฝ์–ด์˜ค๊ธฐ
16
  HF_REPO_NAME = "asteroidddd/onbid-map-round"
17
  HF_TOKEN = os.getenv("HF_TOKEN")
18
  if HF_TOKEN is None:
19
  raise ValueError("ํ™˜๊ฒฝ ๋ณ€์ˆ˜ HF_TOKEN์ด ์„ค์ •๋˜์–ด ์žˆ์ง€ ์•Š์Šต๋‹ˆ๋‹ค.")
20
 
21
+ # ์ด ์Šคํฌ๋ฆฝํŠธ์˜ ๊ฒฝ๋กœ์™€ ํŒŒ์ผ๋ช…
22
  SCRIPT_PATH = os.path.abspath(__file__)
23
  SCRIPT_NAME = os.path.basename(SCRIPT_PATH)
24
 
25
+ # ์ตœ์ดˆ์ž…์ฐฐ์‹œ๊ธฐ ์—ฐ๋„/์›”/์ผ/์š”์ผ ์ถ”์ถœ ๋ณ€ํ™˜๊ธฐ
26
  class DateFeatures(BaseEstimator, TransformerMixin):
27
+
28
  def __init__(self, date_column="์ตœ์ดˆ์ž…์ฐฐ์‹œ๊ธฐ"):
29
  self.date_column = date_column
30
 
 
41
  return X.drop(columns=[self.date_column])
42
 
43
 
44
+ # ์ฝ๊ธฐ ์ „์šฉ ํŒŒ์ผ ์‚ญ์ œ ์‹œ ๊ถŒํ•œ ๋ณ€๊ฒฝ ํ›„ ์žฌ์‹œ๋„
45
  def rm_readonly(func, path, exc_info):
46
  os.chmod(path, stat.S_IWRITE)
47
  func(path)
48
 
49
 
50
  def main():
51
+
52
+ # ๋ฐ์ดํ„ฐ ๋กœ๋“œ
53
  df = pd.read_pickle(r'C:\Users\hwang\Desktop\OSSP\data.pkl')
54
 
55
+ # ๋‚™์ฐฐ์ฐจ์ˆ˜ ๋ ˆ์ด๋ธ” ์ธ์ฝ”๋”ฉ ํ›„, ๋นˆ๋„ โ‰ค 10์ธ ํด๋ž˜์Šค ์ œ๊ฑฐ
56
  le_label = LabelEncoder()
57
  df["๋‚™์ฐฐ์ฐจ์ˆ˜_LE"] = le_label.fit_transform(df["๋‚™์ฐฐ์ฐจ์ˆ˜"])
58
  counts = df["๋‚™์ฐฐ์ฐจ์ˆ˜_LE"].value_counts()
59
  rare = counts[counts <= 10].index.tolist()
60
  df = df[~df["๋‚™์ฐฐ์ฐจ์ˆ˜_LE"].isin(rare)].reset_index(drop=True)
61
 
62
+ # ์ž…๋ ฅ(X)๊ณผ ํƒ€๊นƒ(y) ๋ถ„๋ฆฌ
63
  X = df[["๋Œ€๋ถ„๋ฅ˜", "์ค‘๋ถ„๋ฅ˜", "๊ธฐ๊ด€", "์ตœ์ดˆ์ž…์ฐฐ์‹œ๊ธฐ", "1์ฐจ์ตœ์ €์ž…์ฐฐ๊ฐ€"]]
64
  y = df["๋‚™์ฐฐ์ฐจ์ˆ˜_LE"]
65
 
66
+ # ์ „์ฒ˜๋ฆฌ ๋ฐ ๋ชจ๋ธ ํŒŒ์ดํ”„๋ผ์ธ ์ •์˜
67
  cat_cols = ["๋Œ€๋ถ„๋ฅ˜", "์ค‘๋ถ„๋ฅ˜", "๊ธฐ๊ด€"]
68
  preprocessor = ColumnTransformer(
69
  transformers=[
 
77
  ("classifier", XGBClassifier(eval_metric="mlogloss", random_state=42))
78
  ])
79
 
80
+ # ํŒŒ์ดํ”„๋ผ์ธ ํ•™์Šต
81
  pipeline.fit(X, y)
82
 
83
+ # ํ•™์Šต๋œ ํŒŒ์ดํ”„๋ผ์ธ๊ณผ ๋ ˆ์ด๋ธ” ์ธ์ฝ”๋” ์ €์žฅ
84
  os.makedirs("output", exist_ok=True)
85
  pipeline_path = "output/auction_pipeline.pkl"
86
  label_path = "output/label_encoder.pkl"
87
  joblib.dump(pipeline, pipeline_path)
88
  joblib.dump(le_label, label_path)
89
 
90
+ # requirements.txt ์ž‘์„ฑ
91
  deps = ["pandas", "scikit-learn", "xgboost", "joblib", "huggingface_hub"]
92
  with open("requirements.txt", "w", encoding="utf-8") as f:
93
  f.write("\n".join(deps))
94
 
95
+ # Hugging Face ๋ ˆํฌ์ง€ํ† ๋ฆฌ ์ƒ์„ฑ ์‹œ๋„
96
  api = HfApi()
97
  try:
98
  api.create_repo(repo_id=HF_REPO_NAME, token=HF_TOKEN)
99
  except:
100
+ pass # ์ด๋ฏธ ์กด์žฌํ•˜๋ฉด ๋ฌด์‹œ
101
 
102
+ # ๋กœ์ปฌ์— ๋ ˆํฌ ํด๋ก  (๊ธฐ์กด ์‚ญ์ œ ์‹œ read-only ์˜ค๋ฅ˜ ์ฒ˜๋ฆฌ)
103
  local_dir = "hf_repo"
104
  if os.path.isdir(local_dir):
105
  shutil.rmtree(local_dir, onerror=rm_readonly)
106
  repo = Repository(local_dir=local_dir, clone_from=HF_REPO_NAME, use_auth_token=HF_TOKEN)
107
 
108
+ # ํ•„์š”ํ•œ ํŒŒ์ผ ๋ณต์‚ฌ
109
  for src in [SCRIPT_PATH, "requirements.txt", pipeline_path, label_path]:
110
  dst = os.path.join(local_dir, os.path.basename(src))
111
  shutil.copy(src, dst)
112
 
113
+ # ์ปค๋ฐ‹ ๋ฐ ํ‘ธ์‹œ
114
  repo.git_add(auto_lfs_track=True)
115
  repo.git_commit("Add trained pipeline + preprocessing code")
116
  repo.git_push()