bteodoru commited on
Commit
58908df
·
verified ·
1 Parent(s): 34b730a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -25
app.py CHANGED
@@ -3,8 +3,22 @@ from pydantic import BaseModel, validator
3
  import pandas as pd
4
  import joblib
5
  import numpy as np
6
- from sklearn.ensemble import RandomForestRegressor # Important să importăm pentru deserializare
7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  class SoilInput(BaseModel):
9
  cement_perecent: float
10
  curing_period: float
@@ -28,40 +42,26 @@ class SoilInput(BaseModel):
28
  raise ValueError("Viteza de compactare trebuie să fie între 0.5 și 1.5 mm/min")
29
  return v
30
 
31
- app = FastAPI()
32
-
33
- # Încărcăm modelul Random Forest direct
34
- try:
35
- model = joblib.load('rf_model.joblib')
36
- FEATURE_ORDER = model.feature_names_in_
37
- print("Model încărcat cu succes!")
38
- except Exception as e:
39
- print(f"Eroare la încărcarea modelului: {str(e)}")
40
- raise
41
-
42
- # FEATURE_ORDER = ['cement_perecent', 'curing_period', 'compaction_rate']
43
 
44
  @app.post("/predict")
45
  async def predict(soil_data: SoilInput):
46
  """
47
  Realizează predicții pentru UCS
48
  """
 
 
 
49
  try:
50
  # Construim DataFrame-ul pentru predicție
51
  input_data = soil_data.dict()
52
  input_df = pd.DataFrame([input_data])
53
-
54
- # Ne asigurăm că ordinea coloanelor este corectă
55
- feature_order = FEATURE_ORDER
56
- input_df = input_df[feature_order]
57
- # input_df = pd.DataFrame([soil_data.dict()])[FEATURE_ORDER]
58
-
59
- # expected_features = model.feature_names_in_ # Extrage ordinea corectă din model
60
- # input_df = input_df[expected_features] # Reordonează caracteristicile
61
-
62
  # Facem predicția
63
  prediction = model.predict(input_df)
64
-
65
  return {
66
  "success": True,
67
  "prediction": float(prediction[0]),
@@ -71,6 +71,7 @@ async def predict(soil_data: SoilInput):
71
  except Exception as e:
72
  raise HTTPException(status_code=400, detail=str(e))
73
 
 
74
  @app.get("/status")
75
  async def root():
76
  """
@@ -78,14 +79,18 @@ async def root():
78
  """
79
  return {"status": "API is running", "model_loaded": model is not None}
80
 
 
81
  @app.get("/model-info")
82
  async def model_info():
83
  """
84
  Endpoint pentru informații despre model
85
  """
 
 
 
86
  return {
87
  "model_type": "Random Forest Regressor",
88
- "features": FEATURE_ORDER,
89
  "target": "UCS (kPa)",
90
  "valid_ranges": {
91
  "cement_perecent": {"min": 0, "max": 10, "units": "%"},
@@ -98,4 +103,4 @@ async def model_info():
98
  "min_samples_split": 6,
99
  "min_samples_leaf": 2
100
  }
101
- }
 
3
  import pandas as pd
4
  import joblib
5
  import numpy as np
6
+ from sklearn.ensemble import RandomForestRegressor # Important pentru deserializare
7
 
8
+ app = FastAPI()
9
+
10
+ # Încărcăm modelul
11
+ try:
12
+ model = joblib.load('rf_model.joblib')
13
+ FEATURE_ORDER = model.feature_names_in_ # Obținem ordinea corectă a caracteristicilor
14
+ print("Model încărcat cu succes! Feature Order:", FEATURE_ORDER)
15
+ except Exception as e:
16
+ print(f"Eroare la încărcarea modelului: {str(e)}")
17
+ model = None # Setăm modelul ca None în caz de eroare
18
+ FEATURE_ORDER = [] # Inițializăm o listă goală pentru a evita erorile ulterioare
19
+
20
+
21
+ # Definirea clasei pentru inputuri
22
  class SoilInput(BaseModel):
23
  cement_perecent: float
24
  curing_period: float
 
42
  raise ValueError("Viteza de compactare trebuie să fie între 0.5 și 1.5 mm/min")
43
  return v
44
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
  @app.post("/predict")
47
  async def predict(soil_data: SoilInput):
48
  """
49
  Realizează predicții pentru UCS
50
  """
51
+ if model is None:
52
+ raise HTTPException(status_code=500, detail="Modelul nu a fost încărcat corect")
53
+
54
  try:
55
  # Construim DataFrame-ul pentru predicție
56
  input_data = soil_data.dict()
57
  input_df = pd.DataFrame([input_data])
58
+
59
+ # Ne asigurăm că ordinea caracteristicilor este corectă
60
+ input_df = input_df[FEATURE_ORDER]
61
+
 
 
 
 
 
62
  # Facem predicția
63
  prediction = model.predict(input_df)
64
+
65
  return {
66
  "success": True,
67
  "prediction": float(prediction[0]),
 
71
  except Exception as e:
72
  raise HTTPException(status_code=400, detail=str(e))
73
 
74
+
75
  @app.get("/status")
76
  async def root():
77
  """
 
79
  """
80
  return {"status": "API is running", "model_loaded": model is not None}
81
 
82
+
83
  @app.get("/model-info")
84
  async def model_info():
85
  """
86
  Endpoint pentru informații despre model
87
  """
88
+ if model is None:
89
+ raise HTTPException(status_code=500, detail="Modelul nu a fost încărcat corect")
90
+
91
  return {
92
  "model_type": "Random Forest Regressor",
93
+ "features": FEATURE_ORDER, # Acum este garantat că FEATURE_ORDER este definit
94
  "target": "UCS (kPa)",
95
  "valid_ranges": {
96
  "cement_perecent": {"min": 0, "max": 10, "units": "%"},
 
103
  "min_samples_split": 6,
104
  "min_samples_leaf": 2
105
  }
106
+ }