Spaces:
Sleeping
Sleeping
File size: 6,593 Bytes
12bc208 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 | """
classifier.py
-------------
Loads the trained EfficientNetV2-L aircraft classifier and runs inference
with Test Time Augmentation (TTA).
Responsibilities:
- Set TF_USE_LEGACY_KERAS before any TensorFlow import
- Load the saved .keras model from disk once at module import time
- Accept an image file path and return a predicted aircraft class name
- Apply TTA (N=15) using horizontal flip, vertical flip,
brightness and contrast jitter
Used by: main.py (FastAPI endpoint)
"""
# Must be set before any TensorFlow or TF Hub import.
# The model was trained with tf_keras (Keras 2) and saved in that format.
# Without this flag, TF Hub's KerasLayer will fail to deserialize correctly
# on TF 2.13+ which defaults to Keras 3.
import os
os.environ["TF_USE_LEGACY_KERAS"] = "1"
import tensorflow as tf
import tensorflow_hub as hub
from src.schemas import MODEL_PATHS
import numpy as np
# Prevent TF from allocating all VRAM at startup — allocate as needed instead.
# Memory growth must be enabled on EVERY visible GPU: TensorFlow rejects
# configurations where the setting differs between GPU devices.
# Skipped automatically on CPU-only environments (e.g. Hugging Face Spaces),
# because the loop body never runs when no GPU is listed.
gpus = tf.config.list_physical_devices('GPU')
try:
    for _gpu in gpus:
        tf.config.experimental.set_memory_growth(_gpu, True)
except RuntimeError as e:
    # Raised when GPUs were already initialized (e.g. TF was used elsewhere
    # before this module was imported); memory growth can then no longer be
    # changed, so report it instead of failing the import.
    print(f"Could not enable GPU memory growth: {e}")
# Load the Keras model once, at module import time.
# On failure the error is printed and `classifier_model` is left as None
# (instead of never being defined), so importing this module does not raise
# and callers can detect the missing model explicitly.
classifier_model = None
try:
    classifier_model = tf.keras.models.load_model(
        MODEL_PATHS['classifier'],
        # The saved model uses TF Hub's KerasLayer; map the serialized class
        # name back to hub.KerasLayer so it can be deserialized.
        custom_objects={"KerasLayer": hub.KerasLayer},
    )
except FileNotFoundError:
    print(f"Error: model file not found at {MODEL_PATHS['classifier']}")
except Exception as e:
    # Deliberately broad: loading appears intended to be best-effort, so any
    # failure is reported rather than propagated out of the import.
    print(f"Error loading classifier model: {e}")
else:
    print("Classifier model loaded successfully.")
# Lookup table from model output index to aircraft class name:
# predict_aircraft() takes argmax over the model's averaged output vector and
# uses that index into this list, so position i must name the class of
# output element i. The list is in plain ASCII sort order (uppercase before
# lowercase, e.g. 'AV8B' before 'An124'), presumably mirroring how labels
# were sorted/encoded at training time — TODO confirm against the training
# pipeline before reordering, renaming or adding entries.
CLASS_NAMES = [
'A10', 'A400M', 'AG600', 'AH64', 'AKINCI', 'AV8B', 'An124', 'An22',
'An225', 'An72', 'B1', 'B2', 'B21', 'B52', 'Be200', 'C1', 'C130',
'C17', 'C2', 'C390', 'C5', 'CH47', 'CH53', 'CL415', 'E2', 'E7',
'EF2000', 'EMB314', 'F117', 'F14', 'F15', 'F16', 'F18', 'F2',
'F22', 'F35', 'F4', 'FCK1', 'H6', 'Il76', 'J10', 'J20', 'J35',
'J36', 'J50', 'JAS39', 'JF17', 'JH7', 'KAAN', 'KC135', 'KF21',
'KIZILELMA', 'KJ600', 'Ka27', 'Ka52', 'MQ20', 'MQ25', 'MQ28',
'MQ9', 'Mi24', 'Mi26', 'Mi28', 'Mi8', 'Mig29', 'Mig31',
'Mirage2000', 'NH90', 'P3', 'RQ4', 'Rafale', 'SR71', 'Su24',
'Su25', 'Su34', 'Su47', 'Su57', 'T50', 'TB001', 'TB2', 'Tejas',
'Tornado', 'Tu160', 'Tu22M', 'Tu95', 'U2', 'UH60', 'US2', 'V22',
'V280', 'Vulcan', 'WZ10', 'WZ7', 'WZ9', 'X29', 'X32', 'XB70',
'XQ58', 'Y20', 'YF23', 'Z10', 'Z19'
]
# Create a function to preprocess the image
def _process_image(image_path: str, image_size=(480, 480)):
    """
    Load and preprocess a single image from disk.
    - Reads raw bytes from the filepath
    - Decodes BMP, GIF (first frame only), JPEG or PNG into an RGB tensor
    - Converts pixel values from uint8 [0, 255] to float32 [0, 1]
    - Resizes to the target image size (aspect ratio is not preserved)
    Args:
        image_path: Path to the image file (str or os.PathLike, e.g. pathlib.Path)
        image_size: Target (height, width) of the output image
    Returns:
        float32 image tensor of shape (image_size[0], image_size[1], 3)
    """
    # tf.io.read_file cannot convert pathlib.Path objects to a tensor;
    # normalize any os.PathLike to a plain string first.
    if isinstance(image_path, os.PathLike):
        image_path = os.fspath(image_path)
    # Read the image as raw bytes from the filepath
    image = tf.io.read_file(image_path)
    # decode_image detects the format from the file bytes.
    # expand_animations=False returns only the first frame of an animated GIF
    # as a 3-D (H, W, 3) tensor, so resizing and batching downstream still work.
    image = tf.io.decode_image(image, channels=3, expand_animations=False)
    # Convert pixel values from 0-255 (uint8) to 0-1 (float32)
    image = tf.image.convert_image_dtype(image, tf.float32)
    # Resize the image to the desired shape
    image = tf.image.resize(image, image_size)
    return image
# Random perturbation used to build the test-time augmentation (TTA) batch
def _tta_augment(image):
    """
    Return one randomly perturbed copy of an image for TTA.
    The following are applied in this order:
    - random left/right flip
    - random up/down flip
    - brightness shift drawn from [-0.1, 0.1]
    - contrast factor drawn from [0.9, 1.1]
    The original author states these match the training-time augmentations.
    Pixel values are not clipped afterwards, so they may fall slightly outside [0, 1].
    Args:
        image: preprocessed image tensor, shape (Height, Width, 3 channels)
    Returns:
        Augmented image tensor with the same shape as the input.
    """
    flipped = tf.image.random_flip_up_down(tf.image.random_flip_left_right(image))
    brightened = tf.image.random_brightness(flipped, max_delta=0.1)
    return tf.image.random_contrast(brightened, lower=0.9, upper=1.1)
# Function to do prediction and apply augmentation during predictions
def _tta_predict(model, image_path, n_augments=15):
    """
    Performs Test-Time Augmentation (TTA) prediction on a single image.
    The image is loaded once and stacked with ``n_augments - 1`` randomly
    augmented copies into one batch. That batch goes through the model in a
    single forward pass. The result is the mean of the per-version model
    outputs, which makes the prediction more robust. The original author
    describes these outputs as softmax probabilities; the model's final
    activation is not visible in this module.
    Args:
        model (tf.keras.Model): The loaded aircraft classifier model.
        image_path (str or Path): The file path to the input image.
        n_augments (int, optional): The total number of image variations to
            evaluate (1 original + n_augments-1 augmentations). Values below 1
            behave like 1, i.e. only the original image is used.
            Defaults to 15.
    Returns:
        numpy.ndarray: A 1D array containing the averaged output vector for
        the input image.
    """
    # 1. Load and preprocess the image once; every version derives from it.
    image = _process_image(image_path=image_path)
    # 2. Put the original image first, then n_augments - 1 random variants.
    #    range() of a non-positive count is empty, so values below 1 keep only the original.
    versions = [image] + [_tta_augment(image) for _ in range(n_augments - 1)]
    # 3. Stack into one batch of shape (N, H, W, 3), with N = max(n_augments, 1).
    batch = tf.stack(versions, axis=0)
    # 4. Call the model directly instead of model.predict(): for one small batch
    #    this avoids predict()'s per-call overhead. training=False runs layers
    #    such as Dropout and BatchNorm in inference mode.
    predictions = model(batch, training=False).numpy()
    # 5. Average the N output vectors into one.
    return np.mean(predictions, axis=0)
# Function to predict the class name of the aircraft
def predict_aircraft(image_path: str, n_augments: int = 15):
    """
    Run TTA inference on a single image and return the predicted aircraft class name.
    Args:
        image_path (str): Path to the input image file.
        n_augments (int, optional): Number of image versions used for TTA
            (1 original + n_augments-1 augmentations). Defaults to 15.
    Returns:
        str: Predicted aircraft class name (e.g. 'F22', 'Rafale').
    Raises:
        RuntimeError: If the classifier model was not loaded at import time.
        ValueError: If the model's output size does not match CLASS_NAMES.
    """
    # If loading failed at import time, `classifier_model` may be None or not
    # bound at all; handle both with one clear error instead of an obscure
    # NameError/TypeError deep inside inference.
    try:
        model = classifier_model
    except NameError:
        model = None
    if model is None:
        raise RuntimeError(
            "Classifier model is not loaded; check the startup logs for the load error."
        )
    # Averaged output vector over all TTA versions
    average_tta_predictions = _tta_predict(model=model,
                                           image_path=image_path,
                                           n_augments=n_augments)
    # A length mismatch would otherwise yield a wrong label (or an IndexError).
    num_outputs = average_tta_predictions.shape[-1]
    if num_outputs != len(CLASS_NAMES):
        raise ValueError(
            f"Model outputs {num_outputs} classes but CLASS_NAMES has "
            f"{len(CLASS_NAMES)} entries."
        )
    # Take the index of the highest-probability class
    class_label_idx = int(np.argmax(average_tta_predictions))
    return CLASS_NAMES[class_label_idx]