File size: 1,814 Bytes
d160281
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import os
import json
import joblib
import pandas as pd

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score

# Train a RandomForest classifier that predicts the "prognosis" column from the
# other columns of dataset/cleaned_dataset.csv, then save the fitted model, the
# fitted label encoder, and the held-out accuracy under model/ next to this file.
#
# This is a flat script: every statement below runs at import time.

BASE_DIR = os.path.dirname(os.path.abspath(__file__))

DATASET_PATH = os.path.join(BASE_DIR, "dataset", "cleaned_dataset.csv")
MODEL_DIR = os.path.join(BASE_DIR, "model")
MODEL_PATH = os.path.join(MODEL_DIR, "doctor_model.pkl")
ENCODER_PATH = os.path.join(MODEL_DIR, "label_encoder.pkl")
ACCURACY_PATH = os.path.join(MODEL_DIR, "accuracy.json")

# Validation failures exit with status 1 and the message on stderr, so shells
# and CI see the failure. A bare exit() would report success (status 0), and it
# is an interactive-shell helper that does not exist under `python -S`.
if not os.path.exists(DATASET_PATH):
    raise SystemExit(f"❌ Dataset not found: {DATASET_PATH}")

print("📂 Loading dataset...")
df = pd.read_csv(DATASET_PATH)

if "prognosis" not in df.columns:
    raise SystemExit("❌ 'prognosis' column not found in dataset")

# Features and target: every column except "prognosis" is used as a feature.
# NOTE(review): assumes the remaining columns are all numeric, as
# RandomForestClassifier requires — TODO confirm against the cleaning step.
X = df.drop("prognosis", axis=1)
y = df["prognosis"]

# Encode the target labels as integer class ids. The fitted encoder is saved
# below, presumably so consumers can map predicted ids back to label values.
label_encoder = LabelEncoder()
y_encoded = label_encoder.fit_transform(y)

# Hold out 20% of the rows for evaluation; the fixed seed makes the split
# (and therefore the reported accuracy) reproducible between runs.
X_train, X_test, y_train, y_test = train_test_split(
    X, y_encoded, test_size=0.2, random_state=42
)

# Train model
print("🤖 Training model...")
model = RandomForestClassifier(n_estimators=100, random_state=42)
model.fit(X_train, y_train)

# Evaluate on the held-out split; stored as a percentage (0-100), 2 decimals.
y_pred = model.predict(X_test)
accuracy = round(accuracy_score(y_test, y_pred) * 100, 2)

# Create the output directory only once there is something to write, so a
# failed run does not leave an empty model/ directory behind.
os.makedirs(MODEL_DIR, exist_ok=True)

# Persist artifacts with joblib (pickle-based: only load them from trusted
# locations).
joblib.dump(model, MODEL_PATH)
joblib.dump(label_encoder, ENCODER_PATH)

with open(ACCURACY_PATH, "w", encoding="utf-8") as f:
    json.dump({"accuracy": accuracy}, f)

print("✅ Model trained successfully")
print("📦 Saved model:", MODEL_PATH)
print("📦 Saved encoder:", ENCODER_PATH)
print("📊 Accuracy:", accuracy, "%")