File size: 3,920 Bytes
656ae3a
ab81f90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8450fd0
31222e8
ab81f90
 
 
 
 
 
12cd8e3
 
 
 
 
 
 
 
 
 
cc19bfe
 
 
12cd8e3
 
 
 
 
 
 
ab81f90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0bc36ed
ab81f90
0bc36ed
31222e8
0bc36ed
 
 
ab81f90
 
 
 
 
 
 
 
 
 
 
 
31222e8
 
 
ab81f90
 
 
31222e8
0bc36ed
656ae3a
 
ab81f90
 
656ae3a
31222e8
ab81f90
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import base64
import math
import os

from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
import tensorflow as tf
import numpy as np
import cv2

# Ask TensorFlow to allocate GPU memory on demand for every visible GPU.
# With no GPU present the list is empty and the loop body never runs.
gpus = tf.config.list_physical_devices('GPU')
try:
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
except RuntimeError as e:
    # Best effort: report the problem and keep starting the server.
    print(e)

# ASGI application object; the /predict route and the static mount attach to it.
app = FastAPI()

# CORS policy: any origin, any method and any header are accepted, while
# credentialed (cookie / auth) cross-origin requests are disabled.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Request the quietest TensorFlow native log level.
# NOTE(review): this assignment runs after `import tensorflow` and after the
# GPU query above; TensorFlow generally needs this variable to be set before
# it is imported -- confirm it still has the intended effect here.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

class SafeDense(tf.keras.layers.Dense):
    """Dense layer that ignores a ``quantization_config`` constructor argument.

    Every other keyword argument is forwarded unchanged to
    ``tf.keras.layers.Dense``. Presumably the saved model's layer configs
    carry this key and the installed Keras rejects it -- verify.
    """

    def __init__(self, **kwargs):
        accepted = {
            key: value
            for key, value in kwargs.items()
            if key != 'quantization_config'
        }
        super().__init__(**accepted)

class SafeConv2D(tf.keras.layers.Conv2D):
    """Conv2D layer that ignores a ``quantization_config`` constructor argument.

    Every other keyword argument is forwarded unchanged to
    ``tf.keras.layers.Conv2D``. Presumably the saved model's layer configs
    carry this key and the installed Keras rejects it -- verify.
    """

    def __init__(self, **kwargs):
        accepted = {
            key: value
            for key, value in kwargs.items()
            if key != 'quantization_config'
        }
        super().__init__(**accepted)

# The model file is looked up next to this source file, independent of the
# directory the server process was started from.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
MODEL_PATH = os.path.join(BASE_DIR, 'alexnet_cifar10_keras.h5')

# Deserialize with the tolerant layer subclasses substituted for the stock
# Dense / Conv2D classes.
model = tf.keras.models.load_model(
    MODEL_PATH,
    custom_objects=dict(Dense=SafeDense, Conv2D=SafeConv2D),
)

# SafeConv2D subclasses Conv2D, so the isinstance test still matches the
# substituted layers.
conv_layers = list(
    filter(lambda candidate: isinstance(candidate, tf.keras.layers.Conv2D), model.layers)
)

# Second model sharing the same inputs that returns one activation tensor per
# convolutional layer, in layer order.
feature_extractor = tf.keras.Model(
    inputs=model.inputs,
    outputs=[conv.output for conv in conv_layers],
)

# Human-readable label for each class index; predict_image indexes this list
# with the argmax of the model output, so the order must match the model's
# output units.
CIFAR10_CLASSES = [
    'Airplane',    # 0
    'Automobile',  # 1
    'Bird',        # 2
    'Cat',         # 3
    'Deer',        # 4
    'Dog',         # 5
    'Frog',        # 6
    'Horse',       # 7
    'Ship',        # 8
    'Truck',       # 9
]

def generate_feature_grid(feature_map, max_features=64):
    """Render the channels of a feature map as a base64-encoded PNG tile grid.

    The first ``max_features`` channels are each min-max normalised to the
    0-255 range, tiled row-major into a square grid of
    ``ceil(sqrt(num_features))`` tiles per side, coloured with OpenCV's
    VIRIDIS colormap and given an alpha channel equal to the normalised
    intensity. Unused grid cells stay at zero.

    Args:
        feature_map: Array-like of shape (height, width, channels), or a 4-D
            batch (n, height, width, channels) of which only item 0 is drawn.
            The input is never modified.
        max_features: Upper bound on the number of leading channels drawn.

    Returns:
        The PNG bytes as base64 text (without a ``data:`` URI prefix).

    Raises:
        ValueError: If there is nothing to draw (no channels, a zero-sized
            spatial dimension, or ``max_features`` < 1).
    """
    feature_map = np.asarray(feature_map)
    if feature_map.ndim == 4:
        feature_map = feature_map[0]

    height, width, channels = feature_map.shape
    num_features = min(channels, max_features)
    if num_features < 1 or height == 0 or width == 0:
        # Fail with a clear message instead of an opaque OpenCV error on an
        # empty image further down.
        raise ValueError(
            f"nothing to render: feature map shape {feature_map.shape}, "
            f"max_features={max_features}"
        )
    grid_size = math.ceil(math.sqrt(num_features))

    grid_image = np.zeros((grid_size * height, grid_size * width), dtype=np.float32)

    for i in range(num_features):
        row, col = divmod(i, grid_size)

        # astype() returns a fresh float32 copy, so the in-place normalisation
        # below cannot write through to the caller's array (a plain slice is
        # a view) and also works for integer-typed inputs.
        channel_img = feature_map[:, :, i].astype(np.float32)
        channel_img -= channel_img.min()
        peak = channel_img.max()
        if peak > 0:
            # A constant channel has peak 0 and is left as all zeros.
            channel_img /= peak
        channel_img *= 255.0

        y_start, x_start = row * height, col * width
        grid_image[y_start:y_start + height, x_start:x_start + width] = channel_img

    grid_image = grid_image.astype(np.uint8)

    colored_grid = cv2.applyColorMap(grid_image, cv2.COLORMAP_VIRIDIS)
    b_channel, g_channel, r_channel = cv2.split(colored_grid)
    # Alpha follows intensity: zero activations become fully transparent.
    alpha_channel = grid_image
    transparent_grid = cv2.merge((b_channel, g_channel, r_channel, alpha_channel))

    _, buffer = cv2.imencode('.png', transparent_grid)
    return base64.b64encode(buffer).decode('utf-8')

@app.post("/predict")
async def predict_image(file: UploadFile = File(...)):
    """Classify an uploaded image and return per-conv-layer activation grids.

    The upload is decoded with OpenCV, resized to 227x227, scaled to [0, 1]
    and run through both the classifier and the feature extractor.

    Args:
        file: The uploaded image file (multipart form field ``file``).

    Returns:
        A dict with ``"prediction"`` (the class label) and ``"layers"``, a
        list with one entry per convolutional layer holding its 1-based
        ``layer_index``, its activation ``shape`` without the batch axis, and
        ``texture_b64``, a PNG data URI of the activation grid.

    Raises:
        HTTPException: 400 when the upload is empty or cannot be decoded as
            an image.
    """
    contents = await file.read()
    if not contents:
        raise HTTPException(status_code=400, detail="Uploaded file is empty.")

    nparr = np.frombuffer(contents, np.uint8)
    img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
    if img is None:
        # imdecode signals an undecodable buffer by returning None; without
        # this check cv2.resize would fail and surface as a 500.
        raise HTTPException(
            status_code=400,
            detail="Uploaded file could not be decoded as an image.",
        )

    # NOTE(review): OpenCV decodes to BGR channel order -- confirm the model
    # was trained on BGR input, otherwise a BGR->RGB conversion is needed.
    # 227x227 is presumably the input size the loaded model expects.
    img_resized = cv2.resize(img, (227, 227))
    img_normalized = img_resized.astype(np.float32) / 255.0
    img_batch = np.expand_dims(img_normalized, axis=0)

    activations = feature_extractor(img_batch, training=False)
    predictions = model(img_batch, training=False)
    # Plain int rather than a NumPy scalar for list indexing.
    class_idx = int(np.argmax(predictions[0].numpy()))

    layer_data = []
    for layer_index, activation in enumerate(activations, start=1):
        b64_image = generate_feature_grid(activation.numpy())

        layer_data.append({
            "layer_index": layer_index,
            # Drop the batch axis and convert to plain ints for JSON output.
            "shape": [int(dim) for dim in activation.shape[1:]],
            "texture_b64": f"data:image/png;base64,{b64_image}"
        })

    return {
        "prediction": CIFAR10_CLASSES[class_idx],
        "layers": layer_data
    }

# Serve the front-end files from the "static" directory at the site root.
# This mount is registered after the /predict route, which is declared
# earlier in the file.
# NOTE(review): "static" is resolved relative to the process working
# directory, unlike MODEL_PATH which is anchored to this file's directory --
# confirm the server is always started from the project root.
app.mount("/", StaticFiles(directory="static", html=True), name="static")