asteroidddd committed on
Commit
b6795f3
·
1 Parent(s): 5a2a0c6

Add trained pipeline + preprocessing code

Browse files
Files changed (2) hide show
  1. auction_pipeline.pkl +2 -2
  2. onbid-map-round-train.py +14 -26
auction_pipeline.pkl CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:1534cc58fcb68d75d575a23b1df3b0bbc23d4faa40efbe91f9796affd524f47a
3
- size 3566953
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1b2302cd4ef6f2af0d667e28288ebf90cf823cef5f08a4e372f443d506f8a42e
3
+ size 3567270
onbid-map-round-train.py CHANGED
@@ -5,14 +5,13 @@ import shutil
5
  import stat
6
  import pandas as pd
7
  import joblib
8
- from sklearn.base import BaseEstimator, TransformerMixin
9
- from sklearn.preprocessing import OneHotEncoder, LabelEncoder
10
  from sklearn.compose import ColumnTransformer
11
  from sklearn.pipeline import Pipeline
12
  from xgboost import XGBClassifier
13
  from huggingface_hub import HfApi, Repository
14
 
15
- # 환경 변수에서 Hugging Face 토큰 읽어오기
16
  HF_REPO_NAME = "asteroidddd/onbid-map-round"
17
  HF_TOKEN = os.getenv("HF_TOKEN")
18
  if HF_TOKEN is None:
@@ -22,33 +21,22 @@ if HF_TOKEN is None:
22
  SCRIPT_PATH = os.path.abspath(__file__)
23
  SCRIPT_NAME = os.path.basename(SCRIPT_PATH)
24
 
25
- # 최초입찰시기 연도/월/일/요일 추출 변환기
26
- class DateFeatures(BaseEstimator, TransformerMixin):
 
 
 
 
 
 
 
27
 
28
- def __init__(self, date_column="์ตœ์ดˆ์ž…์ฐฐ์‹œ๊ธฐ"):
29
- self.date_column = date_column
30
-
31
- def fit(self, X, y=None):
32
- return self
33
-
34
- def transform(self, X):
35
- X = X.copy()
36
- dt = pd.to_datetime(X[self.date_column])
37
- X["์ตœ์ดˆ์ž…์ฐฐ_์—ฐ๋„"] = dt.dt.year
38
- X["์ตœ์ดˆ์ž…์ฐฐ_์›”"] = dt.dt.month
39
- X["์ตœ์ดˆ์ž…์ฐฐ_์ผ"] = dt.dt.day
40
- X["์ตœ์ดˆ์ž…์ฐฐ_์š”์ผ"] = dt.dt.weekday
41
- return X.drop(columns=[self.date_column])
42
-
43
-
44
- # 읽기 전용 파일 삭제 시 권한 변경 후 재시도
45
  def rm_readonly(func, path, exc_info):
 
46
  os.chmod(path, stat.S_IWRITE)
47
  func(path)
48
 
49
-
50
  def main():
51
-
52
  # ๋ฐ์ดํ„ฐ ๋กœ๋“œ
53
  df = pd.read_pickle(r'C:\Users\hwang\Desktop\OSSP\data.pkl')
54
 
@@ -67,7 +55,7 @@ def main():
67
  cat_cols = ["๋Œ€๋ถ„๋ฅ˜", "์ค‘๋ถ„๋ฅ˜", "๊ธฐ๊ด€"]
68
  preprocessor = ColumnTransformer(
69
  transformers=[
70
- ("datefeat", DateFeatures("์ตœ์ดˆ์ž…์ฐฐ์‹œ๊ธฐ"), ["์ตœ์ดˆ์ž…์ฐฐ์‹œ๊ธฐ"]),
71
  ("ohe", OneHotEncoder(handle_unknown="ignore"), cat_cols)
72
  ],
73
  remainder="passthrough"
@@ -97,7 +85,7 @@ def main():
97
  try:
98
  api.create_repo(repo_id=HF_REPO_NAME, token=HF_TOKEN)
99
  except:
100
- pass # 이미 존재하면 무시
101
 
102
  # 로컬에 레포 클론 (기존 삭제 시 read-only 오류 처리)
103
  local_dir = "hf_repo"
 
5
  import stat
6
  import pandas as pd
7
  import joblib
8
+ from sklearn.preprocessing import OneHotEncoder, LabelEncoder, FunctionTransformer
 
9
  from sklearn.compose import ColumnTransformer
10
  from sklearn.pipeline import Pipeline
11
  from xgboost import XGBClassifier
12
  from huggingface_hub import HfApi, Repository
13
 
14
+ # 환경 변수에서 토큰 읽어오기
15
  HF_REPO_NAME = "asteroidddd/onbid-map-round"
16
  HF_TOKEN = os.getenv("HF_TOKEN")
17
  if HF_TOKEN is None:
 
21
  SCRIPT_PATH = os.path.abspath(__file__)
22
  SCRIPT_NAME = os.path.basename(SCRIPT_PATH)
23
 
def extract_date_features(df, date_column="최초입찰시기"):
    """Expand a bid-date column into year/month/day/weekday features.

    Args:
        df: DataFrame that contains ``date_column``. It is copied, never
            modified in place.
        date_column: Name of the column to parse with ``pd.to_datetime``.
            Defaults to the initial-bid-date column the pipeline passes in.

    Returns:
        A copy of ``df`` with the four derived columns appended and
        ``date_column`` itself dropped.

    Raises:
        KeyError: If ``date_column`` is not present in ``df``.
    """
    X = df.copy()
    dt = pd.to_datetime(X[date_column])
    # Output column names stay fixed so a fitted pipeline keeps seeing the
    # same feature layout regardless of which source column was parsed.
    X["최초입찰_연도"] = dt.dt.year
    X["최초입찰_월"] = dt.dt.month
    X["최초입찰_일"] = dt.dt.day
    X["최초입찰_요일"] = dt.dt.weekday  # Monday=0 ... Sunday=6
    return X.drop(columns=[date_column])
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
def rm_readonly(func, path, exc_info):
    """Make ``path`` writable, then retry the operation that failed on it.

    The three-argument signature matches an error callback such as the
    ``onerror`` hook of ``shutil.rmtree`` (presumably how it is used when
    the old local clone is deleted; the call site is not visible here).

    Args:
        func: Callable that originally failed; it is re-invoked with ``path``.
        path: Filesystem path whose read-only flag blocked the operation.
        exc_info: Exception details supplied by the caller; unused.
    """
    # Read-only files (e.g. git objects on Windows) cannot be removed until
    # the write bit is set.
    os.chmod(path, stat.S_IWRITE)
    func(path)
38
 
 
39
  def main():
 
40
  # ๋ฐ์ดํ„ฐ ๋กœ๋“œ
41
  df = pd.read_pickle(r'C:\Users\hwang\Desktop\OSSP\data.pkl')
42
 
 
55
  cat_cols = ["๋Œ€๋ถ„๋ฅ˜", "์ค‘๋ถ„๋ฅ˜", "๊ธฐ๊ด€"]
56
  preprocessor = ColumnTransformer(
57
  transformers=[
58
+ ("datefeat", FunctionTransformer(extract_date_features, validate=False), ["์ตœ์ดˆ์ž…์ฐฐ์‹œ๊ธฐ"]),
59
  ("ohe", OneHotEncoder(handle_unknown="ignore"), cat_cols)
60
  ],
61
  remainder="passthrough"
 
85
  try:
86
  api.create_repo(repo_id=HF_REPO_NAME, token=HF_TOKEN)
87
  except:
88
+ pass
89
 
90
  # 로컬에 레포 클론 (기존 삭제 시 read-only 오류 처리)
91
  local_dir = "hf_repo"