import os
import joblib
import numpy as np
import pandas as pd
import cv2
import tensorflow as tf
from patchify import patchify
# 1. Define Custom Layers
@tf.keras.utils.register_keras_serializable()
class ClassToken(tf.keras.layers.Layer):
    """Learnable [CLS] token for a Vision-Transformer-style model.

    Emits a trainable (1, 1, hidden_dim) token broadcast to the batch size
    of ``inputs``; the caller is expected to concatenate it with the patch
    embeddings. Registered as serializable so saved models can be reloaded
    via ``custom_objects`` (see DiamondInference.__init__).
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        # Hidden dimension is inferred from the incoming embeddings' last axis.
        self.hidden_dim = input_shape[-1]
        self.w = self.add_weight(
            name="cls_token",
            shape=(1, 1, self.hidden_dim),
            initializer="random_normal",
            trainable=True,
        )

    def call(self, inputs):
        # Only the batch dimension of `inputs` is used — its contents are not.
        # The single learned token is broadcast (shared values) across the batch.
        batch_size = tf.shape(inputs)[0]
        cls = tf.broadcast_to(self.w, [batch_size, 1, self.hidden_dim])
        return cls
@tf.keras.utils.register_keras_serializable()
class ExtractCLSToken(tf.keras.layers.Layer):
    """Select the token at sequence position 0 from a (batch, seq, dim)
    tensor, yielding a (batch, dim) representation for the classifier head.

    NOTE(review): assumes the ClassToken output is concatenated at position 0
    upstream — confirm the concatenation order in the model definition.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, inputs):
        return inputs[:, 0, :]
class DiamondInference:
    """Single-sample diamond-grade inference for a hybrid image+tabular model.

    Loads a saved Keras model (with the custom ClassToken/ExtractCLSToken
    layers) together with the preprocessing artifacts produced at training
    time: hyperparameters, categorical encoders, numeric scaler, target
    encoder and (optionally) image normalization statistics.
    """

    # Suffix of the legacy artifact file names used when no model_id is given.
    _DEFAULT_ARTIFACT_ID = "imagenet_100ep"

    # TTA variants averaged by predict() when use_tta=True (order preserved).
    _TTA_TRANSFORMS = (
        "original",
        "horizontal_flip",
        "rotation_5",
        "rotation_minus_5",
        "brightness_up",
    )

    def __init__(self, model_path, encoder_dir, model_id=None):
        """Load the model and all preprocessing artifacts.

        Args:
            model_path: Path of the saved Keras model file.
            encoder_dir: Directory holding the ``*.pkl`` preprocessing artifacts.
            model_id: Optional artifact-name suffix; falls back to the legacy
                "imagenet_100ep" artifact names when omitted.
        """
        self.model_id = model_id
        # Both the ID'd and the legacy artifacts follow the same
        # "<prefix>_<suffix>.pkl" pattern, so build every path one way
        # instead of duplicating the five joins per branch.
        artifact_id = model_id if model_id else self._DEFAULT_ARTIFACT_ID

        def _artifact(prefix):
            return os.path.join(encoder_dir, f"{prefix}_{artifact_id}.pkl")

        print(f"[INFO] Loading artifacts for model ID: {model_id or 'default'}")
        self.hp = joblib.load(_artifact("hyperparameters"))
        self.cat_encoders = joblib.load(_artifact("cat_encoders"))
        self.num_scaler = joblib.load(_artifact("num_scaler"))
        self.target_encoder = joblib.load(_artifact("target_encoder"))

        norm_stats_path = _artifact("norm_stats")
        if os.path.exists(norm_stats_path):
            self.norm_stats = joblib.load(norm_stats_path)
        else:
            # Older artifact sets shipped no stats; assume ImageNet values.
            self.norm_stats = {
                "mean": np.array([0.485, 0.456, 0.406]),
                "std": np.array([0.229, 0.224, 0.225]),
            }

        self.model = tf.keras.models.load_model(
            model_path,
            custom_objects={"ClassToken": ClassToken, "ExtractCLSToken": ExtractCLSToken},
            compile=False,
        )
        print(f"[INFO] Model and artifacts loaded successfully from {model_path}.")

    def apply_tta_transform(self, img, transform_type):
        """Apply one named Test-Time Augmentation to a uint8 BGR image.

        Unknown transform names (and "original") return the image unchanged.
        """
        if transform_type == "original":
            return img
        elif transform_type == "horizontal_flip":
            return cv2.flip(img, 1)
        elif transform_type == "rotation_5":
            h, w = img.shape[:2]
            M = cv2.getRotationMatrix2D((w // 2, h // 2), 5, 1.0)
            return cv2.warpAffine(img, M, (w, h), borderMode=cv2.BORDER_REFLECT)
        elif transform_type == "rotation_minus_5":
            h, w = img.shape[:2]
            M = cv2.getRotationMatrix2D((w // 2, h // 2), -5, 1.0)
            return cv2.warpAffine(img, M, (w, h), borderMode=cv2.BORDER_REFLECT)
        elif transform_type == "brightness_up":
            return np.clip(img * 1.1, 0, 255).astype(np.uint8)
        return img

    def process_image(self, image_path, tta_transform=None):
        """Load an image and convert it to the model's flat patch tensor.

        Returns an all-zero patch tensor on any failure so inference can
        proceed best-effort (behaviour preserved from training-time code).
        """
        try:
            image = cv2.imread(image_path, cv2.IMREAD_COLOR)
            if image is None:
                # Missing/unreadable file: blank patches instead of a crash.
                return np.zeros(self.hp["flat_patches_shape"], dtype=np.float32)
            image = cv2.resize(image, (self.hp["image_size"], self.hp["image_size"]))
            if tta_transform:
                image = self.apply_tta_transform(image, tta_transform)
            # Scale to [0, 1], then standardize with the training statistics.
            image = image / 255.0
            image = (image - self.norm_stats["mean"]) / (self.norm_stats["std"] + 1e-7)
            patch_shape = (self.hp["patch_size"], self.hp["patch_size"], self.hp["num_channels"])
            patches = patchify(image, patch_shape, self.hp["patch_size"])
            return np.reshape(patches, self.hp["flat_patches_shape"]).astype(np.float32)
        except Exception as e:
            print(f"[ERROR] Image processing failed: {e}")
            return np.zeros(self.hp["flat_patches_shape"], dtype=np.float32)

    def predict(self, df_row, image_path, use_tta=True):
        """Predict the decoded target label for one sample.

        Args:
            df_row: Mapping/Series holding the tabular features.
            image_path: Path of the stone image.
            use_tta: Average predictions over the TTA transforms.

        Returns:
            The target label decoded via the training target encoder.
        """
        # 1. Tabular features — column order must match training:
        #    StoneType, Color, Brown, BlueUv, GrdType, Result, then Carat.
        categorical_cols = ["StoneType", "Color", "Brown", "BlueUv", "GrdType", "Result"]
        numerical_cols = ["Carat"]
        tab_data_list = []
        for col in categorical_cols:
            val = str(df_row.get(col, "__missing__"))
            try:
                if col in self.cat_encoders:
                    classes = self.cat_encoders[col].classes_
                    if val not in classes:
                        # Unseen category: prefer the sentinel, else class 0.
                        val = "__missing__" if "__missing__" in classes else classes[0]
                    encoded_val = self.cat_encoders[col].transform([val])[0]
                else:
                    print(f"[WARN] Encoder for column {col} not found. Using 0.")
                    encoded_val = 0
            except Exception as e:
                print(f"[ERROR] Encoding failed for {col} with value {val}: {e}. Using 0.")
                encoded_val = 0
            tab_data_list.append(encoded_val)

        for col in numerical_cols:
            try:
                raw = df_row.get(col, 0)
                # NaN from a pandas row passes float() silently and would
                # flow through the scaler into the network — treat it as 0.
                val = 0.0 if pd.isna(raw) else float(raw)
                # The scaler expects a 2D array.
                scaled_val = self.num_scaler.transform([[val]])[0][0]
            except Exception as e:
                print(f"[ERROR] Scaling failed for {col}: {e}. Using 0.")
                scaled_val = 0
            tab_data_list.append(scaled_val)
        tab_input = np.expand_dims(np.array(tab_data_list, dtype=np.float32), axis=0)

        # 2. Image branch + inference, optionally averaged over TTA variants.
        if use_tta:
            all_preds = []
            for transform in self._TTA_TRANSFORMS:
                img_patches = self.process_image(image_path, tta_transform=transform)
                img_input = np.expand_dims(img_patches, axis=0)
                all_preds.append(self.model.predict([img_input, tab_input], verbose=0)[0])
            final_pred_probs = np.mean(all_preds, axis=0)
        else:
            img_patches = self.process_image(image_path)
            img_input = np.expand_dims(img_patches, axis=0)
            final_pred_probs = self.model.predict([img_input, tab_input], verbose=0)[0]

        # 3. Decode the arg-max class index back to the original label.
        pred_idx = np.argmax(final_pred_probs)
        return self.target_encoder.inverse_transform([pred_idx])[0]