Spaces:
Sleeping
Sleeping
| """ | |
| classifier.py | |
| ------------- | |
| Loads the trained EfficientNetV2-L aircraft classifier and runs inference | |
| with Test Time Augmentation (TTA). | |
| Responsibilities: | |
| - Set TF_USE_LEGACY_KERAS before any TensorFlow import | |
| - Load the saved .keras model from disk once at module import time | |
| - Accept an image file path and return a predicted aircraft class name | |
| - Apply TTA (N=15) using horizontal flip, vertical flip, | |
| brightness and contrast jitter | |
| Used by: main.py (FastAPI endpoint) | |
| """ | |
| # Must be set before any TensorFlow or TF Hub import. | |
| # The model was trained with tf_keras (Keras 2) and saved in that format. | |
| # Without this flag, TF Hub's KerasLayer will fail to deserialize correctly | |
| # on TF 2.13+ which defaults to Keras 3. | |
| import os | |
| os.environ["TF_USE_LEGACY_KERAS"] = "1" | |
| import tensorflow as tf | |
| import tensorflow_hub as hub | |
| from src.schemas import MODEL_PATHS | |
| import numpy as np | |
# Let TensorFlow allocate GPU memory on demand instead of reserving all VRAM at
# startup. Memory growth is enabled on EVERY visible GPU:
# - TensorFlow refuses to initialise GPUs whose memory-growth settings differ.
# - Configuring only the first GPU would leave the others reserving all their memory.
# The loop is a no-op on CPU-only environments (e.g. Hugging Face Spaces).
gpus = tf.config.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)
# Load the saved Keras classifier once, at module import time.
# If loading fails, the error is printed and `classifier_model` stays None.
# Leaving it None rather than undefined lets callers detect the missing model
# with a clear check instead of crashing on a NameError.
classifier_model = None
try:
    classifier_model = tf.keras.models.load_model(
        MODEL_PATHS['classifier'],
        # The saved model uses TF Hub's KerasLayer. Keras must be told which
        # class to use for it during deserialization.
        custom_objects={"KerasLayer": hub.KerasLayer},
    )
    print("Classifier model loaded successfully.")
except FileNotFoundError:
    print(f"Error: model file not found at {MODEL_PATHS['classifier']}")
except Exception as e:
    print(f"Error loading classifier model: {e}")
# Aircraft class labels. Position i is the label for output unit i of the
# classifier: predict_aircraft maps argmax(prediction) -> CLASS_NAMES[index].
# The list is in plain ASCII-sorted order (uppercase before lowercase).
# NOTE(review): presumably that is how labels were encoded at training time.
# Confirm against the training pipeline before adding or renaming entries,
# since any reordering silently mislabels every prediction.
CLASS_NAMES = [
    # A
    'A10', 'A400M', 'AG600', 'AH64', 'AKINCI', 'AV8B',
    'An124', 'An22', 'An225', 'An72',
    # B
    'B1', 'B2', 'B21', 'B52', 'Be200',
    # C
    'C1', 'C130', 'C17', 'C2', 'C390', 'C5', 'CH47', 'CH53', 'CL415',
    # E
    'E2', 'E7', 'EF2000', 'EMB314',
    # F
    'F117', 'F14', 'F15', 'F16', 'F18', 'F2', 'F22', 'F35', 'F4', 'FCK1',
    # H / I
    'H6', 'Il76',
    # J
    'J10', 'J20', 'J35', 'J36', 'J50', 'JAS39', 'JF17', 'JH7',
    # K
    'KAAN', 'KC135', 'KF21', 'KIZILELMA', 'KJ600', 'Ka27', 'Ka52',
    # M
    'MQ20', 'MQ25', 'MQ28', 'MQ9', 'Mi24', 'Mi26', 'Mi28', 'Mi8',
    'Mig29', 'Mig31', 'Mirage2000',
    # N / P / R
    'NH90', 'P3', 'RQ4', 'Rafale',
    # S
    'SR71', 'Su24', 'Su25', 'Su34', 'Su47', 'Su57',
    # T
    'T50', 'TB001', 'TB2', 'Tejas', 'Tornado', 'Tu160', 'Tu22M', 'Tu95',
    # U / V
    'U2', 'UH60', 'US2', 'V22', 'V280', 'Vulcan',
    # W / X / Y / Z
    'WZ10', 'WZ7', 'WZ9', 'X29', 'X32', 'XB70', 'XQ58',
    'Y20', 'YF23', 'Z10', 'Z19',
]
# Preprocess one image file into the tensor format the classifier consumes.
def _process_image(image_path: str, image_size=(480, 480)):
    """
    Load and preprocess a single image from disk.

    Steps, in order:
      - Read the raw file bytes.
      - Decode them into a 3-channel RGB uint8 tensor.
      - Convert pixel values from [0, 255] to float32 in [0, 1].
      - Resize to ``image_size`` (bilinear, the ``tf.image.resize`` default).

    ``tf.io.decode_image`` is used rather than ``decode_jpeg`` because uploads
    may be PNG/BMP/GIF, and ``decode_jpeg`` is documented for JPEG only.
    ``expand_animations=False`` returns only the first frame of an animated
    GIF as a rank-3 tensor, which ``tf.image.resize`` requires.

    Args:
        image_path: Path to the image file.
        image_size: Target (height, width). The default (480, 480) presumably
            matches the model's training input size — confirm against training.

    Returns:
        float32 tensor of shape (height, width, 3) with values in [0, 1].
    """
    raw_bytes = tf.io.read_file(image_path)
    image = tf.io.decode_image(raw_bytes, channels=3, expand_animations=False)
    # For uint8 input, convert_image_dtype rescales to float32 in [0, 1].
    image = tf.image.convert_image_dtype(image, tf.float32)
    return tf.image.resize(image, image_size)
# Random perturbation used to build the test-time-augmentation batch.
def _tta_augment(image):
    """
    Return one randomly perturbed copy of ``image`` for test-time augmentation.

    Operations, applied in this fixed order:
      1. random left-right flip,
      2. random up-down flip,
      3. random brightness shift with max_delta=0.1,
      4. random contrast factor between 0.9 and 1.1.
    No explicit clipping is done here.

    Per the original author, this mirrors the training-time augmentation
    (flips, brightness, contrast). Verify against the training code if either
    side changes.

    Args:
        image: float image tensor, presumably (height, width, 3) as produced
            by ``_process_image``.

    Returns:
        Augmented tensor with the same shape as ``image``.
    """
    for flip in (tf.image.random_flip_left_right, tf.image.random_flip_up_down):
        image = flip(image)
    brightened = tf.image.random_brightness(image, max_delta=0.1)
    return tf.image.random_contrast(brightened, lower=0.9, upper=1.1)
# Test-time-augmentation prediction: average the model output over N views.
def _tta_predict(model, image_path, n_augments=15):
    """
    Perform Test-Time Augmentation (TTA) prediction on a single image.

    The image is loaded once. A batch of ``n_augments`` views is then built:
    the untouched image plus ``n_augments - 1`` random augmentations. The whole
    batch goes through the model in one forward pass, and the per-view outputs
    are averaged.

    Args:
        model (tf.keras.Model): The loaded aircraft classifier. Its outputs
            are assumed to be softmax probabilities — confirm against the
            model's final layer.
        image_path (str or Path): Path to the input image.
        n_augments (int, optional): Total number of views in the batch
            (1 original + n_augments - 1 augmented). Must be >= 1.
            Defaults to 15.

    Returns:
        numpy.ndarray: 1-D array with the output vector averaged over all views.

    Raises:
        ValueError: if ``n_augments`` is less than 1.
    """
    if n_augments < 1:
        raise ValueError(f"n_augments must be >= 1, got {n_augments}")

    image = _process_image(image_path=image_path)

    # Slot 0 is always the unaugmented image; the rest are random variants.
    versions = [image] + [_tta_augment(image) for _ in range(n_augments - 1)]
    batch = tf.stack(versions, axis=0)  # shape (n_augments, H, W, 3)

    # Call the model directly instead of .predict(), which adds per-call
    # overhead for a single small batch. training=False runs Dropout and
    # BatchNorm in inference mode.
    predictions = model(batch, training=False).numpy()

    return np.mean(predictions, axis=0)
# Public entry point: image path -> predicted aircraft class name.
def predict_aircraft(image_path: str, n_augments: int = 15):
    """
    Run TTA inference on a single image and return the predicted class name.

    Args:
        image_path (str): Path to the input image file.
        n_augments (int, optional): Number of TTA views to average (1 original
            + n_augments - 1 augmented). Defaults to 15.

    Returns:
        str: Predicted aircraft class name (e.g. 'F22', 'Rafale').

    Raises:
        RuntimeError: if the classifier model was not loaded at import time.
    """
    # The model is loaded at import time. If that failed, the global may be
    # None or missing entirely; either way, report it clearly instead of
    # failing with a NameError/AttributeError deep inside inference.
    try:
        model = classifier_model
    except NameError:
        model = None
    if model is None:
        raise RuntimeError(
            "Classifier model is not loaded; see the startup log for the load error."
        )

    average_tta_predictions = _tta_predict(
        model=model,
        image_path=image_path,
        n_augments=n_augments,
    )
    # Index of the highest averaged probability; int() converts numpy's intp.
    class_label_idx = int(np.argmax(average_tta_predictions))
    return CLASS_NAMES[class_label_idx]