ucs / app_.py
bteodoru's picture
Rename app.py to app_.py
14baa6d verified
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse
# from pydantic import BaseModel, Field, validator, model_validator
# from pydantic.errors import PydanticValueError
import pandas as pd
import joblib
import numpy as np
from sklearn.ensemble import RandomForestRegressor # Important pentru deserializare
from pydantic import BaseModel, ValidationError, Field, field_validator, model_validator
from typing import Any
class RobustModelWrapper:
"""Wrapper robust pentru model, compatibil cu FastAPI."""
def __init__(self, model, feature_names):
self.model = model
self.feature_names_in_ = np.array(feature_names)
def predict(self, X):
"""Realizează predicții asigurându-se că datele sunt în formatul corect."""
# Convertim la DataFrame dacă nu este deja
if not isinstance(X, pd.DataFrame):
X = pd.DataFrame(X, columns=self.feature_names_in_)
# Asigură-te că DataFrame-ul are exact coloanele necesare în ordinea corectă
prediction_df = pd.DataFrame()
for feature in self.feature_names_in_:
if feature in X.columns:
prediction_df[feature] = X[feature]
else:
raise ValueError(f"Caracteristica '{feature}' lipsește din datele de intrare")
# Acum realizăm predicția cu coloanele în ordinea corectă
return self.model.predict(prediction_df)
app = FastAPI()
# @app.exception_handler(ValidationError)
# async def validation_exception_handler(request: Request, exc: ValidationError):
# errors = []
# for error in exc.errors():
# # Elimină prefixul "Value error" din mesaj
# message = error['msg']
# if message.startswith("Value error, "):
# message = message[12:] # Lungimea "Value error: " este 12
# errors.append({"loc": error['loc'], "msg": message})
# return JSONResponse(
# status_code=422,
# content={"detail": errors}
# )
# @app.exception_handler(ValidationError)
# async def validation_exception_handler(request: Request, exc: ValidationError):
# errors = []
# for error in exc.errors():
# # Extragem mesajul și eliminăm prefixul "Value error, "
# message = error['msg']
# if message.startswith("Value error, "):
# message = message[12:] # Lungimea "Value error, " este 12
# # Construim eroarea păstrând toate câmpurile originale,
# # dar cu mesajul modificat
# error_dict = {
# "type": error.get('type'),
# "loc": error.get('loc'),
# "msg": message,
# "input": error.get('input'),
# "ctx": error.get('ctx'),
# "url": error.get('url')
# }
# errors.append(error_dict)
# return JSONResponse(
# status_code=422,
# content={"detail": errors}
# )
# Încărcăm modelul
try:
model = joblib.load('rf_model_optim.joblib')
FEATURE_ORDER = model.feature_names_in_ # Obținem ordinea corectă a caracteristicilor
print("Model încărcat cu succes! Feature Order:", FEATURE_ORDER)
except Exception as e:
print(f"Eroare la încărcarea modelului: {str(e)}")
model = None # Setăm modelul ca None în caz de eroare
FEATURE_ORDER = [] # Inițializăm o listă goală pentru a evita erorile ulterioare
# # Definim clase personalizate pentru erori
# class CementPercentError(PydanticValueError):
# msg_template = "Cement percentage must be between 0% and 15%"
# class CuringPeriodError(PydanticValueError):
# msg_template = "Curing period must be between 1 and 90 days"
# class CompactionRateError(PydanticValueError):
# msg_template = "Compaction velocity must be between 0.5 and 1.5 mm/min"
# class SoilInput__(BaseModel):
# cement_perecent: float
# curing_period: float
# compaction_rate: float
# @model_validator(mode="after")
# def check_cement_and_curing(self):
# if self.cement_perecent == 0:
# self.curing_period = 0
# else:
# if not (1 <= self.curing_period <= 90):
# # raise CuringPeriodError()
# raise ValueError("Curing period must be between 1 and 90 days")
# return self
# @validator('cement_perecent')
# def validate_cement(cls, v):
# if not 0 <= v <= 15:
# # raise CementPercentError()
# raise ValueError("Cement percentage must be between 0% and 15%")
# return v
# @validator('compaction_rate')
# def validate_compaction(cls, v):
# if not 0.5 <= v <= 1.5:
# # raise CompactionRateError()
# raise ValueError("Compaction velocity must be between 0.5 and 1.5 mm/min")
# return v
class SoilInput(BaseModel):
cement_perecent: float = Field(
...,
# ge=0,
# le=15,
description="Cement percentage in the mixture"
)
curing_period: float = Field(
...,
# ge=1,
# le=90,
description="Number of days for curing"
)
compaction_rate: float = Field(
...,
# ge=0.5,
# le=1.5,
description="Rate of compaction in mm/min"
)
@model_validator(mode="after")
def check_cement_and_curing(self):
if self.cement_perecent == 0:
self.curing_period = 0
else:
if not (1 <= self.curing_period <= 90):
# raise CuringPeriodError()
raise ValueError("Curing period must be between 1 and 90 days")
return self
@field_validator('cement_perecent')
@classmethod
def validate_cement(cls, v: float) -> float:
if not 0 <= v <= 15:
raise ValueError("Cement percentage must be between 0% and 15%")
return v
# @field_validator('curing_period')
# @classmethod
# def validate_curing(cls, v: float) -> float:
# if not 1 <= v <= 90:
# raise ValueError("Curing period must be between 1 and 90 days")
# return v
@field_validator('compaction_rate')
@classmethod
def validate_compaction(cls, v: float) -> float:
if not 0.5 <= v <= 1.5:
raise ValueError("Compaction rate must be between 0.5 and 1.5 mm/min")
return v
# class SoilInput(BaseModel):
# cement_perecent: float = Field(...)
# curing_period: float = Field(...)
# compaction_rate: float = Field(...)
# @field_validator('cement_perecent')
# @classmethod
# def validate_cement(cls, v: float) -> float:
# if not 0 <= v <= 15:
# raise ValueError("Cement percentage must be between 0% and 15%")
# return v
# @field_validator('curing_period')
# @classmethod
# def validate_curing(cls, v: float, info) -> float:
# # Obținem valorile celorlalte câmpuri
# values = info.data
# cement_percent = values.get('cement_perecent', 0)
# # Aplicăm logica de validare
# if cement_percent == 0:
# return 0
# if not 1 <= v <= 90:
# raise ValueError("Curing period must be between 1 and 90 days")
# return v
# @field_validator('compaction_rate')
# @classmethod
# def validate_compaction(cls, v: float) -> float:
# if not 0.5 <= v <= 1.5:
# raise ValueError("Compaction rate must be between 0.5 and 1.5 mm/min")
# return v
@app.post("/predict")
async def predict(soil_data: SoilInput):
"""
Realizează predicții pentru UCS
"""
if model is None:
raise HTTPException(status_code=500, detail="Modelul nu a fost încărcat corect")
try:
# input_data = soil_data.dict()
# # Aplicăm regula fizică: dacă nu avem ciment, perioada de maturare este 0
# if input_data['cement_perecent'] == 0:
# # Păstrăm datele originale pentru răspuns
# original_data = input_data.copy()
# # Modificăm perioada de maturare pentru predicție
# input_data['curing_period'] = 0
# # Adăugăm o notă explicativă în răspuns
# explanation = "Pentru amestecuri fără ciment, perioada de maturare nu influențează rezistența."
# else:
# original_data = input_data
# explanation = None
# Construim DataFrame-ul pentru predicție
input_data = soil_data.dict()
input_df = pd.DataFrame([input_data])
# Ne asigurăm că ordinea caracteristicilor este corectă
input_df = input_df[FEATURE_ORDER]
# Facem predicția
prediction = model.predict(input_df)
return {
"success": True,
"prediction": float(prediction[0]),
"units": "kPa",
"input_parameters": input_data
}
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
@app.get("/status")
async def root():
"""
Endpoint pentru verificarea stării API-ului
"""
return {"status": "API is running", "model_loaded": model is not None}
@app.get("/model-info")
async def model_info():
"""
Endpoint pentru informații despre model
"""
if model is None:
raise HTTPException(status_code=500, detail="Modelul nu a fost încărcat corect")
return {
"model_type": "Random Forest Regressor",
"features": FEATURE_ORDER.tolist(), # 🔥 Conversia la listă pentru compatibilitate cu JSON
"target": "UCS (kPa)",
"valid_ranges": {
"cement_perecent": {"min": 0, "max": 10, "units": "%"},
"curing_period": {"min": 1, "max": 90, "units": "days"},
"compaction_rate": {"min": 0.5, "max": 1.5, "units": "mm/min"}
},
"model_parameters": {
"n_estimators": 205,
"max_depth": 11,
"min_samples_split": 6,
"min_samples_leaf": 2
}
}