bteodoru commited on
Commit
5d45889
·
verified ·
1 Parent(s): 31ed164

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -0
app.py CHANGED
@@ -9,6 +9,28 @@ from sklearn.ensemble import RandomForestRegressor # Important pentru deseriali
9
  from pydantic import BaseModel, ValidationError, Field, field_validator, model_validator
10
  from typing import Any
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  app = FastAPI()
14
 
 
9
  from pydantic import BaseModel, ValidationError, Field, field_validator, model_validator
10
  from typing import Any
11
 
12
+ class RobustModelWrapper:
13
+ """Wrapper robust pentru model, compatibil cu FastAPI."""
14
+ def __init__(self, model, feature_names):
15
+ self.model = model
16
+ self.feature_names_in_ = np.array(feature_names)
17
+
18
+ def predict(self, X):
19
+ """Realizează predicții asigurându-se că datele sunt în formatul corect."""
20
+ # Convertim la DataFrame dacă nu este deja
21
+ if not isinstance(X, pd.DataFrame):
22
+ X = pd.DataFrame(X, columns=self.feature_names_in_)
23
+
24
+ # Asigură-te că DataFrame-ul are exact coloanele necesare în ordinea corectă
25
+ prediction_df = pd.DataFrame()
26
+ for feature in self.feature_names_in_:
27
+ if feature in X.columns:
28
+ prediction_df[feature] = X[feature]
29
+ else:
30
+ raise ValueError(f"Caracteristica '{feature}' lipsește din datele de intrare")
31
+
32
+ # Acum realizăm predicția cu coloanele în ordinea corectă
33
+ return self.model.predict(prediction_df)
34
 
35
  app = FastAPI()
36