Spaces:
No application file
No application file
| from fastapi import APIRouter, HTTPException | |
| from src.app.schema.transportation import TransportationRequest, TransportationResponse | |
| import pickle | |
| import joblib | |
| import numpy as np | |
| from pathlib import Path | |
| import logging | |
| from datetime import datetime | |
| import warnings | |
| from sklearn.exceptions import InconsistentVersionWarning | |
| from src.config.logging_config import get_logger | |
# Silence scikit-learn's InconsistentVersionWarning for the whole process:
# model artifacts loaded by this module may have been pickled with a
# different scikit-learn version than the one installed here.
warnings.filterwarnings(action="ignore", category=InconsistentVersionWarning)

logger = get_logger("transportation")
router = APIRouter()
class TransportationPredictor:
    """Predict a shipment mode with model artifacts fetched from Hugging Face.

    All artifacts are downloaded and deserialized eagerly in ``__init__`` (via
    ``load_models``), so constructing an instance performs network I/O and
    raises ``RuntimeError`` if any artifact cannot be loaded.
    """

    # Labels for the integer classes the model predicts, indexed by class id.
    # NOTE(review): this is alphabetical order (what a sklearn LabelEncoder
    # would produce for these labels) — confirm it matches the training-time
    # target encoding.
    SHIPMENT_MODES = ['Air', 'Air Charter', 'Ocean', 'Truck']

    def __init__(self):
        """Create the predictor and eagerly load every model artifact."""
        self.model = None  # classifier exposing predict() / predict_proba()
        self.label_encoders = None  # loaded, but not used by this class
        self.shipment_encoders = None  # column name -> fitted encoder (.classes_, .transform)
        self.scaler = None  # loaded, but not applied to features in this class
        self.load_models()

    def _download_file(self, url, token=None, timeout=60.0):
        """Download ``url`` and return its body as an in-memory binary buffer.

        Args:
            url: Absolute URL of the file to fetch.
            token: Optional bearer token sent in the ``Authorization`` header.
            timeout: Seconds ``requests`` waits for the connection and for each
                read before giving up. Without a timeout a stalled server would
                block model loading (and therefore the first request) forever.

        Returns:
            A ``BytesIO`` positioned at the start of the downloaded content.

        Raises:
            requests.HTTPError: If the server answers with an error status.
            requests.RequestException: On connection failures or timeouts.
        """
        from io import BytesIO

        import requests

        headers = {"Authorization": f"Bearer {token}"} if token else {}
        response = requests.get(url, headers=headers, timeout=timeout)
        response.raise_for_status()
        return BytesIO(response.content)

    def load_models(self):
        """Download and deserialize all model artifacts from Hugging Face.

        Reads ``HF_MODEL_BASE_URL``, ``HF_MODEL_FILES`` and ``HF_TOKEN`` from
        ``src.config.setting.settings``. ``HF_MODEL_FILES`` must provide the keys
        ``xgboost_model``, ``label_encoders``, ``shipment_encoders`` and
        ``scaler``; each value is appended to the base URL.

        Raises:
            RuntimeError: If settings are missing or any download or
                deserialization fails. The original error is chained as
                ``__cause__`` so the root cause stays visible in tracebacks.
        """
        try:
            from src.config.setting import settings

            base_url = settings.HF_MODEL_BASE_URL
            files = settings.HF_MODEL_FILES
            token = settings.HF_TOKEN

            def fetch(key):
                # Download the artifact registered under ``key`` in HF_MODEL_FILES.
                return self._download_file(base_url + files[key], token)

            with warnings.catch_warnings():
                # Artifacts may have been pickled with another sklearn version.
                warnings.filterwarnings("ignore", category=InconsistentVersionWarning)
                # SECURITY: pickle/joblib can execute arbitrary code while
                # loading. Only point these settings at a model repository
                # you control and trust.
                self.model = pickle.load(fetch("xgboost_model"))
                self.label_encoders = joblib.load(fetch("label_encoders"))
                self.shipment_encoders = joblib.load(fetch("shipment_encoders"))
                self.scaler = joblib.load(fetch("scaler"))
            logger.info("All models loaded successfully from Hugging Face")
        except Exception as e:
            raise RuntimeError(f"Không thể load models từ Hugging Face: {e}") from e

    def _encode_safe(self, encoder, value: str) -> int:
        """Encode ``value`` with a fitted label encoder, falling back to 0.

        Returns 0 when ``encoder`` is missing or falsy, when ``value`` is not
        one of ``encoder.classes_``, or when encoding fails. Note that 0 is also
        the code of the encoder's first class, so unknown values are
        indistinguishable from that class downstream.

        Returns:
            The code as a plain Python ``int`` (not a NumPy integer), so it can
            be placed directly into the API response.
        """
        try:
            if encoder and value in encoder.classes_:
                return int(encoder.transform([value])[0])
        except Exception:
            # Best-effort by design: a malformed encoder or an incomparable
            # value must not fail the whole prediction.
            logger.debug("Could not encode %r; falling back to 0", value, exc_info=True)
        return 0

    @staticmethod
    def _delivery_date_ordinal(delivery_date) -> int:
        """Return the proleptic Gregorian ordinal of ``delivery_date``.

        Accepts a ``date``/``datetime`` object or a ``YYYY-MM-DD`` string.
        Falls back to today's ordinal when the value is missing or cannot be
        parsed.
        """
        from datetime import date

        if not delivery_date:
            return datetime.now().toordinal()
        if isinstance(delivery_date, date):
            # Already parsed (e.g. by request validation); use it directly
            # instead of discarding it in favour of today's date.
            return delivery_date.toordinal()
        try:
            return datetime.strptime(delivery_date, "%Y-%m-%d").toordinal()
        except (TypeError, ValueError):
            return datetime.now().toordinal()

    def predict_shipment_mode(self, request: TransportationRequest) -> TransportationResponse:
        """Predict the most suitable shipment mode for ``request``.

        A missing ``weight_kg`` or ``freight_cost_usd`` is estimated with simple
        heuristics and echoed back in the ``estimated_*`` response fields.
        Categorical values unknown to the encoders are encoded as 0.

        Raises:
            HTTPException: With status 500 if feature construction or inference
                fails. The failure is logged with its traceback.
        """
        try:
            # Heuristic fill-ins. Falsy inputs (including 0) count as missing.
            weight = request.weight_kg or (request.line_item_quantity or 100) * 0.1
            freight = request.freight_cost_usd or max(weight * 3, request.pack_price * 0.04, 50)

            encoders = self.shipment_encoders
            project_encoded = self._encode_safe(encoders.get('Project Code'), request.project_code)
            country_encoded = self._encode_safe(encoders.get('Country'), request.country)
            vendor_encoded = self._encode_safe(encoders.get('Vendor'), request.vendor)

            date_ordinal = self._delivery_date_ordinal(request.delivery_date)

            # Column order must match the model's training features.
            # NOTE(review): the leading 0 is a constant placeholder for the
            # first training column, and self.scaler is not applied here —
            # confirm the model was trained on unscaled features.
            features = np.array(
                [0, project_encoded, country_encoded, request.pack_price,
                 vendor_encoded, freight, weight, date_ordinal]
            ).reshape(1, -1)

            predicted_idx = int(self.model.predict(features)[0])
            probabilities = self.model.predict_proba(features)[0]

            modes = self.SHIPMENT_MODES
            predicted_mode = modes[predicted_idx]
            confidence = float(probabilities[predicted_idx])

            # Every other mode, most probable first, with plain Python floats.
            alternatives = sorted(
                (
                    {'mode': mode, 'probability': float(probabilities[i])}
                    for i, mode in enumerate(modes)
                    if i != predicted_idx
                ),
                key=lambda alt: alt['probability'],
                reverse=True,
            )

            return TransportationResponse(
                predicted_shipment_mode=predicted_mode,
                confidence_score=confidence,
                alternative_modes=alternatives,
                estimated_weight_kg=float(weight) if not request.weight_kg else None,
                estimated_freight_cost_usd=float(freight) if not request.freight_cost_usd else None,
                encoded_features={
                    'Project_Code': int(project_encoded),
                    'Country': int(country_encoded),
                    'Vendor': int(vendor_encoded),
                },
                processing_notes=[],
            )
        except Exception as e:
            logger.exception("Shipment mode prediction failed")
            raise HTTPException(status_code=500, detail=f"Prediction failed: {e}") from e
# Module-level cached predictor; stays None until get_predictor() first runs.
predictor = None


def get_predictor():
    """Return the module's cached TransportationPredictor, building it lazily.

    The first call constructs the predictor, which downloads the models and
    can raise RuntimeError. Later calls return the same instance.
    NOTE(review): creation is not guarded by a lock, so concurrent first calls
    could each load the models — confirm whether that matters for deployment.
    """
    global predictor
    if predictor is not None:
        return predictor
    predictor = TransportationPredictor()
    return predictor
def predict_transportation(request: TransportationRequest):
    """Predict the optimal shipment mode for ``request``.

    Delegates to the cached predictor, creating it on first use.
    """
    shared_predictor = get_predictor()
    return shared_predictor.predict_shipment_mode(request)
def get_transportation_options():
    """Return selectable values for the transportation input fields.

    Lists every supported shipment mode, plus up to ten known vendors,
    countries and project codes taken from the fitted shipment encoders. A
    field whose encoder is unavailable gets an empty list rather than failing
    the whole response.

    Returns:
        A dict with ``shipment_modes``, ``sample_vendors``,
        ``sample_countries`` and ``sample_projects`` keys, or
        ``{"error": <message>}`` if the options cannot be built (for example
        when the models fail to load).
    """
    try:
        encoders = get_predictor().shipment_encoders or {}

        def sample(column):
            # First 10 known classes of ``column``'s encoder, or [] if absent.
            encoder = encoders.get(column)
            return list(encoder.classes_[:10]) if encoder is not None else []

        return {
            "shipment_modes": ["Air", "Air Charter", "Ocean", "Truck"],
            "sample_vendors": sample('Vendor'),
            "sample_countries": sample('Country'),
            "sample_projects": sample('Project Code'),
        }
    except Exception as e:
        # Keep the existing error-payload contract, but record the cause.
        logger.exception("Failed to build transportation options")
        return {"error": str(e)}