File size: 6,620 Bytes
80dbe44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
from fastapi import APIRouter, HTTPException
from src.app.schema.transportation import TransportationRequest, TransportationResponse
import pickle
import joblib
import numpy as np
from pathlib import Path
import logging
from datetime import datetime
import warnings
from sklearn.exceptions import InconsistentVersionWarning
from src.config.logging_config import get_logger

# Suppress sklearn version warning for model compatibility.
# This installs an "ignore" filter for InconsistentVersionWarning at import
# time, so it applies to everything that runs after this module is imported.
warnings.filterwarnings("ignore", category=InconsistentVersionWarning)

# Router on which the endpoint functions of this module are registered.
router = APIRouter()
# Module logger, obtained from the project's logging configuration helper.
logger = get_logger("transportation")

class TransportationPredictor:
    """Predict a shipment mode from request data using downloaded model artifacts.

    Constructing an instance downloads and deserializes four artifacts (the
    classifier, two encoder collections and a scaler) from the location
    configured in ``src.config.setting.settings``.

    NOTE(review): ``label_encoders`` and ``scaler`` are loaded but never read
    by any method of this class - confirm whether the feature vector was meant
    to be scaled before prediction.
    """

    # Shipment mode names indexed by the integer class the model predicts.
    # NOTE(review): presumably mirrors the label order used at training time - confirm.
    SHIPMENT_MODES = ('Air', 'Air Charter', 'Ocean', 'Truck')

    # Default timeout (seconds) passed to ``requests.get`` so that a stalled
    # connection cannot block model loading forever.
    DOWNLOAD_TIMEOUT_SECONDS = 60

    def __init__(self):
        """Initialise empty artifact slots and immediately load all artifacts.

        Raises:
            RuntimeError: propagated from ``load_models`` when loading fails.
        """
        self.model = None
        self.label_encoders = None
        self.shipment_encoders = None
        self.scaler = None
        self.load_models()

    def _download_file(self, url, token=None, timeout=None):
        """Download ``url`` and return its body as an in-memory binary stream.

        Args:
            url: Address of the file to fetch.
            token: Optional bearer token; when truthy it is sent in the
                ``Authorization`` header.
            timeout: Optional ``requests`` timeout in seconds; defaults to
                ``DOWNLOAD_TIMEOUT_SECONDS`` when ``None``.

        Returns:
            A ``BytesIO`` wrapping the response content.

        Raises:
            requests.RequestException: on connection problems, timeouts or a
                non-success HTTP status (via ``raise_for_status``).
        """
        import requests
        from io import BytesIO

        headers = {"Authorization": f"Bearer {token}"} if token else {}
        if timeout is None:
            timeout = self.DOWNLOAD_TIMEOUT_SECONDS
        # Without a timeout ``requests`` waits indefinitely on a stalled server.
        response = requests.get(url, headers=headers, timeout=timeout)
        response.raise_for_status()
        return BytesIO(response.content)

    def load_models(self):
        """Load all model artifacts from Hugging Face.

        Reads the base URL, file-name mapping and token from the project
        settings, downloads each artifact and stores the deserialized objects
        on the instance.

        Raises:
            RuntimeError: if settings cannot be read or any download or
                deserialization step fails; the original error is chained.
        """
        try:
            from src.config.setting import settings
            base_url = settings.HF_MODEL_BASE_URL
            files = settings.HF_MODEL_FILES
            token = settings.HF_TOKEN

            def fetch(name):
                # Download the artifact registered under ``name`` in the settings mapping.
                return self._download_file(base_url + files[name], token)

            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", category=InconsistentVersionWarning)
                # SECURITY: pickle/joblib deserialization executes arbitrary code
                # embedded in the payload. Only point the settings at a model
                # repository you fully trust.
                self.model = pickle.load(fetch("xgboost_model"))
                self.label_encoders = joblib.load(fetch("label_encoders"))
                self.shipment_encoders = joblib.load(fetch("shipment_encoders"))
                self.scaler = joblib.load(fetch("scaler"))
            logger.info("All models loaded successfully from Hugging Face")
        except Exception as e:
            # Chain the cause so the original traceback is kept for debugging.
            raise RuntimeError(f"Không thể load models từ Hugging Face: {e}") from e

    def _encode_safe(self, encoder, value: str) -> int:
        """Encode ``value`` with ``encoder``, falling back to 0.

        Args:
            encoder: Object exposing ``classes_`` and ``transform``; may be
                ``None``/falsy when no encoder is available.
            value: Raw category value to encode.

        Returns:
            The encoded class as a plain Python ``int``, or 0 when the encoder
            is missing, the value is not among ``encoder.classes_``, or
            encoding raises.
        """
        if not encoder:
            return 0
        try:
            if value in encoder.classes_:
                # Convert the numpy integer to a Python int.
                return int(encoder.transform([value])[0])
        except Exception:
            # Best-effort by design: keep the fallback, but leave a trace and
            # do not swallow KeyboardInterrupt/SystemExit like a bare except.
            logger.warning("Could not encode value %r; using fallback 0", value, exc_info=True)
        return 0

    @staticmethod
    def _delivery_ordinal(delivery_date) -> int:
        """Return the proleptic ordinal of ``delivery_date`` (``YYYY-MM-DD``).

        Falls back to today's ordinal when the date is missing or cannot be
        parsed with the ``%Y-%m-%d`` format.
        """
        if delivery_date:
            try:
                return datetime.strptime(delivery_date, "%Y-%m-%d").toordinal()
            except (ValueError, TypeError):
                logger.warning("Unparseable delivery_date %r; using today's date", delivery_date)
        return datetime.now().toordinal()

    def predict_shipment_mode(self, request: TransportationRequest) -> TransportationResponse:
        """Predict the shipment mode for ``request``.

        Missing weight and freight cost are estimated from the other request
        fields; the categorical fields are encoded with the shipment encoders
        (unknown values become 0); the delivery date becomes a date ordinal.

        Args:
            request: Incoming prediction request.

        Returns:
            A ``TransportationResponse`` with the predicted mode, its
            probability, the remaining modes sorted by descending probability,
            any estimated values and the encoded categorical features.

        Raises:
            HTTPException: status 500 wrapping any error raised while
                building features or running the model.
        """
        try:
            # Estimate weight and freight when they are missing (or zero).
            weight = request.weight_kg or (request.line_item_quantity or 100) * 0.1
            freight = request.freight_cost_usd or max(weight * 3, request.pack_price * 0.04, 50)

            # Encode categorical features as plain Python ints.
            encoders = self.shipment_encoders
            project_encoded = self._encode_safe(encoders.get('Project Code'), request.project_code)
            country_encoded = self._encode_safe(encoders.get('Country'), request.country)
            vendor_encoded = self._encode_safe(encoders.get('Vendor'), request.vendor)

            date_ordinal = self._delivery_ordinal(request.delivery_date)

            # Build the feature vector.
            # NOTE(review): the leading constant 0 fills the model's first input
            # column; its meaning is not visible in this file - confirm.
            features = [0, project_encoded, country_encoded, request.pack_price, vendor_encoded, freight, weight, date_ordinal]
            # Build the 1 x n input once and reuse it for both model calls.
            feature_matrix = np.array(features).reshape(1, -1)

            prediction = self.model.predict(feature_matrix)
            probabilities = self.model.predict_proba(feature_matrix)[0]

            classes = self.SHIPMENT_MODES
            predicted_idx = int(prediction[0])
            predicted_mode = classes[predicted_idx]
            confidence = float(probabilities[predicted_idx])

            # Every other mode, most probable first, with numpy floats
            # converted to Python floats.
            alternatives = sorted(
                (
                    {'mode': mode, 'probability': float(probabilities[idx])}
                    for idx, mode in enumerate(classes)
                    if idx != predicted_idx
                ),
                key=lambda alt: alt['probability'],
                reverse=True,
            )

            return TransportationResponse(
                predicted_shipment_mode=predicted_mode,
                confidence_score=confidence,
                alternative_modes=alternatives,
                # Estimates are reported only when the request did not supply the value.
                estimated_weight_kg=float(weight) if not request.weight_kg else None,
                estimated_freight_cost_usd=float(freight) if not request.freight_cost_usd else None,
                encoded_features={
                    'Project_Code': int(project_encoded),
                    'Country': int(country_encoded),
                    'Vendor': int(vendor_encoded)
                },
                processing_notes=[]
            )

        except Exception as e:
            # Record the full traceback server-side before converting to an HTTP error.
            logger.exception("Transportation prediction failed")
            raise HTTPException(status_code=500, detail=f"Prediction failed: {e}") from e

# Global instance, created lazily by get_predictor() so that importing this
# module does not by itself construct a predictor.
predictor = None


def get_predictor():
    """Return the shared ``TransportationPredictor``, creating it on first use.

    The instance is stored in the module-level ``predictor`` variable and
    reused by later calls. If construction raises, the global stays ``None``
    and the next call tries again.

    NOTE(review): no lock guards the first call, so concurrent first calls
    could each construct a predictor - confirm whether that matters for the
    deployment.
    """
    global predictor
    # Identity check: only the "not created yet" sentinel triggers construction.
    if predictor is None:
        predictor = TransportationPredictor()
    return predictor

@router.post('/predict-transportation', response_model=TransportationResponse)
def predict_transportation(request: TransportationRequest):
    """Predict the optimal shipment mode for the given request."""
    # Delegate to the shared predictor instance.
    service = get_predictor()
    result = service.predict_shipment_mode(request)
    return result

@router.get('/transportation-options')
def get_transportation_options():
    """Return the possible options for the prediction input fields.

    Returns:
        A dict with the list of shipment modes and, for vendors, countries
        and project codes, up to the first ten classes known to the shipment
        encoders (empty lists when no encoders are loaded). If anything
        fails, a dict with a single ``"error"`` key holding the exception
        message is returned instead.
    """
    try:
        encoders = get_predictor().shipment_encoders

        def sample(column):
            # Up to the first ten known classes for ``column``; empty when no
            # encoders are available.
            return list(encoders[column].classes_[:10]) if encoders else []

        return {
            "shipment_modes": ["Air", "Air Charter", "Ocean", "Truck"],
            "sample_vendors": sample('Vendor'),
            "sample_countries": sample('Country'),
            "sample_projects": sample('Project Code'),
        }
    except Exception as e:
        # The error-dict response is kept for existing clients, but the
        # failure is now logged so it is not lost.
        logger.exception("Failed to build transportation options")
        return {"error": str(e)}