Kirtan001 committed on
Commit
9e97e8f
·
verified ·
1 Parent(s): dab3832

Upload folder using huggingface_hub

Browse files
Files changed (6) hide show
  1. .github/workflows/deploy_to_hf.yml +52 -0
  2. app.py +212 -0
  3. cart_model.pkl +3 -0
  4. id3_model.pkl +3 -0
  5. requirement.txt +4 -0
  6. train.py +53 -0
.github/workflows/deploy_to_hf.yml ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Deploys the repository contents to a Hugging Face Space.
on:
  push:
    branches:
      - main  # deploy whenever you push to main
  workflow_dispatch:  # allow manual run from Actions tab

jobs:
  deploy:
    runs-on: ubuntu-latest

    steps:
      # NOTE(review): if the .pkl models are tracked with Git LFS in this
      # GitHub repository, this step needs `with: lfs: true`, otherwise only
      # the pointer files would be uploaded - confirm.
      - name: Checkout repository
        uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: "3.11"

      - name: Install Hugging Face Hub
        run: pip install huggingface_hub

      - name: Push to Hugging Face Space
        env:
          # Read by huggingface_hub for authentication.
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
        run: |
          python - << "EOF"
          from huggingface_hub import HfApi

          # 🔴 CHANGE THIS to your real Hugging Face Space id
          # format: "<hf-username>/<space-name>"
          # example: "ANANDA89/dt_hf_deploy"
          repo_id = "Kirtan001/dc_tree"

          api = HfApi()

          # Create the Space if it doesn't exist
          api.create_repo(
              repo_id=repo_id,
              repo_type="space",
              exist_ok=True,
              space_sdk="gradio",
          )

          # Upload all files except git/CI metadata.
          # Patterns are matched against each file's relative path, so a bare
          # directory name such as ".github" does not exclude the files inside
          # it; the trailing "/*" is required.
          api.upload_folder(
              folder_path=".",
              repo_id=repo_id,
              repo_type="space",
              ignore_patterns=[".git/*", ".github/*"],
          )
          EOF
app.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import pickle
3
+ import numpy as np
4
+ import pandas as pd
5
+
# 1. Load trained models
# Both pickles are expected next to this script; a missing file raises at import.
with open("cart_model.pkl", "rb") as cart_file:
    cart_model = pickle.load(cart_file)

with open("id3_model.pkl", "rb") as id3_file:
    id3_model = pickle.load(id3_file)

# Display names, indexed by the position of the highest predicted probability.
CLASS_NAMES = ["Setosa", "Versicolor", "Virginica"]

# Accept either canonical names or sklearn's original iris names
FEATURES_CANONICAL = ["sepal_length", "sepal_width", "petal_length", "petal_width"]
# sklearn spells the same four features with spaces and a "(cm)" suffix,
# e.g. "sepal length (cm)"; the order matches FEATURES_CANONICAL.
FEATURES_SKLEARN = [
    f"{canonical.replace('_', ' ')} (cm)" for canonical in FEATURES_CANONICAL
]
23
+
24
+
def _get_model(model_type: str):
    """Pick the model for a UI label.

    Returns ``cart_model`` when ``model_type`` is exactly ``"CART (Gini)"``;
    every other value falls through to ``id3_model``.
    """
    return cart_model if model_type == "CART (Gini)" else id3_model
30
+
31
+
# 2A. Single-row prediction
def predict_single(sepal_length, sepal_width, petal_length, petal_width, model_type):
    """Classify a single flower from its four measurements.

    Args:
        sepal_length, sepal_width, petal_length, petal_width: Numeric feature
            values, in the order the models were trained on.
        model_type: UI label of the tree to use; resolved by ``_get_model``.

    Returns:
        A dict with the chosen model label, the predicted class name and the
        per-class probabilities keyed by ``CLASS_NAMES``. If any measurement
        is ``None``, a dict with a single ``"error"`` message instead.
    """
    measurements = [sepal_length, sepal_width, petal_length, petal_width]

    # An emptied number field can reach this function as None; report it in
    # the same JSON output instead of failing inside predict_proba.
    if any(value is None for value in measurements):
        return {"error": "Please enter a value for all four measurements."}

    X = np.array([measurements])
    model = _get_model(model_type)

    probs = model.predict_proba(X)[0]
    pred_idx = int(np.argmax(probs))

    return {
        "Chosen model": model_type,
        "Predicted class": CLASS_NAMES[pred_idx],
        # Keys come from CLASS_NAMES so labels and probabilities cannot drift apart.
        "Probabilities": {
            name: float(prob) for name, prob in zip(CLASS_NAMES, probs)
        },
    }
# 2B. Batch prediction from uploaded CSV
def predict_batch(file, model_type):
    """Classify every row of an uploaded CSV and return the annotated table.

    Args:
        file: The uploaded CSV. Either an object whose ``name`` attribute is
            the path of the file, or the path itself. ``None`` means nothing
            was uploaded.
        model_type: UI label of the tree to use; resolved by ``_get_model``.

    Returns:
        A ``pandas.DataFrame``. On success it holds the input columns
        (sklearn-style feature names renamed to the canonical ones) plus
        ``predicted_class``, ``prob_setosa``, ``prob_versicolor`` and
        ``prob_virginica``. On any problem it is a one-column ``"error"``
        frame describing what went wrong.
    """
    if file is None:
        return pd.DataFrame({"error": ["Please upload a CSV file."]})

    # Accept both an upload object exposing its path as `.name` and a plain path.
    csv_path = getattr(file, "name", file)

    # Try to read the CSV. Deliberately broad: parsing can fail in many ways
    # (missing file, bad encoding, malformed rows) and all are reported alike.
    try:
        df = pd.read_csv(csv_path)
    except Exception as e:
        return pd.DataFrame({"error": [f"Could not read CSV: {e}"]})

    # Handle sklearn-style column names by renaming to canonical.
    # FEATURES_SKLEARN and FEATURES_CANONICAL list the same features in the
    # same order, so zipping them yields the rename map.
    if all(col in df.columns for col in FEATURES_SKLEARN):
        df = df.rename(columns=dict(zip(FEATURES_SKLEARN, FEATURES_CANONICAL)))

    # Now check that canonical feature names exist
    if not all(col in df.columns for col in FEATURES_CANONICAL):
        return pd.DataFrame({
            "error": [
                "Input CSV must contain either:\n"
                " - 'sepal_length','sepal_width','petal_length','petal_width'\n"
                " OR\n"
                " - 'sepal length (cm)','sepal width (cm)',"
                "'petal length (cm)','petal width (cm)'"
            ]
        })

    # Drop completely empty rows
    df = df.dropna(how="all")
    if df.empty:
        return pd.DataFrame({"error": ["All rows are empty after dropping NA."]})

    # Ensure numeric
    try:
        X = df[FEATURES_CANONICAL].astype(float).to_numpy()
    except (ValueError, TypeError):
        return pd.DataFrame({
            "error": [
                "Feature columns must be numeric: "
                + ", ".join(FEATURES_CANONICAL)
            ]
        })

    model = _get_model(model_type)
    try:
        probs = model.predict_proba(X)
    except ValueError as e:
        # Report input the model refuses as an error table, like every other
        # failure in this function, instead of letting the exception escape.
        return pd.DataFrame({"error": [f"Prediction failed: {e}"]})
    preds = np.argmax(probs, axis=1)
    pred_labels = [CLASS_NAMES[i] for i in preds]

    # Build result DataFrame: original columns + predictions
    result = df.copy()
    result["predicted_class"] = pred_labels
    result["prob_setosa"] = probs[:, 0]
    result["prob_versicolor"] = probs[:, 1]
    result["prob_virginica"] = probs[:, 2]

    return result
def predict_batch_and_save(file, model_type):
    """Run ``predict_batch`` and write its table to a CSV file for download.

    Args:
        file: The uploaded CSV, passed through to ``predict_batch``.
        model_type: UI label of the tree to use, passed through as well.

    Returns:
        Path of the written ``batch_predictions.csv``. When ``predict_batch``
        reports a problem, the file contains its ``"error"`` table.
    """
    import os
    import tempfile

    result = predict_batch(file, model_type)
    if not isinstance(result, pd.DataFrame):
        result = pd.DataFrame({"error": ["Unknown error"]})

    # Each call gets its own directory: a single fixed path in the working
    # directory would let simultaneous requests overwrite each other's
    # results. The file name itself is kept so the download name is unchanged.
    # The directory is not removed here; it lives in the system temp location.
    out_dir = tempfile.mkdtemp(prefix="batch_predictions_")
    csv_path = os.path.join(out_dir, "batch_predictions.csv")
    result.to_csv(csv_path, index=False)
    return csv_path
124
+
125
+
# 3. Gradio UI with Tabs
def _model_radio():
    """Create the CART/ID3 selector shown on both tabs (CART preselected)."""
    return gr.Radio(
        choices=["CART (Gini)", "ID3 (Entropy)"],
        value="CART (Gini)",
        label="Decision tree type",
    )


with gr.Blocks() as demo:
    gr.Markdown("# Decision Tree Classifier (CART vs ID3)")
    gr.Markdown(
        "Use the single prediction tab for one Iris flower, "
        "or upload a CSV file with multiple rows for batch prediction.\n\n"
        "*Data feeding happens entirely at the user end:* "
        "they prepare their own CSV, upload it, and see model outputs."
    )

    # ---- Tab 1: Single prediction ----
    with gr.Tab("Single prediction"):
        # Two rows of two number fields, pre-filled with one sample flower.
        with gr.Row():
            sepal_length = gr.Number(label="Sepal length (cm)", value=5.1)
            sepal_width = gr.Number(label="Sepal width (cm)", value=3.5)
        with gr.Row():
            petal_length = gr.Number(label="Petal length (cm)", value=1.4)
            petal_width = gr.Number(label="Petal width (cm)", value=0.2)

        model_single = _model_radio()

        btn_single = gr.Button("Predict")
        out_single = gr.JSON(label="Prediction details")

        single_inputs = [
            sepal_length,
            sepal_width,
            petal_length,
            petal_width,
            model_single,
        ]
        btn_single.click(
            fn=predict_single,
            inputs=single_inputs,
            outputs=out_single,
        )

    # ---- Tab 2: Batch prediction (CSV upload) ----
    with gr.Tab("Batch prediction (CSV upload)"):
        gr.Markdown(
            "Upload a CSV file with column names either:\n"
            "- sepal_length, sepal_width, petal_length, petal_width, or\n"
            "- sepal length (cm), sepal width (cm), petal length (cm), petal width (cm).\n\n"
            "You can edit your data in Excel / Python, save as CSV, upload here, "
            "and see the predictions instantly."
        )

        file_input = gr.File(label="Upload CSV file", file_types=[".csv"])
        model_batch = _model_radio()

        btn_batch = gr.Button("Run batch prediction")
        out_batch = gr.Dataframe(
            label="Predictions (input + model outputs)",
            interactive=False,
        )

        download_btn = gr.DownloadButton(label="Download results as CSV")

        # Both buttons read the same upload and model choice.
        batch_inputs = [file_input, model_batch]

        # Show table
        btn_batch.click(
            fn=predict_batch,
            inputs=batch_inputs,
            outputs=out_batch,
        )

        # Download CSV: the handler's returned path becomes the button's value.
        download_btn.click(
            fn=predict_batch_and_save,
            inputs=batch_inputs,
            outputs=download_btn,
        )


# 4. Entry point
if __name__ == "__main__":
    demo.launch()
210
+
211
+
212
+
cart_model.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4a620bd87909613bb53ab80352775391ef12d16e7fff511a421047a2cde95a8f
3
+ size 2263
id3_model.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:30ffeee87e11d56c693cd740bd002bfed524681be23145f4af4eb1f90b731c75
3
+ size 2266
requirement.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ numpy
2
+ pandas
3
+ scikit-learn
4
+ gradio
train.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle
2
+ from sklearn.datasets import load_iris
3
+ from sklearn.model_selection import train_test_split
4
+ from sklearn.tree import DecisionTreeClassifier
5
+
6
+
def train_and_save_models():
    """Train two depth-4 decision trees on Iris and pickle them.

    Fits one tree with the Gini criterion and one with entropy on a
    stratified 80/20 split, writes them to ``cart_model.pkl`` and
    ``id3_model.pkl`` in the working directory, then prints each model's
    accuracy on the held-out 20%. Progress is printed along the way.
    """
    print("🔹 Loading Iris dataset...")
    dataset = load_iris()
    features, labels = dataset.data, dataset.target
    print(f" Features shape: {features.shape}")
    print(f" Target shape: {labels.shape}")

    print("Splitting train/test...")
    x_fit, x_eval, y_fit, y_eval = train_test_split(
        features,
        labels,
        test_size=0.2,
        random_state=42,
        stratify=labels,
    )
    print(f" Train size: {x_fit.shape[0]}, Test size: {x_eval.shape[0]}")

    def fit_tree(criterion):
        # Same depth limit and seed for both trees; only the criterion differs.
        tree = DecisionTreeClassifier(
            criterion=criterion,
            random_state=42,
            max_depth=4,
        )
        tree.fit(x_fit, y_fit)
        return tree

    print("🔹 Training CART (Gini) model...")
    cart_clf = fit_tree("gini")

    print("🔹 Training ID3-like (Entropy) model...")
    id3_clf = fit_tree("entropy")

    print("🔹 Saving models to cart_model.pkl and id3_model.pkl...")
    for path, fitted in (("cart_model.pkl", cart_clf), ("id3_model.pkl", id3_clf)):
        with open(path, "wb") as out:
            pickle.dump(fitted, out)

    print(f" CART test accuracy: {cart_clf.score(x_eval, y_eval):.3f}")
    print(f"ID3 test accuracy: {id3_clf.score(x_eval, y_eval):.3f}")
48
+
49
+
def _main():
    """Script entry point: announce the run, train and save both models, report completion."""
    print(" Starting training script train_and_save_models()")
    train_and_save_models()
    print("Training completed.")


if __name__ == "__main__":
    _main()