| from fastapi import FastAPI, HTTPException, Request | |
| from fastapi.responses import JSONResponse | |
| # from pydantic import BaseModel, Field, validator, model_validator | |
| # from pydantic.errors import PydanticValueError | |
| import pandas as pd | |
| import joblib | |
| import numpy as np | |
| from sklearn.ensemble import RandomForestRegressor # Important pentru deserializare | |
| from pydantic import BaseModel, ValidationError, Field, field_validator, model_validator | |
| from typing import Any | |
class RobustModelWrapper:
    """Model wrapper that aligns input columns before delegating to ``model``.

    Attributes:
        model: the wrapped estimator; only its ``predict`` method is used here.
        feature_names_in_: numpy array of feature names, in the order the
            wrapped model is called with them.
    """

    def __init__(self, model, feature_names):
        self.model = model
        self.feature_names_in_ = np.array(feature_names)

    def predict(self, X):
        """Predict after aligning ``X`` to ``feature_names_in_``.

        Non-DataFrame input is wrapped in a DataFrame whose columns are
        ``feature_names_in_`` (assigned positionally). The frame is then
        reduced to exactly those columns, in that order; extra columns are
        dropped.

        Raises:
            ValueError: if a required feature column is absent (the first
                missing one, in feature order, is reported).
        """
        frame = X if isinstance(X, pd.DataFrame) else pd.DataFrame(X, columns=self.feature_names_in_)

        available = frame.columns
        missing = [name for name in self.feature_names_in_ if name not in available]
        if missing:
            raise ValueError(f"Caracteristica '{missing[0]}' lipsește din datele de intrare")

        # Rebuild the frame column by column so the order matches training.
        ordered = pd.DataFrame({name: frame[name] for name in self.feature_names_in_})
        return self.model.predict(ordered)
# The FastAPI application object for this service.
app: FastAPI = FastAPI()
| # @app.exception_handler(ValidationError) | |
| # async def validation_exception_handler(request: Request, exc: ValidationError): | |
| # errors = [] | |
| # for error in exc.errors(): | |
| # # Elimină prefixul "Value error" din mesaj | |
| # message = error['msg'] | |
| # if message.startswith("Value error, "): | |
| # message = message[12:] # Lungimea "Value error: " este 12 | |
| # errors.append({"loc": error['loc'], "msg": message}) | |
| # return JSONResponse( | |
| # status_code=422, | |
| # content={"detail": errors} | |
| # ) | |
| # @app.exception_handler(ValidationError) | |
| # async def validation_exception_handler(request: Request, exc: ValidationError): | |
| # errors = [] | |
| # for error in exc.errors(): | |
| # # Extragem mesajul și eliminăm prefixul "Value error, " | |
| # message = error['msg'] | |
| # if message.startswith("Value error, "): | |
| # message = message[12:] # Lungimea "Value error, " este 12 | |
| # # Construim eroarea păstrând toate câmpurile originale, | |
| # # dar cu mesajul modificat | |
| # error_dict = { | |
| # "type": error.get('type'), | |
| # "loc": error.get('loc'), | |
| # "msg": message, | |
| # "input": error.get('input'), | |
| # "ctx": error.get('ctx'), | |
| # "url": error.get('url') | |
| # } | |
| # errors.append(error_dict) | |
| # return JSONResponse( | |
| # status_code=422, | |
| # content={"detail": errors} | |
| # ) | |
# Load the serialized model once, at import time. Any failure is reported on
# stdout and the module falls back to `model = None`, which the endpoints
# check before using it.
try:
    model = joblib.load('rf_model_optim.joblib')
    # Column order the estimator exposes via `feature_names_in_`; request
    # frames are reindexed to it before predicting.
    FEATURE_ORDER = model.feature_names_in_
    print("Model încărcat cu succes! Feature Order:", FEATURE_ORDER)
except Exception as load_error:
    print(f"Eroare la încărcarea modelului: {load_error!s}")
    model = None
    FEATURE_ORDER = []  # empty placeholder so later references still resolve
| # # Definim clase personalizate pentru erori | |
| # class CementPercentError(PydanticValueError): | |
| # msg_template = "Cement percentage must be between 0% and 15%" | |
| # class CuringPeriodError(PydanticValueError): | |
| # msg_template = "Curing period must be between 1 and 90 days" | |
| # class CompactionRateError(PydanticValueError): | |
| # msg_template = "Compaction velocity must be between 0.5 and 1.5 mm/min" | |
| # class SoilInput__(BaseModel): | |
| # cement_perecent: float | |
| # curing_period: float | |
| # compaction_rate: float | |
| # @model_validator(mode="after") | |
| # def check_cement_and_curing(self): | |
| # if self.cement_perecent == 0: | |
| # self.curing_period = 0 | |
| # else: | |
| # if not (1 <= self.curing_period <= 90): | |
| # # raise CuringPeriodError() | |
| # raise ValueError("Curing period must be between 1 and 90 days") | |
| # return self | |
| # @validator('cement_perecent') | |
| # def validate_cement(cls, v): | |
| # if not 0 <= v <= 15: | |
| # # raise CementPercentError() | |
| # raise ValueError("Cement percentage must be between 0% and 15%") | |
| # return v | |
| # @validator('compaction_rate') | |
| # def validate_compaction(cls, v): | |
| # if not 0.5 <= v <= 1.5: | |
| # # raise CompactionRateError() | |
| # raise ValueError("Compaction velocity must be between 0.5 and 1.5 mm/min") | |
| # return v | |
class SoilInput(BaseModel):
    """Validated request body for a UCS prediction.

    The field names (including the ``cement_perecent`` spelling) become the
    DataFrame column names used for prediction, so they presumably have to
    match the model's feature names -- do not rename them.

    Range checks live in the validators below rather than in
    ``Field(ge=..., le=...)``. This lets each violation carry a custom
    message, and lets ``curing_period`` be checked conditionally on
    ``cement_perecent``.
    """

    cement_perecent: float = Field(..., description="Cement percentage in the mixture")
    curing_period: float = Field(..., description="Number of days for curing")
    compaction_rate: float = Field(..., description="Rate of compaction in mm/min")

    # The decorators are what register these methods with Pydantic; without
    # them the methods are plain functions that are never called.
    @model_validator(mode="after")
    def check_cement_and_curing(self) -> "SoilInput":
        """Apply the curing-period rule after all fields are validated.

        With no cement, the curing period is forced to 0 regardless of the
        value sent. Otherwise it must lie in [1, 90] days.

        Raises:
            ValueError: if cement is present and the curing period is out of range.
        """
        if self.cement_perecent == 0:
            self.curing_period = 0.0
        elif not 1 <= self.curing_period <= 90:
            raise ValueError("Curing period must be between 1 and 90 days")
        return self

    @field_validator("cement_perecent")
    @classmethod
    def validate_cement(cls, v: float) -> float:
        """Reject cement percentages outside [0, 15]."""
        if not 0 <= v <= 15:
            raise ValueError("Cement percentage must be between 0% and 15%")
        return v

    @field_validator("compaction_rate")
    @classmethod
    def validate_compaction(cls, v: float) -> float:
        """Reject compaction rates outside [0.5, 1.5] mm/min."""
        if not 0.5 <= v <= 1.5:
            raise ValueError("Compaction rate must be between 0.5 and 1.5 mm/min")
        return v
| # class SoilInput(BaseModel): | |
| # cement_perecent: float = Field(...) | |
| # curing_period: float = Field(...) | |
| # compaction_rate: float = Field(...) | |
| # @field_validator('cement_perecent') | |
| # @classmethod | |
| # def validate_cement(cls, v: float) -> float: | |
| # if not 0 <= v <= 15: | |
| # raise ValueError("Cement percentage must be between 0% and 15%") | |
| # return v | |
| # @field_validator('curing_period') | |
| # @classmethod | |
| # def validate_curing(cls, v: float, info) -> float: | |
| # # Obținem valorile celorlalte câmpuri | |
| # values = info.data | |
| # cement_percent = values.get('cement_perecent', 0) | |
| # # Aplicăm logica de validare | |
| # if cement_percent == 0: | |
| # return 0 | |
| # if not 1 <= v <= 90: | |
| # raise ValueError("Curing period must be between 1 and 90 days") | |
| # return v | |
| # @field_validator('compaction_rate') | |
| # @classmethod | |
| # def validate_compaction(cls, v: float) -> float: | |
| # if not 0.5 <= v <= 1.5: | |
| # raise ValueError("Compaction rate must be between 0.5 and 1.5 mm/min") | |
| # return v | |
async def predict(soil_data: SoilInput):
    """Return the model's UCS prediction for one validated input record.

    Args:
        soil_data: validated request body; its dumped fields are used as the
            columns of the one-row frame passed to the model.

    Returns:
        dict with ``success`` (True), ``prediction`` (float), ``units``
        ("kPa") and ``input_parameters`` (the dumped request fields).

    Raises:
        HTTPException: 500 if the model failed to load at startup; 400 if
            building the feature frame or running the prediction fails.
    """
    # NOTE(review): no route decorator is visible for this handler; if it is
    # meant to be served it must be registered on `app` -- confirm the path.
    # NOTE(review): the cement == 0 -> curing_period = 0 rule is expected to be
    # enforced by SoilInput's validation, not here.
    if model is None:
        raise HTTPException(status_code=500, detail="Modelul nu a fost încărcat corect")
    try:
        # `model_dump()` is the Pydantic v2 API (this file uses
        # field_validator/model_validator); `.dict()` is deprecated there.
        input_data = soil_data.model_dump()
        # One-row frame, with columns reordered to the training feature order.
        input_df = pd.DataFrame([input_data])[FEATURE_ORDER]
        prediction = model.predict(input_df)
        predicted_value = float(prediction[0])
    except Exception as e:
        # Every failure here is reported to the client as a 400; chain the
        # cause so the original traceback is not lost.
        raise HTTPException(status_code=400, detail=str(e)) from e
    return {
        "success": True,
        "prediction": predicted_value,
        "units": "kPa",
        "input_parameters": input_data,
    }
async def root():
    """Health check: report that the API is running and whether the model loaded."""
    # NOTE(review): no route decorator is visible for this handler -- confirm
    # it is registered on `app`.
    model_loaded = model is not None
    return {"status": "API is running", "model_loaded": model_loaded}
async def model_info():
    """Describe the loaded model.

    Reports the feature names, the target, the advertised input ranges and
    the hard-coded hyper-parameter values.

    Raises:
        HTTPException: 500 when the model failed to load at startup.
    """
    if model is None:
        raise HTTPException(status_code=500, detail="Modelul nu a fost încărcat corect")

    # NOTE(review): the cement_perecent max advertised here is 10, while
    # SoilInput.validate_cement accepts values up to 15. Confirm which limit
    # is intended.
    advertised_ranges = {
        "cement_perecent": {"min": 0, "max": 10, "units": "%"},
        "curing_period": {"min": 1, "max": 90, "units": "days"},
        "compaction_rate": {"min": 0.5, "max": 1.5, "units": "mm/min"},
    }
    # Hard-coded values, not read from the loaded estimator.
    hyperparameters = dict(
        n_estimators=205,
        max_depth=11,
        min_samples_split=6,
        min_samples_leaf=2,
    )
    return {
        "model_type": "Random Forest Regressor",
        # ndarray -> plain list so the response body is JSON-serializable.
        "features": FEATURE_ORDER.tolist(),
        "target": "UCS (kPa)",
        "valid_ranges": advertised_ranges,
        "model_parameters": hyperparameters,
    }