bteodoru commited on
Commit
ad165e0
·
verified ·
1 Parent(s): e062419

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -6
app.py CHANGED
@@ -1,8 +1,9 @@
1
- from fastapi import FastAPI, HTTPException, Request
2
  from pydantic import BaseModel, validator
3
  import pandas as pd
4
  import joblib
5
  import numpy as np
 
6
 
7
  class SoilInput(BaseModel):
8
  cement_perecent: float
@@ -29,9 +30,9 @@ class SoilInput(BaseModel):
29
 
30
  app = FastAPI()
31
 
32
- # Încărcăm direct modelul Random Forest, fără wrapper
33
  try:
34
- model = joblib.load('model.joblib') # Aici folosim modelul salvat fără wrapper
35
  print("Model încărcat cu succes!")
36
  except Exception as e:
37
  print(f"Eroare la încărcarea modelului: {str(e)}")
@@ -39,15 +40,33 @@ except Exception as e:
39
 
40
  @app.post("/predict")
41
  async def predict(soil_data: SoilInput):
 
 
 
42
  try:
43
- input_df = pd.DataFrame([soil_data.dict()])
 
 
 
 
 
 
 
 
44
  prediction = model.predict(input_df)
45
 
46
  return {
47
  "success": True,
48
  "prediction": float(prediction[0]),
49
  "units": "kPa",
50
- "input_parameters": soil_data.dict()
51
  }
52
  except Exception as e:
53
- raise HTTPException(status_code=400, detail=str(e))
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel, validator
3
  import pandas as pd
4
  import joblib
5
  import numpy as np
6
+ from sklearn.ensemble import RandomForestRegressor # Important să importăm pentru deserializare
7
 
8
  class SoilInput(BaseModel):
9
  cement_perecent: float
 
30
 
31
  app = FastAPI()
32
 
33
+ # Încărcăm modelul Random Forest direct
34
  try:
35
+ model = joblib.load('rf_model.joblib')
36
  print("Model încărcat cu succes!")
37
  except Exception as e:
38
  print(f"Eroare la încărcarea modelului: {str(e)}")
 
40
 
41
  @app.post("/predict")
42
  async def predict(soil_data: SoilInput):
43
+ """
44
+ Realizează predicții pentru UCS folosind parametrii solului
45
+ """
46
  try:
47
+ # Construim DataFrame-ul pentru predicție
48
+ input_data = soil_data.dict()
49
+ input_df = pd.DataFrame([input_data])
50
+
51
+ # Ne asigurăm că ordinea coloanelor este corectă
52
+ feature_order = ['cement_perecent', 'curing_period', 'compaction_rate']
53
+ input_df = input_df[feature_order]
54
+
55
+ # Facem predicția
56
  prediction = model.predict(input_df)
57
 
58
  return {
59
  "success": True,
60
  "prediction": float(prediction[0]),
61
  "units": "kPa",
62
+ "input_parameters": input_data
63
  }
64
  except Exception as e:
65
+ raise HTTPException(status_code=400, detail=str(e))
66
+
67
+ @app.get("/")
68
+ async def root():
69
+ """
70
+ Endpoint pentru verificarea stării API-ului
71
+ """
72
+ return {"status": "API is running", "model_loaded": model is not None}