File size: 10,087 Bytes
f68b2e8
714e511
d44b8b1
4fc617b
9df4ce7
518cec6
9df4ce7
58908df
82e9d98
d44b8b1
9df4ce7
5d45889
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6660aa5
58908df
 
7d29db3
 
 
 
 
 
 
 
 
8cedd98
7d29db3
 
 
 
ea26e84
f394060
 
 
 
 
 
 
 
ea26e84
f394060
 
 
 
 
 
 
 
 
 
 
ea26e84
f394060
 
 
 
8cedd98
58908df
 
31ed164
 
 
58908df
31ed164
 
 
 
147fe82
58908df
4fc617b
 
 
58908df
4fc617b
 
2ab8ae3
4fc617b
 
b0af23d
2ab8ae3
cded13d
 
 
 
9df4ce7
b0af23d
cded13d
 
 
 
 
 
 
 
 
8d081b4
cded13d
 
 
 
 
 
9df4ce7
cded13d
 
 
 
 
 
4fc617b
6d369a4
f394060
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31e30c5
 
 
 
 
 
 
 
f394060
 
 
 
 
 
31e30c5
 
 
 
 
 
 
f394060
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2ab8ae3
9df4ce7
 
ad165e0
8f8838f
ad165e0
58908df
 
 
9df4ce7
31e30c5
4fc617b
31e30c5
 
 
 
 
 
4fc617b
31e30c5
 
 
 
 
4fc617b
 
31e30c5
ad165e0
58908df
 
 
 
ad165e0
9df4ce7
58908df
9df4ce7
 
 
 
ad165e0
9df4ce7
 
ad165e0
 
58908df
8f8838f
ad165e0
 
 
 
8f8838f
 
58908df
8f8838f
 
 
 
 
58908df
 
 
8f8838f
 
e4a553c
8f8838f
 
 
 
 
 
 
 
 
 
 
 
7d29db3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse
# from pydantic import BaseModel, Field, validator, model_validator
# from pydantic.errors import PydanticValueError
import pandas as pd
import joblib
import numpy as np
from sklearn.ensemble import RandomForestRegressor  # Important pentru deserializare
from pydantic import BaseModel, ValidationError, Field, field_validator, model_validator
from typing import Any

class RobustModelWrapper:
    """Wrapper robust pentru model, compatibil cu FastAPI."""
    def __init__(self, model, feature_names):
        self.model = model
        self.feature_names_in_ = np.array(feature_names)
    
    def predict(self, X):
        """Realizează predicții asigurându-se că datele sunt în formatul corect."""
        # Convertim la DataFrame dacă nu este deja
        if not isinstance(X, pd.DataFrame):
            X = pd.DataFrame(X, columns=self.feature_names_in_)
        
        # Asigură-te că DataFrame-ul are exact coloanele necesare în ordinea corectă
        prediction_df = pd.DataFrame()
        for feature in self.feature_names_in_:
            if feature in X.columns:
                prediction_df[feature] = X[feature]
            else:
                raise ValueError(f"Caracteristica '{feature}' lipsește din datele de intrare")
        
        # Acum realizăm predicția cu coloanele în ordinea corectă
        return self.model.predict(prediction_df)

app = FastAPI()

# @app.exception_handler(ValidationError)
# async def validation_exception_handler(request: Request, exc: ValidationError):
#     errors = []
#     for error in exc.errors():
#         # Elimină prefixul "Value error" din mesaj
#         message = error['msg']
#         if message.startswith("Value error, "):
#             message = message[12:]  # Lungimea "Value error: " este 12
#         errors.append({"loc": error['loc'], "msg": message})
    
#     return JSONResponse(
#         status_code=422,
#         content={"detail": errors}
#     )

# @app.exception_handler(ValidationError)
# async def validation_exception_handler(request: Request, exc: ValidationError):
#     errors = []
#     for error in exc.errors():
#         # Extragem mesajul și eliminăm prefixul "Value error, "
#         message = error['msg']
#         if message.startswith("Value error, "):
#             message = message[12:]  # Lungimea "Value error, " este 12
            
#         # Construim eroarea păstrând toate câmpurile originale,
#         # dar cu mesajul modificat
#         error_dict = {
#             "type": error.get('type'),
#             "loc": error.get('loc'),
#             "msg": message,
#             "input": error.get('input'),
#             "ctx": error.get('ctx'),
#             "url": error.get('url')
#         }
#         errors.append(error_dict)
    
#     return JSONResponse(
#         status_code=422,
#         content={"detail": errors}
#     )
    
# Încărcăm modelul
try:
    model = joblib.load('rf_model_optim.joblib')
    FEATURE_ORDER = model.feature_names_in_  # Obținem ordinea corectă a caracteristicilor
    print("Model încărcat cu succes! Feature Order:", FEATURE_ORDER)
except Exception as e:
    print(f"Eroare la încărcarea modelului: {str(e)}")
    model = None  # Setăm modelul ca None în caz de eroare
    FEATURE_ORDER = []  # Inițializăm o listă goală pentru a evita erorile ulterioare



# # Definim clase personalizate pentru erori
# class CementPercentError(PydanticValueError):
#     msg_template = "Cement percentage must be between 0% and 15%"

# class CuringPeriodError(PydanticValueError):
#     msg_template = "Curing period must be between 1 and 90 days"

# class CompactionRateError(PydanticValueError):
#     msg_template = "Compaction velocity must be between 0.5 and 1.5 mm/min"
    

# class SoilInput__(BaseModel):
#     cement_perecent: float
#     curing_period: float
#     compaction_rate: float

    
#     @model_validator(mode="after")
#     def check_cement_and_curing(self):
#         if self.cement_perecent == 0:
#             self.curing_period = 0
#         else:
#             if not (1 <= self.curing_period <= 90):
#                 # raise CuringPeriodError()
#                 raise ValueError("Curing period must be between 1 and 90 days")
#         return self
    
#     @validator('cement_perecent')
#     def validate_cement(cls, v):
#         if not 0 <= v <= 15:
#             # raise CementPercentError()
#             raise ValueError("Cement percentage must be between 0% and 15%")
#         return v

#     @validator('compaction_rate')
#     def validate_compaction(cls, v):
#         if not 0.5 <= v <= 1.5:
#             # raise CompactionRateError()
#             raise ValueError("Compaction velocity must be between 0.5 and 1.5 mm/min")
#         return v

class SoilInput(BaseModel):
    cement_perecent: float = Field(
        ..., 
        # ge=0, 
        # le=15, 
        description="Cement percentage in the mixture"
    )
    curing_period: float = Field(
        ..., 
        # ge=1, 
        # le=90, 
        description="Number of days for curing"
    )
    compaction_rate: float = Field(
        ..., 
        # ge=0.5, 
        # le=1.5, 
        description="Rate of compaction in mm/min"
    )

    @model_validator(mode="after")
    def check_cement_and_curing(self):
        if self.cement_perecent == 0:
            self.curing_period = 0
        else:
            if not (1 <= self.curing_period <= 90):
                # raise CuringPeriodError()
                raise ValueError("Curing period must be between 1 and 90 days")
        return self

    @field_validator('cement_perecent')
    @classmethod
    def validate_cement(cls, v: float) -> float:
        if not 0 <= v <= 15:
            raise ValueError("Cement percentage must be between 0% and 15%")
        return v

    # @field_validator('curing_period')
    # @classmethod
    # def validate_curing(cls, v: float) -> float:
    #     if not 1 <= v <= 90:
    #         raise ValueError("Curing period must be between 1 and 90 days")
    #     return v

    @field_validator('compaction_rate')
    @classmethod
    def validate_compaction(cls, v: float) -> float:
        if not 0.5 <= v <= 1.5:
            raise ValueError("Compaction rate must be between 0.5 and 1.5 mm/min")
        return v

# class SoilInput(BaseModel):
#     cement_perecent: float = Field(...)
#     curing_period: float = Field(...)
#     compaction_rate: float = Field(...)

#     @field_validator('cement_perecent')
#     @classmethod
#     def validate_cement(cls, v: float) -> float:
#         if not 0 <= v <= 15:
#             raise ValueError("Cement percentage must be between 0% and 15%")
#         return v

#     @field_validator('curing_period')
#     @classmethod
#     def validate_curing(cls, v: float, info) -> float:
#         # Obținem valorile celorlalte câmpuri
#         values = info.data
#         cement_percent = values.get('cement_perecent', 0)
        
#         # Aplicăm logica de validare
#         if cement_percent == 0:
#             return 0
#         if not 1 <= v <= 90:
#             raise ValueError("Curing period must be between 1 and 90 days")
#         return v

#     @field_validator('compaction_rate')
#     @classmethod
#     def validate_compaction(cls, v: float) -> float:
#         if not 0.5 <= v <= 1.5:
#             raise ValueError("Compaction rate must be between 0.5 and 1.5 mm/min")
#         return v
        
@app.post("/predict")
async def predict(soil_data: SoilInput):
    """
    Realizează predicții pentru UCS
    """
    if model is None:
        raise HTTPException(status_code=500, detail="Modelul nu a fost încărcat corect")

    try:
        # input_data = soil_data.dict()
        
        # # Aplicăm regula fizică: dacă nu avem ciment, perioada de maturare este 0
        # if input_data['cement_perecent'] == 0:
        #     # Păstrăm datele originale pentru răspuns
        #     original_data = input_data.copy()
        #     # Modificăm perioada de maturare pentru predicție
        #     input_data['curing_period'] = 0
            
        #     # Adăugăm o notă explicativă în răspuns
        #     explanation = "Pentru amestecuri fără ciment, perioada de maturare nu influențează rezistența."
        # else:
        #     original_data = input_data
        #     explanation = None
        
        # Construim DataFrame-ul pentru predicție
        input_data = soil_data.dict()
        input_df = pd.DataFrame([input_data])

        # Ne asigurăm că ordinea caracteristicilor este corectă
        input_df = input_df[FEATURE_ORDER]

        # Facem predicția
        prediction = model.predict(input_df)

        return {
            "success": True,
            "prediction": float(prediction[0]),
            "units": "kPa",
            "input_parameters": input_data
        }
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))


@app.get("/status")
async def root():
    """
    Endpoint pentru verificarea stării API-ului
    """
    return {"status": "API is running", "model_loaded": model is not None}


@app.get("/model-info")
async def model_info():
    """
    Endpoint pentru informații despre model
    """
    if model is None:
        raise HTTPException(status_code=500, detail="Modelul nu a fost încărcat corect")

    return {
        "model_type": "Random Forest Regressor",
        "features": FEATURE_ORDER.tolist(),  # 🔥 Conversia la listă pentru compatibilitate cu JSON
        "target": "UCS (kPa)",
        "valid_ranges": {
            "cement_perecent": {"min": 0, "max": 10, "units": "%"},
            "curing_period": {"min": 1, "max": 90, "units": "days"},
            "compaction_rate": {"min": 0.5, "max": 1.5, "units": "mm/min"}
        },
        "model_parameters": {
            "n_estimators": 205,
            "max_depth": 11,
            "min_samples_split": 6,
            "min_samples_leaf": 2
        }
    }