File size: 10,087 Bytes
f68b2e8 714e511 d44b8b1 4fc617b 9df4ce7 518cec6 9df4ce7 58908df 82e9d98 d44b8b1 9df4ce7 5d45889 6660aa5 58908df 7d29db3 8cedd98 7d29db3 ea26e84 f394060 ea26e84 f394060 ea26e84 f394060 8cedd98 58908df 31ed164 58908df 31ed164 147fe82 58908df 4fc617b 58908df 4fc617b 2ab8ae3 4fc617b b0af23d 2ab8ae3 cded13d 9df4ce7 b0af23d cded13d 8d081b4 cded13d 9df4ce7 cded13d 4fc617b 6d369a4 f394060 31e30c5 f394060 31e30c5 f394060 2ab8ae3 9df4ce7 ad165e0 8f8838f ad165e0 58908df 9df4ce7 31e30c5 4fc617b 31e30c5 4fc617b 31e30c5 4fc617b 31e30c5 ad165e0 58908df ad165e0 9df4ce7 58908df 9df4ce7 ad165e0 9df4ce7 ad165e0 58908df 8f8838f ad165e0 8f8838f 58908df 8f8838f 58908df 8f8838f e4a553c 8f8838f 7d29db3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 |
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse
# from pydantic import BaseModel, Field, validator, model_validator
# from pydantic.errors import PydanticValueError
import pandas as pd
import joblib
import numpy as np
from sklearn.ensemble import RandomForestRegressor # Important pentru deserializare
from pydantic import BaseModel, ValidationError, Field, field_validator, model_validator
from typing import Any
class RobustModelWrapper:
    """Wrap a fitted estimator so predictions always see columns in training order.

    The wrapper remembers the feature names the estimator was trained on and,
    before delegating to the estimator's ``predict``, rebuilds the input with
    exactly those columns, in that order (extra input columns are dropped).
    """

    def __init__(self, model, feature_names):
        # The wrapped estimator; only its ``predict`` method is used.
        self.model = model
        # Stored as a NumPy array, mirroring scikit-learn's ``feature_names_in_``.
        self.feature_names_in_ = np.array(feature_names)

    def predict(self, X):
        """Predict after aligning the input columns to the training feature order.

        Args:
            X: A DataFrame, or anything ``pd.DataFrame`` accepts together with
                ``columns=self.feature_names_in_`` (e.g. a list of rows).

        Returns:
            Whatever the wrapped estimator's ``predict`` returns for the
            aligned frame.

        Raises:
            ValueError: If one of the required feature columns is absent from
                ``X`` (reported for the first missing one, in training order).
        """
        if isinstance(X, pd.DataFrame):
            frame = X
        else:
            frame = pd.DataFrame(X, columns=self.feature_names_in_)

        aligned = pd.DataFrame()
        for name in self.feature_names_in_:
            # Guard clause: fail fast on the first feature that is not supplied.
            if name not in frame.columns:
                raise ValueError(f"Caracteristica '{name}' lipsește din datele de intrare")
            aligned[name] = frame[name]

        return self.model.predict(aligned)
# FastAPI application instance; the @app.get / @app.post decorators in this
# module register their endpoints on it.
app = FastAPI()
# @app.exception_handler(ValidationError)
# async def validation_exception_handler(request: Request, exc: ValidationError):
# errors = []
# for error in exc.errors():
# # Elimină prefixul "Value error" din mesaj
# message = error['msg']
# if message.startswith("Value error, "):
# message = message[12:] # Lungimea "Value error: " este 12
# errors.append({"loc": error['loc'], "msg": message})
# return JSONResponse(
# status_code=422,
# content={"detail": errors}
# )
# @app.exception_handler(ValidationError)
# async def validation_exception_handler(request: Request, exc: ValidationError):
# errors = []
# for error in exc.errors():
# # Extragem mesajul și eliminăm prefixul "Value error, "
# message = error['msg']
# if message.startswith("Value error, "):
# message = message[12:] # Lungimea "Value error, " este 12
# # Construim eroarea păstrând toate câmpurile originale,
# # dar cu mesajul modificat
# error_dict = {
# "type": error.get('type'),
# "loc": error.get('loc'),
# "msg": message,
# "input": error.get('input'),
# "ctx": error.get('ctx'),
# "url": error.get('url')
# }
# errors.append(error_dict)
# return JSONResponse(
# status_code=422,
# content={"detail": errors}
# )
# Load the trained model once, at import time.
# SECURITY: joblib.load is pickle-based and can execute arbitrary code while
# loading, so this path must only ever point at a trusted, locally produced file.
try:
    model = joblib.load('rf_model_optim.joblib')
    # Column order the model was fitted with; request features are reordered to it.
    FEATURE_ORDER = model.feature_names_in_
    print("Model încărcat cu succes! Feature Order:", FEATURE_ORDER)
except Exception as e:
    # Best-effort startup: keep the API importable and let the endpoints report
    # "model not loaded" (they check ``model is None``) instead of crashing here.
    print(f"Eroare la încărcarea modelului: {str(e)}")
    model = None
    # Empty NumPy array rather than a list, so the fallback has the same type as
    # ``feature_names_in_`` on the success path (e.g. ``.tolist()`` keeps working).
    FEATURE_ORDER = np.array([], dtype=object)
# # Definim clase personalizate pentru erori
# class CementPercentError(PydanticValueError):
# msg_template = "Cement percentage must be between 0% and 15%"
# class CuringPeriodError(PydanticValueError):
# msg_template = "Curing period must be between 1 and 90 days"
# class CompactionRateError(PydanticValueError):
# msg_template = "Compaction velocity must be between 0.5 and 1.5 mm/min"
# class SoilInput__(BaseModel):
# cement_perecent: float
# curing_period: float
# compaction_rate: float
# @model_validator(mode="after")
# def check_cement_and_curing(self):
# if self.cement_perecent == 0:
# self.curing_period = 0
# else:
# if not (1 <= self.curing_period <= 90):
# # raise CuringPeriodError()
# raise ValueError("Curing period must be between 1 and 90 days")
# return self
# @validator('cement_perecent')
# def validate_cement(cls, v):
# if not 0 <= v <= 15:
# # raise CementPercentError()
# raise ValueError("Cement percentage must be between 0% and 15%")
# return v
# @validator('compaction_rate')
# def validate_compaction(cls, v):
# if not 0.5 <= v <= 1.5:
# # raise CompactionRateError()
# raise ValueError("Compaction velocity must be between 0.5 and 1.5 mm/min")
# return v
class SoilInput(BaseModel):
    """Request body for the ``/predict`` endpoint.

    The ``/predict`` handler builds its model input from these field names and
    then selects the model's training features by name. For that reason the
    misspelled ``cement_perecent`` must keep its spelling: it is both a public
    API key and a model column.

    Validation rules:
        * ``cement_perecent`` must lie in [0, 15].
        * ``compaction_rate`` must lie in [0.5, 1.5].
        * ``curing_period`` must lie in [1, 90] when cement is present. When
          ``cement_perecent`` is 0, any value is accepted and it is replaced
          by 0.0.
    """

    cement_perecent: float = Field(
        ...,
        description="Cement percentage in the mixture",
    )
    curing_period: float = Field(
        ...,
        description="Number of days for curing",
    )
    compaction_rate: float = Field(
        ...,
        description="Rate of compaction in mm/min",
    )

    @model_validator(mode="after")
    def check_cement_and_curing(self):
        """Apply the cement-dependent curing rule once all fields are valid.

        Raises:
            ValueError: If cement is present and the curing period is outside
                [1, 90] days.
        """
        if self.cement_perecent == 0:
            # Without cement, the curing time is normalised to zero. Use a float
            # so the stored value matches the declared field type.
            self.curing_period = 0.0
        elif not (1 <= self.curing_period <= 90):
            raise ValueError("Curing period must be between 1 and 90 days")
        return self

    @field_validator('cement_perecent')
    @classmethod
    def validate_cement(cls, v: float) -> float:
        """Reject cement percentages outside [0, 15]."""
        if not 0 <= v <= 15:
            raise ValueError("Cement percentage must be between 0% and 15%")
        return v

    @field_validator('compaction_rate')
    @classmethod
    def validate_compaction(cls, v: float) -> float:
        """Reject compaction rates outside [0.5, 1.5] mm/min."""
        if not 0.5 <= v <= 1.5:
            raise ValueError("Compaction rate must be between 0.5 and 1.5 mm/min")
        return v
# class SoilInput(BaseModel):
# cement_perecent: float = Field(...)
# curing_period: float = Field(...)
# compaction_rate: float = Field(...)
# @field_validator('cement_perecent')
# @classmethod
# def validate_cement(cls, v: float) -> float:
# if not 0 <= v <= 15:
# raise ValueError("Cement percentage must be between 0% and 15%")
# return v
# @field_validator('curing_period')
# @classmethod
# def validate_curing(cls, v: float, info) -> float:
# # Obținem valorile celorlalte câmpuri
# values = info.data
# cement_percent = values.get('cement_perecent', 0)
# # Aplicăm logica de validare
# if cement_percent == 0:
# return 0
# if not 1 <= v <= 90:
# raise ValueError("Curing period must be between 1 and 90 days")
# return v
# @field_validator('compaction_rate')
# @classmethod
# def validate_compaction(cls, v: float) -> float:
# if not 0.5 <= v <= 1.5:
# raise ValueError("Compaction rate must be between 0.5 and 1.5 mm/min")
# return v
@app.post("/predict")
async def predict(soil_data: SoilInput):
    """Return the model's UCS prediction (in kPa) for one soil mixture.

    Args:
        soil_data: Request body, already validated by ``SoilInput`` (including
            its rule that replaces ``curing_period`` with 0 when there is no
            cement).

    Returns:
        dict with ``success``, ``prediction`` (float), ``units`` ("kPa") and
        ``input_parameters`` (the values actually fed to the model).

    Raises:
        HTTPException: 500 if the model is not loaded or if prediction fails.
    """
    if model is None:
        raise HTTPException(status_code=500, detail="Modelul nu a fost încărcat corect")

    # model_dump() is the pydantic v2 replacement for the deprecated .dict().
    input_data = soil_data.model_dump()
    input_df = pd.DataFrame([input_data])

    # NOTE(review): model.predict is synchronous and runs on the event loop in
    # this async endpoint; consider a plain ``def`` endpoint if latency matters.
    try:
        # Reorder/select columns to match the order the model was trained with.
        input_df = input_df[FEATURE_ORDER]
        prediction = model.predict(input_df)
        value = float(prediction[0])
    except Exception as e:
        # The body was already validated by SoilInput, so a failure here is a
        # server-side fault (feature mismatch, model error), not a bad request.
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {
        "success": True,
        "prediction": value,
        "units": "kPa",
        "input_parameters": input_data,
    }
@app.get("/status")
async def root():
    """Health check: report that the API is up and whether the model loaded.

    Returns:
        dict: ``{"status": "API is running", "model_loaded": <bool>}``.
    """
    model_loaded = model is not None
    payload = {"status": "API is running", "model_loaded": model_loaded}
    return payload
@app.get("/model-info")
async def model_info():
    """Describe the loaded model, its input features and the accepted ranges.

    Returns:
        dict with the model type, the feature names in training order, the
        target, the input ranges enforced by ``SoilInput``, and the
        hyper-parameters as hard-coded below.

    Raises:
        HTTPException: 500 if the model failed to load.
    """
    if model is None:
        raise HTTPException(status_code=500, detail="Modelul nu a fost încărcat corect")
    return {
        "model_type": "Random Forest Regressor",
        # Plain ``str`` values: JSON-serialisable whether FEATURE_ORDER is a
        # NumPy array (sklearn's feature_names_in_) or a Python list.
        "features": [str(name) for name in FEATURE_ORDER],
        "target": "UCS (kPa)",
        "valid_ranges": {
            # Matches SoilInput.validate_cement, which accepts 0-15 %
            # (this entry previously advertised a max of 10).
            "cement_perecent": {"min": 0, "max": 15, "units": "%"},
            # NOTE: SoilInput also accepts any curing period when cement is 0
            # and replaces it with 0 in that case.
            "curing_period": {"min": 1, "max": 90, "units": "days"},
            "compaction_rate": {"min": 0.5, "max": 1.5, "units": "mm/min"},
        },
        # NOTE(review): hard-coded, not read from the loaded model; keep in sync
        # with the trained artifact.
        "model_parameters": {
            "n_estimators": 205,
            "max_depth": 11,
            "min_samples_split": 6,
            "min_samples_leaf": 2,
        },
    }