# TrBn17
# reset repo without Dataset.xlsx
# 80dbe44
from fastapi import APIRouter, HTTPException
from src.app.schema.transportation import TransportationRequest, TransportationResponse
import pickle
import joblib
import numpy as np
from pathlib import Path
import logging
from datetime import datetime
import warnings
from sklearn.exceptions import InconsistentVersionWarning
from src.config.logging_config import get_logger
# Suppress scikit-learn's InconsistentVersionWarning for model compatibility.
# The filter is installed when this module is imported.
warnings.filterwarnings("ignore", category=InconsistentVersionWarning)
# Router on which the endpoints defined in this module are registered.
router = APIRouter()
# Module logger obtained from the project logging configuration under the
# name "transportation".
logger = get_logger("transportation")
class TransportationPredictor:
    """Predict a shipment mode from request features using downloaded artifacts.

    Constructing an instance immediately downloads and deserializes four
    artifacts (classifier, label encoders, shipment encoders, scaler) from the
    location configured in the project settings.
    """

    # Labels indexed by the integer the model predicts.
    # NOTE(review): assumes this order matches the label encoding used when the
    # model was trained - confirm against the training pipeline.
    SHIPMENT_MODES = ('Air', 'Air Charter', 'Ocean', 'Truck')

    def __init__(self):
        """Initialise empty artifact slots, then load every artifact."""
        self.model = None
        self.label_encoders = None
        self.shipment_encoders = None
        self.scaler = None
        self.load_models()

    def _download_file(self, url, token=None, timeout=60):
        """Download *url* and return its body as an in-memory binary stream.

        Args:
            url: Address of the file to fetch.
            token: Optional bearer token; when given it is sent in the
                ``Authorization`` header.
            timeout: Seconds passed to ``requests.get`` so that a stalled
                connection cannot block the caller forever.

        Returns:
            A ``BytesIO`` holding the downloaded content.

        Raises:
            requests.HTTPError: If the server answers with an error status.
        """
        from io import BytesIO

        import requests

        headers = {"Authorization": f"Bearer {token}"} if token else {}
        response = requests.get(url, headers=headers, timeout=timeout)
        response.raise_for_status()
        return BytesIO(response.content)

    def load_models(self):
        """Download and deserialize all model artifacts from Hugging Face.

        Reads the base URL, the file-name mapping and the access token from
        the project settings and populates ``model``, ``label_encoders``,
        ``shipment_encoders`` and ``scaler``.

        Raises:
            RuntimeError: If any download or deserialization step fails; the
                original exception is attached as the cause.
        """
        try:
            from src.config.setting import settings

            base_url = settings.HF_MODEL_BASE_URL
            files = settings.HF_MODEL_FILES
            token = settings.HF_TOKEN

            def fetch(key):
                # Resolve the artifact's file name and download it.
                return self._download_file(base_url + files[key], token)

            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", category=InconsistentVersionWarning)
                # SECURITY: pickle.load / joblib.load can execute arbitrary code
                # while deserializing. Only point the settings at a repository
                # whose contents are fully trusted.
                self.model = pickle.load(fetch("xgboost_model"))
                self.label_encoders = joblib.load(fetch("label_encoders"))
                self.shipment_encoders = joblib.load(fetch("shipment_encoders"))
                # NOTE(review): the scaler is loaded but never applied to the
                # features in predict_shipment_mode - confirm this is intended.
                self.scaler = joblib.load(fetch("scaler"))
            logger.info("All models loaded successfully from Hugging Face")
        except Exception as e:
            raise RuntimeError(f"Không thể load models từ Hugging Face: {e}") from e

    def _encode_safe(self, encoder, value: str) -> int:
        """Encode *value* with *encoder*, falling back to ``0``.

        Args:
            encoder: Object exposing ``classes_`` and ``transform``; may be
                ``None`` when no encoder exists for the field.
            value: Raw category to encode.

        Returns:
            The encoded value as a plain Python ``int``, or ``0`` when the
            encoder is missing, the value is not among ``encoder.classes_``,
            or encoding raises.
        """
        try:
            if encoder is not None and value in encoder.classes_:
                # Convert a possible numpy integer to a built-in int.
                return int(encoder.transform([value])[0])
        except Exception:
            # Best-effort by design: an encoding problem must not fail the
            # whole prediction, but it should not vanish without a trace.
            logger.warning("Could not encode value %r; falling back to 0", value, exc_info=True)
        return 0

    @staticmethod
    def _parse_delivery_ordinal(delivery_date) -> int:
        """Return the proleptic Gregorian ordinal for a ``YYYY-MM-DD`` string.

        Today's ordinal is returned when *delivery_date* is empty, is not a
        string, or does not match the expected format.
        """
        if delivery_date:
            try:
                return datetime.strptime(delivery_date, "%Y-%m-%d").toordinal()
            except (ValueError, TypeError):
                # Malformed or non-string date: fall through to "today".
                pass
        return datetime.now().toordinal()

    def predict_shipment_mode(self, request: "TransportationRequest") -> "TransportationResponse":
        """Predict the shipment mode for *request*.

        Missing weight is estimated as 0.1 per line item (100 items assumed
        when the quantity is missing too); missing freight cost is estimated as
        the largest of ``weight * 3``, 4 % of the pack price, and 50.

        Args:
            request: Incoming request carrying the shipment attributes.

        Returns:
            A ``TransportationResponse`` with the predicted mode, its
            probability, the remaining modes sorted by descending probability,
            any estimated inputs and the encoded categorical features.

        Raises:
            HTTPException: With status 500 when any step of the prediction
                fails; the original exception is attached as the cause.
        """
        try:
            # Estimate weight and freight when the caller did not supply them.
            weight = request.weight_kg or (request.line_item_quantity or 100) * 0.1
            freight = request.freight_cost_usd or max(weight * 3, request.pack_price * 0.04, 50)

            # Encode the categorical inputs as plain Python ints.
            encoders = self.shipment_encoders
            project_encoded = self._encode_safe(encoders.get('Project Code'), request.project_code)
            country_encoded = self._encode_safe(encoders.get('Country'), request.country)
            vendor_encoded = self._encode_safe(encoders.get('Vendor'), request.vendor)

            date_ordinal = self._parse_delivery_ordinal(request.delivery_date)

            # NOTE(review): the leading 0 is a fixed placeholder for the first
            # model feature - confirm which training column it stands for.
            features = [0, project_encoded, country_encoded, request.pack_price,
                        vendor_encoded, freight, weight, date_ordinal]
            # Build the single-row input once and reuse it for both calls.
            feature_row = np.array(features).reshape(1, -1)

            prediction = self.model.predict(feature_row)
            probabilities = self.model.predict_proba(feature_row)[0]

            classes = self.SHIPMENT_MODES
            predicted_idx = int(prediction[0])
            predicted_mode = classes[predicted_idx]
            confidence = float(probabilities[predicted_idx])

            # Every other mode, most probable first, with numpy floats
            # converted to built-in floats.
            alternatives = sorted(
                (
                    {'mode': mode, 'probability': float(probabilities[idx])}
                    for idx, mode in enumerate(classes)
                    if idx != predicted_idx
                ),
                key=lambda alt: alt['probability'],
                reverse=True,
            )

            return TransportationResponse(
                predicted_shipment_mode=predicted_mode,
                confidence_score=confidence,
                alternative_modes=alternatives,
                # Estimates are reported only when they were actually used.
                estimated_weight_kg=float(weight) if not request.weight_kg else None,
                estimated_freight_cost_usd=float(freight) if not request.freight_cost_usd else None,
                encoded_features={
                    'Project_Code': int(project_encoded),
                    'Country': int(country_encoded),
                    'Vendor': int(vendor_encoded)
                },
                processing_notes=[]
            )
        except Exception as e:
            logger.exception("Transportation prediction failed")
            raise HTTPException(status_code=500, detail=f"Prediction failed: {e}") from e
# Module-wide predictor instance; stays None until get_predictor() is first called.
predictor = None


def get_predictor():
    """Return the shared ``TransportationPredictor``, creating it on first use.

    The instance is stored in the module-level ``predictor`` variable so later
    calls reuse it instead of constructing a new one.

    Returns:
        The cached predictor instance.
    """
    global predictor
    # Identity check: only a missing instance triggers construction, never an
    # existing object that merely evaluates as falsy.
    if predictor is None:
        predictor = TransportationPredictor()
    return predictor
@router.post('/predict-transportation', response_model=TransportationResponse)
def predict_transportation(request: TransportationRequest):
    """Predict the optimal shipment mode for the submitted request."""
    active_predictor = get_predictor()
    outcome = active_predictor.predict_shipment_mode(request)
    return outcome
@router.get('/transportation-options')
def get_transportation_options():
    """Return the shipment modes plus sample values for the input fields.

    Returns:
        A dict with the key ``shipment_modes`` and up to ten sample vendors,
        countries and project codes taken from the predictor's shipment
        encoders (empty lists when no encoders are loaded). If anything
        fails, the failure is logged and ``{"error": <message>}`` is returned
        instead.
    """
    try:
        encoders = get_predictor().shipment_encoders

        def sample(field):
            # First ten known categories for the field; empty without encoders.
            return list(encoders[field].classes_[:10]) if encoders else []

        return {
            "shipment_modes": ["Air", "Air Charter", "Ocean", "Truck"],
            "sample_vendors": sample('Vendor'),
            "sample_countries": sample('Country'),
            "sample_projects": sample('Project Code')
        }
    except Exception as e:
        # Keep the existing response shape, but record the failure so it is
        # no longer swallowed silently.
        logger.exception("Failed to build transportation options")
        return {"error": str(e)}