File size: 7,936 Bytes
8eab558
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
import os
import joblib
import numpy as np
import pandas as pd
import cv2
import tensorflow as tf
from patchify import patchify

# 1. Define Custom Layers
@tf.keras.utils.register_keras_serializable()
class ClassToken(tf.keras.layers.Layer):
    """Learnable [CLS] token, broadcast to one copy per batch sample.

    Produces a ``(batch, 1, hidden_dim)`` tensor; the token width is taken
    from the last axis of the input seen at build time.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        # Token width follows the embedding dimension of the incoming patches.
        self.hidden_dim = input_shape[-1]
        self.w = self.add_weight(
            name="cls_token",
            shape=(1, 1, self.hidden_dim),
            initializer="random_normal",
            trainable=True,
        )

    def call(self, inputs):
        # Tile the single learned token so each sample in the batch gets one.
        n = tf.shape(inputs)[0]
        return tf.broadcast_to(self.w, [n, 1, self.hidden_dim])

@tf.keras.utils.register_keras_serializable()
class ExtractCLSToken(tf.keras.layers.Layer):
    """Slice the first (CLS) token out of a ``(batch, tokens, dim)`` tensor."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, inputs):
        # Keep the batch and feature axes; drop the token axis at index 0.
        return inputs[:, 0, :]

class DiamondInference:
    """Inference wrapper for a multimodal (image + tabular) diamond classifier.

    Loads a trained Keras model plus the joblib-serialized preprocessing
    artifacts (hyperparameters, categorical encoders, numeric scaler, target
    encoder, normalization stats) and exposes :meth:`predict`, which combines
    image patches with encoded tabular features, optionally averaging over
    test-time-augmented views of the image.
    """

    def __init__(self, model_path, encoder_dir, model_id=None):
        """Load model and preprocessing artifacts.

        model_path: path to the saved Keras model file/directory.
        encoder_dir: directory containing the joblib artifact files.
        model_id: optional suffix selecting a specific artifact set; when
            omitted, legacy "imagenet_100ep" filenames are used instead.
        """
        # Use provided model_id to load specific artifacts, fallback to generic if not provided
        self.model_id = model_id
        
        if model_id:
            hp_path = os.path.join(encoder_dir, f"hyperparameters_{model_id}.pkl")
            cat_path = os.path.join(encoder_dir, f"cat_encoders_{model_id}.pkl")
            num_path = os.path.join(encoder_dir, f"num_scaler_{model_id}.pkl")
            target_path = os.path.join(encoder_dir, f"target_encoder_{model_id}.pkl")
            norm_stats_path = os.path.join(encoder_dir, f"norm_stats_{model_id}.pkl")
        else:
            # Fallback to older generic names if no ID is passed
            hp_path = os.path.join(encoder_dir, "hyperparameters_imagenet_100ep.pkl")
            cat_path = os.path.join(encoder_dir, "cat_encoders_imagenet_100ep.pkl")
            num_path = os.path.join(encoder_dir, "num_scaler_imagenet_100ep.pkl")
            target_path = os.path.join(encoder_dir, "target_encoder_imagenet_100ep.pkl")
            norm_stats_path = os.path.join(encoder_dir, "norm_stats_imagenet_100ep.pkl")

        print(f"[INFO] Loading artifacts for model ID: {model_id or 'default'}")
        # NOTE(review): joblib.load deserializes pickles — only load trusted artifacts.
        self.hp = joblib.load(hp_path)
        self.cat_encoders = joblib.load(cat_path)
        self.num_scaler = joblib.load(num_path)
        self.target_encoder = joblib.load(target_path)
        
        if os.path.exists(norm_stats_path):
            self.norm_stats = joblib.load(norm_stats_path)
        else:
            # Default fallback to ImageNet stats
            self.norm_stats = {"mean": np.array([0.485, 0.456, 0.406]), "std": np.array([0.229, 0.224, 0.225])}
        
        # Custom layers must be registered so the serialized graph can be rebuilt;
        # compile=False since we only run forward passes here.
        self.model = tf.keras.models.load_model(
            model_path, 
            custom_objects={"ClassToken": ClassToken, "ExtractCLSToken": ExtractCLSToken},
            compile=False
        )
        print(f"[INFO] Model and artifacts loaded successfully from {model_path}.")

    def apply_tta_transform(self, img, transform_type):
        """Apply specific Test-Time Augmentation transformation"""
        # Unknown transform_type values fall through and return the image unchanged.
        if transform_type == "original":
            return img
        elif transform_type == "horizontal_flip":
            return cv2.flip(img, 1)
        elif transform_type == "rotation_5":
            # Rotate +5 degrees about the image center; reflect at borders to
            # avoid introducing black corners.
            h, w = img.shape[:2]
            M = cv2.getRotationMatrix2D((w//2, h//2), 5, 1.0)
            return cv2.warpAffine(img, M, (w, h), borderMode=cv2.BORDER_REFLECT)
        elif transform_type == "rotation_minus_5":
            h, w = img.shape[:2]
            M = cv2.getRotationMatrix2D((w//2, h//2), -5, 1.0)
            return cv2.warpAffine(img, M, (w, h), borderMode=cv2.BORDER_REFLECT)
        elif transform_type == "brightness_up":
            # +10% brightness, clipped back into uint8 range (img is still uint8 here).
            return np.clip(img * 1.1, 0, 255).astype(np.uint8)
        return img

    def process_image(self, image_path, tta_transform=None):
        """Read, resize, (optionally) augment, normalize and patchify an image.

        Returns a float32 array shaped ``hp["flat_patches_shape"]``. On any
        failure (unreadable file, patchify error) a zero array of the same
        shape is returned so inference can proceed best-effort.
        """
        try:
            image = cv2.imread(image_path, cv2.IMREAD_COLOR)
            if image is None:
                return np.zeros(self.hp["flat_patches_shape"], dtype=np.float32)
            
            image = cv2.resize(image, (self.hp["image_size"], self.hp["image_size"]))
            
            # TTA is applied while the image is still uint8, before normalization.
            if tta_transform:
                image = self.apply_tta_transform(image, tta_transform)
            
            # Scale to [0, 1] then standardize with the stored channel stats.
            # NOTE(review): cv2 loads channels as BGR, while ImageNet mean/std are
            # conventionally RGB-ordered — confirm training used the same order.
            image = image / 255.0
            image = (image - self.norm_stats["mean"]) / (self.norm_stats["std"] + 1e-7)
            
            # Non-overlapping patches: step equals patch size.
            patch_shape = (self.hp["patch_size"], self.hp["patch_size"], self.hp["num_channels"])
            patches = patchify(image, patch_shape, self.hp["patch_size"])
            patches = np.reshape(patches, self.hp["flat_patches_shape"]).astype(np.float32)
            return patches
        except Exception as e:
            # Deliberate best-effort: log and fall back to a zero image.
            print(f"[ERROR] Image processing failed: {e}")
            return np.zeros(self.hp["flat_patches_shape"], dtype=np.float32)

    def predict(self, df_row, image_path, use_tta=True):
        """Predict the decoded target label for one record.

        df_row: mapping/Series with the categorical and numeric feature columns.
        image_path: path to the diamond image.
        use_tta: average predictions over five augmented views when True.
        """
        # 1. Preprocess Tabular Data
        # Match training categorical features: StoneType, Color, Brown, BlueUv, GrdType, Result
        categorical_cols = ["StoneType", "Color", "Brown", "BlueUv", "GrdType", "Result"]
        numerical_cols = ["Carat"]
        
        tab_data_list = []
        for col in categorical_cols:
            val = str(df_row.get(col, "__missing__"))
            # Safe transform for categorical values
            try:
                # First check if the column exists in encoders
                if col in self.cat_encoders:
                    # Check if val is in encoder classes, otherwise fallback to __missing__
                    # (or the encoder's first class if "__missing__" was never fitted).
                    if val not in self.cat_encoders[col].classes_:
                        val = "__missing__" if "__missing__" in self.cat_encoders[col].classes_ else self.cat_encoders[col].classes_[0]
                    
                    encoded_val = self.cat_encoders[col].transform([val])[0]
                else:
                    print(f"[WARN] Encoder for column {col} not found. Using 0.")
                    encoded_val = 0
            except Exception as e:
                print(f"[ERROR] Encoding failed for {col} with value {val}: {e}. Using 0.")
                encoded_val = 0
            tab_data_list.append(encoded_val)
        
        for col in numerical_cols:
            try:
                val = float(df_row.get(col, 0))
                # Reshape for scaler (expected 2D array)
                scaled_val = self.num_scaler.transform([[val]])[0][0]
            except Exception as e:
                print(f"[ERROR] Scaling failed for {col}: {e}. Using 0.")
                scaled_val = 0
            tab_data_list.append(scaled_val)
            
        # Single-row batch: categorical codes followed by scaled numerics,
        # matching the training-time feature order.
        tab_input = np.expand_dims(np.array(tab_data_list, dtype=np.float32), axis=0)
        
        # 2. Inference with TTA
        if use_tta:
            tta_transforms = ["original", "horizontal_flip", "rotation_5", "rotation_minus_5", "brightness_up"]
            all_preds = []
            
            for transform in tta_transforms:
                img_patches = self.process_image(image_path, tta_transform=transform)
                img_input = np.expand_dims(img_patches, axis=0)
                # Model takes [image_patches, tabular] — assumes this input order
                # matches the training graph; verify against the model definition.
                preds = self.model.predict([img_input, tab_input], verbose=0)[0]
                all_preds.append(preds)
            
            # Average class probabilities across the augmented views.
            final_pred_probs = np.mean(all_preds, axis=0)
        else:
            img_patches = self.process_image(image_path)
            img_input = np.expand_dims(img_patches, axis=0)
            final_pred_probs = self.model.predict([img_input, tab_input], verbose=0)[0]
        
        # Map the winning class index back to its original label.
        pred_idx = np.argmax(final_pred_probs)
        decoded_pred = self.target_encoder.inverse_transform([pred_idx])[0]
        
        return decoded_pred