Update app.py
Browse files
app.py
CHANGED
|
@@ -1,8 +1,9 @@
|
|
| 1 |
-
from fastapi import FastAPI, HTTPException
|
| 2 |
from pydantic import BaseModel, validator
|
| 3 |
import pandas as pd
|
| 4 |
import joblib
|
| 5 |
import numpy as np
|
|
|
|
| 6 |
|
| 7 |
class SoilInput(BaseModel):
|
| 8 |
cement_perecent: float
|
|
@@ -29,9 +30,9 @@ class SoilInput(BaseModel):
|
|
| 29 |
|
| 30 |
app = FastAPI()
|
| 31 |
|
| 32 |
-
# Încărcăm
|
| 33 |
try:
|
| 34 |
-
model = joblib.load('
|
| 35 |
print("Model încărcat cu succes!")
|
| 36 |
except Exception as e:
|
| 37 |
print(f"Eroare la încărcarea modelului: {str(e)}")
|
|
@@ -39,15 +40,33 @@ except Exception as e:
|
|
| 39 |
|
| 40 |
@app.post("/predict")
|
| 41 |
async def predict(soil_data: SoilInput):
|
|
|
|
|
|
|
|
|
|
| 42 |
try:
|
| 43 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
prediction = model.predict(input_df)
|
| 45 |
|
| 46 |
return {
|
| 47 |
"success": True,
|
| 48 |
"prediction": float(prediction[0]),
|
| 49 |
"units": "kPa",
|
| 50 |
-
"input_parameters":
|
| 51 |
}
|
| 52 |
except Exception as e:
|
| 53 |
-
raise HTTPException(status_code=400, detail=str(e))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import FastAPI, HTTPException
|
| 2 |
from pydantic import BaseModel, validator
|
| 3 |
import pandas as pd
|
| 4 |
import joblib
|
| 5 |
import numpy as np
|
| 6 |
+
from sklearn.ensemble import RandomForestRegressor # Important să importăm pentru deserializare
|
| 7 |
|
| 8 |
class SoilInput(BaseModel):
|
| 9 |
cement_perecent: float
|
|
|
|
| 30 |
|
| 31 |
app = FastAPI()
|
| 32 |
|
| 33 |
+
# Încărcăm modelul Random Forest direct
|
| 34 |
try:
|
| 35 |
+
model = joblib.load('rf_model.joblib')
|
| 36 |
print("Model încărcat cu succes!")
|
| 37 |
except Exception as e:
|
| 38 |
print(f"Eroare la încărcarea modelului: {str(e)}")
|
|
|
|
| 40 |
|
| 41 |
@app.post("/predict")
|
| 42 |
async def predict(soil_data: SoilInput):
|
| 43 |
+
"""
|
| 44 |
+
Realizează predicții pentru UCS folosind parametrii solului
|
| 45 |
+
"""
|
| 46 |
try:
|
| 47 |
+
# Construim DataFrame-ul pentru predicție
|
| 48 |
+
input_data = soil_data.dict()
|
| 49 |
+
input_df = pd.DataFrame([input_data])
|
| 50 |
+
|
| 51 |
+
# Ne asigurăm că ordinea coloanelor este corectă
|
| 52 |
+
feature_order = ['cement_perecent', 'curing_period', 'compaction_rate']
|
| 53 |
+
input_df = input_df[feature_order]
|
| 54 |
+
|
| 55 |
+
# Facem predicția
|
| 56 |
prediction = model.predict(input_df)
|
| 57 |
|
| 58 |
return {
|
| 59 |
"success": True,
|
| 60 |
"prediction": float(prediction[0]),
|
| 61 |
"units": "kPa",
|
| 62 |
+
"input_parameters": input_data
|
| 63 |
}
|
| 64 |
except Exception as e:
|
| 65 |
+
raise HTTPException(status_code=400, detail=str(e))
|
| 66 |
+
|
| 67 |
+
@app.get("/")
|
| 68 |
+
async def root():
|
| 69 |
+
"""
|
| 70 |
+
Endpoint pentru verificarea stării API-ului
|
| 71 |
+
"""
|
| 72 |
+
return {"status": "API is running", "model_loaded": model is not None}
|