from fastapi import FastAPI, HTTPException, Request from fastapi.responses import JSONResponse # from pydantic import BaseModel, Field, validator, model_validator # from pydantic.errors import PydanticValueError import pandas as pd import joblib import numpy as np from sklearn.ensemble import RandomForestRegressor # Important pentru deserializare from pydantic import BaseModel, ValidationError, Field, field_validator, model_validator from typing import Any class RobustModelWrapper: """Wrapper robust pentru model, compatibil cu FastAPI.""" def __init__(self, model, feature_names): self.model = model self.feature_names_in_ = np.array(feature_names) def predict(self, X): """Realizează predicții asigurându-se că datele sunt în formatul corect.""" # Convertim la DataFrame dacă nu este deja if not isinstance(X, pd.DataFrame): X = pd.DataFrame(X, columns=self.feature_names_in_) # Asigură-te că DataFrame-ul are exact coloanele necesare în ordinea corectă prediction_df = pd.DataFrame() for feature in self.feature_names_in_: if feature in X.columns: prediction_df[feature] = X[feature] else: raise ValueError(f"Caracteristica '{feature}' lipsește din datele de intrare") # Acum realizăm predicția cu coloanele în ordinea corectă return self.model.predict(prediction_df) app = FastAPI() # @app.exception_handler(ValidationError) # async def validation_exception_handler(request: Request, exc: ValidationError): # errors = [] # for error in exc.errors(): # # Elimină prefixul "Value error" din mesaj # message = error['msg'] # if message.startswith("Value error, "): # message = message[12:] # Lungimea "Value error: " este 12 # errors.append({"loc": error['loc'], "msg": message}) # return JSONResponse( # status_code=422, # content={"detail": errors} # ) # @app.exception_handler(ValidationError) # async def validation_exception_handler(request: Request, exc: ValidationError): # errors = [] # for error in exc.errors(): # # Extragem mesajul și eliminăm prefixul "Value error, " # message = error['msg'] # if message.startswith("Value error, "): # message = message[12:] # Lungimea "Value error, " este 12 # # Construim eroarea păstrând toate câmpurile originale, # # dar cu mesajul modificat # error_dict = { # "type": error.get('type'), # "loc": error.get('loc'), # "msg": message, # "input": error.get('input'), # "ctx": error.get('ctx'), # "url": error.get('url') # } # errors.append(error_dict) # return JSONResponse( # status_code=422, # content={"detail": errors} # ) # Încărcăm modelul try: model = joblib.load('rf_model_optim.joblib') FEATURE_ORDER = model.feature_names_in_ # Obținem ordinea corectă a caracteristicilor print("Model încărcat cu succes! Feature Order:", FEATURE_ORDER) except Exception as e: print(f"Eroare la încărcarea modelului: {str(e)}") model = None # Setăm modelul ca None în caz de eroare FEATURE_ORDER = [] # Inițializăm o listă goală pentru a evita erorile ulterioare # # Definim clase personalizate pentru erori # class CementPercentError(PydanticValueError): # msg_template = "Cement percentage must be between 0% and 15%" # class CuringPeriodError(PydanticValueError): # msg_template = "Curing period must be between 1 and 90 days" # class CompactionRateError(PydanticValueError): # msg_template = "Compaction velocity must be between 0.5 and 1.5 mm/min" # class SoilInput__(BaseModel): # cement_perecent: float # curing_period: float # compaction_rate: float # @model_validator(mode="after") # def check_cement_and_curing(self): # if self.cement_perecent == 0: # self.curing_period = 0 # else: # if not (1 <= self.curing_period <= 90): # # raise CuringPeriodError() # raise ValueError("Curing period must be between 1 and 90 days") # return self # @validator('cement_perecent') # def validate_cement(cls, v): # if not 0 <= v <= 15: # # raise CementPercentError() # raise ValueError("Cement percentage must be between 0% and 15%") # return v # @validator('compaction_rate') # def validate_compaction(cls, v): # if not 0.5 <= v <= 1.5: # # raise CompactionRateError() # raise ValueError("Compaction velocity must be between 0.5 and 1.5 mm/min") # return v class SoilInput(BaseModel): cement_perecent: float = Field( ..., # ge=0, # le=15, description="Cement percentage in the mixture" ) curing_period: float = Field( ..., # ge=1, # le=90, description="Number of days for curing" ) compaction_rate: float = Field( ..., # ge=0.5, # le=1.5, description="Rate of compaction in mm/min" ) @model_validator(mode="after") def check_cement_and_curing(self): if self.cement_perecent == 0: self.curing_period = 0 else: if not (1 <= self.curing_period <= 90): # raise CuringPeriodError() raise ValueError("Curing period must be between 1 and 90 days") return self @field_validator('cement_perecent') @classmethod def validate_cement(cls, v: float) -> float: if not 0 <= v <= 15: raise ValueError("Cement percentage must be between 0% and 15%") return v # @field_validator('curing_period') # @classmethod # def validate_curing(cls, v: float) -> float: # if not 1 <= v <= 90: # raise ValueError("Curing period must be between 1 and 90 days") # return v @field_validator('compaction_rate') @classmethod def validate_compaction(cls, v: float) -> float: if not 0.5 <= v <= 1.5: raise ValueError("Compaction rate must be between 0.5 and 1.5 mm/min") return v # class SoilInput(BaseModel): # cement_perecent: float = Field(...) # curing_period: float = Field(...) # compaction_rate: float = Field(...) # @field_validator('cement_perecent') # @classmethod # def validate_cement(cls, v: float) -> float: # if not 0 <= v <= 15: # raise ValueError("Cement percentage must be between 0% and 15%") # return v # @field_validator('curing_period') # @classmethod # def validate_curing(cls, v: float, info) -> float: # # Obținem valorile celorlalte câmpuri # values = info.data # cement_percent = values.get('cement_perecent', 0) # # Aplicăm logica de validare # if cement_percent == 0: # return 0 # if not 1 <= v <= 90: # raise ValueError("Curing period must be between 1 and 90 days") # return v # @field_validator('compaction_rate') # @classmethod # def validate_compaction(cls, v: float) -> float: # if not 0.5 <= v <= 1.5: # raise ValueError("Compaction rate must be between 0.5 and 1.5 mm/min") # return v @app.post("/predict") async def predict(soil_data: SoilInput): """ Realizează predicții pentru UCS """ if model is None: raise HTTPException(status_code=500, detail="Modelul nu a fost încărcat corect") try: # input_data = soil_data.dict() # # Aplicăm regula fizică: dacă nu avem ciment, perioada de maturare este 0 # if input_data['cement_perecent'] == 0: # # Păstrăm datele originale pentru răspuns # original_data = input_data.copy() # # Modificăm perioada de maturare pentru predicție # input_data['curing_period'] = 0 # # Adăugăm o notă explicativă în răspuns # explanation = "Pentru amestecuri fără ciment, perioada de maturare nu influențează rezistența." # else: # original_data = input_data # explanation = None # Construim DataFrame-ul pentru predicție input_data = soil_data.dict() input_df = pd.DataFrame([input_data]) # Ne asigurăm că ordinea caracteristicilor este corectă input_df = input_df[FEATURE_ORDER] # Facem predicția prediction = model.predict(input_df) return { "success": True, "prediction": float(prediction[0]), "units": "kPa", "input_parameters": input_data } except Exception as e: raise HTTPException(status_code=400, detail=str(e)) @app.get("/status") async def root(): """ Endpoint pentru verificarea stării API-ului """ return {"status": "API is running", "model_loaded": model is not None} @app.get("/model-info") async def model_info(): """ Endpoint pentru informații despre model """ if model is None: raise HTTPException(status_code=500, detail="Modelul nu a fost încărcat corect") return { "model_type": "Random Forest Regressor", "features": FEATURE_ORDER.tolist(), # 🔥 Conversia la listă pentru compatibilitate cu JSON "target": "UCS (kPa)", "valid_ranges": { "cement_perecent": {"min": 0, "max": 10, "units": "%"}, "curing_period": {"min": 1, "max": 90, "units": "days"}, "compaction_rate": {"min": 0.5, "max": 1.5, "units": "mm/min"} }, "model_parameters": { "n_estimators": 205, "max_depth": 11, "min_samples_split": 6, "min_samples_leaf": 2 } }