KLEB38 committed on
Commit
d2a2a12
·
1 Parent(s): 240a59c

feat/setup SQL database and GET predict

Browse files
app/main.py CHANGED
@@ -5,6 +5,17 @@ from app.schemas import EmployeeInput
5
  import shap
6
  import os
7
  import logging
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
10
 
@@ -96,19 +107,11 @@ def interpret_shap(rank: int, value: float) -> str:
96
  direction = "increases resignation risk" if value > 0 else "decreases resignation risk"
97
  return f"{intensity[rank]} — {direction}"
98
 
99
- @app.get("/") # La page d'accueil de ton API
100
- def read_root():
101
- return {"message": "Welcome to the FUTURISYS HR predictor API"}
102
-
103
- @app.post("/predict")
104
- def predict(data: EmployeeInput):
105
- # 1. On transforme le dictionnaire reçu en DataFrame pandas
106
- df = pd.DataFrame([data.model_dump()])
107
-
108
  for col, known in known_values.items():
109
- val = df[col].values[0]
110
- if val not in known:
111
- logger.warning(f"Unknown value '{val}' for column '{col}' — prediction may be unreliable")
112
 
113
  # Encodage binaire non inclus dans le pipeline:
114
  df['genre']= df["genre"].map({"M": 1, "F": 0})
@@ -154,3 +157,55 @@ def predict(data: EmployeeInput):
154
  for rank, factor in enumerate(top_factors.index)
155
  }
156
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  import shap
6
  import os
7
  import logging
8
+ from dotenv import load_dotenv
9
+ from sqlalchemy import create_engine, text
10
+ from sqlalchemy.orm import sessionmaker
11
+ from database.create_db import PredictionLog
12
+ from sqlalchemy.orm import Session
13
+
14
# Load environment variables from the .env file located one directory above
# this module, then build the shared SQLAlchemy engine and session factory.
load_dotenv(dotenv_path=os.path.join(os.path.dirname(__file__), '..', '.env'))
DATABASE_URL = os.getenv("DATABASE_URL")
if not DATABASE_URL:
    # Without this guard create_engine(None) fails with an error that does not
    # mention the missing environment variable.
    raise RuntimeError(
        "DATABASE_URL is not set; define it in the environment or in the .env file"
    )
engine = create_engine(DATABASE_URL)
SessionLocal = sessionmaker(bind=engine)
18
+
19
 
20
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
21
 
 
107
  direction = "increases resignation risk" if value > 0 else "decreases resignation risk"
108
  return f"{intensity[rank]} — {direction}"
109
 
110
+ def run_prediction(df):
 
 
 
 
 
 
 
 
111
  for col, known in known_values.items():
112
+ val = df[col].values[0]
113
+ if val not in known:
114
+ logger.warning(f"Unknown value '{val}' for column '{col}' — prediction may be unreliable")
115
 
116
  # Encodage binaire non inclus dans le pipeline:
117
  df['genre']= df["genre"].map({"M": 1, "F": 0})
 
157
  for rank, factor in enumerate(top_factors.index)
158
  }
159
  }
160
+
161
def log_prediction(df: pd.DataFrame, result: dict, id_employee: int = None):
    """Persist one prediction (inputs and outputs) as a ``PredictionLog`` row.

    Args:
        df: Frame whose first row holds the employee features. Only columns
            that also exist on the ``PredictionLog`` table are stored.
        result: Prediction payload. Must contain ``statut_employe``,
            ``probability_score`` and ``top_5_factors`` (a mapping whose key
            order gives the factor ranking).
        id_employee: Employee identifier to store, or ``None`` when the
            caller has no identifier for this prediction.

    Raises:
        KeyError: If ``result`` lacks one of the required keys.
    """

    def _to_db_value(value):
        # Unwrap numpy scalars into plain Python values, then turn missing
        # values (NaN / None / NA) into None so they are stored as NULL.
        if hasattr(value, "item"):
            value = value.item()
        return None if pd.isna(value) else value

    factor_fields = (
        "primary_driver",
        "strong_factor",
        "moderate_factor",
        "contributing_factor",
        "notable_factor",
    )
    factors = list(result["top_5_factors"].keys())
    # Pad so that fewer than five factors leave the remaining fields empty.
    factors += [None] * (len(factor_fields) - len(factors))

    table_columns = set(PredictionLog.__table__.columns.keys())
    values = {
        col: _to_db_value(df[col].values[0])
        for col in df.columns
        if col in table_columns
    }
    # Explicit arguments and model outputs take precedence over any
    # same-named column that happens to be present in the frame.
    values.update(
        id_employee=id_employee,
        prediction=result["statut_employe"],
        probability_score=_to_db_value(result["probability_score"]),
        **dict(zip(factor_fields, factors)),
    )

    with Session(engine) as session:
        session.add(PredictionLog(**values))
        session.commit()
180
+
181
+
182
+
183
+
184
+
185
@app.get("/")  # API landing page
def read_root():
    """Return the static welcome payload served at the API root."""
    welcome = {"message": "Welcome to the FUTURISYS HR predictor API"}
    return welcome
188
+
189
@app.get("/predict/{id_employee}")
def predict_by_id(id_employee: int):
    """Run and log a prediction for an employee stored in ``employees_full``.

    Args:
        id_employee: Identifier looked up in the ``employees_full`` view.

    Returns:
        The payload produced by ``run_prediction``.

    Raises:
        HTTPException: 404 when no row matches ``id_employee``.
    """
    query = text("SELECT * FROM employees_full WHERE id_employee = :id")
    with engine.connect() as conn:
        row = conn.execute(query, {"id": id_employee}).fetchone()
    if row is None:
        raise HTTPException(status_code=404, detail="Employee ID not found in database")

    df = pd.DataFrame([dict(row._mapping)])
    # run_prediction re-encodes columns of the frame it receives (e.g. it maps
    # "genre" to 0/1), so hand it a copy and log the untouched inputs.
    prediction = run_prediction(df.copy())
    log_prediction(df, prediction, id_employee)
    return prediction
204
+
205
@app.post("/predict")
def predict(data: EmployeeInput):
    """Run and log a prediction for employee features sent in the request body.

    Args:
        data: Validated employee features.

    Returns:
        The payload produced by ``run_prediction``.
    """
    df = pd.DataFrame([data.model_dump()])
    # run_prediction re-encodes columns of the frame it receives (e.g. it maps
    # "genre" to 0/1), so hand it a copy and log the values as submitted.
    result = run_prediction(df.copy())
    log_prediction(df, result)
    return result
211
+
database/__init__.py ADDED
File without changes
database/create_db.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sqlalchemy import create_engine, Column, Integer, String, Float, DateTime
2
+ from sqlalchemy.sql import text
3
+ from sqlalchemy.orm import declarative_base
4
+ from datetime import datetime
5
+ import os
6
+ from dotenv import load_dotenv
7
+
8
+
9
# Declarative base shared by every ORM model defined in this module.
Base = declarative_base()

# Connection string, read from the environment (optionally via a .env file).
load_dotenv()
DATABASE_URL = os.getenv("DATABASE_URL")
if not DATABASE_URL:
    # Without this guard create_engine(None) fails with an error that does not
    # mention the missing environment variable.
    raise RuntimeError(
        "DATABASE_URL is not set; define it in the environment or in a .env file"
    )
engine = create_engine(DATABASE_URL)
15
+
16
class EmployeeSirh(Base):
    """ORM mapping for the ``employees_sirh`` table, keyed by ``id_employee``."""

    __tablename__ = "employees_sirh"

    # Primary key
    id_employee = Column(Integer, primary_key=True)

    # Integer and string attributes, in table column order
    age = Column(Integer)
    genre = Column(String)
    revenu_mensuel = Column(Integer)
    statut_marital = Column(String)
    departement = Column(String)
    poste = Column(String)
    nombre_experiences_precedentes = Column(Integer)
    nombre_heures_travaillees = Column(Integer)
    annee_experience_totale = Column(Integer)
    annees_dans_l_entreprise = Column(Integer)
    annees_dans_le_poste_actuel = Column(Integer)
31
+
32
class EmployeeEval(Base):
    """ORM mapping for the ``employees_eval`` table, keyed by ``eval_number``."""

    __tablename__ = "employees_eval"

    # Primary key
    eval_number = Column(Integer, primary_key=True)

    # Integer-typed columns
    satisfaction_employee_environnement = Column(Integer)
    note_evaluation_precedente = Column(Integer)
    niveau_hierarchique_poste = Column(Integer)
    satisfaction_employee_nature_travail = Column(Integer)
    satisfaction_employee_equipe = Column(Integer)
    satisfaction_employee_equilibre_pro_perso = Column(Integer)
    note_evaluation_actuelle = Column(Integer)

    # String-typed columns
    heure_supplementaires = Column(String)
    augementation_salaire_precedente = Column(String)
45
+
46
+
47
class EmployeeSondage(Base):
    """ORM mapping for the ``employees_sondage`` table, keyed by ``code_sondage``."""

    __tablename__ = "employees_sondage"

    # Primary key
    code_sondage = Column(Integer, primary_key=True)

    # Remaining columns, in table column order
    a_quitte_l_entreprise = Column(String)
    nombre_participation_pee = Column(Integer)
    nb_formations_suivies = Column(Integer)
    nombre_employee_sous_responsabilite = Column(Integer)
    distance_domicile_travail = Column(Integer)
    niveau_education = Column(Integer)
    domaine_etude = Column(String)
    ayant_enfants = Column(String)
    frequence_deplacement = Column(String)
    annees_depuis_la_derniere_promotion = Column(Integer)
    annes_sous_responsable_actuel = Column(Integer)
62
+
63
+
64
def _utcnow():
    """Return the current UTC time as a naive ``datetime``.

    Produces the same value as the deprecated ``datetime.utcnow()`` so rows
    keep being written with naive UTC timestamps.
    """
    from datetime import timezone

    return datetime.now(timezone.utc).replace(tzinfo=None)


class PredictionLog(Base):
    """ORM mapping for the ``predictions_log`` table.

    Each row stores the inputs of one prediction together with its outputs.
    """

    __tablename__ = 'predictions_log'

    id = Column(Integer, primary_key=True, autoincrement=True)
    id_employee = Column(Integer)
    # Filled in by the callable at insert time when no value is supplied.
    timestamp = Column(DateTime, default=_utcnow)

    # Inputs
    genre = Column(String)
    statut_marital = Column(String)
    departement = Column(String)
    poste = Column(String)
    domaine_etude = Column(String)
    frequence_deplacement = Column(String)
    heure_supplementaires = Column(String)
    age = Column(Integer)
    revenu_mensuel = Column(Integer)
    nombre_experiences_precedentes = Column(Integer)
    annee_experience_totale = Column(Integer)
    annees_dans_l_entreprise = Column(Integer)
    annees_dans_le_poste_actuel = Column(Integer)
    nb_formations_suivies = Column(Integer)
    distance_domicile_travail = Column(Integer)
    niveau_education = Column(Integer)
    annees_depuis_la_derniere_promotion = Column(Integer)
    annes_sous_responsable_actuel = Column(Integer)
    satisfaction_employee_environnement = Column(Integer)
    note_evaluation_precedente = Column(Float)
    satisfaction_employee_nature_travail = Column(Integer)
    satisfaction_employee_equipe = Column(Integer)
    satisfaction_employee_equilibre_pro_perso = Column(Integer)
    note_evaluation_actuelle = Column(Float)
    augementation_salaire_precedente = Column(String)

    # Outputs
    prediction = Column(String)
    probability_score = Column(Float)
    primary_driver = Column(String)
    strong_factor = Column(String)
    moderate_factor = Column(String)
    contributing_factor = Column(String)
    notable_factor = Column(String)
    unknown_category_warning = Column(String, nullable=True)
    ground_truth = Column(Integer, nullable=True)
108
+
109
+
110
if __name__ == "__main__":
    # Executed as a script: create every table declared on Base.
    Base.metadata.create_all(bind=engine)
    print("Tables created successfully!")
113
+
114
+
database/insert_data.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ from sqlalchemy import create_engine, Column, Integer, String, Float, DateTime
3
+ from sqlalchemy.sql import text
4
+ import os
5
+ from dotenv import load_dotenv
6
+
7
# Connection string
load_dotenv()
DATABASE_URL = os.getenv("DATABASE_URL")
engine = create_engine(DATABASE_URL)

# Directory holding the three CSV extracts. Set HR_DATA_DIR to override; the
# default is the location this script was originally written against.
DATA_DIR = os.getenv("HR_DATA_DIR", r'C:\Users\Kevin\projects\OC P4\Projet 4')

# Load CSVs first, so a missing or malformed file aborts the script before
# anything is dropped from the database.
df_sirh = pd.read_csv(os.path.join(DATA_DIR, 'extrait_sirh.csv'))
df_eval = pd.read_csv(os.path.join(DATA_DIR, 'extrait_eval.csv'))
# Drop the first two characters of each eval_number and convert the rest to a
# number (presumably a textual prefix; confirm against the CSV).
df_eval['eval_number'] = pd.to_numeric(df_eval['eval_number'].str[2:], errors='raise')
df_sondage = pd.read_csv(os.path.join(DATA_DIR, 'extrait_sondage.csv'))

# The view selects from the tables replaced below, so remove it first.
with engine.connect() as conn:
    conn.execute(text("DROP VIEW IF EXISTS employees_full CASCADE"))
    conn.commit()

# Insert into DB
df_sirh.to_sql('employees_sirh', engine, if_exists='replace', index=False)
df_eval.to_sql('employees_eval', engine, if_exists='replace', index=False)
df_sondage.to_sql('employees_sondage', engine, if_exists='replace', index=False)

print("Data inserted successfully!")

# Recreate the view joining the three tables on the employee identifier.
with engine.connect() as conn:
    conn.execute(text("""
        CREATE OR REPLACE VIEW employees_full AS
        SELECT * FROM employees_sirh s
        INNER JOIN employees_sondage so ON s.id_employee = CAST(so.code_sondage AS INTEGER)
        INNER JOIN employees_eval e ON s.id_employee = CAST(e.eval_number AS INTEGER)
    """))
    conn.commit()

print("View created successfully!")
requirements.txt CHANGED
@@ -7,4 +7,6 @@ shap==0.50.0
7
  psycopg2-binary==2.9.11
8
  pytest
9
  pytest-cov
10
- httpx
 
 
 
7
  psycopg2-binary==2.9.11
8
  pytest
9
  pytest-cov
10
+ httpx
11
+ python-dotenv
12
+ sqlalchemy