Spaces:
No application file
No application file
File size: 6,620 Bytes
80dbe44 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 |
from fastapi import APIRouter, HTTPException
from src.app.schema.transportation import TransportationRequest, TransportationResponse
import pickle
import joblib
import numpy as np
from pathlib import Path
import logging
from datetime import datetime
import warnings
from sklearn.exceptions import InconsistentVersionWarning
from src.config.logging_config import get_logger
# Suppress scikit-learn's InconsistentVersionWarning process-wide, so that
# loading estimators serialized under a different scikit-learn version does
# not emit a warning for every load.
warnings.filterwarnings("ignore", category=InconsistentVersionWarning)
# Router on which the transportation endpoints of this module are registered.
router = APIRouter()
# Module-level logger obtained from the project's logging configuration.
logger = get_logger("transportation")
class TransportationPredictor:
    """Shipment-mode predictor backed by artifacts downloaded from Hugging Face.

    Constructing an instance downloads and deserializes four artifacts (an
    XGBoost model, label encoders, shipment encoders and a scaler) whose
    locations are read from ``src.config.setting.settings``.

    NOTE(review): ``label_encoders`` and ``scaler`` are loaded but never used
    by ``predict_shipment_mode`` -- confirm that the model expects unscaled
    features and that the hard-coded class list matches the training labels.
    """

    def __init__(self):
        """Initialise empty artifact slots and load every artifact eagerly."""
        self.model = None
        self.label_encoders = None
        self.shipment_encoders = None
        self.scaler = None
        self.load_models()

    def _download_file(self, url, token=None, timeout=60):
        """Download ``url`` and return its body as an in-memory binary stream.

        Args:
            url: Absolute URL of the file to fetch.
            token: Optional bearer token sent in the ``Authorization`` header.
            timeout: Seconds to wait for the server to connect or to send
                data before giving up, so a stalled download cannot block
                the caller forever.

        Returns:
            A ``BytesIO`` holding the downloaded content.

        Raises:
            requests.RequestException: On connection failure, timeout, or an
                HTTP error status (via ``raise_for_status``).
        """
        # Imported inside the method so that importing this module does not
        # itself require ``requests``.
        import requests
        from io import BytesIO

        headers = {"Authorization": f"Bearer {token}"} if token else {}
        response = requests.get(url, headers=headers, timeout=timeout)
        response.raise_for_status()
        return BytesIO(response.content)

    def load_models(self):
        """Download and deserialize every model artifact from Hugging Face.

        Raises:
            RuntimeError: If reading the settings, downloading, or
                deserializing any artifact fails. The underlying exception
                is chained as ``__cause__``.
        """
        try:
            from src.config.setting import settings

            base_url = settings.HF_MODEL_BASE_URL
            files = settings.HF_MODEL_FILES
            token = settings.HF_TOKEN

            def fetch(key):
                # Resolve the artifact's file name and download it.
                return self._download_file(base_url + files[key], token)

            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", category=InconsistentVersionWarning)
                # SECURITY: pickle.load / joblib.load execute arbitrary code
                # embedded in the payload. This is only safe while the
                # configured repository is fully trusted.
                self.model = pickle.load(fetch("xgboost_model"))
                self.label_encoders = joblib.load(fetch("label_encoders"))
                self.shipment_encoders = joblib.load(fetch("shipment_encoders"))
                self.scaler = joblib.load(fetch("scaler"))
            logger.info("All models loaded successfully from Hugging Face")
        except Exception as e:
            raise RuntimeError(f"Không thể load models từ Hugging Face: {e}") from e

    def _encode_safe(self, encoder, value: str) -> int:
        """Encode ``value`` with ``encoder``, falling back to ``0``.

        Args:
            encoder: Object exposing ``classes_`` and ``transform``; may be
                ``None`` when no encoder is available for the column.
            value: Raw category value to encode.

        Returns:
            The encoded value as a plain Python ``int``, or ``0`` when the
            encoder is missing, the value is not among ``encoder.classes_``,
            or encoding raises an ordinary exception.
        """
        try:
            if encoder and value in encoder.classes_:
                # Convert the numpy integer to a built-in int.
                return int(encoder.transform([value])[0])
        except Exception:
            # Deliberate best-effort fallback. Only ordinary errors are
            # absorbed; KeyboardInterrupt / SystemExit still propagate.
            return 0
        return 0

    @staticmethod
    def _delivery_ordinal(delivery_date):
        """Return the proleptic Gregorian ordinal of a ``YYYY-MM-DD`` string.

        Falls back to today's ordinal when ``delivery_date`` is empty, is not
        a string, or does not match the ``%Y-%m-%d`` format.
        """
        if delivery_date:
            try:
                return datetime.strptime(delivery_date, "%Y-%m-%d").toordinal()
            except (ValueError, TypeError):
                # Malformed or non-string date: use today's date instead.
                pass
        return datetime.now().toordinal()

    @staticmethod
    def _rank_alternatives(classes, probabilities, predicted_idx):
        """List every non-predicted class with its probability, highest first.

        Entries with equal probability keep their order from ``classes``.
        """
        alternatives = [
            {'mode': mode, 'probability': float(probabilities[idx])}
            for idx, mode in enumerate(classes)
            if idx != predicted_idx
        ]
        return sorted(alternatives, key=lambda alt: alt['probability'], reverse=True)

    def predict_shipment_mode(self, request: "TransportationRequest") -> "TransportationResponse":
        """Predict the shipment mode for ``request``.

        Missing (or zero) weight is estimated as 0.1 per unit of
        ``line_item_quantity`` (100 units assumed when that is missing too).
        Missing (or zero) freight cost is estimated as the largest of
        ``weight * 3``, ``pack_price * 0.04`` and ``50``.

        Args:
            request: Incoming prediction request.

        Returns:
            A ``TransportationResponse`` with the predicted mode, its
            probability, the remaining modes ranked by probability, the
            estimates that were substituted (``None`` where the caller gave a
            value) and the encoded categorical features.

        Raises:
            HTTPException: With status 500 when any step fails. The original
                exception is chained as ``__cause__``.
        """
        try:
            # Estimate weight and freight when they were not supplied.
            weight = request.weight_kg or (request.line_item_quantity or 100) * 0.1
            freight = request.freight_cost_usd or max(weight * 3, request.pack_price * 0.04, 50)

            # Encode categorical features as plain Python ints.
            encoders = self.shipment_encoders
            project_encoded = self._encode_safe(encoders.get('Project Code'), request.project_code)
            country_encoded = self._encode_safe(encoders.get('Country'), request.country)
            vendor_encoded = self._encode_safe(encoders.get('Vendor'), request.vendor)

            date_ordinal = self._delivery_ordinal(request.delivery_date)

            # Feature vector in the order passed to the model.
            # NOTE(review): the leading constant 0 is an unexplained
            # placeholder -- confirm which training column it stands for.
            features = [
                0,
                project_encoded,
                country_encoded,
                request.pack_price,
                vendor_encoded,
                freight,
                weight,
                date_ordinal,
            ]
            # Build the single-row input once and reuse it for both calls.
            feature_row = np.array(features).reshape(1, -1)
            prediction = self.model.predict(feature_row)
            probabilities = self.model.predict_proba(feature_row)[0]

            # NOTE(review): assumed to match the model's class index order.
            classes = ['Air', 'Air Charter', 'Ocean', 'Truck']
            predicted_idx = int(prediction[0])
            predicted_mode = classes[predicted_idx]
            confidence = float(probabilities[predicted_idx])

            alternatives = self._rank_alternatives(classes, probabilities, predicted_idx)

            return TransportationResponse(
                predicted_shipment_mode=predicted_mode,
                confidence_score=confidence,
                alternative_modes=alternatives,
                # Estimates are reported only when they were actually used.
                estimated_weight_kg=float(weight) if not request.weight_kg else None,
                estimated_freight_cost_usd=float(freight) if not request.freight_cost_usd else None,
                encoded_features={
                    'Project_Code': int(project_encoded),
                    'Country': int(country_encoded),
                    'Vendor': int(vendor_encoded),
                },
                processing_notes=[],
            )
        except Exception as e:
            # Keep a server-side traceback; the client only gets the message.
            logger.exception("Transportation prediction failed")
            raise HTTPException(status_code=500, detail=f"Prediction failed: {e}") from e
# Lazily created predictor shared by every request handler in this module.
predictor = None


def get_predictor():
    """Return the shared predictor, constructing it on first use."""
    global predictor
    if predictor:
        return predictor
    # First call: building the predictor loads all of its model artifacts.
    predictor = TransportationPredictor()
    return predictor
@router.post('/predict-transportation', response_model=TransportationResponse)
def predict_transportation(request: TransportationRequest):
    """Predict the optimal shipment mode for the given request."""
    service = get_predictor()
    return service.predict_shipment_mode(request)
@router.get('/transportation-options')
def get_transportation_options():
    """Return the selectable values for the prediction input fields.

    Returns:
        A dict with the supported shipment modes and up to ten sample
        vendors, countries and project codes taken from the shipment
        encoders (empty lists when no encoders are loaded). If anything
        fails, a dict with a single ``"error"`` key holding the message is
        returned instead, and the failure is logged.
    """
    try:
        encoders = get_predictor().shipment_encoders

        def sample(column):
            # First ten known categories of ``column``, or nothing when the
            # encoders are unavailable.
            return list(encoders[column].classes_[:10]) if encoders else []

        return {
            "shipment_modes": ["Air", "Air Charter", "Ocean", "Truck"],
            "sample_vendors": sample('Vendor'),
            "sample_countries": sample('Country'),
            "sample_projects": sample('Project Code'),
        }
    except Exception as e:
        # The response shape is kept for existing clients; the log adds the
        # server-side traceback that was previously lost.
        logger.exception("Failed to build transportation options")
        return {"error": str(e)}
|