Aurélie GABU commited on
Commit
15524ff
·
1 Parent(s): d8517ab

sqlalchemy restrictions on HF

Browse files
Files changed (3) hide show
  1. App/database.py +8 -9
  2. App/model.py +50 -45
  3. App/predict.py +5 -2
App/database.py CHANGED
@@ -17,18 +17,17 @@ IS_PYTEST = "pytest" in os.getenv("PYTHONPATH", "") or os.getenv("PYTEST_CURRENT
17
  IS_HF = os.getenv("SPACE_ID") is not None # Hugging Face
18
 
19
  SKIP_DB = IS_CI or IS_PYTEST or IS_HF or not SQLALCHEMY_AVAILABLE
20
-
21
- DB_USER = os.getenv("DB_USER", "postgres")
22
- DB_PASSWORD = os.getenv("DB_PASSWORD", "password")
23
- DB_HOST = os.getenv("DB_HOST", "localhost")
24
- DB_PORT = os.getenv("DB_PORT", "5432")
25
- DB_NAME = os.getenv("DB_NAME", "test_db")
26
-
27
- DATABASE_URL = (f"postgresql+psycopg2://{DB_USER}:{DB_PASSWORD}"f"@{DB_HOST}:{DB_PORT}/{DB_NAME}")
28
-
29
  Base = declarative_base() if SQLALCHEMY_AVAILABLE else None
30
 
31
  if not SKIP_DB:
 
 
 
 
 
 
 
 
32
  engine = create_engine(DATABASE_URL)
33
  SessionLocal = sessionmaker(autocommit = False, autoflush = False, bind = engine)
34
 
 
17
  IS_HF = os.getenv("SPACE_ID") is not None # Hugging Face
18
 
19
  SKIP_DB = IS_CI or IS_PYTEST or IS_HF or not SQLALCHEMY_AVAILABLE
 
 
 
 
 
 
 
 
 
20
  Base = declarative_base() if SQLALCHEMY_AVAILABLE else None
21
 
22
  if not SKIP_DB:
23
+ DB_USER = os.getenv("DB_USER", "postgres")
24
+ DB_PASSWORD = os.getenv("DB_PASSWORD", "password")
25
+ DB_HOST = os.getenv("DB_HOST", "localhost")
26
+ DB_PORT = os.getenv("DB_PORT", "5432")
27
+ DB_NAME = os.getenv("DB_NAME", "test_db")
28
+
29
+ DATABASE_URL = (f"postgresql+psycopg2://{DB_USER}:{DB_PASSWORD}"f"@{DB_HOST}:{DB_PORT}/{DB_NAME}")
30
+
31
  engine = create_engine(DATABASE_URL)
32
  SessionLocal = sessionmaker(autocommit = False, autoflush = False, bind = engine)
33
 
App/model.py CHANGED
@@ -1,51 +1,56 @@
1
- from sqlalchemy import Column, Integer, String, Float, Boolean, DateTime, ForeignKey
2
- from sqlalchemy.sql import func
3
- from App.database import Base
4
 
5
- class Input(Base):
6
- __tablename__ = "inputs"
 
7
 
8
- id = Column(Integer, primary_key=True, index=True)
9
- genre = Column(String)
10
- statut_marital = Column(String)
11
- departement = Column(String)
12
- poste = Column(String)
13
- domaine_etude = Column(String)
14
- frequence_deplacement = Column(String)
15
- heure_supplementaires = Column(Boolean)
16
- evolution_cat_evol = Column(String)
17
- categorie_employe = Column(String)
18
- satisfaction_employee_nature_travail = Column(Integer)
19
- nombre_participation_pee = Column(Integer)
20
- ecart_note_evaluation = Column(Integer)
21
- revenu_mensuel = Column(Integer)
22
- distance_domicile_travail = Column(Integer)
23
- satisfaction_globale = Column(Float)
24
- niveau_education = Column(Integer)
25
- note_evaluation_actuelle = Column(Integer)
26
- satisfaction_employee_equipe = Column(Integer)
27
- age = Column(Integer)
28
- revenu_par_annee_experience_interne = Column(Integer)
29
- satisfaction_employee_equilibre_pro_perso = Column(Integer)
30
- nombre_experiences_precedentes = Column(Integer)
31
- annees_dans_l_entreprise = Column(Integer)
32
- nb_formations_suivies = Column(Integer)
33
- revenu_par_annee_experience_totale = Column(Integer)
34
- ratio_sans_promotion = Column(Integer)
35
- satisfaction_employee_environnement = Column(Integer)
36
- exp_hors_entreprise = Column(Integer)
37
- mobilite_promotion = Column(Integer)
38
- annees_depuis_la_derniere_promotion = Column(Integer)
39
 
40
- created_at = Column(DateTime(timezone=True), server_default=func.now())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
- class Predictions(Base):
43
- __tablename__ = "predictions"
44
- id = Column(Integer, primary_key=True, index=True)
45
- input_id = Column(Integer, ForeignKey("inputs.id"))
46
 
47
- prediction_label = Column(String)
48
- prediction_proba = Column(Float)
49
- model_version = Column(String)
 
50
 
51
- created_at = Column(DateTime(timezone=True), server_default=func.now())
 
 
 
 
 
 
 
 
1
+ from App.database import Base, SQLALCHEMY_AVAILABLE
 
 
2
 
3
+ if SQLALCHEMY_AVAILABLE:
4
+ from sqlalchemy import Column, Integer, String, Float, Boolean, DateTime, ForeignKey
5
+ from sqlalchemy.sql import func
6
 
7
+ class Input(Base):
8
+ __tablename__ = "inputs"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
+ id = Column(Integer, primary_key=True, index=True)
11
+ genre = Column(String)
12
+ statut_marital = Column(String)
13
+ departement = Column(String)
14
+ poste = Column(String)
15
+ domaine_etude = Column(String)
16
+ frequence_deplacement = Column(String)
17
+ heure_supplementaires = Column(Boolean)
18
+ evolution_cat_evol = Column(String)
19
+ categorie_employe = Column(String)
20
+ satisfaction_employee_nature_travail = Column(Integer)
21
+ nombre_participation_pee = Column(Integer)
22
+ ecart_note_evaluation = Column(Integer)
23
+ revenu_mensuel = Column(Integer)
24
+ distance_domicile_travail = Column(Integer)
25
+ satisfaction_globale = Column(Float)
26
+ niveau_education = Column(Integer)
27
+ note_evaluation_actuelle = Column(Integer)
28
+ satisfaction_employee_equipe = Column(Integer)
29
+ age = Column(Integer)
30
+ revenu_par_annee_experience_interne = Column(Integer)
31
+ satisfaction_employee_equilibre_pro_perso = Column(Integer)
32
+ nombre_experiences_precedentes = Column(Integer)
33
+ annees_dans_l_entreprise = Column(Integer)
34
+ nb_formations_suivies = Column(Integer)
35
+ revenu_par_annee_experience_totale = Column(Integer)
36
+ ratio_sans_promotion = Column(Integer)
37
+ satisfaction_employee_environnement = Column(Integer)
38
+ exp_hors_entreprise = Column(Integer)
39
+ mobilite_promotion = Column(Integer)
40
+ annees_depuis_la_derniere_promotion = Column(Integer)
41
 
42
+ created_at = Column(DateTime(timezone=True), server_default=func.now())
 
 
 
43
 
44
+ class Predictions(Base):
45
+ __tablename__ = "predictions"
46
+ id = Column(Integer, primary_key=True, index=True)
47
+ input_id = Column(Integer, ForeignKey("inputs.id"))
48
 
49
+ prediction_label = Column(String)
50
+ prediction_proba = Column(Float)
51
+ model_version = Column(String)
52
+
53
+ created_at = Column(DateTime(timezone=True), server_default=func.now())
54
+ else:
55
+ Input = None
56
+ Predictions = None
App/predict.py CHANGED
@@ -9,12 +9,15 @@ from huggingface_hub import hf_hub_download
9
  # Import SQLAlchemy uniquement si disponible
10
  try:
11
  from sqlalchemy.orm import Session
 
 
12
  SQLALCHEMY_AVAILABLE = True
13
  except ModuleNotFoundError:
14
  SQLALCHEMY_AVAILABLE = False
 
 
 
15
 
16
- from App.database import SessionLocal
17
- from App.model import Input, Predictions
18
 
19
  MODEL_REPO = "Diaure/xgb_model"
20
 
 
9
  # Import SQLAlchemy uniquement si disponible
10
  try:
11
  from sqlalchemy.orm import Session
12
+ from App.database import SessionLocal
13
+ from App.model import Input, Predictions
14
  SQLALCHEMY_AVAILABLE = True
15
  except ModuleNotFoundError:
16
  SQLALCHEMY_AVAILABLE = False
17
+ essionLocal = None
18
+ Input = None
19
+ Predictions = None
20
 
 
 
21
 
22
  MODEL_REPO = "Diaure/xgb_model"
23