sreekar8811 commited on
Commit
d2c5284
·
verified ·
1 Parent(s): 33daa2e

Upload train_models.py

Browse files
Files changed (1) hide show
  1. train_models.py +71 -0
train_models.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # train_models.py
3
+ import pandas as pd
4
+ import numpy as np
5
+ from sklearn.model_selection import train_test_split
6
+ from sklearn.preprocessing import StandardScaler
7
+ from sklearn.linear_model import LogisticRegression
8
+ from sklearn.metrics import accuracy_score
9
+ from xgboost import XGBClassifier
10
+ from tensorflow.keras.models import Sequential
11
+ from tensorflow.keras.layers import Dense, LSTM
12
+ import joblib
13
+ import os
14
+
15
+ # Create models directory if it doesn't exist
16
+ os.makedirs("models", exist_ok=True)
17
+
18
+ # Load data
19
+ data = pd.read_csv("heart.csv")
20
+ X = data.drop('target', axis=1)
21
+ y = data['target']
22
+
23
+ # Split data
24
+ X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
25
+
26
+ # Scale features
27
+ scaler = StandardScaler()
28
+ X_train_scaled = scaler.fit_transform(X_train)
29
+ X_test_scaled = scaler.transform(X_test)
30
+
31
+ # Save scaler
32
+ joblib.dump(scaler, "models/scaler.sav")
33
+
34
+ # 1. Logistic Regression
35
+ print("Training Logistic Regression...")
36
+ lr_model = LogisticRegression(max_iter=1000)
37
+ lr_model.fit(X_train_scaled, y_train)
38
+ joblib.dump(lr_model, "models/logistic_model.sav")
39
+
40
+ # 2. XGBoost
41
+ print("Training XGBoost...")
42
+ xgb_model = XGBClassifier()
43
+ xgb_model.fit(X_train_scaled, y_train)
44
+ joblib.dump(xgb_model, "models/xgb_model.sav")
45
+
46
+ # 3. ANN (Artificial Neural Network)
47
+ print("Training ANN...")
48
+ ann_model = Sequential([
49
+ Dense(64, activation='relu', input_shape=(X_train_scaled.shape[1],)),
50
+ Dense(32, activation='relu'),
51
+ Dense(1, activation='sigmoid')
52
+ ])
53
+ ann_model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
54
+ ann_model.fit(X_train_scaled, y_train, epochs=50, batch_size=32, verbose=0)
55
+ ann_model.save("models/ann_model.h5")
56
+
57
+ # 4. LSTM (requires reshaping for time-series-like data)
58
+ print("Training LSTM...")
59
+ # Reshape data for LSTM (samples, timesteps, features)
60
+ X_train_lstm = X_train_scaled.reshape(X_train_scaled.shape[0], 1, X_train_scaled.shape[1])
61
+ X_test_lstm = X_test_scaled.reshape(X_test_scaled.shape[0], 1, X_test_scaled.shape[1])
62
+
63
+ lstm_model = Sequential([
64
+ LSTM(64, input_shape=(1, X_train_scaled.shape[1])),
65
+ Dense(1, activation='sigmoid')
66
+ ])
67
+ lstm_model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
68
+ lstm_model.fit(X_train_lstm, y_train, epochs=50, batch_size=32, verbose=0)
69
+ lstm_model.save("models/lstm_model.h5")
70
+
71
+ print("All models trained and saved successfully!")