"""FastAPI service exposing an AlexNet/CIFAR-10 Keras model.

``POST /predict`` accepts an uploaded image, runs it through the model and
returns the predicted class name together with one PNG "feature grid" per
convolutional layer (base64 data URIs). Everything else under ``/`` is served
from the ``static`` directory.
"""
import os

# TensorFlow's native logger picks this up while the library loads, so it has
# to be in the environment *before* ``import tensorflow`` to have any effect.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import base64  # noqa: E402
import math  # noqa: E402

import cv2  # noqa: E402
import numpy as np  # noqa: E402
import tensorflow as tf  # noqa: E402
from fastapi import FastAPI, File, HTTPException, UploadFile  # noqa: E402
from fastapi.middleware.cors import CORSMiddleware  # noqa: E402
from fastapi.staticfiles import StaticFiles  # noqa: E402

# Let TensorFlow grow GPU memory on demand instead of reserving it up front.
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        # Best effort: report the problem and continue with default settings.
        print(e)

app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)


class SafeDense(tf.keras.layers.Dense):
    """``Dense`` layer that discards a ``quantization_config`` kwarg if present."""

    def __init__(self, **kwargs):
        kwargs.pop('quantization_config', None)
        super().__init__(**kwargs)


class SafeConv2D(tf.keras.layers.Conv2D):
    """``Conv2D`` layer that discards a ``quantization_config`` kwarg if present."""

    def __init__(self, **kwargs):
        kwargs.pop('quantization_config', None)
        super().__init__(**kwargs)


BASE_DIR = os.path.dirname(os.path.abspath(__file__))
MODEL_PATH = os.path.join(BASE_DIR, 'alexnet_cifar10_keras.h5')

# The "Safe" subclasses replace the stock layers during deserialisation so the
# unwanted ``quantization_config`` entry in the saved layer configs is dropped.
model = tf.keras.models.load_model(
    MODEL_PATH,
    custom_objects={
        'Dense': SafeDense,
        'Conv2D': SafeConv2D,
    },
)

# Secondary model that exposes the output of every convolutional layer.
conv_layers = [
    layer for layer in model.layers
    if isinstance(layer, tf.keras.layers.Conv2D)
]
feature_extractor = tf.keras.Model(
    inputs=model.inputs,
    outputs=[layer.output for layer in conv_layers],
)

CIFAR10_CLASSES = [
    'Airplane', 'Automobile', 'Bird', 'Cat', 'Deer',
    'Dog', 'Frog', 'Horse', 'Ship', 'Truck',
]


def generate_feature_grid(feature_map, max_features=64):
    """Render the channels of a feature map as one base64-encoded PNG grid.

    Each of the first ``max_features`` channels is min-max scaled to 0..255
    and placed in a square-ish grid (row-major order). The grid is coloured
    with the VIRIDIS colour map and the scaled intensity is reused as the
    alpha channel, so weak activations come out transparent.

    Args:
        feature_map: Array of shape ``(H, W, C)`` or ``(N, H, W, C)``; for the
            4-D form only the first item along axis 0 is rendered. The input
            is not modified.
        max_features: Upper bound on the number of channels drawn.

    Returns:
        The PNG image as a base64 string (no ``data:`` prefix).

    Raises:
        RuntimeError: If OpenCV fails to encode the grid as PNG.
    """
    feature_map = np.asarray(feature_map)
    if feature_map.ndim == 4:
        feature_map = feature_map[0]

    height, width, channels = feature_map.shape
    num_features = min(channels, max_features)
    grid_size = math.ceil(math.sqrt(num_features))

    grid_image = np.zeros(
        (grid_size * height, grid_size * width), dtype=np.float32
    )

    for i in range(num_features):
        row, col = divmod(i, grid_size)

        # ``astype`` returns a copy: slicing alone yields a view, and the
        # in-place normalisation below would then overwrite the caller's data
        # (and fail outright for integer inputs).
        channel_img = feature_map[:, :, i].astype(np.float32)
        channel_img -= channel_img.min()
        peak = channel_img.max()
        if peak > 0:
            channel_img /= peak
        channel_img *= 255.0

        y_start, y_end = row * height, (row + 1) * height
        x_start, x_end = col * width, (col + 1) * width
        grid_image[y_start:y_end, x_start:x_end] = channel_img

    grid_image = grid_image.astype(np.uint8)
    colored_grid = cv2.applyColorMap(grid_image, cv2.COLORMAP_VIRIDIS)

    b_channel, g_channel, r_channel = cv2.split(colored_grid)
    alpha_channel = grid_image  # intensity doubles as opacity
    transparent_grid = cv2.merge(
        (b_channel, g_channel, r_channel, alpha_channel)
    )

    ok, buffer = cv2.imencode('.png', transparent_grid)
    if not ok:
        raise RuntimeError("Failed to encode feature grid as PNG.")
    return base64.b64encode(buffer).decode('utf-8')


@app.post("/predict")
async def predict_image(file: UploadFile = File(...)):
    """Classify an uploaded image and return conv-layer visualisations.

    Args:
        file: The uploaded image file.

    Returns:
        A dict with ``"prediction"`` (a name from ``CIFAR10_CLASSES``) and
        ``"layers"``, a list with one entry per convolutional layer holding
        its 1-based ``layer_index``, activation ``shape`` (batch axis
        removed) and ``texture_b64`` PNG data URI.

    Raises:
        HTTPException: 400 if the upload is empty or cannot be decoded as an
            image.
    """
    contents = await file.read()
    if not contents:
        raise HTTPException(status_code=400, detail="Uploaded file is empty.")

    nparr = np.frombuffer(contents, np.uint8)
    img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
    # imdecode signals failure by returning None rather than raising.
    if img is None:
        raise HTTPException(
            status_code=400,
            detail="Uploaded file could not be decoded as an image.",
        )

    # NOTE(review): cv2.imdecode produces BGR channel order. If the model was
    # trained on RGB images the channels should be swapped here — confirm
    # against the training pipeline before changing.
    img_resized = cv2.resize(img, (227, 227))
    img_normalized = img_resized.astype(np.float32) / 255.0
    img_batch = np.expand_dims(img_normalized, axis=0)

    activations = feature_extractor(img_batch, training=False)
    predictions = model(img_batch, training=False)
    class_idx = int(np.argmax(predictions[0].numpy()))

    layer_data = [
        {
            "layer_index": index,
            "shape": [int(dim) for dim in activation.shape[1:]],
            "texture_b64": (
                "data:image/png;base64,"
                f"{generate_feature_grid(activation.numpy())}"
            ),
        }
        for index, activation in enumerate(activations, start=1)
    ]

    return {
        "prediction": CIFAR10_CLASSES[class_idx],
        "layers": layer_data,
    }


# Mounted after the API route so "/predict" is matched before the catch-all
# static handler.
app.mount("/", StaticFiles(directory="static", html=True), name="static")