Update app.py
Browse files
app.py
CHANGED
|
@@ -9,6 +9,28 @@ from sklearn.ensemble import RandomForestRegressor # Important pentru deseriali
|
|
| 9 |
from pydantic import BaseModel, ValidationError, Field, field_validator, model_validator
|
| 10 |
from typing import Any
|
| 11 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
app = FastAPI()
|
| 14 |
|
|
|
|
| 9 |
from pydantic import BaseModel, ValidationError, Field, field_validator, model_validator
|
| 10 |
from typing import Any
|
| 11 |
|
| 12 |
+
class RobustModelWrapper:
|
| 13 |
+
"""Wrapper robust pentru model, compatibil cu FastAPI."""
|
| 14 |
+
def __init__(self, model, feature_names):
|
| 15 |
+
self.model = model
|
| 16 |
+
self.feature_names_in_ = np.array(feature_names)
|
| 17 |
+
|
| 18 |
+
def predict(self, X):
|
| 19 |
+
"""Realizează predicții asigurându-se că datele sunt în formatul corect."""
|
| 20 |
+
# Convertim la DataFrame dacă nu este deja
|
| 21 |
+
if not isinstance(X, pd.DataFrame):
|
| 22 |
+
X = pd.DataFrame(X, columns=self.feature_names_in_)
|
| 23 |
+
|
| 24 |
+
# Asigură-te că DataFrame-ul are exact coloanele necesare în ordinea corectă
|
| 25 |
+
prediction_df = pd.DataFrame()
|
| 26 |
+
for feature in self.feature_names_in_:
|
| 27 |
+
if feature in X.columns:
|
| 28 |
+
prediction_df[feature] = X[feature]
|
| 29 |
+
else:
|
| 30 |
+
raise ValueError(f"Caracteristica '{feature}' lipsește din datele de intrare")
|
| 31 |
+
|
| 32 |
+
# Acum realizăm predicția cu coloanele în ordinea corectă
|
| 33 |
+
return self.model.predict(prediction_df)
|
| 34 |
|
| 35 |
app = FastAPI()
|
| 36 |
|